dlc2action.project.project

Project interface

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""
   8Project interface
   9"""
  10
  11import gc
  12import os
  13import pickle
  14import shutil
  15import time
  16import warnings
  17from abc import abstractmethod
  18from collections import defaultdict
  19from collections.abc import Iterable, Mapping
  20from copy import deepcopy
  21from email.policy import default
  22from itertools import product
  23from pathlib import Path
  24from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
  25
  26import cv2
  27import numpy as np
  28import optuna
  29import pandas as pd
  30import plotly
  31import torch
  32from matplotlib import cm
  33from matplotlib import pyplot as plt
  34from matplotlib import rc
  35from numpy import ndarray
  36from ruamel.yaml import YAML
  37from ruamel.yaml.comments import CommentedMap, CommentedSet
  38from tqdm import tqdm
  39
  40from dlc2action import __version__, options
  41from dlc2action.data.dataset import BehaviorDataset
  42from dlc2action.project.meta import (
  43    DecisionThresholds,
  44    Run,
  45    SavedRuns,
  46    SavedStores,
  47    Searches,
  48    Suggestions,
  49)
  50from dlc2action.task.task_dispatcher import TaskDispatcher
  51from dlc2action.utils import apply_threshold, binarize_data, load_pickle
  52
  53
  54class Project:
  55    """A class to create and maintain the project files + keep track of experiments."""
  56
  57    def __init__(
  58        self,
  59        name: str,
  60        data_type: str = None,
  61        annotation_type: str = "none",
  62        projects_path: str = None,
  63        data_path: Union[str, List] = None,
  64        annotation_path: Union[str, List] = None,
  65        copy: bool = False,
  66    ) -> None:
  67        """Initialize the class.
  68
  69        Parameters
  70        ----------
  71        name : str
  72            name of the project
  73        data_type : str, optional
  74            data type (run Project.data_types() to see available options; has to be provided if the project is being
  75            created)
  76        annotation_type : str, default 'none'
  77            annotation type (run Project.annotation_types() to see available options)
  78        projects_path : str, optional
  79            path to the projects folder (is filled with ~/DLC2Action by default)
  80        data_path : str, optional
  81            path to the folder containing input files for the project (has to be provided if the project is being
  82            created)
  83        annotation_path : str, optional
  84            path to the folder containing annotation files for the project
  85        copy : bool, default False
  86            if True, the files from annotation_path and data_path will be copied to the projects folder;
  87            otherwise they will be moved
  88
  89        """
  90        if projects_path is None:
  91            projects_path = os.path.join(str(Path.home()), "DLC2Action")
  92        if not os.path.exists(projects_path):
  93            os.mkdir(projects_path)
  94        self.project_path = os.path.join(projects_path, name)
  95        self.name = name
  96        self.data_type = data_type
  97        self.annotation_type = annotation_type
  98        self.data_path = data_path
  99        self.annotation_path = annotation_path
 100        if not os.path.exists(self.project_path):
 101            if data_type is None:
 102                raise ValueError(
 103                    "The data_type parameter is necessary when creating a new project!"
 104                )
 105            self._initialize_project(
 106                data_type, annotation_type, data_path, annotation_path, copy
 107            )
 108        else:
 109            self.annotation_type, self.data_type = self._read_types()
 110            if data_type != self.data_type and data_type is not None:
 111                raise ValueError(
 112                    f"The project has already been initialized with data_type={self.data_type}!"
 113                )
 114            if annotation_type != self.annotation_type and annotation_type != "none":
 115                raise ValueError(
 116                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
 117                )
 118            self.annotation_path, data_path = self._read_paths()
 119            if self.data_path is None:
 120                self.data_path = data_path
 121            # if data_path != self.data_path and data_path is not None:
 122            #     raise ValueError(
 123            #         f"The project has already been initialized with data_path={self.data_path}!"
 124            #     )
 125            if annotation_path != self.annotation_path and annotation_path is not None:
 126                raise ValueError(
 127                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
 128                )
 129        self._update_configs()
 130
 131    def _make_prediction(
 132        self,
 133        prediction_name: str,
 134        episode_names: List,
 135        load_epochs: Union[List[int], int] = None,
 136        parameters_update: Dict = None,
 137        data_path: str = None,
 138        file_paths: Set = None,
 139        mode: str = "all",
 140        augment_n: int = 0,
 141        evaluate: bool = False,
 142        task: TaskDispatcher = None,
 143        embedding: bool = False,
 144        annotation_type: str = "none",
 145    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]:
 146        """Generate a prediction.
 147        Parameters
 148        ----------
 149        prediction_name : str
 150            name of the prediction
 151            episode_names : List
 152            names of the episodes to use for the prediction
 153            load_epochs : Union[List[int],int], optional
 154            epochs to load for each episode; if a single integer is provided, it will be used for all episodes;
 155            if None, the last epochs will be used
 156            parameters_update : Dict, optional
 157            dictionary with parameters to update the task parameters
 158            data_path : str, optional
 159            path to the data folder; if None, the data_path from the project will be used
 160            file_paths : Set, optional
 161            set of file paths to use for the prediction; if None, the data_path will be used
 162            mode : str, default "all
 163            mode of the prediction; can be "train", "val", "test" or "all
 164            augment_n : int, default 0
 165            number of augmentations to apply to the data; if 0, no augmentations are applied
 166            evaluate : bool, default False
 167            if True, the prediction will be evaluated and the results will be saved to the episode meta file
 168            task : TaskDispatcher, optional
 169            task object to use for the prediction; if None, a new task object will be created
 170            embedding : bool, default False
 171            if True, the prediction will be returned as an embedding
 172            annotation_type : str, default "none
 173            type of the annotation to use for the prediction; if "none", the annotation will not be used
 174        Returns
 175        -------
 176        task : TaskDispatcher
 177            task object used for the prediction
 178        parameters : Dict
 179            parameters used for the prediction
 180        mode : str
 181            mode of the prediction
 182        prediction : torch.Tensor
 183            prediction tensor of shape (num_videos, num_behaviors, num_frames)
 184        inference_time : str
 185            time taken for the prediction in the format "HH:MM:SS"
 186        behavior_dict : Dict
 187            dictionary with behavior names and their indices
 188        """
 189
 190        names = []
 191        for episode_name in episode_names:
 192            names += self._episodes().get_runs(episode_name)
 193        if len(names) == 0:
 194            warnings.warn(f"None of the episodes {episode_names} exist!")
 195            names = [None]
 196        if load_epochs is None:
 197            load_epochs = [None for _ in names]
 198        elif isinstance(load_epochs, int):
 199            load_epochs = [load_epochs for _ in names]
 200        assert len(load_epochs) == len(
 201            names
 202        ), f"Length of load_epochs ({len(load_epochs)}) must match the number of episodes ({len(names)})!"
 203        prediction = None
 204        decision_thresholds = None
 205        time_total = 0
 206        behavior_dicts = [
 207            self.get_behavior_dictionary(episode_name) for episode_name in names
 208        ]
 209
 210        if not all(
 211            [
 212                set(d.values()) == set(behavior_dicts[0].values())
 213                for d in behavior_dicts[1:]
 214            ]
 215        ):
 216            raise ValueError(
 217                f"Episodes {episode_names} have different sets of behaviors!"
 218            )
 219        behaviors = list(behavior_dicts[0].values())
 220
 221        for episode_name, load_epoch, behavior_dict in zip(
 222            names, load_epochs, behavior_dicts
 223        ):
 224            print(f"episode {episode_name}")
 225            task, parameters, data_mode = self._make_task_prediction(
 226                prediction_name=prediction_name,
 227                load_episode=episode_name,
 228                parameters_update=parameters_update,
 229                load_epoch=load_epoch,
 230                data_path=data_path,
 231                mode=mode,
 232                file_paths=file_paths,
 233                task=task,
 234                decision_thresholds=decision_thresholds,
 235                annotation_type=annotation_type,
 236            )
 237            # data_mode = "train" if mode == "all" else mode
 238            time_start = time.time()
 239            new_pred = task.predict(
 240                data_mode,
 241                raw_output=True,
 242                apply_primary_function=True,
 243                augment_n=augment_n,
 244                embedding=embedding,
 245            )
 246            indices = [
 247                behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1])
 248            ]
 249            new_pred = new_pred[:, indices, :]
 250            time_end = time.time()
 251            time_total += time_end - time_start
 252            if evaluate:
 253                _, metrics = task.evaluate_prediction(
 254                    new_pred, data=data_mode, indices=indices
 255                )
 256                if mode == "val":
 257                    self._update_episode_metrics(episode_name, metrics)
 258            if prediction is None:
 259                prediction = new_pred
 260            else:
 261                prediction += new_pred
 262            print("\n")
 263        hours = int(time_total // 3600)
 264        time_total -= hours * 3600
 265        minutes = int(time_total // 60)
 266        time_total -= minutes * 60
 267        seconds = int(time_total)
 268        inference_time = f"{hours}:{minutes:02}:{seconds:02}"
 269        prediction /= len(names)
 270        return (
 271            task,
 272            parameters,
 273            data_mode,
 274            prediction,
 275            inference_time,
 276            behavior_dicts[0],
 277        )
 278
 279    def _make_task_prediction(
 280        self,
 281        prediction_name: str,
 282        load_episode: str = None,
 283        parameters_update: Dict = None,
 284        load_epoch: int = None,
 285        data_path: str = None,
 286        annotation_path: str = None,
 287        mode: str = "val",
 288        file_paths: Set = None,
 289        decision_thresholds: List = None,
 290        task: TaskDispatcher = None,
 291        annotation_type: str = "none",
 292    ) -> Tuple[TaskDispatcher, Dict, str]:
 293        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 294        if parameters_update is None:
 295            parameters_update = {}
 296        parameters_update_second = {}
 297        if mode == "all" or data_path is not None or file_paths is not None:
 298            parameters_update_second["training"] = {
 299                "val_frac": 0,
 300                "test_frac": 0,
 301                "partition_method": "random",
 302                "save_split": False,
 303                "split_path": None,
 304            }
 305            mode = "train"
 306        if decision_thresholds is not None:
 307            if (
 308                len(decision_thresholds)
 309                == self._episode(load_episode).get_num_classes()
 310            ):
 311                parameters_update_second["general"] = {
 312                    "threshold_value": decision_thresholds
 313                }
 314            else:
 315                raise ValueError(
 316                    f"The length of the decision thresholds {decision_thresholds} "
 317                    f"must be equal to the length of the behaviors dictionary "
 318                    f"{self._episode(load_episode).get_behaviors_dict()}"
 319                )
 320        data_param_update = {}
 321        if data_path is not None:
 322            data_param_update = {"data_path": data_path}
 323            if annotation_path is None:
 324                data_param_update["annotation_path"] = data_path
 325        if annotation_path is not None:
 326            data_param_update["annotation_path"] = annotation_path
 327        if file_paths is not None:
 328            data_param_update = {"data_path": None, "file_paths": file_paths}
 329        parameters_update = self._update(parameters_update, {"data": data_param_update})
 330        if data_path is not None or file_paths is not None:
 331            general_update = {
 332                "annotation_type": annotation_type,
 333                "only_load_annotated": False,
 334            }
 335        else:
 336            general_update = {}
 337        parameters_update = self._update(parameters_update, {"general": general_update})
 338        task, parameters = self._make_task(
 339            episode_name=prediction_name,
 340            load_episode=load_episode,
 341            parameters_update=parameters_update,
 342            parameters_update_second=parameters_update_second,
 343            load_epoch=load_epoch,
 344            purpose="prediction",
 345            task=task,
 346        )
 347        return task, parameters, mode
 348
 349    def _make_task_training(
 350        self,
 351        episode_name: str,
 352        load_episode: str = None,
 353        parameters_update: Dict = None,
 354        load_epoch: int = None,
 355        load_search: str = None,
 356        load_parameters: list = None,
 357        round_to_binary: list = None,
 358        load_strict: bool = True,
 359        continuing: bool = False,
 360        task: TaskDispatcher = None,
 361        mask_name: str = None,
 362        throwaway: bool = False,
 363    ) -> Tuple[TaskDispatcher, Dict, str]:
 364        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 365        if parameters_update is None:
 366            parameters_update = {}
 367        if continuing:
 368            purpose = "continuing"
 369        else:
 370            purpose = "training"
 371        if mask_name is not None:
 372            mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle")
 373        parameters_update_second = {"data": {"real_lens": mask_name}}
 374        if throwaway:
 375            parameters_update = self._update(
 376                parameters_update, {"training": {"normalize": False, "device": "cpu"}}
 377            )
 378        return self._make_task(
 379            episode_name,
 380            load_episode,
 381            parameters_update,
 382            parameters_update_second,
 383            load_epoch,
 384            load_search,
 385            load_parameters,
 386            round_to_binary,
 387            purpose,
 388            task,
 389            load_strict=load_strict,
 390        )
 391
 392    def _make_parameters(
 393        self,
 394        episode_name: str,
 395        load_episode: str = None,
 396        parameters_update: Dict = None,
 397        parameters_update_second: Dict = None,
 398        load_epoch: int = None,
 399        load_search: str = None,
 400        load_parameters: list = None,
 401        round_to_binary: list = None,
 402        purpose: str = "train",
 403        load_strict: bool = True,
 404    ):
 405        """Construct a parameters dictionary."""
 406        if parameters_update is None:
 407            parameters_update = {}
 408        pars_update = deepcopy(parameters_update)
 409        if parameters_update_second is None:
 410            parameters_update_second = {}
 411        if (
 412            purpose == "prediction"
 413            and "model" in pars_update.keys()
 414            and pars_update["general"]["model_name"] != "motionbert"
 415        ):
 416            raise ValueError("Cannot change model parameters after training!")
 417        if purpose in ["continuing", "prediction"] and load_episode is not None:
 418            read_parameters = self._read_parameters()
 419            parameters = self._episodes().load_parameters(load_episode)
 420            parameters["metrics"] = self._update(
 421                read_parameters["metrics"], parameters["metrics"]
 422            )
 423            parameters["ssl"] = self._update(
 424                read_parameters["ssl"], parameters.get("ssl", {})
 425            )
 426        else:
 427            parameters = self._read_parameters()
 428        if "model" in pars_update:
 429            model_params = pars_update.pop("model")
 430        else:
 431            model_params = None
 432        if "features" in pars_update:
 433            feat_params = pars_update.pop("features")
 434        else:
 435            feat_params = None
 436        if "augmentations" in pars_update:
 437            aug_params = pars_update.pop("augmentations")
 438        else:
 439            aug_params = None
 440        parameters = self._update(parameters, pars_update)
 441        if pars_update.get("general", {}).get("model_name") is not None:
 442            model_name = parameters["general"]["model_name"]
 443            parameters["model"] = self._open_yaml(
 444                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
 445            )
 446        if pars_update.get("general", {}).get("feature_extraction") is not None:
 447            feat_name = parameters["general"]["feature_extraction"]
 448            parameters["features"] = self._open_yaml(
 449                os.path.join(
 450                    self.project_path, "config", "features", f"{feat_name}.yaml"
 451                )
 452            )
 453            aug_name = options.extractor_to_transformer[
 454                parameters["general"]["feature_extraction"]
 455            ]
 456            parameters["augmentations"] = self._open_yaml(
 457                os.path.join(
 458                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
 459                )
 460            )
 461        if model_params is not None:
 462            parameters["model"] = self._update(parameters["model"], model_params)
 463        if feat_params is not None:
 464            parameters["features"] = self._update(parameters["features"], feat_params)
 465        if aug_params is not None:
 466            parameters["augmentations"] = self._update(
 467                parameters["augmentations"], aug_params
 468            )
 469        if load_search is not None:
 470            parameters = self._update_with_search(
 471                parameters, load_search, load_parameters, round_to_binary
 472            )
 473        parameters = self._fill(
 474            parameters,
 475            episode_name,
 476            load_episode,
 477            load_epoch=load_epoch,
 478            load_strict=load_strict,
 479            only_load_model=(purpose != "continuing"),
 480            continuing=(purpose in ["prediction", "continuing"]),
 481            enforce_split_parameters=(purpose == "prediction"),
 482        )
 483        parameters = self._update(parameters, parameters_update_second)
 484        return parameters
 485
 486    def _make_task(
 487        self,
 488        episode_name: str,
 489        load_episode: str = None,
 490        parameters_update: Dict = None,
 491        parameters_update_second: Dict = None,
 492        load_epoch: int = None,
 493        load_search: str = None,
 494        load_parameters: list = None,
 495        round_to_binary: list = None,
 496        purpose: str = "train",
 497        task: TaskDispatcher = None,
 498        load_strict: bool = True,
 499    ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]:
 500        """Make a `TaskDispatcher` object.
 501
 502        The task parameters are read from the config files and then updated with the
 503        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 504        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 505        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 506        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 507        data parameters are used.
 508
 509        Parameters
 510        ----------
 511        episode_name : str
 512            the name of the episode
 513        load_episode : str, optional
 514            the (previously run) episode name to load the model from
 515        parameters_update : dict, optional
 516            the dictionary used to update the parameters from the config
 517        parameters_update_second : dict, optional
 518            the dictionary used to update the parameters after the automatic fill-out
 519        load_epoch : int, optional
 520            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 521        load_search : str, optional
 522            the hyperparameter search result to load
 523        load_parameters : list, optional
 524            a list of string names of the parameters to load from load_search (if not provided, all parameters
 525            are loaded)
 526        round_to_binary : list, optional
 527            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 528        purpose : {"train", "continuing", "prediction"}
 529            the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing
 530            the training of an interrupted episode, `"prediction"` for generating a prediction)
 531        task : TaskDispatcher, optional
 532            a pre-existing task; if provided, the method will update the task instead of creating a new one
 533            (this might save time, mainly on dataset loading)
 534
 535        Returns
 536        -------
 537        task : TaskDispatcher
 538            the `TaskDispatcher` instance
 539        parameters : dict
 540            the parameters dictionary that describes the task
 541
 542        """
 543        parameters = self._make_parameters(
 544            episode_name,
 545            load_episode,
 546            parameters_update,
 547            parameters_update_second,
 548            load_epoch,
 549            load_search,
 550            load_parameters,
 551            round_to_binary,
 552            purpose,
 553            load_strict=load_strict,
 554        )
 555        if task is None:
 556            task = TaskDispatcher(parameters)
 557        else:
 558            task.update_task(parameters)
 559        self._save_stores(parameters)
 560        return task, parameters
 561
 562    def get_decision_thresholds(
 563        self,
 564        episode_names: List,
 565        metric_name: str = "f1",
 566        parameters_update: Dict = None,
 567        load_epochs: List = None,
 568        remove_saved_features: bool = False,
 569    ) -> Tuple[List, List, TaskDispatcher]:
 570        """Compute optimal decision thresholds or load them if they have been computed before.
 571
 572        Parameters
 573        ----------
 574        episode_names : List
 575            a list of episode names
 576        metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
 577            the metric to optimize
 578        parameters_update : dict, optional
 579            the parameter update dictionary
 580        load_epochs : list, optional
 581            a list of epochs to load (by default last are loaded)
 582        remove_saved_features : bool, default False
 583            if `True`, the dataset will be deleted after the computation
 584
 585        Returns
 586        -------
 587        thresholds : list
 588            a list of float decision threshold values
 589        classes : list
 590            the label names corresponding to the values
 591        task : TaskDispatcher | None
 592            the task used in computation
 593
 594        """
 595        parameters = self._make_parameters(
 596            "_",
 597            episode_names[0],
 598            parameters_update,
 599            {},
 600            load_epochs[0],
 601            purpose="prediction",
 602        )
 603        thresholds = self._thresholds().find_thresholds(
 604            episode_names,
 605            load_epochs,
 606            metric_name,
 607            metric_parameters=parameters["metrics"][metric_name],
 608        )
 609        task = None
 610        behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values())
 611        return thresholds, behaviors, task
 612
 613    def run_episode(
 614        self,
 615        episode_name: str,
 616        load_episode: str = None,
 617        parameters_update: Dict = None,
 618        task: TaskDispatcher = None,
 619        load_epoch: int = None,
 620        load_search: str = None,
 621        load_parameters: list = None,
 622        round_to_binary: list = None,
 623        load_strict: bool = True,
 624        n_seeds: int = 1,
 625        force: bool = False,
 626        suppress_name_check: bool = False,
 627        remove_saved_features: bool = False,
 628        mask_name: str = None,
 629        autostop_metric: str = None,
 630        autostop_interval: int = 50,
 631        autostop_threshold: float = 0.001,
 632        loading_bar: bool = False,
 633        trial: Tuple = None,
 634    ) -> TaskDispatcher:
 635        """Run an episode.
 636
 637        The task parameters are read from the config files and then updated with the
 638        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 639        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 640        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 641        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 642        data parameters are used.
 643
 644        You can use the autostop parameters to finish training when the parameters are not improving. It will be
 645        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 646        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 647        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 648
 649        Parameters
 650        ----------
 651        episode_name : str
 652            the episode name
 653        load_episode : str, optional
 654            the (previously run) episode name to load the model from; if the episode has multiple runs,
 655            the new episode will have the same number of runs, each starting with one of the pre-trained models
 656        parameters_update : dict, optional
 657            the dictionary used to update the parameters from the config files
 658        task : TaskDispatcher, optional
 659            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
 660            a new instance)
 661        load_epoch : int, optional
 662            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 663        load_search : str, optional
 664            the hyperparameter search result to load
 665        load_parameters : list, optional
 666            a list of string names of the parameters to load from load_search (if not provided, all parameters
 667            are loaded)
 668        round_to_binary : list, optional
 669            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 670        load_strict : bool, default True
 671            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
 672            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
 673        n_seeds : int, default 1
 674            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 675            `test_episode#0` and `test_episode#1`
 676        force : bool, default False
 677            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
 678        suppress_name_check : bool, default False
 679            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 680            why they are usually forbidden)
 681        remove_saved_features : bool, default False
 682            if `True`, the dataset will be deleted after training
 683        mask_name : str, optional
 684            the name of the real_lens to apply
 685        autostop_metric : str, optional
 686            the autostop metric (can be any one of the tracked metrics of `'loss'`)
 687        autostop_interval : int, default 50
 688            the number of epochs to average the autostop metric over
 689        autostop_threshold : float, default 0.001
 690            the autostop difference threshold
 691        loading_bar : bool, default False
 692            if `True`, a loading bar will be displayed
 693        trial : tuple, optional
 694            a tuple of (trial, metric) for hyperparameter search
 695
 696        Returns
 697        -------
 698        TaskDispatcher
 699            the `TaskDispatcher` object
 700
 701        """
 702
 703        import gc
 704
 705        gc.collect()
 706        if torch.cuda.is_available():
 707            torch.cuda.empty_cache()
 708
 709        if type(n_seeds) is not int or n_seeds < 1:
 710            raise ValueError(
 711                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
 712            )
 713        if n_seeds > 1 and mask_name is not None:
 714            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
 715        self._check_episode_validity(
 716            episode_name, allow_doublecolon=suppress_name_check, force=force
 717        )
 718        load_runs = self._episodes().get_runs(load_episode)
 719        if len(load_runs) > 1:
 720            task = self.run_episodes(
 721                episode_names=[
 722                    f'{episode_name}#{run.split("#")[-1]}' for run in load_runs
 723                ],
 724                load_episodes=load_runs,
 725                parameters_updates=[parameters_update for _ in load_runs],
 726                load_epochs=[load_epoch for _ in load_runs],
 727                load_searches=[load_search for _ in load_runs],
 728                load_parameters=[load_parameters for _ in load_runs],
 729                round_to_binary=[round_to_binary for _ in load_runs],
 730                load_strict=[load_strict for _ in load_runs],
 731                suppress_name_check=True,
 732                force=force,
 733                remove_saved_features=False,
 734            )
 735            if remove_saved_features:
 736                self._remove_stores(
 737                    {
 738                        "general": task.general_parameters,
 739                        "data": task.data_parameters,
 740                        "features": task.feature_parameters,
 741                    }
 742                )
 743            if n_seeds > 1:
 744                warnings.warn(
 745                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
 746                )
 747        elif n_seeds > 1:
 748
 749            self.run_episodes(
 750                episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)],
 751                load_episodes=[load_episode for _ in range(n_seeds)],
 752                parameters_updates=[parameters_update for _ in range(n_seeds)],
 753                load_epochs=[load_epoch for _ in range(n_seeds)],
 754                load_searches=[load_search for _ in range(n_seeds)],
 755                load_parameters=[load_parameters for _ in range(n_seeds)],
 756                round_to_binary=[round_to_binary for _ in range(n_seeds)],
 757                load_strict=[load_strict for _ in range(n_seeds)],
 758                suppress_name_check=True,
 759                force=force,
 760                remove_saved_features=remove_saved_features,
 761            )
 762        else:
 763            print(f"TRAINING {episode_name}")
 764            try:
 765                task, parameters = self._make_task_training(
 766                    episode_name,
 767                    load_episode,
 768                    parameters_update,
 769                    load_epoch,
 770                    load_search,
 771                    load_parameters,
 772                    round_to_binary,
 773                    continuing=False,
 774                    task=task,
 775                    mask_name=mask_name,
 776                    load_strict=load_strict,
 777                )
 778                self._save_episode(
 779                    episode_name,
 780                    parameters,
 781                    task.behaviors_dict(),
 782                    norm_stats=task.get_normalization_stats(),
 783                )
 784                time_start = time.time()
 785                if trial is not None:
 786                    trial, metric = trial
 787                else:
 788                    trial, metric = None, None
 789                logs = task.train(
 790                    autostop_metric=autostop_metric,
 791                    autostop_interval=autostop_interval,
 792                    autostop_threshold=autostop_threshold,
 793                    loading_bar=loading_bar,
 794                    trial=trial,
 795                    optimized_metric=metric,
 796                )
 797                time_end = time.time()
 798                time_total = time_end - time_start
 799                hours = int(time_total // 3600)
 800                time_total -= hours * 3600
 801                minutes = int(time_total // 60)
 802                time_total -= minutes * 60
 803                seconds = int(time_total)
 804                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 805                self._update_episode_results(episode_name, logs, training_time)
 806                if remove_saved_features:
 807                    self._remove_stores(parameters)
 808                print("\n")
 809                return task
 810
 811            except Exception as e:
 812                if isinstance(e, optuna.exceptions.TrialPruned):
 813                    raise e
 814                else:
 815                    # if str(e) != f"The {episode_name} episode name is already in use!":
 816                    #     self.remove_episode(episode_name)
 817                    raise RuntimeError(f"Episode {episode_name} could not run")
 818
 819    def run_episodes(
 820        self,
 821        episode_names: List,
 822        load_episodes: List = None,
 823        parameters_updates: List = None,
 824        load_epochs: List = None,
 825        load_searches: List = None,
 826        load_parameters: List = None,
 827        round_to_binary: List = None,
 828        load_strict: List = None,
 829        force: bool = False,
 830        suppress_name_check: bool = False,
 831        remove_saved_features: bool = False,
 832    ) -> TaskDispatcher:
 833        """Run multiple episodes in sequence (and re-use previously loaded information).
 834
 835        For each episode, the task parameters are read from the config files and then updated with the
 836        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
 837        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 838        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 839        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 840        data parameters are used.
 841
 842        Parameters
 843        ----------
 844        episode_names : list
 845            a list of strings of episode names
 846        load_episodes : list, optional
 847            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
 848            the new episode will have the same number of runs, each starting with one of the pre-trained models
 849        parameters_updates : list, optional
 850            a list of dictionaries used to update the parameters from the config
 851        load_epochs : list, optional
 852            a list of integers used to specify the epoch to load (if load_episodes is not None)
 853        load_searches : list, optional
 854            a list of strings of hyperparameter search results to load
 855        load_parameters : list, optional
 856            a list of lists of string names of the parameters to load from the searches
 857        round_to_binary : list, optional
 858            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 859        load_strict : list, optional
 860            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
 861            the corresponding episode and differences in parameter name lists and
 862            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
 863            every episode)
 864        force : bool, default False
 865            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
 866        suppress_name_check : bool, default False
 867            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 868            why they are usually forbidden)
 869        remove_saved_features : bool, default False
 870            if `True`, the dataset will be deleted after training
 871
 872        Returns
 873        -------
 874        TaskDispatcher
 875            the task dispatcher object
 876
 877        """
 878        task = None
 879        if load_searches is None:
 880            load_searches = [None for _ in episode_names]
 881        if load_episodes is None:
 882            load_episodes = [None for _ in episode_names]
 883        if parameters_updates is None:
 884            parameters_updates = [None for _ in episode_names]
 885        if load_parameters is None:
 886            load_parameters = [None for _ in episode_names]
 887        if load_epochs is None:
 888            load_epochs = [None for _ in episode_names]
 889        if load_strict is None:
 890            load_strict = [True for _ in episode_names]
 891        for (
 892            parameters_update,
 893            episode_name,
 894            load_episode,
 895            load_epoch,
 896            load_search,
 897            load_parameters_list,
 898            load_strict_value,
 899        ) in zip(
 900            parameters_updates,
 901            episode_names,
 902            load_episodes,
 903            load_epochs,
 904            load_searches,
 905            load_parameters,
 906            load_strict,
 907        ):
 908            task = self.run_episode(
 909                episode_name,
 910                load_episode,
 911                parameters_update,
 912                task,
 913                load_epoch,
 914                load_search,
 915                load_parameters_list,
 916                round_to_binary,
 917                load_strict_value,
 918                suppress_name_check=suppress_name_check,
 919                force=force,
 920                remove_saved_features=remove_saved_features,
 921            )
 922        return task
 923
 924    def continue_episode(
 925        self,
 926        episode_name: str,
 927        num_epochs: int = None,
 928        task: TaskDispatcher = None,
 929        n_seeds: int = 1,
 930        remove_saved_features: bool = False,
 931        device: str = "cuda",
 932        num_cpus: int = None,
 933    ) -> TaskDispatcher:
 934        """Load an older episode and continue running from the latest checkpoint.
 935
 936        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 937
 938        Parameters
 939        ----------
 940        episode_name : str
 941            the name of the episode to continue
 942        num_epochs : int, optional
 943            the new number of epochs
 944        task : TaskDispatcher, optional
 945            a pre-existing task; if provided, the method will update the task instead of creating a new one
 946            (this might save time, mainly on dataset loading)
 947        n_seeds : int, default 1
 948            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 949            `test_episode#0` and `test_episode#1`
 950        remove_saved_features : bool, default False
 951            if `True`, pre-computed features will be deleted after the run
 952        device : str, default "cuda"
 953            the torch device to use
 954        num_cpus : int, optional
 955            the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used
 956
 957        Returns
 958        -------
 959        TaskDispatcher
 960            the task dispatcher
 961
 962        """
 963        runs = self._episodes().get_runs(episode_name)
 964        for run in runs:
 965            print(f"TRAINING {run}")
 966            if num_epochs is None and not self._episode(run).unfinished():
 967                continue
 968            parameters_update = {
 969                "training": {
 970                    "num_epochs": num_epochs,
 971                    "device": device,
 972                },
 973                "general": {"num_cpus": num_cpus},
 974            }
 975            task, parameters = self._make_task_training(
 976                run,
 977                load_episode=run,
 978                parameters_update=parameters_update,
 979                continuing=True,
 980                task=task,
 981            )
 982            time_start = time.time()
 983            logs = task.train()
 984            time_end = time.time()
 985            old_time = self._training_time(run)
 986            if not np.isnan(old_time):
 987                time_end += old_time
 988                time_total = time_end - time_start
 989                hours = int(time_total // 3600)
 990                time_total -= hours * 3600
 991                minutes = int(time_total // 60)
 992                time_total -= minutes * 60
 993                seconds = int(time_total)
 994                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 995            else:
 996                training_time = np.nan
 997            self._save_episode(
 998                run,
 999                parameters,
1000                task.behaviors_dict(),
1001                suppress_validation=True,
1002                training_time=training_time,
1003                norm_stats=task.get_normalization_stats(),
1004            )
1005            self._update_episode_results(run, logs)
1006            print("\n")
1007        if len(runs) < n_seeds:
1008            for i in range(len(runs), n_seeds):
1009                self.run_episode(
1010                    f"{episode_name}#{i}",
1011                    parameters_update=self._episodes().load_parameters(runs[0]),
1012                    task=task,
1013                    suppress_name_check=True,
1014                )
1015        if remove_saved_features:
1016            self._remove_stores(parameters)
1017        return task
1018
1019    def run_default_hyperparameter_search(
1020        self,
1021        search_name: str,
1022        model_name: str,
1023        metric: str = "f1",
1024        best_n: int = 3,
1025        direction: str = "maximize",
1026        load_episode: str = None,
1027        load_epoch: int = None,
1028        load_strict: bool = True,
1029        prune: bool = True,
1030        force: bool = False,
1031        remove_saved_features: bool = False,
1032        overlap: float = 0,
1033        num_epochs: int = 50,
1034        test_frac: float = None,
1035        n_trials=150,
1036        batch_size=32,
1037    ):
1038        """Run an optuna hyperparameter search with default parameters for a model.
1039
1040        For the vast majority of cases, optimizing the default parameters should be enough.
1041        Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
1042        There are also options to set overlap, test fraction and number of epochs parameters for the search without
1043        modifying the project config files. However, if you want something more complex, look into
1044        `Project.run_hyperparameter_search`.
1045
1046        The task parameters are read from the config files and updated with the parameters_update dictionary.
1047        The model can be either initialized from scratch or loaded from a previously run episode.
1048        For each trial, the objective metric is averaged over a few best epochs.
1049
1050        Parameters
1051        ----------
1052        search_name : str
1053            the name of the search to store it in the meta files and load in run_episode
1054        model_name : str
1055            the name
1056        metric : str
1057            the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
1058            `"none"` in the config files, it will be reset to `"macro"` for the search
1059        best_n : int, default 1
1060            the number of epochs to average the metric; if 0, the last value is taken
1061        direction : {'maximize', 'minimize'}
1062            optimization direction
1063        load_episode : str, optional
1064            the name of the episode to load the model from
1065        load_epoch : int, optional
1066            the epoch to load the model from (if not provided, the last checkpoint is used)
1067        load_strict : bool, default True
1068            if `True`, the model will be loaded only if the parameters match exactly
1069        prune : bool, default False
1070            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1071            (with optuna HyperBand pruner)
1072        force : bool, default False
1073            if `True`, existing searches with the same name will be overwritten
1074        remove_saved_features : bool, default False
1075            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1076        overlap : float, default 0
1077            the overlap to use for the search
1078        num_epochs : int, default 50
1079            the number of epochs to use for the search
1080        test_frac : float, optional
1081            the test fraction to use for the search
1082        n_trials : int, default 150
1083            the number of trials to run
1084        batch_size : int, default 32
1085            the batch size to use for the search
1086
1087        Returns
1088        -------
1089        best_parameters : dict
1090            a dictionary of best parameters
1091
1092        """
1093        if model_name not in options.model_hyperparameters:
1094            raise ValueError(
1095                f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
1096            )
1097        pars = {
1098            "general": {"overlap": overlap, "model_name": model_name},
1099            "training": {"num_epochs": num_epochs, "batch_size": batch_size},
1100        }
1101        if test_frac is not None:
1102            pars["training"]["test_frac"] = test_frac
1103        if not metric.split("_")[-1].isnumeric():
1104            project_pars = self._read_parameters()
1105            if project_pars["metrics"][metric].get("average") == "none":
1106                pars["metrics"] = {metric: {"average": "macro"}}
1107        return self.run_hyperparameter_search(
1108            search_name=search_name,
1109            search_space=options.model_hyperparameters[model_name],
1110            metric=metric,
1111            n_trials=n_trials,
1112            best_n=best_n,
1113            parameters_update=pars,
1114            direction=direction,
1115            load_episode=load_episode,
1116            load_epoch=load_epoch,
1117            load_strict=load_strict,
1118            prune=prune,
1119            force=force,
1120            remove_saved_features=remove_saved_features,
1121        )
1122
1123    def run_hyperparameter_search(
1124        self,
1125        search_name: str,
1126        search_space: Dict,
1127        metric: str = "f1",
1128        n_trials: int = 20,
1129        best_n: int = 1,
1130        parameters_update: Dict = None,
1131        direction: str = "maximize",
1132        load_episode: str = None,
1133        load_epoch: int = None,
1134        load_strict: bool = True,
1135        prune: bool = False,
1136        force: bool = False,
1137        remove_saved_features: bool = False,
1138        make_plots: bool = True,
1139    ) -> Dict:
1140        """Run an optuna hyperparameter search.
1141
1142        For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.
1143
1144        To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
1145        a dictionary where keys are model names and values are default search spaces.
1146
1147        The task parameters are read from the config files and updated with the parameters_update dictionary.
1148        The model can be either initialized from scratch or loaded from a previously run episode.
1149        For each trial, the objective metric is averaged over a few best epochs.
1150
1151        Parameters
1152        ----------
1153        search_name : str
1154            the name of the search to store it in the meta files and load in run_episode
1155        search_space : dict
1156            a dictionary representing the search space; of this general structure:
1157            {'group/param_name': ('float/int/float_log/int_log', start, end),
1158            'group/param_name': ('categorical', [choices])}, e.g.
1159            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
1160            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
1161        metric : str, default f1
1162            the metric to maximize/minimize (see direction)
1163        n_trials : int, default 20
1164            the number of optimization trials to run
1165        best_n : int, default 1
1166            the number of epochs to average the metric; if 0, the last value is taken
1167        parameters_update : dict, optional
1168            the parameters update dictionary
1169        direction : {'maximize', 'minimize'}
1170            optimization direction
1171        load_episode : str, optional
1172            the name of the episode to load the model from
1173        load_epoch : int, optional
1174            the epoch to load the model from (if not provided, the last checkpoint is used)
1175        load_strict : bool, default True
1176            if `True`, the model will be loaded only if the parameters match exactly
1177        prune : bool, default False
1178            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1179            (with optuna HyperBand pruner)
1180        force : bool, default False
1181            if `True`, existing searches with the same name will be overwritten
1182        remove_saved_features : bool, default False
1183            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1184
1185        Returns
1186        -------
1187        dict
1188            a dictionary of best parameters
1189
1190        """
1191        self._check_search_validity(search_name, force=force)
1192        print(f"SEARCH {search_name}")
1193        self.remove_episode(f"_{search_name}")
1194        if parameters_update is None:
1195            parameters_update = {}
1196        parameters_update = self._update(
1197            parameters_update, {"general": {"metric_functions": {metric}}}
1198        )
1199        parameters = self._make_parameters(
1200            f"_{search_name}",
1201            load_episode,
1202            parameters_update,
1203            parameters_update_second={"training": {"model_save_path": None}},
1204            load_epoch=load_epoch,
1205            load_strict=load_strict,
1206        )
1207        task = None
1208
1209        if prune:
1210            pruner = optuna.pruners.HyperbandPruner()
1211        else:
1212            pruner = optuna.pruners.NopPruner()
1213        study = optuna.create_study(direction=direction, pruner=pruner)
1214        runner = _Runner(
1215            search_space=search_space,
1216            load_episode=load_episode,
1217            load_epoch=load_epoch,
1218            metric=metric,
1219            average=best_n,
1220            task=task,
1221            remove_saved_features=remove_saved_features,
1222            project=self,
1223            search_name=search_name,
1224        )
1225        study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
1226        if make_plots:
1227            search_path = self._search_path(search_name)
1228            os.mkdir(search_path)
1229            fig = optuna.visualization.plot_contour(study)
1230            plotly.offline.plot(
1231                fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
1232            )
1233            fig = optuna.visualization.plot_param_importances(study)
1234            plotly.offline.plot(
1235                fig,
1236                filename=os.path.join(search_path, f"{search_name}_importances.html"),
1237            )
1238        best_params = study.best_params
1239        best_value = study.best_value
1240        if best_value == 0 or best_value == float("inf"):
1241            raise ValueError(
1242                f"Best metric value is {best_value}, check your partition method and make sure that all behaviors are present in the validation set!"
1243            )
1244        self._save_search(
1245            search_name,
1246            parameters,
1247            n_trials,
1248            best_params,
1249            best_value,
1250            metric,
1251            search_space,
1252        )
1253        self.remove_episode(f"_{search_name}")
1254        runner.clean()
1255        print(f"best parameters: {best_params}")
1256        print("\n")
1257        return best_params
1258
1259    def run_prediction(
1260        self,
1261        prediction_name: str,
1262        episode_names: List,
1263        load_epochs: List = None,
1264        parameters_update: Dict = None,
1265        augment_n: int = 10,
1266        data_path: str = None,
1267        mode: str = "all",
1268        file_paths: Set = None,
1269        remove_saved_features: bool = False,
1270        frame_number_map_file: str = None,
1271        force: bool = False,
1272        embedding: bool = False,
1273    ) -> None:
1274        """Load models from previously run episodes to generate a prediction.
1275
1276        The probabilities predicted by the models are averaged.
1277        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1278        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1279        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1280        are the prediction arrays.
1281
1282        Parameters
1283        ----------
1284        prediction_name : str
1285            the name of the prediction
1286        episode_names : list
1287            a list of string episode names to load the models from
1288        load_epochs : list or int, optional
1289            a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes
1290        parameters_update : dict, optional
1291            a dictionary of parameter updates
1292        augment_n : int, default 10
1293            the number of augmentations to average over
1294        data_path : str, optional
1295            the data path to run the prediction for
1296        mode : {'all', 'test', 'val', 'train'}
1297            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1298        file_paths : set, optional
1299            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1300            for
1301        remove_saved_features : bool, default False
1302            if `True`, pre-computed features will be deleted
1303        submission : bool, default False
1304            if `True`, a MABe-22 style submission file is generated
1305        frame_number_map_file : str, optional
1306            path to the frame number map file
1307        force : bool, default False
1308            if `True`, existing prediction with this name will be overwritten
1309        embedding : bool, default False
1310            if `True`, the prediction is made for the embedding task
1311
1312        """
1313        self._check_prediction_validity(prediction_name, force=force)
1314        print(f"PREDICTION {prediction_name}")
1315        task, parameters, mode, prediction, inference_time, behavior_dict = (
1316            self._make_prediction(
1317                prediction_name,
1318                episode_names,
1319                load_epochs,
1320                parameters_update,
1321                data_path,
1322                file_paths,
1323                mode,
1324                augment_n,
1325                evaluate=False,
1326                embedding=embedding,
1327            )
1328        )
1329        predicted = task.dataset(mode).generate_full_length_prediction(prediction)
1330
1331        if remove_saved_features:
1332            self._remove_stores(parameters)
1333
1334        self._save_prediction(
1335            prediction_name,
1336            predicted,
1337            parameters,
1338            task,
1339            mode,
1340            embedding,
1341            inference_time,
1342            behavior_dict,
1343        )
1344        print("\n")
1345
1346    def evaluate_prediction(
1347        self,
1348        prediction_name: str,
1349        parameters_update: Dict = None,
1350        data_path: str = None,
1351        annotation_path: str = None,
1352        file_paths: Set = None,
1353        mode: str = None,
1354        remove_saved_features: bool = False,
1355        annotation_type: str = "none",
1356        num_classes: int = None,  # Set when using data_path
1357    ) -> Tuple[float, dict]:
1358        """Make predictions and evaluate them
1359        inputs:
1360            prediction_name (str): the name of the prediction
1361            parameters_update (dict): a dictionary of parameter updates
1362            data_path (str): the data path to run the prediction for
1363            annotation_path (str): the annotation path to run the prediction for
1364            file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
1365            mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1366            remove_saved_features (bool): if `True`, pre-computed features will be deleted
1367            annotation_type (str): the type of annotation to use for evaluation
1368            num_classes (int): the number of classes in the dataset, must be set with data_path
1369        outputs:
1370            results (dict): a dictionary of average values of metric functions
1371        """
1372
1373        prediction_path = os.path.join(
1374            self.project_path, "results", "predictions", f"{prediction_name}"
1375        )
1376        prediction_dict = {}
1377        for prediction_file_path in [
1378            os.path.join(prediction_path, i) for i in os.listdir(prediction_path)
1379        ]:
1380            with open(os.path.join(prediction_file_path), "rb") as f:
1381                prediction = pickle.load(f)
1382            video_id = os.path.basename(prediction_file_path).split(
1383                "_" + prediction_name
1384            )[0]
1385            prediction_dict[video_id] = prediction
1386        if parameters_update is None:
1387            parameters_update = {}
1388        parameters_update = self._update(
1389            self._predictions().load_parameters(prediction_name), parameters_update
1390        )
1391        parameters_update.pop("model")
1392        if not data_path is None:
1393            assert (
1394                not num_classes is None
1395            ), "num_classes must be provided if data_path is provided"
1396            parameters_update["general"]["num_classes"] = num_classes + int(
1397                parameters_update["general"]["exclusive"]
1398            )
1399        task, parameters, mode = self._make_task_prediction(
1400            "_",
1401            load_episode=None,
1402            parameters_update=parameters_update,
1403            data_path=data_path,
1404            annotation_path=annotation_path,
1405            file_paths=file_paths,
1406            mode=mode,
1407            annotation_type=annotation_type,
1408        )
1409        results = task.evaluate_prediction(prediction_dict, data=mode)
1410        if remove_saved_features:
1411            self._remove_stores(parameters)
1412        results = Project._reformat_results(
1413            results[1],
1414            task.behaviors_dict(),
1415            exclusive=task.general_parameters["exclusive"],
1416        )
1417        return results
1418
1419    def evaluate(
1420        self,
1421        episode_names: List,
1422        load_epochs: List = None,
1423        augment_n: int = 0,
1424        data_path: str = None,
1425        file_paths: Set = None,
1426        mode: str = None,
1427        parameters_update: Dict = None,
1428        multiple_episode_policy: str = "average",
1429        remove_saved_features: bool = False,
1430        skip_updating_meta: bool = True,
1431        annotation_type: str = "none",
1432    ) -> Dict:
1433        """Load one or several models from previously run episodes to make an evaluation.
1434
1435        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1436
1437        Parameters
1438        ----------
1439        episode_names : list
1440            a list of string episode names to load the models from
1441        load_epochs : list, optional
1442            a list of integer epoch indices to load the model from; if None, the last ones are used
1443        augment_n : int, default 0
1444            the number of augmentations to average over
1445        data_path : str, optional
1446            the data path to run the prediction for
1447        file_paths : set, optional
1448            a set of files to run the prediction for
1449        mode : {'test', 'val', 'train', 'all'}
1450            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1451            by default 'test' if test subset is not empty and 'val' otherwise)
1452        parameters_update : dict, optional
1453            a dictionary with parameter updates (cannot change model parameters)
1454        multiple_episode_policy : {'average', 'statistics'}
1455            the policy to use when multiple episodes are provided
1456        remove_saved_features : bool, default False
1457            if `True`, the dataset will be deleted
1458        skip_updating_meta : bool, default True
1459            if `True`, the meta file will not be updated with the computed metrics
1460
1461        Returns
1462        -------
1463        metric : dict
1464            a dictionary of average values of metric functions
1465
1466        """
1467        names = []
1468        for episode_name in episode_names:
1469            names += self._episodes().get_runs(episode_name)
1470        if len(set(episode_names)) == 1:
1471            print(f"EVALUATION {episode_names[0]}")
1472        else:
1473            print(f"EVALUATION {episode_names}")
1474        if len(names) > 1:
1475            evaluate = True
1476        else:
1477            evaluate = False
1478        if multiple_episode_policy == "average":
1479            task, parameters, mode, prediction, inference_time, behavior_dict = (
1480                self._make_prediction(
1481                    "_",
1482                    episode_names,
1483                    load_epochs,
1484                    parameters_update,
1485                    mode=mode,
1486                    data_path=data_path,
1487                    file_paths=file_paths,
1488                    augment_n=augment_n,
1489                    evaluate=evaluate,
1490                    annotation_type=annotation_type,
1491                )
1492            )
1493            print("EVALUATE PREDICTION:")
1494            indices = [
1495                list(behavior_dict.keys()).index(i) for i in range(len(behavior_dict))
1496            ]
1497            _, results = task.evaluate_prediction(
1498                prediction, data=mode, indices=indices
1499            )
1500            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1501                self._update_episode_metrics(names[0], results)
1502            results = Project._reformat_results(
1503                results,
1504                behavior_dict,
1505                exclusive=task.general_parameters["exclusive"],
1506            )
1507
1508        elif multiple_episode_policy == "statistics":
1509            values = defaultdict(lambda: [])
1510            task = None
1511            for name in names:
1512                (
1513                    task,
1514                    parameters,
1515                    mode,
1516                    prediction,
1517                    inference_time,
1518                    behavior_dict,
1519                ) = self._make_prediction(
1520                    "_",
1521                    [name],
1522                    load_epochs,
1523                    parameters_update,
1524                    mode=mode,
1525                    data_path=data_path,
1526                    file_paths=file_paths,
1527                    augment_n=augment_n,
1528                    evaluate=evaluate,
1529                    task=task,
1530                )
1531                _, metrics = task.evaluate_prediction(
1532                    prediction, data=mode, indices=list(behavior_dict.keys())
1533                )
1534                for name, value in metrics.items():
1535                    values[name].append(value)
1536                if mode == "val" and not skip_updating_meta:
1537                    self._update_episode_metrics(name, metrics)
1538            results = defaultdict(lambda: {})
1539            mean_string = ""
1540            std_string = ""
1541            for key, value_list in values.items():
1542                results[key]["mean"] = np.mean(value_list)
1543                results[key]["std"] = np.std(value_list)
1544                results[key]["all"] = value_list
1545                mean_string += f"{key} {np.mean(value_list):.3f}, "
1546                std_string += f"{key} {np.std(value_list):.3f}, "
1547            print("MEAN:")
1548            print(mean_string)
1549            print("STD:")
1550            print(std_string)
1551        else:
1552            raise ValueError(
1553                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1554                f"from ['average', 'statistics']"
1555            )
1556        if len(names) > 0 and remove_saved_features:
1557            self._remove_stores(parameters)
1558        print(f"Inference time: {inference_time}")
1559        print("\n")
1560        return results
1561
1562    def run_suggestion(
1563        self,
1564        suggestions_name: str,
1565        error_episode: str = None,
1566        error_load_epoch: int = None,
1567        error_class: str = None,
1568        suggestions_prediction: str = None,
1569        suggestion_episodes: List = [None],
1570        suggestion_load_epoch: int = None,
1571        suggestion_classes: List = None,
1572        error_threshold: float = 0.5,
1573        error_threshold_diff: float = 0.1,
1574        error_hysteresis: bool = False,
1575        suggestion_threshold: Union[float, List] = 0.5,
1576        suggestion_threshold_diff: Union[float, List] = 0.1,
1577        suggestion_hysteresis: Union[bool, List] = True,
1578        min_frames_suggestion: int = 10,
1579        min_frames_al: int = 30,
1580        visibility_min_score: float = 0,
1581        visibility_min_frac: float = 0.7,
1582        augment_n: int = 0,
1583        exclude_classes: List = None,
1584        exclude_threshold: Union[float, List] = 0.6,
1585        exclude_threshold_diff: Union[float, List] = 0.1,
1586        exclude_hysteresis: Union[bool, List] = False,
1587        include_classes: List = None,
1588        include_threshold: Union[float, List] = 0.4,
1589        include_threshold_diff: Union[float, List] = 0.1,
1590        include_hysteresis: Union[bool, List] = False,
1591        data_path: str = None,
1592        file_paths: Set = None,
1593        parameters_update: Dict = None,
1594        mode: str = "all",
1595        force: bool = False,
1596        remove_saved_features: bool = False,
1597        cut_annotated: bool = False,
1598        background_threshold: float = None,
1599    ) -> None:
1600        """Create active learning and suggestion files.
1601
1602        Generate predictions with the error and suggestion model and use them to create
1603        suggestion files for the labeling interface. Those files will render as suggested labels
1604        at intervals with high pose estimation quality. Quality here is defined by probability of error
1605        (predicted by the error model) and visibility parameters.
1606
1607        If `error_episode` or `exclude_classes` is not `None`,
1608        an active learning file will be created as well (with frames with high predicted probability of classes
1609        from `exclude_classes` and/or errors excluded from the active learning intervals).
1610
1611        In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals)
1612        you can apply one of three methods.
1613
1614        - **Simple threshold**
1615
1616            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold`
1617            parameter to $\alpha$.
1618            In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will
1619            be considered labeled.
1620
1621        - **Hysteresis threshold**
1622
1623            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1624            parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$.
1625            Now intervals will be marked with a label if the probability of that label for all frames is higher
1626            than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$.
1627
1628        - **Max hysteresis threshold**
1629
1630            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1631            parameter to $\alpha$ and the `threshold_diff` parameter to `None`.
1632            With this combination intervals are marked with a label if that label is more likely than any other
1633            for all frames in this interval and at for at least one of those frames its probability is higher than
1634            $\alpha$.
1635
1636        Parameters
1637        ----------
1638        suggestions_name : str
1639            the name of the suggestions
1640        error_episode : str, optional
1641            the name of the episode where the error model should be loaded from
1642        error_load_epoch : int, optional
1643            the epoch the error model should be loaded from
1644        error_class : str, optional
1645            the name of the error class (in `error_episode`)
1646        suggestions_prediction : str, optional
1647            the name of the predictions that should be used for the suggestion model
1648        suggestion_episodes : list, optional
1649            the names of the episodes where the suggestion models should be loaded from
1650        suggestion_load_epoch : int, optional
1651            the epoch the suggestion model should be loaded from
1652        suggestion_classes : list, optional
1653            a list of string names of the classes that should be suggested (in `suggestion_episode`)
1654        error_threshold : float, default 0.5
1655            the hard threshold for error prediction
1656        error_threshold_diff : float, default 0.1
1657            the difference between soft and hard thresholds for error prediction (in case hysteresis is used)
1658        error_hysteresis : bool, default False
1659            if True, hysteresis is used for error prediction
1660        suggestion_threshold : float | list, default 0.5
1661            the hard threshold for class prediction (use a list to set different rules for different classes)
1662        suggestion_threshold_diff : float | list, default 0.1
1663            the difference between soft and hard thresholds for class prediction (in case hysteresis is used;
1664            use a list to set different rules for different classes)
1665        suggestion_hysteresis : bool | list, default True
1666            if True, hysteresis is used for class prediction (use a list to set different rules for different classes)
1667        min_frames_suggestion : int, default 10
1668            only actions longer than this number of frames will be suggested
1669        min_frames_al : int, default 30
1670            only active learning intervals longer than this number of frames will be suggested
1671        visibility_min_score : float, default 0
1672            the minimum visibility score for visibility filtering
1673        visibility_min_frac : float, default 0.7
1674            the minimum fraction of visible frames for visibility filtering
1675        augment_n : int, default 10
1676            the number of augmentations to average the predictions over
1677        exclude_classes : list, optional
1678            a list of string names of classes that should be excluded from the active learning intervals
1679        exclude_threshold : float | list, default 0.6
1680            the hard threshold for excluded class prediction (use a list to set different rules for different classes)
1681        exclude_threshold_diff : float | list, default 0.1
1682            the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used)
1683        exclude_hysteresis : bool | list, default False
1684            if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes)
1685        include_classes : list, optional
1686            a list of string names of classes that should be included into the active learning intervals
1687        include_threshold : float | list, default 0.6
1688            the hard threshold for included class prediction (use a list to set different rules for different classes)
1689        include_threshold_diff : float | list, default 0.1
1690            the difference between soft and hard thresholds for included class prediction (in case hysteresis is used)
1691        include_hysteresis : bool | list, default False
1692            if True, hysteresis is used for included class prediction (use a list to set different rules for different classes)
1693        data_path : str, optional
1694            the data path to run the prediction for
1695        file_paths : set, optional
1696            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1697            for
1698        parameters_update : dict, optional
1699            the parameters update dictionary
1700        mode : {'all', 'test', 'val', 'train'}
1701            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1702        force : bool, default False
1703            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
1704        remove_saved_features : bool, default False
1705            if `True`, the dataset will be deleted.
1706        cut_annotated : bool, default False
1707            if `True`, annotated frames will be cut from the suggestions
1708        background_threshold : float, default 0.5
1709            the threshold for background prediction
1710
1711        """
1712        self._check_suggestions_validity(suggestions_name, force=force)
1713        if any([x is None for x in suggestion_episodes]):
1714            suggestion_episodes = None
1715        if error_episode is None and (
1716            suggestion_episodes is None and suggestions_prediction is None
1717        ):
1718            raise ValueError(
1719                "Both error_episode and suggestion_episode parameters cannot be None at the same time"
1720            )
1721        print(f"SUGGESTION {suggestions_name}")
1722        task = None
1723        if suggestion_classes is None:
1724            suggestion_classes = []
1725        if exclude_classes is None:
1726            exclude_classes = []
1727        if include_classes is None:
1728            include_classes = []
1729        if isinstance(suggestion_threshold, list):
1730            if len(suggestion_threshold) != len(suggestion_classes):
1731                raise ValueError(
1732                    "The suggestion_threshold parameter has to be either a float value or a list of "
1733                    f"float values of the same length as suggestion_classes (got a list of length "
1734                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1735                )
1736        else:
1737            suggestion_threshold = [suggestion_threshold for _ in suggestion_classes]
1738        if isinstance(suggestion_threshold_diff, list):
1739            if len(suggestion_threshold_diff) != len(suggestion_classes):
1740                raise ValueError(
1741                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1742                    f"float values of the same length as suggestion_classes (got a list of length "
1743                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1744                )
1745        else:
1746            suggestion_threshold_diff = [
1747                suggestion_threshold_diff for _ in suggestion_classes
1748            ]
1749        if isinstance(suggestion_hysteresis, list):
1750            if len(suggestion_hysteresis) != len(suggestion_classes):
1751                raise ValueError(
1752                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1753                    f"float values of the same length as suggestion_classes (got a list of length "
1754                    f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)"
1755                )
1756        else:
1757            suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes]
1758        if isinstance(exclude_threshold, list):
1759            if len(exclude_threshold) != len(exclude_classes):
1760                raise ValueError(
1761                    "The exclude_threshold parameter has to be either a float value or a list of "
1762                    f"float values of the same length as exclude_classes (got a list of length "
1763                    f"{len(exclude_threshold)} for {len(exclude_classes)} classes)"
1764                )
1765        else:
1766            exclude_threshold = [exclude_threshold for _ in exclude_classes]
1767        if isinstance(exclude_threshold_diff, list):
1768            if len(exclude_threshold_diff) != len(exclude_classes):
1769                raise ValueError(
1770                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1771                    f"float values of the same length as exclude_classes (got a list of length "
1772                    f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)"
1773                )
1774        else:
1775            exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes]
1776        if isinstance(exclude_hysteresis, list):
1777            if len(exclude_hysteresis) != len(exclude_classes):
1778                raise ValueError(
1779                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1780                    f"float values of the same length as suggestion_classes (got a list of length "
1781                    f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)"
1782                )
1783        else:
1784            exclude_hysteresis = [exclude_hysteresis for _ in exclude_classes]
1785        if isinstance(include_threshold, list):
1786            if len(include_threshold) != len(include_classes):
1787                raise ValueError(
1788                    "The exclude_threshold parameter has to be either a float value or a list of "
1789                    f"float values of the same length as exclude_classes (got a list of length "
1790                    f"{len(include_threshold)} for {len(include_classes)} classes)"
1791                )
1792        else:
1793            include_threshold = [include_threshold for _ in include_classes]
1794        if isinstance(include_threshold_diff, list):
1795            if len(include_threshold_diff) != len(include_classes):
1796                raise ValueError(
1797                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1798                    f"float values of the same length as exclude_classes (got a list of length "
1799                    f"{len(include_threshold_diff)} for {len(include_classes)} classes)"
1800                )
1801        else:
1802            include_threshold_diff = [include_threshold_diff for _ in include_classes]
1803        if isinstance(include_hysteresis, list):
1804            if len(include_hysteresis) != len(include_classes):
1805                raise ValueError(
1806                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1807                    f"float values of the same length as suggestion_classes (got a list of length "
1808                    f"{len(include_hysteresis)} for {len(include_classes)} classes)"
1809                )
1810        else:
1811            include_hysteresis = [include_hysteresis for _ in include_classes]
1812        if (suggestion_episodes is None and suggestions_prediction is None) and len(
1813            exclude_classes
1814        ) > 0:
1815            raise ValueError(
1816                "In order to exclude classes from the active learning intervals you need to set the "
1817                "suggestion_episode parameter"
1818            )
1819
1820        task = None
1821        if error_episode is not None:
1822            task, parameters, mode = self._make_task_prediction(
1823                prediction_name=suggestions_name,
1824                load_episode=error_episode,
1825                parameters_update=parameters_update,
1826                load_epoch=error_load_epoch,
1827                data_path=data_path,
1828                mode=mode,
1829                file_paths=file_paths,
1830                task=task,
1831            )
1832            predicted_error = task.predict(
1833                data=mode,
1834                raw_output=True,
1835                apply_primary_function=True,
1836                augment_n=augment_n,
1837            )
1838        else:
1839            predicted_error = None
1840
1841        if suggestion_episodes is not None:
1842            (
1843                task,
1844                parameters,
1845                mode,
1846                predicted_classes,
1847                inference_time,
1848                behavior_dict,
1849            ) = self._make_prediction(
1850                prediction_name=suggestions_name,
1851                episode_names=suggestion_episodes,
1852                load_epochs=suggestion_load_epoch,
1853                parameters_update=parameters_update,
1854                data_path=data_path,
1855                file_paths=file_paths,
1856                mode=mode,
1857                task=task,
1858            )
1859        elif suggestions_prediction is not None:
1860            with open(
1861                os.path.join(
1862                    self.project_path,
1863                    "results",
1864                    "predictions",
1865                    f"{suggestions_prediction}.pickle",
1866                ),
1867                "rb",
1868            ) as f:
1869                predicted_classes = pickle.load(f)
1870            if parameters_update is None:
1871                parameters_update = {}
1872            parameters_update = self._update(
1873                self._predictions().load_parameters(suggestions_prediction),
1874                parameters_update,
1875            )
1876            parameters_update.pop("model")
1877            if suggestion_episodes is None:
1878                suggestion_episodes = [
1879                    os.path.basename(
1880                        os.path.dirname(
1881                            parameters_update["training"]["checkpoint_path"]
1882                        )
1883                    )
1884                ]
1885            task, parameters, mode = self._make_task_prediction(
1886                "_",
1887                load_episode=None,
1888                parameters_update=parameters_update,
1889                data_path=data_path,
1890                file_paths=file_paths,
1891                mode=mode,
1892            )
1893        else:
1894            predicted_classes = None
1895
1896        if len(suggestion_classes) > 0 and predicted_classes is not None:
1897            suggestions = self._make_suggestions(
1898                task,
1899                predicted_error,
1900                predicted_classes,
1901                suggestion_threshold,
1902                suggestion_threshold_diff,
1903                suggestion_hysteresis,
1904                suggestion_episodes,
1905                suggestion_classes,
1906                error_threshold,
1907                min_frames_suggestion,
1908                min_frames_al,
1909                visibility_min_score,
1910                visibility_min_frac,
1911                cut_annotated=cut_annotated,
1912            )
1913            videos = list(suggestions.keys())
1914            for v_id in videos:
1915                times_dict = defaultdict(lambda: defaultdict(lambda: []))
1916                clips = set()
1917                for c in suggestions[v_id]:
1918                    for start, end, ind in suggestions[v_id][c]:
1919                        times_dict[ind][c].append([start, end, 2])
1920                        clips.add(ind)
1921                clips = list(clips)
1922                times_dict = dict(times_dict)
1923                times = [
1924                    [times_dict[ind][c] for c in suggestion_classes] for ind in clips
1925                ]
1926                save_path = self._suggestion_path(v_id, suggestions_name)
1927                Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1928                with open(save_path, "wb") as f:
1929                    pickle.dump((None, suggestion_classes, clips, times), f)
1930
1931        if (
1932            error_episode is not None
1933            or len(exclude_classes) > 0
1934            or len(include_classes) > 0
1935        ):
1936            al_points = self._make_al_points(
1937                task,
1938                predicted_error,
1939                predicted_classes,
1940                exclude_classes,
1941                exclude_threshold,
1942                exclude_threshold_diff,
1943                exclude_hysteresis,
1944                include_classes,
1945                include_threshold,
1946                include_threshold_diff,
1947                include_hysteresis,
1948                error_episode,
1949                error_class,
1950                suggestion_episodes,
1951                error_threshold,
1952                error_threshold_diff,
1953                error_hysteresis,
1954                min_frames_al,
1955                visibility_min_score,
1956                visibility_min_frac,
1957            )
1958        else:
1959            al_points = self._make_al_points_from_suggestions(
1960                suggestions_name,
1961                task,
1962                predicted_classes,
1963                background_threshold,
1964                visibility_min_score,
1965                visibility_min_frac,
1966                num_behaviors=len(task.behaviors_dict()),
1967            )
1968        save_path = self._al_points_path(suggestions_name)
1969        Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1970        with open(save_path, "wb") as f:
1971            pickle.dump(al_points, f)
1972
1973        meta_parameters = {
1974            "error_episode": error_episode,
1975            "error_load_epoch": error_load_epoch,
1976            "error_class": error_class,
1977            "suggestion_episode": suggestion_episodes,
1978            "suggestion_load_epoch": suggestion_load_epoch,
1979            "suggestion_classes": suggestion_classes,
1980            "error_threshold": error_threshold,
1981            "error_threshold_diff": error_threshold_diff,
1982            "error_hysteresis": error_hysteresis,
1983            "suggestion_threshold": suggestion_threshold,
1984            "suggestion_threshold_diff": suggestion_threshold_diff,
1985            "suggestion_hysteresis": suggestion_hysteresis,
1986            "min_frames_suggestion": min_frames_suggestion,
1987            "min_frames_al": min_frames_al,
1988            "visibility_min_score": visibility_min_score,
1989            "visibility_min_frac": visibility_min_frac,
1990            "augment_n": augment_n,
1991            "exclude_classes": exclude_classes,
1992            "exclude_threshold": exclude_threshold,
1993            "exclude_threshold_diff": exclude_threshold_diff,
1994            "exclude_hysteresis": exclude_hysteresis,
1995        }
1996        self._save_suggestions(suggestions_name, {}, meta_parameters)
1997        if data_path is not None or file_paths is not None or remove_saved_features:
1998            self._remove_stores(parameters)
1999        print(f"\n")
2000
2001    def _generate_similarity_score(
2002        self,
2003        prediction_name: str,
2004        target_video_id: str,
2005        target_clip: str,
2006        target_start: int,
2007        target_end: int,
2008    ) -> Dict:
2009        with open(
2010            os.path.join(
2011                self.project_path,
2012                "results",
2013                "predictions",
2014                f"{prediction_name}.pickle",
2015            ),
2016            "rb",
2017        ) as f:
2018            prediction = pickle.load(f)
2019        target = prediction[target_video_id][target_clip][:, target_start:target_end]
2020        score_dict = defaultdict(lambda: {})
2021        for video_id in prediction:
2022            for clip_id in prediction[video_id]:
2023                score_dict[video_id][clip_id] = torch.cdist(
2024                    target.T, prediction[video_id][score_dict].T
2025                ).min(0)
2026        return score_dict
2027
2028    def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict:
2029        """Suggest intervals from a score dictionary.
2030
2031        Parameters
2032        ----------
2033        score_dict : dict
2034            a dictionary containing scores for intervals
2035        min_length : int
2036            minimum length of intervals to suggest
2037        n_intervals : int
2038            number of intervals to suggest
2039
2040        Returns
2041        -------
2042        intervals : dict
2043            a dictionary of suggested intervals
2044
2045        """
2046        interval_address = {}
2047        interval_value = {}
2048        s = 0
2049        n = 0
2050        for video_id, video_dict in score_dict.items():
2051            for clip_id, value in video_dict.items():
2052                s += value.mean()
2053                n += 1
2054        mean_value = s / n
2055        alpha = 1.75
2056        for it in range(10):
2057            id = 0
2058            interval_address = {}
2059            interval_value = {}
2060            for video_id, video_dict in score_dict.items():
2061                for clip_id, value in video_dict.items():
2062                    res_indices_start, res_indices_end = apply_threshold(
2063                        value,
2064                        threshold=(2 - alpha * (0.9**it)) * mean_value,
2065                        low=True,
2066                        error_mask=None,
2067                        min_frames=min_length,
2068                        smooth_interval=0,
2069                    )
2070                    for start, end in zip(res_indices_start, res_indices_end):
2071                        interval_address[id] = [video_id, clip_id, start, end]
2072                        interval_value[id] = score_dict[video_id][clip_id][
2073                            start:end
2074                        ].mean()
2075                        id += 1
2076            if len(interval_address) >= n_intervals:
2077                break
2078        if len(interval_address) < n_intervals:
2079            warnings.warn(
2080                f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals"
2081            )
2082        sorted_intervals = sorted(
2083            interval_value.items(), key=lambda x: x[1], reverse=True
2084        )
2085        output_intervals = [
2086            interval_address[x[0]]
2087            for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)]
2088        ]
2089        output = defaultdict(lambda: [])
2090        for video_id, clip_id, start, end in output_intervals:
2091            output[video_id].append([start, end, clip_id])
2092        return output
2093
2094    def suggest_intervals_with_similarity(
2095        self,
2096        suggestions_name: str,
2097        prediction_name: str,
2098        target_video_id: str,
2099        target_clip: str,
2100        target_start: int,
2101        target_end: int,
2102        min_length: int = 60,
2103        n_intervals: int = 5,
2104        force: bool = False,
2105    ):
2106        """
2107        Suggest intervals based on similarity to a target interval.
2108
2109        Parameters
2110        ----------
2111        suggestions_name : str
2112            Name of the suggestion.
2113        prediction_name : str
2114            Name of the prediction to use.
2115        target_video_id : str
2116            Video id of the target interval.
2117        target_clip : str
2118            Clip id of the target interval.
2119        target_start : int
2120            Start frame of the target interval.
2121        target_end : int
2122            End frame of the target interval.
2123        min_length : int, default 60
2124            Minimum length of the suggested intervals.
2125        n_intervals : int, default 5
2126            Number of suggested intervals.
2127        force : bool, default False
2128            If True, the suggestion is overwritten if it already exists.
2129
2130        """
2131        self._check_suggestions_validity(suggestions_name, force=force)
2132        print(f"SUGGESTION {suggestions_name}")
2133        score_dict = self._generate_similarity_score(
2134            prediction_name, target_video_id, target_clip, target_start, target_end
2135        )
2136        intervals = self._suggest_intervals_from_dict(
2137            score_dict, min_length, n_intervals
2138        )
2139        suggestions_path = os.path.join(
2140            self.project_path,
2141            "results",
2142            "suggestions",
2143            suggestions_name,
2144        )
2145        if not os.path.exists(suggestions_path):
2146            os.mkdir(suggestions_path)
2147        with open(
2148            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2149        ) as f:
2150            pickle.dump(intervals, f)
2151        meta_parameters = {
2152            "prediction_name": prediction_name,
2153            "min_frames_suggestion": min_length,
2154            "n_intervals": n_intervals,
2155            "target_clip": target_clip,
2156            "target_start": target_start,
2157            "target_end": target_end,
2158        }
2159        self._save_suggestions(suggestions_name, {}, meta_parameters)
2160        print("\n")
2161
2162    def suggest_intervals_with_uncertainty(
2163        self,
2164        suggestions_name: str,
2165        episode_names: List,
2166        load_epochs: List = None,
2167        classes: List = None,
2168        n_frames: int = 10000,
2169        method: str = "least_confidence",
2170        min_length: int = 60,
2171        augment_n: int = 0,
2172        data_path: str = None,
2173        file_paths: Set = None,
2174        parameters_update: Dict = None,
2175        mode: str = "all",
2176        force: bool = False,
2177        remove_saved_features: bool = False,
2178    ) -> None:
2179        """Generate an active learning file based on model uncertainty.
2180
2181        If you provide several episode names, the predicted probabilities will be averaged.
2182
2183        Parameters
2184        ----------
2185        suggestions_name : str
2186            the name of the suggestion
2187        episode_names : list
2188            a list of string episode names to load the models from
2189        load_epochs : list, optional
2190            a list of epoch indices to load the models from (if `None`, the last ones will be used)
2191        classes : list, optional
2192            a list of classes to look at (by default all)
2193        n_frames : int, default 10000
2194            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2195            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2196            with the set parameters)
2197        method : {"least_confidence", "entropy"}
2198            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
2199            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
2200        min_length : int, default 60
2201            the minimum number of frames in one interval
2202        augment_n : int, default 0
2203            the number of augmentations to average the predictions over
2204        data_path : str, optional
2205            the path to a data folder (by default, the project data is used)
2206        file_paths : set, optional
2207            a list of file paths (by default, the project data is used)
2208        parameters_update : dict, optional
2209            a dictionary of parameter updates
2210        mode : {"test", "val", "train", "all"}
2211            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2212            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2213        force : bool, default False
2214            if `True`, existing suggestions with the same name will be overwritten
2215        remove_saved_features : bool, default False
2216            if `True`, the dataset will be deleted after the computation
2217
2218        """
2219        self._check_suggestions_validity(suggestions_name, force=force)
2220        print(f"SUGGESTION {suggestions_name}")
2221        task, parameters, mode, predicted, inference_time, behavior_dict = (
2222            self._make_prediction(
2223                suggestions_name,
2224                episode_names,
2225                load_epochs,
2226                parameters_update,
2227                data_path=data_path,
2228                file_paths=file_paths,
2229                mode=mode,
2230                augment_n=augment_n,
2231                evaluate=False,
2232            )
2233        )
2234        if classes is None:
2235            classes = behavior_dict.values()
2236        episode = self._episodes().get_runs(episode_names[0])[0]
2237        score_tensors = task.generate_uncertainty_score(
2238            classes,
2239            augment_n,
2240            method,
2241            predicted,
2242            self._episode(episode).get_behaviors_dict(),
2243        )
2244        intervals = self._suggest_intervals(
2245            task.dataset(mode), score_tensors, n_frames, min_length
2246        )
2247        for k, v in intervals.items():
2248            l = sum([x[1] - x[0] for x in v])
2249            print(f"{k}: {len(v)} ({l})")
2250        if remove_saved_features:
2251            self._remove_stores(parameters)
2252        suggestions_path = os.path.join(
2253            self.project_path,
2254            "results",
2255            "suggestions",
2256            suggestions_name,
2257        )
2258        if not os.path.exists(suggestions_path):
2259            os.mkdir(suggestions_path)
2260        with open(
2261            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2262        ) as f:
2263            pickle.dump(intervals, f)
2264        meta_parameters = {
2265            "suggestion_episode": episode_names,
2266            "suggestion_load_epoch": load_epochs,
2267            "suggestion_classes": classes,
2268            "min_frames_suggestion": min_length,
2269            "augment_n": augment_n,
2270            "method": method,
2271            "num_frames": n_frames,
2272        }
2273        self._save_suggestions(suggestions_name, {}, meta_parameters)
2274        print("\n")
2275
2276    def suggest_intervals_with_bald(
2277        self,
2278        suggestions_name: str,
2279        episode_name: str,
2280        load_epoch: int = None,
2281        classes: List = None,
2282        n_frames: int = 10000,
2283        num_models: int = 10,
2284        kernel_size: int = 11,
2285        min_length: int = 60,
2286        augment_n: int = 0,
2287        data_path: str = None,
2288        file_paths: Set = None,
2289        parameters_update: Dict = None,
2290        mode: str = "all",
2291        force: bool = False,
2292        remove_saved_features: bool = False,
2293    ):
2294        """Generate an active learning file based on Bayesian Active Learning by Disagreement.
2295
2296        Parameters
2297        ----------
2298        suggestions_name : str
2299            the name of the suggestion
2300        episode_name : str
2301            the name of the episode to load the model from
2302        load_epoch : int, optional
2303            the index of the epoch to load the model from (if `None`, the last one will be used)
2304        classes : list, optional
2305            a list of classes to look at (by default all)
2306        n_frames : int, default 10000
2307            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2308            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2309            with the set parameters)
2310        num_models : int, default 10
2311            the number of dropout masks to apply
2312        kernel_size : int, default 11
2313            the size of the smoothing kernel applied to the discrete results
2314        min_length : int, default 60
2315            the minimum number of frames in one interval
2316        augment_n : int, default 0
2317            the number of augmentations to average the predictions over
2318        data_path : str, optional
2319            the path to a data folder (by default, the project data is used)
2320        file_paths : set, optional
2321            a list of file paths (by default, the project data is used)
2322        parameters_update : dict, optional
2323            a dictionary of parameter updates
2324        mode : {"test", "val", "train", "all"}
2325            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2326            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2327        force : bool, default False
2328            if `True`, existing suggestions with the same name will be overwritten
2329        remove_saved_features : bool, default False
2330            if `True`, the dataset will be deleted after the computation
2331
2332        """
2333        self._check_suggestions_validity(suggestions_name, force=force)
2334        print(f"SUGGESTION {suggestions_name}")
2335        task, parameters, mode = self._make_task_prediction(
2336            suggestions_name,
2337            episode_name,
2338            parameters_update,
2339            load_epoch,
2340            data_path=data_path,
2341            file_paths=file_paths,
2342            mode=mode,
2343        )
2344        if classes is None:
2345            classes = list(task.behaviors_dict().values())
2346        score_tensors = task.generate_bald_score(
2347            classes, augment_n, num_models, kernel_size
2348        )
2349        intervals = self._suggest_intervals(
2350            task.dataset(mode), score_tensors, n_frames, min_length
2351        )
2352        if remove_saved_features:
2353            self._remove_stores(parameters)
2354        suggestions_path = os.path.join(
2355            self.project_path,
2356            "results",
2357            "suggestions",
2358            suggestions_name,
2359        )
2360        if not os.path.exists(suggestions_path):
2361            os.mkdir(suggestions_path)
2362        with open(
2363            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2364        ) as f:
2365            pickle.dump(intervals, f)
2366        meta_parameters = {
2367            "suggestion_episode": episode_name,
2368            "suggestion_load_epoch": load_epoch,
2369            "suggestion_classes": classes,
2370            "min_frames_suggestion": min_length,
2371            "augment_n": augment_n,
2372            "method": f"BALD:{num_models}",
2373            "num_frames": n_frames,
2374        }
2375        self._save_suggestions(suggestions_name, {}, meta_parameters)
2376        print("\n")
2377
2378    def list_episodes(
2379        self,
2380        episode_names: List = None,
2381        value_filter: str = "",
2382        display_parameters: List = None,
2383        print_results: bool = True,
2384    ) -> pd.DataFrame:
2385        """Get a filtered pandas dataframe with episode metadata.
2386
2387        Parameters
2388        ----------
2389        episode_names : list
2390            a list of strings of episode names
2391        value_filter : str
2392            a string of filters to apply; of this general structure:
2393            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
2394            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
2395        display_parameters : list
2396            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2397        print_results : bool, default True
2398            if True, the result will be printed to standard output
2399
2400        Returns
2401        -------
2402        pd.DataFrame
2403            the filtered dataframe
2404
2405        """
2406        episodes = self._episodes().list_episodes(
2407            episode_names, value_filter, display_parameters
2408        )
2409        if print_results:
2410            print("TRAINING EPISODES")
2411            print(episodes)
2412            print("\n")
2413        return episodes
2414
2415    def list_predictions(
2416        self,
2417        episode_names: List = None,
2418        value_filter: str = "",
2419        display_parameters: List = None,
2420        print_results: bool = True,
2421    ) -> pd.DataFrame:
2422        """Get a filtered pandas dataframe with prediction metadata.
2423
2424        Parameters
2425        ----------
2426        episode_names : list
2427            a list of strings of episode names
2428        value_filter : str
2429            a string of filters to apply; of this general structure:
2430            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2431            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2432        display_parameters : list
2433            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2434        print_results : bool, default True
2435            if True, the result will be printed to standard output
2436
2437        Returns
2438        -------
2439        pd.DataFrame
2440            the filtered dataframe
2441
2442        """
2443        predictions = self._predictions().list_episodes(
2444            episode_names, value_filter, display_parameters
2445        )
2446        if print_results:
2447            print("PREDICTIONS")
2448            print(predictions)
2449            print("\n")
2450        return predictions
2451
2452    def list_suggestions(
2453        self,
2454        suggestions_names: List = None,
2455        value_filter: str = "",
2456        display_parameters: List = None,
2457        print_results: bool = True,
2458    ) -> pd.DataFrame:
2459        """Get a filtered pandas dataframe with prediction metadata.
2460
2461        Parameters
2462        ----------
2463        suggestions_names : list
2464            a list of strings of suggestion names
2465        value_filter : str
2466            a string of filters to apply; of this general structure:
2467            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2468            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2469        display_parameters : list
2470            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2471        print_results : bool, default True
2472            if True, the result will be printed to standard output
2473
2474        Returns
2475        -------
2476        pd.DataFrame
2477            the filtered dataframe
2478
2479        """
2480        suggestions = self._suggestions().list_episodes(
2481            suggestions_names, value_filter, display_parameters
2482        )
2483        if print_results:
2484            print("SUGGESTIONS")
2485            print(suggestions)
2486            print("\n")
2487        return suggestions
2488
2489    def list_searches(
2490        self,
2491        search_names: List = None,
2492        value_filter: str = "",
2493        display_parameters: List = None,
2494        print_results: bool = True,
2495    ) -> pd.DataFrame:
2496        """Get a filtered pandas dataframe with hyperparameter search metadata.
2497
2498        Parameters
2499        ----------
2500        search_names : list
2501            a list of strings of search names
2502        value_filter : str
2503            a string of filters to apply; of this general structure:
2504            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2505            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2506        display_parameters : list
2507            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2508        print_results : bool, default True
2509            if True, the result will be printed to standard output
2510
2511        Returns
2512        -------
2513        pd.DataFrame
2514            the filtered dataframe
2515
2516        """
2517        searches = self._searches().list_episodes(
2518            search_names, value_filter, display_parameters
2519        )
2520        if print_results:
2521            print("SEARCHES")
2522            print(searches)
2523            print("\n")
2524        return searches
2525
2526    def get_best_parameters(
2527        self,
2528        search_name: str,
2529        round_to_binary: List = None,
2530    ):
2531        """Get the best parameters found by a search.
2532
2533        Parameters
2534        ----------
2535        search_name : str
2536            the name of the search
2537        round_to_binary : list, default None
2538            a list of parameters to round to binary values
2539
2540        Returns
2541        -------
2542        best_params : dict
2543            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2544
2545        """
2546        params, model = self._searches().get_best_params(
2547            search_name, round_to_binary=round_to_binary
2548        )
2549        params = self._update(params, {"general": {"model_name": model}})
2550        return params
2551
2552    def list_best_parameters(
2553        self, search_name: str, print_results: bool = True
2554    ) -> Dict:
2555        """Get the raw dictionary of best parameters found by a search.
2556
2557        Parameters
2558        ----------
2559        search_name : str
2560            the name of the search
2561        print_results : bool, default True
2562            if True, the result will be printed to standard output
2563
2564        Returns
2565        -------
2566        best_params : dict
2567            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2568
2569        """
2570        params = self._searches().get_best_params_raw(search_name)
2571        if print_results:
2572            print(f"SEARCH RESULTS {search_name}")
2573            for k, v in params.items():
2574                print(f"{k}: {v}")
2575            print("\n")
2576        return params
2577
2578    def plot_episodes(
2579        self,
2580        episode_names: List,
2581        metrics: List | str,
2582        modes: List | str = None,
2583        title: str = None,
2584        episode_labels: List = None,
2585        save_path: str = None,
2586        add_hlines: List = None,
2587        epoch_limits: List = None,
2588        colors: List = None,
2589        add_highpoint_hlines: bool = False,
2590        remove_box: bool = False,
2591        font_size: float = None,
2592        linewidth: float = None,
2593        return_ax: bool = False,
2594    ) -> None:
2595        """Plot episode training curves.
2596
2597        Parameters
2598        ----------
2599        episode_names : list
2600            a list of episode names to plot; to plot to episodes in one line combine them in a list
2601            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
2602        metrics : list
2603            a list of metric to plot
2604        modes : list, optional
2605            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
2606        title : str, optional
2607            title for the plot
2608        episode_labels : list, optional
2609            a list of strings used to label the curves (has to be the same length as episode_names)
2610        save_path : str, optional
2611            the path to save the resulting plot
2612        add_hlines : list, optional
2613            a list of float values (or (value, label) tuples) to mark with horizontal lines
2614        epoch_limits : list, optional
2615            a list of (min, max) tuples to set the x-axis limits for each episode
2616        colors: list, optional
2617            a list of matplotlib colors
2618        add_highpoint_hlines : bool, default False
2619            if `True`, horizontal lines will be added at the highest value of each episode
2620        """
2621
2622        if isinstance(metrics, str):
2623            metrics = [metrics]
2624        if isinstance(modes, str):
2625            modes = [modes]
2626
2627        if font_size is not None:
2628            font = {"size": font_size}
2629            rc("font", **font)
2630        if modes is None:
2631            modes = ["val"]
2632        if add_hlines is None:
2633            add_hlines = []
2634        logs = []
2635        epochs = []
2636        labels = []
2637        if episode_labels is not None:
2638            assert len(episode_labels) == len(episode_names)
2639        for name_i, name in enumerate(episode_names):
2640            log_params = product(metrics, modes)
2641            for metric, mode in log_params:
2642                if episode_labels is not None:
2643                    label = episode_labels[name_i]
2644                else:
2645                    label = deepcopy(name)
2646                if len(modes) != 1:
2647                    label += f"_{mode}"
2648                if len(metrics) != 1:
2649                    label += f"_{metric}"
2650                labels.append(label)
2651                if isinstance(name, Iterable) and not isinstance(name, str):
2652                    epoch_list = defaultdict(lambda: [])
2653                    multi_logs = defaultdict(lambda: [])
2654                    for i, n in enumerate(name):
2655                        runs = self._episodes().get_runs(n)
2656                        if len(runs) > 1:
2657                            for run in runs:
2658                                if "::" in run:
2659                                    index = run.split("::")[-1]
2660                                else:
2661                                    index = run.split("#")[-1]
2662                                if multi_logs[index] == []:
2663                                    if multi_logs["null"] is None:
2664                                        raise RuntimeError(
2665                                            "The run indices are not consistent across episodes!"
2666                                        )
2667                                    else:
2668                                        multi_logs[index] += multi_logs["null"]
2669                                multi_logs[index] += list(
2670                                    self._episode(run).get_metric_log(mode, metric)
2671                                )
2672                                start = (
2673                                    0
2674                                    if len(epoch_list[index]) == 0
2675                                    else epoch_list[index][-1]
2676                                )
2677                                epoch_list[index] += [
2678                                    x + start
2679                                    for x in self._episode(run).get_epoch_list(mode)
2680                                ]
2681                            multi_logs["null"] = None
2682                        else:
2683                            if len(multi_logs.keys()) > 1:
2684                                raise RuntimeError(
2685                                    "Cannot plot a single-run episode after a multi-run episode!"
2686                                )
2687                            multi_logs["null"] += list(
2688                                self._episode(n).get_metric_log(mode, metric)
2689                            )
2690                            start = (
2691                                0
2692                                if len(epoch_list["null"]) == 0
2693                                else epoch_list["null"][-1]
2694                            )
2695                            epoch_list["null"] += [
2696                                x + start for x in self._episode(n).get_epoch_list(mode)
2697                            ]
2698                    if len(multi_logs.keys()) == 1:
2699                        log = multi_logs["null"]
2700                        epochs.append(epoch_list["null"])
2701                    else:
2702                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
2703                        epochs.append(
2704                            tuple([v for k, v in epoch_list.items() if k != "null"])
2705                        )
2706                else:
2707                    runs = self._episodes().get_runs(name)
2708                    if len(runs) > 1:
2709                        log = []
2710                        for run in runs:
2711                            tracked_metrics = self._episode(run).get_metrics()
2712                            if metric in tracked_metrics:
2713                                log.append(
2714                                    list(
2715                                        self._episode(run).get_metric_log(mode, metric)
2716                                    )
2717                                )
2718                            else:
2719                                relevant = []
2720                                for m in tracked_metrics:
2721                                    m_split = m.split("_")
2722                                    if (
2723                                        "_".join(m_split[:-1]) == metric
2724                                        and m_split[-1].isnumeric()
2725                                    ):
2726                                        relevant.append(m)
2727                                if len(relevant) == 0:
2728                                    raise ValueError(
2729                                        f"The {metric} metric was not tracked at {run}"
2730                                    )
2731                                arr = 0
2732                                for m in relevant:
2733                                    arr += self._episode(run).get_metric_log(mode, m)
2734                                arr /= len(relevant)
2735                                log.append(list(arr))
2736                        log = tuple(log)
2737                        epochs.append(
2738                            tuple(
2739                                [
2740                                    self._episode(run).get_epoch_list(mode)
2741                                    for run in runs
2742                                ]
2743                            )
2744                        )
2745                    else:
2746                        tracked_metrics = self._episode(name).get_metrics()
2747                        if metric in tracked_metrics:
2748                            log = list(self._episode(name).get_metric_log(mode, metric))
2749                        else:
2750                            relevant = []
2751                            for m in tracked_metrics:
2752                                m_split = m.split("_")
2753                                if (
2754                                    "_".join(m_split[:-1]) == metric
2755                                    and m_split[-1].isnumeric()
2756                                ):
2757                                    relevant.append(m)
2758                            if len(relevant) == 0:
2759                                raise ValueError(
2760                                    f"The {metric} metric was not tracked at {name}"
2761                                )
2762                            arr = 0
2763                            for m in relevant:
2764                                arr += self._episode(name).get_metric_log(mode, m)
2765                            arr /= len(relevant)
2766                            log = list(arr)
2767                        epochs.append(self._episode(name).get_epoch_list(mode))
2768                logs.append(log)
2769        # if episode_labels is not None:
2770        #     print(f'{len(episode_labels)=}, {len(logs)=}')
2771        #     if len(episode_labels) != len(logs):
2772
2773        #         raise ValueError(
2774        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
2775        #             f"curves ({len(logs)})!"
2776        #         )
2777        #     else:
2778        #         labels = episode_labels
2779        if colors is None:
2780            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
2781        if len(colors) != len(logs):
2782            raise ValueError(
2783                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
2784            )
2785        f, ax = plt.subplots()
2786        length = 0
2787        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
2788            if type(log) is list:
2789                if len(log) > length:
2790                    length = len(log)
2791                ax.plot(
2792                    epoch_list,
2793                    log,
2794                    label=label,
2795                    color=color,
2796                )
2797                if add_highpoint_hlines:
2798                    ax.axhline(np.max(log), linestyle="dashed", color=color)
2799            else:
2800                for l, xx in zip(log, epoch_list):
2801                    if len(l) > length:
2802                        length = len(l)
2803                    ax.plot(
2804                        xx,
2805                        l,
2806                        color=color,
2807                        alpha=0.2,
2808                    )
2809                if not all([len(x) == len(log[0]) for x in log]):
2810                    warnings.warn(
2811                        f"Got logs with unequal lengths in parallel runs for {label}"
2812                    )
2813                    log = list(log)
2814                    epoch_list = list(epoch_list)
2815                    for i, x in enumerate(epoch_list):
2816                        to_remove = []
2817                        for j, y in enumerate(x[1:]):
2818                            if y <= x[j - 1]:
2819                                y_ind = x.index(y)
2820                                to_remove += list(range(y_ind, j))
2821                        epoch_list[i] = [
2822                            y for j, y in enumerate(x) if j not in to_remove
2823                        ]
2824                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2825                    length = min([len(x) for x in log])
2826                    for i in range(len(log)):
2827                        log[i] = log[i][:length]
2828                        epoch_list[i] = epoch_list[i][:length]
2829                    if not all([x == epoch_list[0] for x in epoch_list]):
2830                        raise RuntimeError(
2831                            f"Got different epoch indices in parallel runs for {label}"
2832                        )
2833                mean = np.array(log).mean(0)
2834                ax.plot(
2835                    epoch_list[0],
2836                    mean,
2837                    label=label,
2838                    color=color,
2839                    linewidth=linewidth,
2840                )
2841                if add_highpoint_hlines:
2842                    ax.axhline(np.max(mean), linestyle="dashed", color=color)
2843        for x in add_hlines:
2844            label = None
2845            if isinstance(x, Iterable):
2846                x, label = x
2847            ax.axhline(x, label=label)
2848            ax.set_xlim((0, length))
2849
2850        ax.legend()
2851        ax.set_xlabel("epochs")
2852        if len(metrics) == 1:
2853            ax.set_ylabel(metrics[0])
2854        else:
2855            ax.set_ylabel("value")
2856        if title is None:
2857            if len(episode_names) == 1:
2858                title = episode_names[0]
2859            elif len(metrics) == 1:
2860                title = metrics[0]
2861        if epoch_limits is not None:
2862            ax.set_xlim(epoch_limits)
2863        if title is not None:
2864            ax.set_title(title)
2865        if remove_box:
2866            ax.box(False)
2867        if return_ax:
2868            return ax
2869        if save_path is not None:
2870            plt.savefig(save_path)
2871        plt.show()
2872
2873    def update_parameters(
2874        self,
2875        parameters_update: Dict = None,
2876        load_search: str = None,
2877        load_parameters: List = None,
2878        round_to_binary: List = None,
2879    ) -> None:
2880        """Update the parameters in the project config files.
2881
2882        Parameters
2883        ----------
2884        parameters_update : dict, optional
2885            a dictionary of parameter updates
2886        load_search : str, optional
2887            the name of hyperparameter search results to load to config
2888        load_parameters : list, optional
2889            a list of lists of string names of the parameters to load from the searches
2890        round_to_binary : list, optional
2891            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2892
2893        """
2894        keys = [
2895            "general",
2896            "losses",
2897            "metrics",
2898            "ssl",
2899            "training",
2900            "data",
2901        ]
2902        parameters = self._read_parameters(catch_blanks=False)
2903        if parameters_update is not None:
2904            model_params = (
2905                parameters_update.pop("model") if "model" in parameters_update else None
2906            )
2907            feat_params = (
2908                parameters_update.pop("features")
2909                if "features" in parameters_update
2910                else None
2911            )
2912            aug_params = (
2913                parameters_update.pop("augmentations")
2914                if "augmentations" in parameters_update
2915                else None
2916            )
2917
2918            parameters = self._update(parameters, parameters_update)
2919            model_name = parameters["general"]["model_name"]
2920            parameters["model"] = self._open_yaml(
2921                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2922            )
2923            if model_params is not None:
2924                parameters["model"] = self._update(parameters["model"], model_params)
2925            feat_name = parameters["general"]["feature_extraction"]
2926            parameters["features"] = self._open_yaml(
2927                os.path.join(
2928                    self.project_path, "config", "features", f"{feat_name}.yaml"
2929                )
2930            )
2931            if feat_params is not None:
2932                parameters["features"] = self._update(
2933                    parameters["features"], feat_params
2934                )
2935            aug_name = options.extractor_to_transformer[
2936                parameters["general"]["feature_extraction"]
2937            ]
2938            parameters["augmentations"] = self._open_yaml(
2939                os.path.join(
2940                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2941                )
2942            )
2943            if aug_params is not None:
2944                parameters["augmentations"] = self._update(
2945                    parameters["augmentations"], aug_params
2946                )
2947        if load_search is not None:
2948            parameters_update, model_name = self._searches().get_best_params(
2949                load_search, load_parameters, round_to_binary
2950            )
2951            parameters["general"]["model_name"] = model_name
2952            parameters["model"] = self._open_yaml(
2953                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2954            )
2955            parameters = self._update(parameters, parameters_update)
2956        for key in keys:
2957            with open(
2958                os.path.join(self.project_path, "config", f"{key}.yaml"),
2959                "w",
2960                encoding="utf-8",
2961            ) as f:
2962                YAML().dump(parameters[key], f)
2963        model_name = parameters["general"]["model_name"]
2964        model_path = os.path.join(
2965            self.project_path, "config", "model", f"{model_name}.yaml"
2966        )
2967        with open(model_path, "w", encoding="utf-8") as f:
2968            YAML().dump(parameters["model"], f)
2969        features_name = parameters["general"]["feature_extraction"]
2970        features_path = os.path.join(
2971            self.project_path, "config", "features", f"{features_name}.yaml"
2972        )
2973        with open(features_path, "w", encoding="utf-8") as f:
2974            YAML().dump(parameters["features"], f)
2975        aug_name = options.extractor_to_transformer[features_name]
2976        aug_path = os.path.join(
2977            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2978        )
2979        with open(aug_path, "w", encoding="utf-8") as f:
2980            YAML().dump(parameters["augmentations"], f)
2981
2982    def get_summary(
2983        self,
2984        episode_names: list,
2985        method: str = "last",
2986        average: int = 1,
2987        metrics: List = None,
2988        return_values: bool = False,
2989    ) -> Dict:
2990        """Get a summary of episode statistics.
2991
2992        If an episode has multiple runs, the statistics will be aggregated over all of them.
2993
2994        Parameters
2995        ----------
2996        episode_names : str
2997            the names of the episodes
2998        method : ["best", "last"]
2999            the method for choosing the epochs
3000        average : int, default 1
3001            the number of epochs to average over (for each run)
3002        metrics : list, optional
3003            a list of metrics
3004
3005        Returns
3006        -------
3007        statistics : dict
3008            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
3009            and 'std' for the standard deviation
3010
3011        """
3012        runs = []
3013        for episode_name in episode_names:
3014            runs_ep = self._episodes().get_runs(episode_name)
3015            if len(runs_ep) == 0:
3016                raise RuntimeError(
3017                    f"There is no {episode_name} episode in the project memory"
3018                )
3019            runs += runs_ep
3020        if metrics is None:
3021            metrics = self._episode(runs[0]).get_metrics()
3022
3023        values = {m: [] for m in metrics}
3024        for run in runs:
3025            for m in metrics:
3026                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
3027                if method == "best":
3028                    log = sorted(log)
3029                    values[m] += list(log[-average:])
3030                elif method == "last":
3031                    if len(log) == 0:
3032                        episodes = self._episodes().data
3033                        if average == 1 and ("results", m) in episodes.columns:
3034                            values[m] += [episodes.loc[run, ("results", m)]]
3035                        else:
3036                            raise RuntimeError(f"Did not find {m} metric for {run} run")
3037                    values[m] += list(log[-average:])
3038                elif method.startswith("epoch"):
3039                    epoch = int(method[5:]) - 1
3040                    pars = self._episodes().load_parameters(run)
3041                    step = int(pars["training"]["validation_interval"])
3042                    values[m] += [log[epoch // step]]
3043                else:
3044                    raise ValueError(
3045                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
3046                    )
3047        statistics = defaultdict(lambda: {})
3048        for m, v in values.items():
3049            statistics[m]["mean"] = np.mean(v)
3050            statistics[m]["std"] = np.std(v)
3051        print(f"SUMMARY {episode_names}")
3052        for m, v in statistics.items():
3053            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
3054        print("\n")
3055
3056        return (dict(statistics), values) if return_values else dict(statistics)
3057
3058    @staticmethod
3059    def remove_project(name: str, projects_path: str = None) -> None:
3060        """Remove all project files and experiment records and results.
3061
3062        Parameters
3063        ----------
3064        name : str
3065            the name of the project to remove
3066        projects_path : str, optional
3067            the path to the projects directory (by default the home DLC2Action directory)
3068
3069        """
3070        if projects_path is None:
3071            projects_path = os.path.join(str(Path.home()), "DLC2Action")
3072        project_path = os.path.join(projects_path, name)
3073        if os.path.exists(project_path):
3074            shutil.rmtree(project_path)
3075
3076    def remove_saved_features(
3077        self,
3078        dataset_names: List = None,
3079        exceptions: List = None,
3080        remove_active: bool = False,
3081    ) -> None:
3082        """Remove saved pre-computed dataset feature files.
3083
3084        By default, all features will be deleted.
3085        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
3086        while training or inference is happening though.
3087
3088        Parameters
3089        ----------
3090        dataset_names : list, optional
3091            a list of dataset names to delete (by default all names are added)
3092        exceptions : list, optional
3093            a list of dataset names to not be deleted
3094        remove_active : bool, default False
3095            if `False`, datasets used by unfinished episodes will not be deleted
3096
3097        """
3098        print("Removing datasets...")
3099        if dataset_names is None:
3100            dataset_names = []
3101        if exceptions is None:
3102            exceptions = []
3103        if not remove_active:
3104            exceptions += self._episodes().get_active_datasets()
3105        dataset_path = os.path.join(self.project_path, "saved_datasets")
3106        if os.path.exists(dataset_path):
3107            if dataset_names == []:
3108                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
3109
3110            to_remove = [
3111                x
3112                for x in dataset_names
3113                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
3114            ]
3115            if len(to_remove) > 2:
3116                to_remove = tqdm(to_remove)
3117            for dataset in to_remove:
3118                shutil.rmtree(os.path.join(dataset_path, dataset))
3119            to_remove = [
3120                f"{x}.pickle"
3121                for x in dataset_names
3122                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
3123                and x not in exceptions
3124            ]
3125            for dataset in to_remove:
3126                os.remove(os.path.join(dataset_path, dataset))
3127            names = self._saved_datasets().dataset_names()
3128            self._saved_datasets().remove(names)
3129        print("\n")
3130
3131    def remove_extra_checkpoints(
3132        self, episode_names: List = None, exceptions: List = None
3133    ) -> None:
3134        """Remove intermediate model checkpoint files (only leave the files for the last epoch).
3135
3136        By default, all intermediate checkpoints will be deleted.
3137        Files in the model folder that are not associated with any record in the meta files are also deleted.
3138
3139        Parameters
3140        ----------
3141        episode_names : list, optional
3142            a list of episode names to clean (by default all names are added)
3143        exceptions : list, optional
3144            a list of episode names to not clean
3145
3146        """
3147        model_path = os.path.join(self.project_path, "results", "model")
3148        try:
3149            all_names = self._episodes().data.index
3150        except:
3151            all_names = os.listdir(model_path)
3152        if episode_names is None:
3153            episode_names = all_names
3154        if exceptions is None:
3155            exceptions = []
3156        to_remove = [x for x in episode_names if x not in exceptions]
3157        folders = os.listdir(model_path)
3158        for folder in folders:
3159            if folder not in all_names:
3160                shutil.rmtree(os.path.join(model_path, folder))
3161            elif folder in to_remove:
3162                files = os.listdir(os.path.join(model_path, folder))
3163                for file in sorted(files)[:-1]:
3164                    os.remove(os.path.join(model_path, folder, file))
3165
3166    def remove_search(self, search_name: str) -> None:
3167        """Remove a hyperparameter search record.
3168
3169        Parameters
3170        ----------
3171        search_name : str
3172            the name of the search to remove
3173
3174        """
3175        self._searches().remove_episode(search_name)
3176        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
3177        if os.path.exists(graph_path):
3178            shutil.rmtree(graph_path)
3179
3180    def remove_suggestion(self, suggestion_name: str) -> None:
3181        """Remove a suggestion record.
3182
3183        Parameters
3184        ----------
3185        suggestion_name : str
3186            the name of the suggestion to remove
3187
3188        """
3189        self._suggestions().remove_episode(suggestion_name)
3190        suggestion_path = os.path.join(
3191            self.project_path, "results", "suggestions", suggestion_name
3192        )
3193        if os.path.exists(suggestion_path):
3194            shutil.rmtree(suggestion_path)
3195
3196    def remove_prediction(self, prediction_name: str) -> None:
3197        """Remove a prediction record.
3198
3199        Parameters
3200        ----------
3201        prediction_name : str
3202            the name of the prediction to remove
3203
3204        """
3205        self._predictions().remove_episode(prediction_name)
3206        prediction_path = self.prediction_path(prediction_name)
3207        if os.path.exists(prediction_path):
3208            shutil.rmtree(prediction_path)
3209
3210    def check_prediction_exists(self, prediction_name: str) -> str | None:
3211        """Check if a prediction exists.
3212
3213        Parameters
3214        ----------
3215        prediction_name : str
3216            the name of the prediction to check
3217
3218        Returns
3219        -------
3220        str | None
3221            the path to the prediction if it exists, `None` otherwise
3222
3223        """
3224        prediction_path = self.prediction_path(prediction_name)
3225        if os.path.exists(prediction_path):
3226            return prediction_path
3227        return None
3228
3229    def remove_episode(self, episode_name: str) -> None:
3230        """Remove all model, logs and metafile records related to an episode.
3231
3232        Parameters
3233        ----------
3234        episode_name : str
3235            the name of the episode to remove
3236
3237        """
3238        runs = self._episodes().get_runs(episode_name)
3239        runs.append(episode_name)
3240        for run in runs:
3241            self._episodes().remove_episode(run)
3242            model_path = os.path.join(self.project_path, "results", "model", run)
3243            if os.path.exists(model_path):
3244                shutil.rmtree(model_path)
3245            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
3246            if os.path.exists(log_path):
3247                os.remove(log_path)
3248
3249    @abstractmethod
3250    def _reformat_results(res: dict, classes: dict, exclusive=False):
3251        """Add classes to micro metrics in results from evaluation"""
3252        results = deepcopy(res)
3253        for key in results.keys():
3254            if isinstance(results[key], list):
3255                if exclusive and len(classes) == len(results[key]) + 1:
3256                    other_ind = list(classes.keys())[
3257                        list(classes.values()).index("other")
3258                    ]
3259                    classes = {
3260                        (i if i < other_ind else i - 1): c
3261                        for i, c in classes.items()
3262                        if i != other_ind
3263                    }
3264                assert len(results[key]) == len(
3265                    classes
3266                ), f"Results for {key} have {len(results[key])} values, but {len(classes)} classes were provided!"
3267                results[key] = {
3268                    classes[i]: float(v) for i, v in enumerate(results[key])
3269                }
3270        return results
3271
3272    def prune_unfinished(self, exceptions: List = None) -> List:
3273        """Remove all interrupted episodes.
3274
3275        Remove all episodes that either don't have a log file or have less epochs in the log file than in
3276        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
3277        currently running!
3278
3279        Parameters
3280        ----------
3281        exceptions : list
3282            the episodes to keep even if they are interrupted
3283
3284        Returns
3285        -------
3286        pruned : list
3287            a list of the episode names that were pruned
3288
3289        """
3290        if exceptions is None:
3291            exceptions = []
3292        unfinished = self._episodes().unfinished_episodes()
3293        unfinished = [x for x in unfinished if x not in exceptions]
3294        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
3295        unfinished += [
3296            x for x in model_folders if x not in self._episodes().list_episodes().index
3297        ]
3298        print(f"PRUNING {unfinished}")
3299        for episode_name in unfinished:
3300            self.remove_episode(episode_name)
3301        print(f"\n")
3302        return unfinished
3303
3304    def prediction_path(self, prediction_name: str) -> str:
3305        """Get the path where prediction files are saved.
3306
3307        Parameters
3308        ----------
3309        prediction_name : str
3310            name of the prediction
3311
3312        Returns
3313        -------
3314        prediction_path : str
3315            the file path
3316
3317        """
3318        return os.path.join(
3319            self.project_path, "results", "predictions", f"{prediction_name}"
3320        )
3321
3322    def suggestion_path(self, suggestion_name: str) -> str:
3323        """Get the path where suggestion files are saved.
3324
3325        Parameters
3326        ----------
3327        suggestion_name : str
3328            name of the prediction
3329
3330        Returns
3331        -------
3332        suggestion_path : str
3333            the file path
3334
3335        """
3336        return os.path.join(
3337            self.project_path, "results", "suggestions", f"{suggestion_name}"
3338        )
3339
3340    @classmethod
3341    def print_data_types(cls):
3342        """Print available data types."""
3343        print("DATA TYPES:")
3344        for key, value in cls.data_types().items():
3345            print(f"{key}:")
3346            print(value.__doc__)
3347
3348    @classmethod
3349    def print_annotation_types(cls):
3350        """Print available annotation types."""
3351        print("ANNOTATION TYPES:")
3352        for key, value in cls.annotation_types():
3353            print(f"{key}:")
3354            print(value.__doc__)
3355
3356    @staticmethod
3357    def data_types() -> List:
3358        """Get available data types.
3359
3360        Returns
3361        -------
3362        data_types : list
3363            available data types
3364
3365        """
3366        return options.input_stores
3367
3368    @staticmethod
3369    def annotation_types() -> List:
3370        """Get available annotation types.
3371
3372        Returns
3373        -------
3374        list
3375            available annotation types
3376
3377        """
3378        return options.annotation_stores
3379
3380    def _save_mask(self, file: Dict, mask_name: str):
3381        """Save a mask file.
3382
3383        Parameters
3384        ----------
3385        file : dict
3386            the mask file data to save
3387        mask_name : str
3388            the name of the mask file
3389
3390        """
3391        if not os.path.exists(self._mask_path()):
3392            os.mkdir(self._mask_path())
3393        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f:
3394            pickle.dump(file, f)
3395
3396    def _load_mask(self, mask_name: str) -> Dict:
3397        """Load a mask file.
3398
3399        Parameters
3400        ----------
3401        mask_name : str
3402            the name of the mask file to load
3403
3404        Returns
3405        -------
3406        mask : dict
3407            the loaded mask data
3408
3409        """
3410        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f:
3411            data = pickle.load(f)
3412        return data
3413
3414    def _thresholds(self) -> DecisionThresholds:
3415        """Get the decision thresholds meta object.
3416
3417        Returns
3418        -------
3419        thresholds : DecisionThresholds
3420            the decision thresholds meta object
3421
3422        """
3423        return DecisionThresholds(self._thresholds_path())
3424
3425    def _episodes(self) -> SavedRuns:
3426        """Get the episodes meta object.
3427
3428        Returns
3429        -------
3430        episodes : SavedRuns
3431            the episodes meta object
3432
3433        """
3434        try:
3435            return SavedRuns(self._episodes_path(), self.project_path)
3436        except:
3437            self.load_metadata_backup()
3438            return SavedRuns(self._episodes_path(), self.project_path)
3439
3440    def _suggestions(self) -> Suggestions:
3441        """Get the suggestions meta object.
3442
3443        Returns
3444        -------
3445        suggestions : Suggestions
3446            the suggestions meta object
3447
3448        """
3449        try:
3450            return Suggestions(self._suggestions_path(), self.project_path)
3451        except:
3452            self.load_metadata_backup()
3453            return Suggestions(self._suggestions_path(), self.project_path)
3454
3455    def _predictions(self) -> SavedRuns:
3456        """Get the predictions meta object.
3457
3458        Returns
3459        -------
3460        predictions : SavedRuns
3461            the predictions meta object
3462
3463        """
3464        try:
3465            return SavedRuns(self._predictions_path(), self.project_path)
3466        except:
3467            self.load_metadata_backup()
3468            return SavedRuns(self._predictions_path(), self.project_path)
3469
3470    def _saved_datasets(self) -> SavedStores:
3471        """Get the datasets meta object.
3472
3473        Returns
3474        -------
3475        datasets : SavedStores
3476            the datasets meta object
3477
3478        """
3479        try:
3480            return SavedStores(self._saved_datasets_path())
3481        except:
3482            self.load_metadata_backup()
3483            return SavedStores(self._saved_datasets_path())
3484
3485    def _prediction(self, name: str) -> Run:
3486        """Get a prediction meta object.
3487
3488        Parameters
3489        ----------
3490        name : str
3491            episode name
3492
3493        Returns
3494        -------
3495        prediction : Run
3496            the prediction meta object
3497
3498        """
3499        try:
3500            return Run(name, self.project_path, meta_path=self._predictions_path())
3501        except:
3502            self.load_metadata_backup()
3503            return Run(name, self.project_path, meta_path=self._predictions_path())
3504
3505    def _episode(self, name: str) -> Run:
3506        """Get an episode meta object.
3507
3508        Parameters
3509        ----------
3510        name : str
3511            episode name
3512
3513        Returns
3514        -------
3515        episode : Run
3516            the episode meta object
3517
3518        """
3519        try:
3520            return Run(name, self.project_path, meta_path=self._episodes_path())
3521        except:
3522            self.load_metadata_backup()
3523            return Run(name, self.project_path, meta_path=self._episodes_path())
3524
3525    def _searches(self) -> Searches:
3526        """Get the hyperparameter search meta object.
3527
3528        Returns
3529        -------
3530        searches : Searches
3531            the searches meta object
3532
3533        """
3534        try:
3535            return Searches(self._searches_path(), self.project_path)
3536        except:
3537            self.load_metadata_backup()
3538            return Searches(self._searches_path(), self.project_path)
3539
3540    def _update_configs(self) -> None:
3541        """Update the project config files with newly added files and parameters.
3542
3543        This method updates the project configuration with the data path and copies
3544        any new configuration files from the original package to the project.
3545
3546        """
3547        self.update_parameters({"data": {"data_path": self.data_path}})
3548        folders = ["augmentations", "features", "model"]
3549        original_path = os.path.join(
3550            os.path.dirname(os.path.dirname(__file__)), "config"
3551        )
3552        project_path = os.path.join(self.project_path, "config")
3553        filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")]
3554        for folder in folders:
3555            filenames += [
3556                os.path.join(folder, x)
3557                for x in os.listdir(os.path.join(original_path, folder))
3558            ]
3559        filenames.append(os.path.join("data", f"{self.data_type}.yaml"))
3560        if self.annotation_type != "none":
3561            filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml"))
3562        for file in filenames:
3563            filepath_original = os.path.join(original_path, file)
3564            if file.startswith("data") or file.startswith("annotation"):
3565                file = os.path.basename(file)
3566            filepath_project = os.path.join(project_path, file)
3567            if not os.path.exists(filepath_project):
3568                shutil.copy(filepath_original, filepath_project)
3569            else:
3570                original_pars = self._open_yaml(filepath_original)
3571                project_pars = self._open_yaml(filepath_project)
3572                to_remove = []
3573                for key, value in project_pars.items():
3574                    if key not in original_pars:
3575                        if key not in ["data_type", "annotation_type"]:
3576                            to_remove.append(key)
3577                for key in to_remove:
3578                    project_pars.pop(key)
3579                to_remove = []
3580                for key, value in original_pars.items():
3581                    if key in project_pars:
3582                        to_remove.append(key)
3583                for key in to_remove:
3584                    original_pars.pop(key)
3585                project_pars = self._update(project_pars, original_pars)
3586                with open(filepath_project, "w", encoding="utf-8") as f:
3587                    YAML().dump(project_pars, f)
3588
3589    def _update_project(self) -> None:
3590        """Update project files with the current version."""
3591        version_file = self._version_path()
3592        ok = True
3593        if not os.path.exists(version_file):
3594            ok = False
3595        else:
3596            with open(version_file, encoding="utf-8") as f:
3597                project_version = f.read()
3598            if project_version < __version__:
3599                ok = False
3600            elif project_version > __version__:
3601                warnings.warn(
3602                    f"The project expects a higher dlc2action version ({project_version}), please update!"
3603                )
3604        if not ok:
3605            project_config_path = os.path.join(self.project_path, "config")
3606            config_path = os.path.join(
3607                os.path.dirname(os.path.dirname(__path__)), "config"
3608            )
3609            episodes = self._episodes()
3610            folders = ["annotation", "augmentations", "data", "features", "model"]
3611
3612            project_annotation_configs = os.listdir(
3613                os.path.join(project_config_path, "annotation")
3614            )
3615            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
3616            for ann_config in annotation_configs:
3617                if ann_config not in project_annotation_configs:
3618                    shutil.copytree(
3619                        os.path.join(config_path, "annotation", ann_config),
3620                        os.path.join(project_config_path, "annotation", ann_config),
3621                        dirs_exist_ok=True,
3622                    )
3623                else:
3624                    project_pars = self._open_yaml(
3625                        os.path.join(project_config_path, "annotation", ann_config)
3626                    )
3627                    pars = self._open_yaml(
3628                        os.path.join(config_path, "annotation", ann_config)
3629                    )
3630                    new_keys = set(pars.keys()) - set(project_pars.keys())
3631                    for key in new_keys:
3632                        project_pars[key] = pars[key]
3633                        c = self._get_comment(pars.ca.items.get(key))
3634                        project_pars.yaml_add_eol_comment(c, key=key)
3635                        episodes.update(
3636                            condition=f"general/annotation_type::={ann_config}",
3637                            update={f"data/{key}": pars[key]},
3638                        )
3639
3640    def _initialize_project(
3641        self,
3642        data_type: str,
3643        annotation_type: str = None,
3644        data_path: str = None,
3645        annotation_path: str = None,
3646        copy: bool = True,
3647    ) -> None:
3648        """Initialize a new project."""
3649        if data_type not in self.data_types():
3650            raise ValueError(
3651                f"The {data_type} data type is not available. "
3652                f"Please choose from {self.data_types()}"
3653            )
3654        if annotation_type not in self.annotation_types():
3655            raise ValueError(
3656                f"The {annotation_type} annotation type is not available. "
3657                f"Please choose from {self.annotation_types()}"
3658            )
3659        os.mkdir(self.project_path)
3660        folders = ["results", "saved_datasets", "meta", "config"]
3661        for f in folders:
3662            os.mkdir(os.path.join(self.project_path, f))
3663        results_subfolders = [
3664            "model",
3665            "logs",
3666            "predictions",
3667            "splits",
3668            "searches",
3669            "suggestions",
3670        ]
3671        for sf in results_subfolders:
3672            os.mkdir(os.path.join(self.project_path, "results", sf))
3673        if data_path is not None:
3674            if copy:
3675                os.mkdir(os.path.join(self.project_path, "data"))
3676                shutil.copytree(
3677                    data_path,
3678                    os.path.join(self.project_path, "data"),
3679                    dirs_exist_ok=True,
3680                )
3681                data_path = os.path.join(self.project_path, "data")
3682        if annotation_path is not None:
3683            if copy:
3684                os.mkdir(os.path.join(self.project_path, "annotation"))
3685                shutil.copytree(
3686                    annotation_path,
3687                    os.path.join(self.project_path, "annotation"),
3688                    dirs_exist_ok=True,
3689                )
3690                annotation_path = os.path.join(self.project_path, "annotation")
3691        self._generate_config(
3692            data_type,
3693            annotation_type,
3694            data_path=data_path,
3695            annotation_path=annotation_path,
3696        )
3697        self._generate_meta()
3698
3699    def _read_types(self) -> Tuple[str, str]:
3700        """Get data type and annotation type from existing project files."""
3701        config_path = os.path.join(self.project_path, "config", "general.yaml")
3702        with open(config_path, encoding="utf-8") as f:
3703            pars = YAML().load(f)
3704        data_type = pars["data_type"]
3705        annotation_type = pars["annotation_type"]
3706        return annotation_type, data_type
3707
3708    def _read_paths(self) -> Tuple[str, str]:
3709        """Get data type and annotation type from existing project files."""
3710        config_path = os.path.join(self.project_path, "config", "data.yaml")
3711        with open(config_path, encoding="utf-8") as f:
3712            pars = YAML().load(f)
3713        data_path = pars["data_path"]
3714        annotation_path = pars["annotation_path"]
3715        return annotation_path, data_path
3716
3717    def _generate_config(
3718        self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
3719    ) -> None:
3720        """Initialize the config files."""
3721        default_path = os.path.join(
3722            os.path.dirname(os.path.dirname(__file__)), "config"
3723        )
3724        config_path = os.path.join(self.project_path, "config")
3725        files = ["losses", "metrics", "ssl", "training"]
3726        for f in files:
3727            shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path)
3728        shutil.copytree(
3729            os.path.join(default_path, "model"), os.path.join(config_path, "model")
3730        )
3731        shutil.copytree(
3732            os.path.join(default_path, "features"),
3733            os.path.join(config_path, "features"),
3734        )
3735        shutil.copytree(
3736            os.path.join(default_path, "augmentations"),
3737            os.path.join(config_path, "augmentations"),
3738        )
3739        yaml = YAML()
3740        data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
3741        if os.path.exists(data_param_path):
3742            with open(data_param_path, encoding="utf-8") as f:
3743                data_params = yaml.load(f)
3744        if data_params is None:
3745            data_params = {}
3746        if annotation_type is None:
3747            ann_params = {}
3748        else:
3749            ann_param_path = os.path.join(
3750                default_path, "annotation", f"{annotation_type}.yaml"
3751            )
3752            if os.path.exists(ann_param_path):
3753                ann_params = self._open_yaml(ann_param_path)
3754            elif annotation_type == "none":
3755                ann_params = {}
3756            else:
3757                raise ValueError(
3758                    f"The {annotation_type} data type is not available. "
3759                    f"Please choose from {BehaviorDataset.annotation_types()}"
3760                )
3761        if ann_params is None:
3762            ann_params = {}
3763        data_params = self._update(data_params, ann_params)
3764        data_params["data_path"] = data_path
3765        data_params["annotation_path"] = annotation_path
3766        with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
3767            yaml.dump(data_params, f)
3768        with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
3769            general_params = yaml.load(f)
3770        general_params["data_type"] = data_type
3771        general_params["annotation_type"] = annotation_type
3772        with open(
3773            os.path.join(config_path, "general.yaml"), "w", encoding="utf-8"
3774        ) as f:
3775            yaml.dump(general_params, f)
3776
3777    def _generate_meta(self) -> None:
3778        """Initialize the meta files."""
3779        config_file = os.path.join(self.project_path, "config")
3780        meta_fields = ["time"]
3781        columns = [("meta", field) for field in meta_fields]
3782        episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3783        episodes.to_pickle(self._episodes_path())
3784        meta_fields = ["time", "objective"]
3785        result_fields = ["best_params", "best_value"]
3786        columns = [("meta", field) for field in meta_fields] + [
3787            ("results", field) for field in result_fields
3788        ]
3789        searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3790        searches.to_pickle(self._searches_path())
3791        meta_fields = ["time"]
3792        columns = [("meta", field) for field in meta_fields]
3793        predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3794        predictions.to_pickle(self._predictions_path())
3795        suggestions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3796        suggestions.to_pickle(self._suggestions_path())
3797        with open(os.path.join(config_file, "data.yaml"), encoding="utf-8") as f:
3798            data_keys = list(YAML().load(f).keys())
3799        saved_data = pd.DataFrame(columns=data_keys)
3800        saved_data.to_pickle(self._saved_datasets_path())
3801        pd.DataFrame().to_pickle(self._thresholds_path())
3802        # with open(self._version_path()) as f:
3803        #     f.write(__version__)
3804
3805    def _open_yaml(self, path: str) -> CommentedMap:
3806        """Load a parameter dictionary from a .yaml file."""
3807        with open(path, encoding="utf-8") as f:
3808            data = YAML().load(f)
3809        if data is None:
3810            data = {}
3811        return data
3812
3813    def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7):
3814        """Compare nested dictionaries with 'almost equal' condition."""
3815        ok = True
3816        if u.keys() != d.keys():
3817            ok = False
3818        else:
3819            for k, v in u.items():
3820                if isinstance(v, Mapping):
3821                    ok = self._compare(d[k], v, allow_diff=allow_diff)
3822                else:
3823                    if isinstance(v, float) or isinstance(d[k], float):
3824                        if not isinstance(d[k], float) and not isinstance(d[k], int):
3825                            ok = False
3826                        elif not isinstance(v, float) and not isinstance(v, int):
3827                            ok = False
3828                        elif np.abs(v - d[k]) > allow_diff:
3829                            ok = False
3830                    elif v != d[k]:
3831                        ok = False
3832        return ok
3833
3834    def _check_comment(self, comment_sequence: List) -> bool:
3835        """Check if a comment already exists in a ruamel.yaml comment sequence."""
3836        if comment_sequence is None:
3837            return False
3838        c = self._get_comment(comment_sequence)
3839        if c != "":
3840            return True
3841        else:
3842            return False
3843
3844    def _get_comment(self, comment_sequence: List, strip=True) -> str:
3845        """Get the comment string from a ruamel.yaml comment sequence."""
3846        if comment_sequence is None:
3847            return ""
3848        c = ""
3849        for cm in comment_sequence:
3850            if cm is not None:
3851                if isinstance(cm, Iterable):
3852                    for c in cm:
3853                        if c is not None:
3854                            c = c.value
3855                            break
3856                    break
3857                else:
3858                    c = cm.value
3859                    break
3860        if strip:
3861            c = c.strip()
3862        return c
3863
3864    def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]):
3865        """Update a nested dictionary."""
3866        if "general" in u and "model_name" in u["general"] and "model" in d:
3867            model_name = u["general"]["model_name"]
3868            if d["general"]["model_name"] != model_name:
3869                d["model"] = self._open_yaml(
3870                    os.path.join(
3871                        self.project_path, "config", "model", f"{model_name}.yaml"
3872                    )
3873                )
3874        d_copied = deepcopy(d)
3875        for k, v in u.items():
3876            if (
3877                k in d_copied
3878                and isinstance(d_copied[k], list)
3879                and isinstance(v, Mapping)
3880                and all([isinstance(x, int) for x in v.keys()])
3881            ):
3882                for kk, vv in v.items():
3883                    d_copied[k][kk] = vv
3884            elif (
3885                isinstance(v, Mapping)
3886                and k in d_copied
3887                and isinstance(d_copied[k], Mapping)
3888            ):
3889                if d_copied[k] is None:
3890                    d_k = CommentedMap()
3891                else:
3892                    d_k = d_copied[k]
3893                d_copied[k] = self._update(d_k, v)
3894            else:
3895                d_copied[k] = v
3896                if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None:
3897                    c = self._get_comment(u.ca.items.get(k), strip=False)
3898                    if isinstance(d_copied, CommentedMap) and not self._check_comment(
3899                        d_copied.ca.items.get(k)
3900                    ):
3901                        d_copied.yaml_add_eol_comment(c, key=k)
3902        return d_copied
3903
3904    def _update_with_search(
3905        self,
3906        d: Dict,
3907        search_name: str,
3908        load_parameters: list = None,
3909        round_to_binary: list = None,
3910    ):
3911        """Update a dictionary with best parameters from a hyperparameter search."""
3912        u, _ = self._searches().get_best_params(
3913            search_name, load_parameters, round_to_binary
3914        )
3915        return self._update(d, u)
3916
3917    def _read_parameters(self, catch_blanks=True) -> Dict:
3918        """Compose a parameter dictionary to create a task from the config files."""
3919        config_path = os.path.join(self.project_path, "config")
3920        keys = [
3921            "data",
3922            "general",
3923            "losses",
3924            "metrics",
3925            "ssl",
3926            "training",
3927        ]
3928        parameters = {}
3929        for key in keys:
3930            parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml"))
3931        features = parameters["general"]["feature_extraction"]
3932        parameters["features"] = self._open_yaml(
3933            os.path.join(config_path, "features", f"{features}.yaml")
3934        )
3935        transformer = options.extractor_to_transformer[features]
3936        parameters["augmentations"] = self._open_yaml(
3937            os.path.join(config_path, "augmentations", f"{transformer}.yaml")
3938        )
3939        model = parameters["general"]["model_name"]
3940        parameters["model"] = self._open_yaml(
3941            os.path.join(config_path, "model", f"{model}.yaml")
3942        )
3943        # input = parameters["general"]["input"]
3944        # parameters["model"] = self._open_yaml(
3945        #     os.path.join(config_path, "model", f"{model}.yaml")
3946        # )
3947        if catch_blanks:
3948            blanks = self._get_blanks()
3949            if len(blanks) > 0:
3950                self.list_blanks()
3951                raise ValueError(
3952                    f"Please fill in all the blanks before running experiments"
3953                )
3954        return parameters
3955
3956    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3957        """Select the model and the metrics.
3958
3959        Parameters
3960        ----------
3961        model_name : str, optional
3962            model name; run `project.help("model") to find out more
3963        metric_names : list, optional
3964            a list of metric function names; run `project.help("metrics") to find out more
3965
3966        """
3967        pars = {"general": {}}
3968        if model_name is not None:
3969            assert model_name in options.models
3970            pars["general"]["model_name"] = model_name
3971        if metric_names is not None:
3972            for metric in metric_names:
3973                assert metric in options.metrics
3974            pars["general"]["metric_functions"] = metric_names
3975        self.update_parameters(pars)
3976
3977    def help(self, keyword: str = None):
3978        """Get information on available options.
3979
3980        Parameters
3981        ----------
3982        keyword : str, optional
3983            the keyword for options (run without arguments to see which keywords are available)
3984
3985        """
3986        if keyword is None:
3987            print("AVAILABLE HELP FUNCTIONS:")
3988            print("- Try running `project.help(keyword)` with the following keywords:")
3989            print("    - model: to get more information on available models,")
3990            print(
3991                "    - features: to get more information on available feature extraction modes,"
3992            )
3993            print(
3994                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3995            )
3996            print("    - metrics: to see a list of available metric functions.")
3997            print("    - data: to see help for expected data structure")
3998            print(
3999                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
4000            )
4001            print(
4002                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
4003            )
4004            print(
4005                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
4006            )
4007        elif keyword == "model":
4008            print("MODELS:")
4009            for key, model in options.models.items():
4010                print(f"{key}:")
4011                print(model.__doc__)
4012        elif keyword == "features":
4013            print("FEATURE EXTRACTORS:")
4014            for key, extractor in options.feature_extractors.items():
4015                print(f"{key}:")
4016                print(extractor.__doc__)
4017        elif keyword == "partition_method":
4018            print("PARTITION METHODS:")
4019            print(
4020                BehaviorDataset.partition_train_test_val.__doc__.split(
4021                    "The partitioning method:"
4022                )[1].split("val_frac :")[0]
4023            )
4024        elif keyword == "metrics":
4025            print("METRICS:")
4026            for key, metric in options.metrics.items():
4027                print(f"{key}:")
4028                print(metric.__doc__)
4029        elif keyword == "data":
4030            print("DATA:")
4031            print(f"Video data: {self.data_type}")
4032            print(options.input_stores[self.data_type].__doc__)
4033            print(f"Annotation data: {self.annotation_type}")
4034            print(options.annotation_stores[self.annotation_type].__doc__)
4035            print(
4036                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
4037            )
4038        else:
4039            raise ValueError(f"The {keyword} keyword is not recognized")
4040        print("\n")
4041
4042    def _process_value(self, value):
4043        """Process a configuration value for display.
4044
4045        Parameters
4046        ----------
4047        value : any
4048            the value to process
4049
4050        Returns
4051        -------
4052        processed_value : any
4053            the processed value
4054
4055        """
4056        if isinstance(value, str):
4057            value = f'"{value}"'
4058        elif isinstance(value, CommentedSet):
4059            value = {x for x in value}
4060        return value
4061
4062    def _get_blanks(self):
4063        """Get a list of blank (unset) parameters in the configuration.
4064
4065        Returns
4066        -------
4067        caught : list
4068            a list of parameter keys that have blank values
4069
4070        """
4071        caught = []
4072        parameters = self._read_parameters(catch_blanks=False)
4073        for big_key, big_value in parameters.items():
4074            for key, value in big_value.items():
4075                if value == "???":
4076                    caught.append(
4077                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
4078                    )
4079        return caught
4080
4081    def list_blanks(self, blanks=None):
4082        """List parameters that need to be filled in.
4083
4084        Parameters
4085        ----------
4086        blanks : list, optional
4087            a list of the parameters to list, if already known
4088
4089        """
4090        if blanks is None:
4091            blanks = self._get_blanks()
4092        if len(blanks) > 0:
4093            to_update = defaultdict(lambda: [])
4094            for b, k, c in blanks:
4095                to_update[b].append((k, c))
4096            print("Before running experiments, please update all the blanks.")
4097            print("To do that, you can run this.")
4098            print("--------------------------------------------------------")
4099            print(f"project.update_parameters(")
4100            print(f"    {{")
4101            for big_key, keys in to_update.items():
4102                print(f'        "{big_key}": {{')
4103                for key, comment in keys:
4104                    print(f'            "{key}": ..., {comment}')
4105                print(f"        }}")
4106            print(f"    }}")
4107            print(")")
4108            print("--------------------------------------------------------")
4109            print("Replace ... with relevant values.")
4110        else:
4111            print("There is no blanks left!")
4112
4113    def list_basic_parameters(
4114        self,
4115    ):
4116        """Get a list of most relevant parameters and code to modify them."""
4117        parameters = self._read_parameters()
4118        print("BASIC PARAMETERS:")
4119        model_name = parameters["general"]["model_name"]
4120        metric_names = parameters["general"]["metric_functions"]
4121        loss_name = parameters["general"]["loss_function"]
4122        feature_extraction = parameters["general"]["feature_extraction"]
4123        print("Here is a list of current parameters.")
4124        print(
4125            "You can copy this code, change the parameters you want to set and run it to update the project config."
4126        )
4127        print("--------------------------------------------------------")
4128        print("project.update_parameters(")
4129        print("    {")
4130        for group in ["general", "data", "training"]:
4131            print(f'        "{group}": {{')
4132            for key in options.basic_parameters[group]:
4133                if key in parameters[group]:
4134                    print(
4135                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
4136                    )
4137            print("        },")
4138        print('        "losses": {')
4139        print(f'            "{loss_name}": {{')
4140        for key in options.basic_parameters["losses"][loss_name]:
4141            if key in parameters["losses"][loss_name]:
4142                print(
4143                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
4144                )
4145        print("            },")
4146        print("        },")
4147        print('        "metrics": {')
4148        for metric in metric_names:
4149            print(f'            "{metric}": {{')
4150            for key in parameters["metrics"][metric]:
4151                print(
4152                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
4153                )
4154            print("            },")
4155        print("        },")
4156        print('        "model": {')
4157        for key in options.basic_parameters["model"][model_name]:
4158            if key in parameters["model"]:
4159                print(
4160                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
4161                )
4162
4163        print("        },")
4164        print('        "features": {')
4165        for key in options.basic_parameters["features"][feature_extraction]:
4166            if key in parameters["features"]:
4167                print(
4168                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
4169                )
4170
4171        print("        },")
4172        print('        "augmentations": {')
4173        for key in options.basic_parameters["augmentations"][feature_extraction]:
4174            if key in parameters["augmentations"]:
4175                print(
4176                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
4177                )
4178        print("        },")
4179        print("    },")
4180        print(")")
4181        print("--------------------------------------------------------")
4182        print("\n")
4183
4184    def _create_record(
4185        self,
4186        episode_name: str,
4187        behaviors_dict: Dict,
4188        load_episode: str = None,
4189        parameters_update: Dict = None,
4190        task: TaskDispatcher = None,
4191        load_epoch: int = None,
4192        load_search: str = None,
4193        load_parameters: list = None,
4194        round_to_binary: list = None,
4195        load_strict: bool = True,
4196        n_seeds: int = 1,
4197    ) -> TaskDispatcher:
4198        """Create a meta data episode record."""
4199        if episode_name in self._episodes().data.index:
4200            return
4201        if type(n_seeds) is not int or n_seeds < 1:
4202            raise ValueError(
4203                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
4204            )
4205        if parameters_update is None:
4206            parameters_update = {}
4207        parameters = self._read_parameters()
4208        parameters = self._update(parameters, parameters_update)
4209        if load_search is not None:
4210            parameters = self._update_with_search(
4211                parameters, load_search, load_parameters, round_to_binary
4212            )
4213        parameters = self._fill(
4214            parameters,
4215            episode_name,
4216            load_episode,
4217            load_epoch=load_epoch,
4218            only_load_model=True,
4219            load_strict=load_strict,
4220            continuing=True,
4221        )
4222        self._save_episode(episode_name, parameters, behaviors_dict)
4223        return task
4224
4225    def _save_thresholds(
4226        self,
4227        episode_names: List,
4228        metric_name: str,
4229        parameters: Dict,
4230        thresholds: List,
4231        load_epochs: List,
4232    ):
4233        """Save optimal decision thresholds in the meta records."""
4234        metric_parameters = parameters["metrics"][metric_name]
4235        self._thresholds().save_thresholds(
4236            episode_names, load_epochs, metric_name, metric_parameters, thresholds
4237        )
4238
4239    def _save_episode(
4240        self,
4241        episode_name: str,
4242        parameters: Dict,
4243        behaviors_dict: Dict,
4244        suppress_validation: bool = False,
4245        training_time: str = None,
4246        norm_stats: Dict = None,
4247    ) -> None:
4248        """Save an episode in the meta files."""
4249        try:
4250            split_info = self._split_info_from_filename(
4251                parameters["training"]["split_path"]
4252            )
4253            parameters["training"]["partition_method"] = split_info["partition_method"]
4254        except:
4255            pass
4256        if norm_stats is not None:
4257            norm_stats = dict(norm_stats)
4258        parameters["training"]["stats"] = norm_stats
4259        self._episodes().save_episode(
4260            episode_name,
4261            parameters,
4262            behaviors_dict,
4263            suppress_validation=suppress_validation,
4264            training_time=training_time,
4265        )
4266
4267    def _save_suggestions(
4268        self, suggestions_name: str, parameters: Dict, meta_parameters: Dict
4269    ) -> None:
4270        """Save a suggestion in the meta files."""
4271        self._suggestions().save_suggestion(
4272            suggestions_name, parameters, meta_parameters
4273        )
4274
4275    def _update_episode_results(
4276        self,
4277        episode_name: str,
4278        logs: Tuple,
4279        training_time: str = None,
4280    ) -> None:
4281        """Save the results of a run in the meta files."""
4282        self._episodes().update_episode_results(episode_name, logs, training_time)
4283
4284    def _save_prediction(
4285        self,
4286        prediction_name: str,
4287        predicted: Dict[str, Dict],
4288        parameters: Dict,
4289        task: TaskDispatcher,
4290        mode: str = "test",
4291        embedding: bool = False,
4292        inference_time: str = None,
4293        behavior_dict: List[Dict[str, Any]] = None,
4294    ) -> None:
4295        """Save a prediction in the meta files."""
4296
4297        folder = self.prediction_path(prediction_name)
4298        os.mkdir(folder)
4299        for video_id, prediction in predicted.items():
4300            with open(
4301                os.path.join(
4302                    folder, video_id + f"_{prediction_name}_prediction.pickle"
4303                ),
4304                "wb",
4305            ) as f:
4306                prediction["min_frames"], prediction["max_frames"] = task.dataset(
4307                    mode
4308                ).get_min_max_frames(video_id)
4309                prediction["classes"] = behavior_dict
4310                pickle.dump(prediction, f)
4311
4312        parameters = self._update(
4313            parameters,
4314            {"meta": {"embedding": embedding, "inference_time": inference_time}},
4315        )
4316        self._predictions().save_episode(
4317            prediction_name, parameters, task.behaviors_dict()
4318        )
4319
4320    def _save_search(
4321        self,
4322        search_name: str,
4323        parameters: Dict,
4324        n_trials: int,
4325        best_params: Dict,
4326        best_value: float,
4327        metric: str,
4328        search_space: Dict,
4329    ) -> None:
4330        """Save a hyperparameter search in the meta files."""
4331        self._searches().save_search(
4332            search_name,
4333            parameters,
4334            n_trials,
4335            best_params,
4336            best_value,
4337            metric,
4338            search_space,
4339        )
4340
4341    def _save_stores(self, parameters: Dict) -> None:
4342        """Save a pickled dataset in the meta files."""
4343        name = os.path.basename(parameters["data"]["feature_save_path"])
4344        self._saved_datasets().save_store(name, self._get_data_pars(parameters))
4345        self.create_metadata_backup()
4346
4347    def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None:
4348        """Remove the pre-computed features folder."""
4349        name = os.path.basename(parameters["data"]["feature_save_path"])
4350        if remove_active or name not in self._episodes().get_active_datasets():
4351            self.remove_saved_features([name])
4352
4353    def _check_episode_validity(
4354        self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
4355    ) -> None:
4356        """Check whether the episode name is valid."""
4357        if episode_name.startswith("_"):
4358            raise ValueError(
4359                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4360            )
4361        elif "." in episode_name:
4362            raise ValueError("Names containing '.' cannot be used!")
4363        if not allow_doublecolon and "#" in episode_name:
4364            raise ValueError(
4365                "Names containing '#' are reserved by dlc2action and cannot be used!"
4366            )
4367        if "::" in episode_name:
4368            raise ValueError(
4369                "Names containing '::' are reserved by dlc2action and cannot be used!"
4370            )
4371        if force:
4372            self.remove_episode(episode_name)
4373        elif not self._episodes().check_name_validity(episode_name):
4374            raise ValueError(
4375                f"The {episode_name} name is already taken! Set force=True to overwrite."
4376            )
4377
4378    def _check_search_validity(self, search_name: str, force: bool = False) -> None:
4379        """Check whether the search name is valid."""
4380        if search_name.startswith("_"):
4381            raise ValueError(
4382                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4383            )
4384        elif "." in search_name:
4385            raise ValueError("Names containing '.' cannot be used!")
4386        if force:
4387            self.remove_search(search_name)
4388        elif not self._searches().check_name_validity(search_name):
4389            raise ValueError(f"The {search_name} name is already taken!")
4390
4391    def _check_prediction_validity(
4392        self, prediction_name: str, force: bool = False
4393    ) -> None:
4394        """Check whether the prediction name is valid."""
4395        if prediction_name.startswith("_"):
4396            raise ValueError(
4397                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4398            )
4399        elif "." in prediction_name:
4400            raise ValueError("Names containing '.' cannot be used!")
4401        if force:
4402            self.remove_prediction(prediction_name)
4403        elif not self._predictions().check_name_validity(prediction_name):
4404            raise ValueError(f"The {prediction_name} name is already taken!")
4405
4406    def _check_suggestions_validity(
4407        self, suggestions_name: str, force: bool = False
4408    ) -> None:
4409        """Check whether the suggestions name is valid."""
4410        if suggestions_name.startswith("_"):
4411            raise ValueError(
4412                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4413            )
4414        elif "." in suggestions_name:
4415            raise ValueError("Names containing '.' cannot be used!")
4416        if force:
4417            self.remove_suggestion(suggestions_name)
4418        elif not self._suggestions().check_name_validity(suggestions_name):
4419            raise ValueError(f"The {suggestions_name} name is already taken!")
4420
4421    def _training_time(self, episode_name: str) -> int:
4422        """Get the training time of an episode in seconds."""
4423        return self._episode(episode_name).training_time()
4424
4425    def _mask_path(self) -> str:
4426        """Get the path to the masks folder.
4427
4428        Returns
4429        -------
4430        path : str
4431            the path to the masks folder
4432
4433        """
4434        return os.path.join(self.project_path, "results", "masks")
4435
4436    def _thresholds_path(self) -> str:
4437        """Get the path to the thresholds meta file.
4438
4439        Returns
4440        -------
4441        path : str
4442            the path to the thresholds meta file
4443
4444        """
4445        return os.path.join(self.project_path, "meta", "thresholds.pickle")
4446
4447    def _episodes_path(self) -> str:
4448        """Get the path to the episodes meta file.
4449
4450        Returns
4451        -------
4452        path : str
4453            the path to the episodes meta file
4454
4455        """
4456        return os.path.join(self.project_path, "meta", "episodes.pickle")
4457
4458    def _suggestions_path(self) -> str:
4459        """Get the path to the suggestions meta file.
4460
4461        Returns
4462        -------
4463        path : str
4464            the path to the suggestions meta file
4465
4466        """
4467        return os.path.join(self.project_path, "meta", "suggestions.pickle")
4468
4469    def _saved_datasets_path(self) -> str:
4470        """Get the path to the datasets meta file.
4471
4472        Returns
4473        -------
4474        path : str
4475            the path to the datasets meta file
4476
4477        """
4478        return os.path.join(self.project_path, "meta", "saved_datasets.pickle")
4479
4480    def _predictions_path(self) -> str:
4481        """Get the path to the predictions meta file.
4482
4483        Returns
4484        -------
4485        path : str
4486            the path to the predictions meta file
4487
4488        """
4489        return os.path.join(self.project_path, "meta", "predictions.pickle")
4490
4491    def _dataset_store_path(self, name: str) -> str:
4492        """Get the path to a specific pickled dataset.
4493
4494        Parameters
4495        ----------
4496        name : str
4497            the name of the dataset
4498
4499        Returns
4500        -------
4501        path : str
4502            the path to the dataset file
4503
4504        """
4505        return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle")
4506
4507    def _al_points_path(self, suggestions_name: str) -> str:
4508        """Get the path to an active learning intervals file.
4509
4510        Parameters
4511        ----------
4512        suggestions_name : str
4513            the name of the suggestions
4514
4515        Returns
4516        -------
4517        path : str
4518            the path to the active learning points file
4519
4520        """
4521        path = os.path.join(
4522            self.project_path,
4523            "results",
4524            "suggestions",
4525            suggestions_name,
4526            f"{suggestions_name}_al_points.pickle",
4527        )
4528        return path
4529
4530    def _suggestion_path(self, v_id: str, suggestions_name: str) -> str:
4531        """Get the path to a suggestion file.
4532
4533        Parameters
4534        ----------
4535        v_id : str
4536            the video ID
4537        suggestions_name : str
4538            the name of the suggestions
4539
4540        Returns
4541        -------
4542        path : str
4543            the path to the suggestion file
4544
4545        """
4546        path = os.path.join(
4547            self.project_path,
4548            "results",
4549            "suggestions",
4550            suggestions_name,
4551            f"{v_id}_suggestion.pickle",
4552        )
4553        return path
4554
4555    def _searches_path(self) -> str:
4556        """Get the path to the hyperparameter search meta file.
4557
4558        Returns
4559        -------
4560        path : str
4561            the path to the searches meta file
4562
4563        """
4564        return os.path.join(self.project_path, "meta", "searches.pickle")
4565
4566    def _search_path(self, name: str) -> str:
4567        """Get the default path to the graph folder for a specific hyperparameter search.
4568
4569        Parameters
4570        ----------
4571        name : str
4572            the name of the search
4573
4574        Returns
4575        -------
4576        path : str
4577            the path to the search folder
4578
4579        """
4580        return os.path.join(self.project_path, "results", "searches", name)
4581
4582    def _version_path(self) -> str:
4583        """Get the path to the version file.
4584
4585        Returns
4586        -------
4587        path : str
4588            the path to the version file
4589
4590        """
4591        return os.path.join(self.project_path, "meta", "version.txt")
4592
4593    def _default_split_file(self, split_info: Dict) -> Optional[str]:
4594        """Generate a path to a split file from split parameters.
4595
4596        Parameters
4597        ----------
4598        split_info : dict
4599            the split parameters dictionary
4600
4601        Returns
4602        -------
4603        split_file_path : str or None
4604            the path to the split file, or None if not applicable
4605
4606        """
4607        if split_info["partition_method"].startswith("time"):
4608            return None
4609        val_frac = split_info["val_frac"]
4610        test_frac = split_info["test_frac"]
4611        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
4612        if not split_info["only_load_annotated"]:
4613            split_name += "_all"
4614        split_name += ".txt"
4615        return os.path.join(self.project_path, "results", "splits", split_name)
4616
4617    def _split_info_from_filename(self, split_name: str) -> Dict:
4618        """Get split parameters from default path to a split file.
4619
4620        Parameters
4621        ----------
4622        split_name : str
4623            the name/path of the split file
4624
4625        Returns
4626        -------
4627        split_info : dict
4628            the split parameters dictionary
4629
4630        """
4631        if split_name is None:
4632            return {}
4633        try:
4634            name = os.path.basename(split_name)[:-4]
4635            split = name.split("_")
4636            if len(split) == 6:
4637                only_load_annotated = False
4638            else:
4639                only_load_annotated = True
4640            len_segment = int(split[3][3:])
4641            overlap = float(split[4][7:])
4642            if overlap > 1:
4643                overlap = int(overlap)
4644            method, val, test = split[:3]
4645            val = float(val[3:-1]) / 100
4646            test = float(test[4:-1]) / 100
4647            return {
4648                "partition_method": method,
4649                "val_frac": val,
4650                "test_frac": test,
4651                "only_load_annotated": only_load_annotated,
4652                "len_segment": len_segment,
4653                "overlap": overlap,
4654            }
4655        except:
4656            return {"partition_method": "file"}
4657
4658    def _fill(
4659        self,
4660        parameters: Dict,
4661        episode_name: str,
4662        load_experiment: str = None,
4663        load_epoch: int = None,
4664        load_strict: bool = True,
4665        only_load_model: bool = False,
4666        continuing: bool = False,
4667        enforce_split_parameters: bool = False,
4668    ) -> Dict:
4669        """Update the parameters from the config files with project specific information.
4670
4671        Fill in the constant file path parameters and generate a unique log file and a model folder.
4672        Fill in the split file if the same split has been run before in the project and change partition method to
4673        from_file.
4674        Fill in saved data path if a dataset with the same data parameters already exists in the project.
4675        If load_experiment is not None, fill in the checkpoint path as well.
4676        The only_load_model training parameter is defined by the corresponding argument.
4677        If continuing is True, new files are not created and all information is loaded from load_experiment.
4678        If prediction is True, log and model files are not created.
4679        The enforce_split_parameters parameter is used to resolve conflicts
4680        between split file path and split parameters when they arise.
4681
4682        Parameters
4683        ----------
4684        parameters : dict
4685            the parameters dictionary to update
4686        episode_name : str
4687            the name of the episode
4688        load_experiment : str, optional
4689            the name of the experiment to load from
4690        load_epoch : int, optional
4691            the epoch to load (by default the last one)
4692        load_strict : bool, default True
4693            if `True`, strict loading is enforced
4694        only_load_model : bool, default False
4695            if `True`, only the model is loaded
4696        continuing : bool, default False
4697            if `True`, continues from existing files
4698        enforce_split_parameters : bool, default False
4699            if `True`, split parameters are enforced
4700
4701        Returns
4702        -------
4703        parameters : dict
4704            the updated parameters dictionary
4705
4706        """
4707        pars = deepcopy(parameters)
4708        if episode_name == "_":
4709            self.remove_episode("_")
4710        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
4711        model_save_path = os.path.join(
4712            self.project_path, "results", "model", episode_name
4713        )
4714        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
4715            raise ValueError(
4716                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
4717            )
4718        keys = ["val_frac", "test_frac", "partition_method"]
4719        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
4720            pars["general"]["len_segment"] = pars["data"]["len_segment"]
4721        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
4722            pars["general"]["overlap"] = pars["data"]["overlap"]
4723        if "len_segment" in pars["data"]:
4724            pars["data"].pop("len_segment")
4725        if "overlap" in pars["data"]:
4726            pars["data"].pop("overlap")
4727        split_info = {k: pars["training"][k] for k in keys}
4728        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
4729        split_info["len_segment"] = pars["general"]["len_segment"]
4730        split_info["overlap"] = pars["general"]["overlap"]
4731        pars["training"]["log_file"] = log
4732        if not os.path.exists(model_save_path):
4733            os.mkdir(model_save_path)
4734        pars["training"]["model_save_path"] = model_save_path
4735        if load_experiment is not None:
4736            if load_experiment not in self._episodes().data.index:
4737                raise ValueError(f"The {load_experiment} episode does not exist!")
4738            old_episode = self._episode(load_experiment)
4739            old_file = old_episode.split_file()
4740            old_info = self._split_info_from_filename(old_file)
4741            if len(old_info) == 0:
4742                old_info = old_episode.split_info()
4743            if enforce_split_parameters:
4744                if split_info["partition_method"] != "file":
4745                    pars["training"]["split_path"] = self._default_split_file(
4746                        split_info
4747                    )
4748            else:
4749                equal = True
4750                if old_info["partition_method"] != split_info["partition_method"]:
4751                    equal = False
4752                if old_info["partition_method"] != "file":
4753                    if (
4754                        old_info["val_frac"] != split_info["val_frac"]
4755                        or old_info["test_frac"] != split_info["test_frac"]
4756                    ):
4757                        equal = False
4758                if not continuing and not equal:
4759                    warnings.warn(
4760                        f"The partitioning parameters in the loaded experiment ({old_info}) "
4761                        f"are not equal to the current partitioning parameters ({split_info}). "
4762                        f"The current parameters are replaced."
4763                    )
4764                pars["training"]["split_path"] = old_file
4765                for k, v in old_info.items():
4766                    pars["training"][k] = v
4767            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
4768            pars["training"]["load_strict"] = load_strict
4769        else:
4770            pars["training"]["checkpoint_path"] = None
4771            if pars["training"]["partition_method"] == "file":
4772                if (
4773                    "split_path" not in pars["training"]
4774                    or pars["training"]["split_path"] is None
4775                ):
4776                    raise ValueError(
4777                        "The partition_method parameter is set to file but the "
4778                        "split_path parameter is not set!"
4779                    )
4780                elif not os.path.exists(pars["training"]["split_path"]):
4781                    raise ValueError(
4782                        f'The {pars["training"]["split_path"]} split file does not exist'
4783                    )
4784            else:
4785                pars["training"]["split_path"] = self._default_split_file(split_info)
4786        pars["training"]["only_load_model"] = only_load_model
4787        pars["data"]["saved_data_path"] = None
4788        pars["data"]["feature_save_path"] = None
4789        pars_data_copy = self._get_data_pars(pars)
4790        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
4791        if saved_data_name is not None:
4792            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
4793            pars["data"]["feature_save_path"] = self._dataset_store_path(
4794                saved_data_name
4795            ).split(".")[0]
4796        else:
4797            dataset_path = self._dataset_store_path(episode_name)
4798            if os.path.exists(dataset_path):
4799                name, ext = dataset_path.split(".")
4800                i = 0
4801                while os.path.exists(f"{name}_{i}.{ext}"):
4802                    i += 1
4803                dataset_path = f"{name}_{i}.{ext}"
4804            pars["data"]["saved_data_path"] = dataset_path
4805            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
4806        split_split = pars["training"]["partition_method"].split(":")
4807        random = True
4808        for partition_method in options.partition_methods["fixed"]:
4809            method_split = partition_method.split(":")
4810            if len(split_split) != len(method_split):
4811                continue
4812            equal = True
4813            for x, y in zip(split_split, method_split):
4814                if y.startswith("{"):
4815                    continue
4816                if x != y:
4817                    equal = False
4818                    break
4819            if equal:
4820                random = False
4821                break
4822        if random and os.path.exists(pars["training"]["split_path"]):
4823            pars["training"]["partition_method"] = "file"
4824        pars["general"]["save_dataset"] = True
4825        # Check len_segment for c2f models
4826        if pars["general"]["model_name"].startswith("c2f"):
4827            if int(pars["general"]["len_segment"]) < 512:
4828                raise ValueError(
4829                    "The segment length should be higher than 512 when using one of the C2F models"
4830                )
4831        return pars
4832
4833    def _get_data_pars(self, pars: Dict) -> Dict:
4834        """Get a complete description of the data from a general parameters dictionary.
4835
4836        Parameters
4837        ----------
4838        pars : dict
4839            the general parameters dictionary
4840
4841        Returns
4842        -------
4843        pars_data : dict
4844            the complete data parameters dictionary
4845
4846        """
4847        pars_data_copy = deepcopy(pars["data"])
4848        for par in [
4849            "only_load_annotated",
4850            "exclusive",
4851            "feature_extraction",
4852            "ignored_clips",
4853            "len_segment",
4854            "overlap",
4855        ]:
4856            pars_data_copy[par] = pars["general"].get(par, None)
4857        pars_data_copy.update(pars["features"])
4858        return pars_data_copy
4859
4860    def _make_al_points_from_suggestions(
4861        self,
4862        suggestions_name: str,
4863        task: TaskDispatcher,
4864        predicted_classes: Dict,
4865        background_threshold: Optional[float],
4866        visibility_min_score: float,
4867        visibility_min_frac: float,
4868        num_behaviors: int,
4869    ):
4870        valleys = []
4871        if background_threshold is not None:
4872            for i in range(num_behaviors):
4873                print(f"generating background for behavior {i}...")
4874                valleys.append(
4875                    task.dataset("train").find_valleys(
4876                        predicted_classes,
4877                        threshold=background_threshold,
4878                        visibility_min_score=visibility_min_score,
4879                        visibility_min_frac=visibility_min_frac,
4880                        main_class=i,
4881                        low=True,
4882                        cut_annotated=True,
4883                        min_frames=1,
4884                    )
4885                )
4886        valleys = task.dataset("train").valleys_intersection(valleys)
4887        folder = os.path.join(
4888            self.project_path, "results", "suggestions", suggestions_name
4889        )
4890        os.makedirs(os.path.dirname(folder), exist_ok=True)
4891        res = {}
4892        for file in os.listdir(folder):
4893            video_id = file.split("_suggestion.p")[0]
4894            res[video_id] = []
4895            with open(os.path.join(folder, file), "rb") as f:
4896                data = pickle.load(f)
4897            for clip_id, ind_list in zip(data[2], data[3]):
4898                max_len = max(
4899                    [
4900                        max([x[1] for x in cat_list]) if len(cat_list) > 0 else 0
4901                        for cat_list in ind_list
4902                    ]
4903                )
4904                if max_len == 0:
4905                    continue
4906                arr = torch.zeros(max_len)
4907                for cat_list in ind_list:
4908                    for start, end, amb in cat_list:
4909                        arr[start:end] = 1
4910                if video_id in valleys:
4911                    for start, end, clip in valleys[video_id]:
4912                        if clip == clip_id:
4913                            arr[start:end] = 1
4914                output, indices, counts = torch.unique_consecutive(
4915                    arr > 0, return_inverse=True, return_counts=True
4916                )
4917                long_indices = torch.where(output)[0]
4918                res[video_id] += [
4919                    (
4920                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
4921                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
4922                        clip_id,
4923                    )
4924                    for i in long_indices
4925                ]
4926        return res
4927
    def _make_al_points(
        self,
        task: TaskDispatcher,
        predicted_error: torch.Tensor,
        predicted_classes: torch.Tensor,
        exclude_classes: List,
        exclude_threshold: List,
        exclude_threshold_diff: List,
        exclude_hysteresis: List,
        include_classes: List,
        include_threshold: List,
        include_threshold_diff: List,
        include_hysteresis: List,
        error_episode: str = None,
        error_class: str = None,
        suggestion_episodes: List = None,
        error_threshold: float = 0.5,
        error_threshold_diff: float = 0.1,
        error_hysteresis: bool = False,
        min_frames_al: int = 30,
        visibility_min_score: float = 5,
        visibility_min_frac: float = 0.7,
    ) -> Dict:
        """Generate an active learning file.

        If any classes to exclude or include are given, the intervals are found with
        `find_valleys` on the class predictions (one call per class), combined with
        `valleys_union` within each group and then with `valleys_intersection`
        between the two groups. Otherwise, the intervals are found on the predicted
        error of `error_class`. In both cases intervals of the same clip that are
        less than 30 frames apart are merged at the end.

        Parameters
        ----------
        task : TaskDispatcher
            the task dispatcher that holds the "train" dataset
        predicted_error : torch.Tensor
            the predicted error, passed to `find_valleys`
        predicted_classes : torch.Tensor
            the predicted classes, passed to `find_valleys`
        exclude_classes : list
            the names of the classes that are searched with `low=True`
        exclude_threshold : list
            the thresholds for `exclude_classes` (one per class)
        exclude_threshold_diff : list
            the `threshold_diff` values for `exclude_classes` (one per class)
        exclude_hysteresis : list
            the `hysteresis` values for `exclude_classes` (one per class)
        include_classes : list
            the names of the classes that are searched with `low=False`
        include_threshold : list
            the thresholds for `include_classes` (one per class)
        include_threshold_diff : list
            the `threshold_diff` values for `include_classes` (one per class)
        include_hysteresis : list
            the `hysteresis` values for `include_classes` (one per class)
        error_episode : str, optional
            the episode used to look up the index of `error_class` (only used when
            both class lists are empty)
        error_class : str, optional
            the name of the error class (only used when both class lists are empty)
        suggestion_episodes : list, optional
            the suggestion episodes; the first run of the first one is used to look
            up class indices (required when any of the class lists is not empty)
        error_threshold : float, default 0.5
            the error threshold, passed to `find_valleys`
        error_threshold_diff : float, default 0.1
            the `threshold_diff` value for the error (only used when both class lists
            are empty)
        error_hysteresis : bool, default False
            the `hysteresis` value for the error (only used when both class lists
            are empty)
        min_frames_al : int, default 30
            passed to `find_valleys` as both `min_frames` and `min_frames_error`
        visibility_min_score : float, default 5
            passed to `find_valleys`
        visibility_min_frac : float, default 0.7
            passed to `find_valleys`

        Returns
        -------
        al_points : dict
            a dictionary with video IDs as keys and sorted lists of merged intervals
            as values

        """
        if len(exclude_classes) > 0 or len(include_classes) > 0:
            valleys = []
            # NOTE(review): `included` is filled from exclude_classes and `excluded` from
            # include_classes; both only feed the intersection below
            included = None
            excluded = None
            for class_name, thr, thr_diff, hysteresis in zip(
                exclude_classes,
                exclude_threshold,
                exclude_threshold_diff,
                exclude_hysteresis,
            ):
                # Class indices are looked up in the first run of the first suggestion episode
                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
                class_index = self._episode(episode).get_class_ind(class_name)
                valleys.append(
                    task.dataset("train").find_valleys(
                        predicted_classes,
                        predicted_error=predicted_error,
                        min_frames=min_frames_al,
                        threshold=thr,
                        visibility_min_score=visibility_min_score,
                        visibility_min_frac=visibility_min_frac,
                        error_threshold=error_threshold,
                        main_class=class_index,
                        low=True,
                        threshold_diff=thr_diff,
                        min_frames_error=min_frames_al,
                        hysteresis=hysteresis,
                    )
                )
            if len(valleys) > 0:
                included = task.dataset("train").valleys_union(valleys)
            valleys = []
            for class_name, thr, thr_diff, hysteresis in zip(
                include_classes,
                include_threshold,
                include_threshold_diff,
                include_hysteresis,
            ):
                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
                class_index = self._episode(episode).get_class_ind(class_name)
                valleys.append(
                    task.dataset("train").find_valleys(
                        predicted_classes,
                        predicted_error=predicted_error,
                        min_frames=min_frames_al,
                        threshold=thr,
                        visibility_min_score=visibility_min_score,
                        visibility_min_frac=visibility_min_frac,
                        error_threshold=error_threshold,
                        main_class=class_index,
                        low=False,
                        threshold_diff=thr_diff,
                        min_frames_error=min_frames_al,
                        hysteresis=hysteresis,
                    )
                )
            if len(valleys) > 0:
                excluded = task.dataset("train").valleys_union(valleys)
            # One of the two can still be None here if only one class list was given
            al_points = task.dataset("train").valleys_intersection([included, excluded])
        else:
            class_index = self._episode(error_episode).get_class_ind(error_class)
            print("generating active learning intervals...")
            al_points = task.dataset("train").find_valleys(
                predicted_error,
                min_frames=min_frames_al,
                threshold=error_threshold,
                visibility_min_score=visibility_min_score,
                visibility_min_frac=visibility_min_frac,
                main_class=class_index,
                low=True,
                threshold_diff=error_threshold_diff,
                min_frames_error=min_frames_al,
                hysteresis=error_hysteresis,
            )
        # Merge intervals of the same clip that are close to each other
        for v_id in al_points:
            clip_dict = defaultdict(lambda: [])
            res = []
            # Group the intervals by their last element (used as the clip ID)
            for x in al_points[v_id]:
                clip_dict[x[-1]].append(x)
            for clip_id in clip_dict:
                clip_dict[clip_id] = sorted(clip_dict[clip_id])
                # i is the interval that is being extended, j the next candidate
                i = 0
                j = 1
                while j < len(clip_dict[clip_id]):
                    end = clip_dict[clip_id][i][1]
                    start = clip_dict[clip_id][j][0]
                    # The maximum gap is fixed at 30 frames (independent of min_frames_al)
                    if start - end < 30:
                        # Extends interval i in place, so intervals have to be mutable
                        clip_dict[clip_id][i][1] = clip_dict[clip_id][j][1]
                    else:
                        res.append(clip_dict[clip_id][i])
                        i = j
                    j += 1
                res.append(clip_dict[clip_id][i])
            al_points[v_id] = sorted(res)
        return al_points
5046
5047    def _make_suggestions(
5048        self,
5049        task: TaskDispatcher,
5050        predicted_error: torch.Tensor,
5051        predicted_classes: torch.Tensor,
5052        suggestion_threshold: List,
5053        suggestion_threshold_diff: List,
5054        suggestion_hysteresis: List,
5055        suggestion_episodes: List = None,
5056        suggestion_classes: List = None,
5057        error_threshold: float = 0.5,
5058        min_frames_suggestion: int = 3,
5059        min_frames_al: int = 30,
5060        visibility_min_score: float = 0,
5061        visibility_min_frac: float = 0.7,
5062        cut_annotated: bool = False,
5063    ) -> Dict:
5064        """Make a suggestions dictionary."""
5065        suggestions = defaultdict(lambda: {})
5066        for class_name, thr, thr_diff, hysteresis in zip(
5067            suggestion_classes,
5068            suggestion_threshold,
5069            suggestion_threshold_diff,
5070            suggestion_hysteresis,
5071        ):
5072            episode = self._episodes().get_runs(suggestion_episodes[0])[0]
5073            class_index = self._episode(episode).get_class_ind(class_name)
5074            print(f"generating suggestions for {class_name}...")
5075            found = task.dataset("train").find_valleys(
5076                predicted_classes,
5077                smooth_interval=2,
5078                predicted_error=predicted_error,
5079                min_frames=min_frames_suggestion,
5080                threshold=thr,
5081                visibility_min_score=visibility_min_score,
5082                visibility_min_frac=visibility_min_frac,
5083                error_threshold=error_threshold,
5084                main_class=class_index,
5085                low=False,
5086                threshold_diff=thr_diff,
5087                min_frames_error=min_frames_al,
5088                hysteresis=hysteresis,
5089                cut_annotated=cut_annotated,
5090            )
5091            for v_id in found:
5092                suggestions[v_id][class_name] = found[v_id]
5093        suggestions = dict(suggestions)
5094        return suggestions
5095
5096    def count_classes(
5097        self,
5098        load_episode: str = None,
5099        parameters_update: Dict = None,
5100        remove_saved_features: bool = False,
5101        bouts: bool = True,
5102    ) -> Dict:
5103        """Get a dictionary of class counts in different modes.
5104
5105        Parameters
5106        ----------
5107        load_episode : str, optional
5108            the episode settings to load
5109        parameters_update : dict, optional
5110            a dictionary of parameter updates (only for "data" and "general" categories)
5111        remove_saved_features : bool, default False
5112            if `True`, the dataset that is used for computation is then deleted
5113        bouts : bool, default False
5114            if `True`, instead of frame counts segment counts are returned
5115
5116        Returns
5117        -------
5118        class_counts : dict
5119            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
5120            class names and values are class counts (in frames)
5121
5122        """
5123        if load_episode is None:
5124            task, parameters = self._make_task_training(
5125                episode_name="_", parameters_update=parameters_update, throwaway=True
5126            )
5127        else:
5128            task, parameters, _ = self._make_task_prediction(
5129                "_",
5130                load_episode=load_episode,
5131                parameters_update=parameters_update,
5132            )
5133        class_counts = task.count_classes(bouts=bouts)
5134        behaviors = task.behaviors_dict()
5135        class_counts = {
5136            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
5137            for kk, vv in class_counts.items()
5138        }
5139        if remove_saved_features:
5140            self._remove_stores(parameters)
5141        return class_counts
5142
5143    def plot_class_distribution(
5144        self,
5145        parameters_update: Dict = None,
5146        frame_cutoff: int = 1,
5147        bout_cutoff: int = 1,
5148        print_full: bool = False,
5149        remove_saved_features: bool = False,
5150        save: str = None,
5151    ) -> None:
5152        """Make a class distribution plot.
5153
5154        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
5155        is created or loaded for the computation with the default parameters).
5156
5157        Parameters
5158        ----------
5159        parameters_update : dict, optional
5160            a dictionary of parameter updates (only for "data" and "general" categories)
5161        frame_cutoff : int, default 1
5162            the minimum number of frames for a segment to be considered
5163        bout_cutoff : int, default 1
5164            the minimum number of bouts for a class to be considered
5165        print_full : bool, default False
5166            if `True`, the full class distribution is printed
5167        remove_saved_features : bool, default False
5168            if `True`, the dataset that is used for computation is then deleted
5169
5170        """
5171        task, parameters = self._make_task_training(
5172            episode_name="_", parameters_update=parameters_update, throwaway=True
5173        )
5174        cutoff = {True: bout_cutoff, False: frame_cutoff}
5175        for bouts in [True, False]:
5176            class_counts = task.count_classes(bouts=bouts)
5177            if print_full:
5178                print("Bouts:" if bouts else "Frames:")
5179                for k, v in class_counts.items():
5180                    if sum(v.values()) != 0:
5181                        print(f"  {k}:")
5182                        values, keys = zip(
5183                            *[
5184                                x
5185                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
5186                                if x[-1] != -100
5187                            ]
5188                        )
5189                        for kk, vv in zip(keys, values):
5190                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
5191            class_counts = {
5192                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
5193                for kk, vv in class_counts.items()
5194            }
5195            for key, d in class_counts.items():
5196                if sum(d.values()) != 0:
5197                    values, keys = zip(
5198                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
5199                    )
5200                    keys = [task.behaviors_dict()[x] for x in keys]
5201                    plt.bar(keys, values)
5202                    plt.title(key)
5203                    plt.xticks(rotation=45, ha="right")
5204                    if bouts:
5205                        plt.ylabel("bouts")
5206                    else:
5207                        plt.ylabel("frames")
5208                    plt.tight_layout()
5209
5210                    if save is None:
5211                        plt.savefig(save)
5212                        plt.close()
5213                    else:
5214                        plt.show()
5215        if remove_saved_features:
5216            self._remove_stores(parameters)
5217
5218    def _generate_mask(
5219        self,
5220        mask_name: str,
5221        perc_annotated: float = 0.1,
5222        parameters_update: Dict = None,
5223        remove_saved_features: bool = False,
5224    ) -> None:
5225        """Generate a real_lens for active learning simulation.
5226
5227        Parameters
5228        ----------
5229        mask_name : str
5230            the name of the real_lens
5231        perc_annotated : float, default 0.1
5232            a
5233
5234        """
5235        print(f"GENERATING {mask_name}")
5236        task, parameters = self._make_task_training(
5237            f"_{mask_name}", parameters_update=parameters_update, throwaway=True
5238        )
5239        val_intervals, val_ids = task.dataset("val").get_intervals()  # 1
5240        unannotated_intervals = task.dataset("train").get_unannotated_intervals()  # 2
5241        unannotated_intervals = task.dataset("val").get_unannotated_intervals(
5242            first_intervals=unannotated_intervals
5243        )
5244        ids = task.dataset("train").get_ids()
5245        mask = {video_id: {} for video_id in ids}
5246        total_all = 0
5247        total_masked = 0
5248        for video_id, clip_ids in ids.items():
5249            for clip_id in clip_ids:
5250                frames = np.ones(task.dataset("train").get_len(video_id, clip_id))
5251                if clip_id in val_intervals[video_id]:
5252                    for start, end in val_intervals[video_id][clip_id]:
5253                        frames[start:end] = 0
5254                if clip_id in unannotated_intervals[video_id]:
5255                    for start, end in unannotated_intervals[video_id][clip_id]:
5256                        frames[start:end] = 0
5257                annotated = np.where(frames)[0]
5258                total_all += len(annotated)
5259                masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :]
5260                total_masked += len(masked)
5261                mask[video_id][clip_id] = self._get_intervals(masked)
5262        file = {
5263            "masked": mask,
5264            "val_intervals": val_intervals,
5265            "val_ids": val_ids,
5266            "unannotated": unannotated_intervals,
5267        }
5268        self._save_mask(file, mask_name)
5269        if remove_saved_features:
5270            self._remove_stores(parameters)
5271        print("\n")
5272        # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames')
5273
5274    def _get_intervals(self, frame_indices: np.ndarray):
5275        """Get a list of intervals from a list of frame indices.
5276
5277        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
5278
5279        Parameters
5280        ----------
5281        frame_indices : np.ndarray
5282            a list of frame indices
5283
5284        Returns
5285        -------
5286        intervals : list
5287            a list of interval boundaries
5288
5289        """
5290        masked_intervals = []
5291        if len(frame_indices) > 0:
5292            breaks = np.where(np.diff(frame_indices) != 1)[0]
5293            start = frame_indices[0]
5294            for k in breaks:
5295                masked_intervals.append([start, frame_indices[k] + 1])
5296                start = frame_indices[k + 1]
5297            masked_intervals.append([start, frame_indices[-1] + 1])
5298        return masked_intervals
5299
5300    def _update_mask_with_uncertainty(
5301        self,
5302        mask_name: str,
5303        episode_name: Union[str, None],
5304        classes: List,
5305        load_epoch: int = None,
5306        n_frames: int = 10000,
5307        method: str = "least_confidence",
5308        min_length: int = 30,
5309        augment_n: int = 0,
5310        parameters_update: Dict = None,
5311    ):
5312        """Update real_lens with frame-wise uncertainty scores for active learning.
5313
5314        Parameters
5315        ----------
5316        mask_name : str
5317            the name of the real_lens
5318        episode_name : str
5319            the name of the episode to load
5320        classes : list
5321            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5322        load_epoch : int, optional
5323            the epoch to load (by default last; if this epoch is not saved the closest checkpoint is chosen)
5324        n_frames : int, default 10000
5325            the number of frames to "annotate"
5326        method : {"least_confidence", "entropy"}
5327            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
5328            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
5329        min_length : int
5330            the minimum length (in frames) of the annotated intervals
5331        augment_n : int, default 0
5332            the number of augmentations to average over
5333        parameters_update : dict, optional
5334            the dictionary used to update the parameters from the config
5335
5336        Returns
5337        -------
5338        score_dicts : dict
5339            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5340            are score tensors
5341
5342        """
5343        print(f"UPDATING {mask_name}")
5344        task, parameters, _ = self._make_task_prediction(
5345            prediction_name=mask_name,
5346            load_episode=episode_name,
5347            parameters_update=parameters_update,
5348            load_epoch=load_epoch,
5349            mode="train",
5350        )
5351        score_tensors = task.generate_uncertainty_score(classes, augment_n, method)
5352        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5353        print("\n")
5354
5355    def _update_mask_with_BALD(
5356        self,
5357        mask_name: str,
5358        episode_name: str,
5359        classes: List,
5360        load_epoch: int = None,
5361        augment_n: int = 0,
5362        n_frames: int = 10000,
5363        num_models: int = 10,
5364        kernel_size: int = 11,
5365        min_length: int = 30,
5366        parameters_update: Dict = None,
5367    ):
5368        """Update real_lens with frame-wise Bayesian Active Learning by Disagreement scores for active learning.
5369
5370        Parameters
5371        ----------
5372        mask_name : str
5373            the name of the real_lens
5374        episode_name : str
5375            the name of the episode to load
5376        classes : list
5377            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5378        load_epoch : int, optional
5379            the epoch to load (by default last)
5380        augment_n : int, default 0
5381            the number of augmentations to average over
5382        n_frames : int, default 10000
5383            the number of frames to "annotate"
5384        num_models : int, default 10
5385            the number of dropout masks to apply
5386        kernel_size : int, default 11
5387            the size of the smoothing gaussian kernel
5388        min_length : int
5389            the minimum length (in frames) of the annotated intervals
5390        parameters_update : dict, optional
5391            the dictionary used to update the parameters from the config
5392
5393        Returns
5394        -------
5395        score_dicts : dict
5396            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5397            are score tensors
5398
5399        """
5400        print(f"UPDATING {mask_name}")
5401        task, parameters, mode = self._make_task_prediction(
5402            mask_name,
5403            load_episode=episode_name,
5404            parameters_update=parameters_update,
5405            load_epoch=load_epoch,
5406        )
5407        score_tensors = task.generate_bald_score(
5408            classes, augment_n, num_models, kernel_size
5409        )
5410        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5411        print("\n")
5412
5413    def _suggest_intervals(
5414        self,
5415        dataset: BehaviorDataset,
5416        score_tensors: Dict,
5417        n_frames: int,
5418        min_length: int,
5419    ) -> Dict:
5420        """Suggest intervals with highest score of total length `n_frames`.
5421
5422        Parameters
5423        ----------
5424        dataset : BehaviorDataset
5425            the dataset
5426        score_tensors : dict
5427            a dictionary where keys are clip ids and values are framewise score tensors
5428        n_frames : int
5429            the number of frames to "annotate"
5430        min_length : int
5431            minimum length of suggested intervals
5432
5433        Returns
5434        -------
5435        active_learning_intervals : Dict
5436            active learning dictionary with suggested intervals
5437
5438        """
5439        video_intervals, _ = dataset.get_intervals()
5440        taken = {
5441            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5442        }
5443        annotated = dataset.get_annotated_intervals()
5444        for video_id in video_intervals:
5445            for clip_id in video_intervals[video_id]:
5446                taken[video_id][clip_id] = torch.zeros(
5447                    dataset.get_len(video_id, clip_id)
5448                )
5449                if video_id in annotated and clip_id in annotated[video_id]:
5450                    for start, end in annotated[video_id][clip_id]:
5451                        score_tensors[video_id][clip_id][:, start:end] = -10
5452                        taken[video_id][clip_id][int(start) : int(end)] = 1
5453        n_frames = (
5454            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5455            + n_frames
5456        )
5457        factor = 1
5458        threshold_start = float(
5459            torch.mean(
5460                torch.tensor(
5461                    [
5462                        torch.mean(
5463                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5464                        )
5465                        for x in score_tensors.values()
5466                    ]
5467                )
5468            )
5469        )
5470        while (
5471            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5472            < n_frames
5473        ):
5474            threshold = threshold_start * factor
5475            intervals = []
5476            interval_scores = []
5477            key1 = list(score_tensors.keys())[0]
5478            key2 = list(score_tensors[key1].keys())[0]
5479            num_scores = score_tensors[key1][key2].shape[0]
5480            for i in range(num_scores):
5481                v_dict = dataset.find_valleys(
5482                    predicted=score_tensors,
5483                    threshold=threshold,
5484                    min_frames=min_length,
5485                    main_class=i,
5486                    low=False,
5487                )
5488                for v_id, interval_list in v_dict.items():
5489                    intervals += [x + [v_id] for x in interval_list]
5490                    interval_scores += [
5491                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5492                        for start, end, clip_id in interval_list
5493                    ]
5494            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5495            i = 0
5496            while sum(
5497                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
5498            ) < n_frames and i < len(intervals):
5499                start, end, clip_id, video_id = intervals[i]
5500                i += 1
5501                taken[video_id][clip_id][int(start) : int(end)] = 1
5502            factor *= 0.9
5503            if factor < 0.05:
5504                warnings.warn(f"Could not find enough frames!")
5505                break
5506        active_learning_intervals = {video_id: [] for video_id in video_intervals}
5507        for video_id in taken:
5508            for clip_id in taken[video_id]:
5509                if video_id in annotated and clip_id in annotated[video_id]:
5510                    for start, end in annotated[video_id][clip_id]:
5511                        taken[video_id][clip_id][int(start) : int(end)] = 0
5512                if (taken[video_id][clip_id] == 1).sum() == 0:
5513                    continue
5514                indices = np.where(taken[video_id][clip_id].numpy())[0]
5515                boundaries = self._get_intervals(indices)
5516                active_learning_intervals[video_id] += [
5517                    [start, end, clip_id] for start, end in boundaries
5518                ]
5519        return active_learning_intervals
5520
5521    def _update_mask(
5522        self,
5523        task: TaskDispatcher,
5524        mask_name: str,
5525        score_tensors: Dict,
5526        n_frames: int,
5527        min_length: int,
5528    ) -> None:
5529        """Update the real_lens with intervals with the highest score of total length `n_frames`.
5530
5531        Parameters
5532        ----------
5533        task : TaskDispatcher
5534            the task dispatcher object
5535        mask_name : str
5536            the name of the real_lens
5537        score_tensors : dict
5538            a dictionary where keys are clip ids and values are framewise score tensors
5539        n_frames : int
5540            the number of frames to "annotate"
5541        min_length : int
5542            the minimum length of the annotated intervals
5543
5544        """
5545        mask = self._load_mask(mask_name)
5546        video_intervals, _ = task.dataset("train").get_intervals()
5547        masked = {
5548            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5549        }
5550        total_masked = 0
5551        total_all = 0
5552        for video_id in video_intervals:
5553            for clip_id in video_intervals[video_id]:
5554                masked[video_id][clip_id] = torch.zeros(
5555                    task.dataset("train").get_len(video_id, clip_id)
5556                )
5557                if (
5558                    video_id in mask["unannotated"]
5559                    and clip_id in mask["unannotated"][video_id]
5560                ):
5561                    for start, end in mask["unannotated"][video_id][clip_id]:
5562                        score_tensors[video_id][clip_id][:, start:end] = -10
5563                        masked[video_id][clip_id][int(start) : int(end)] = 1
5564                if (
5565                    video_id in mask["val_intervals"]
5566                    and clip_id in mask["val_intervals"][video_id]
5567                ):
5568                    for start, end in mask["val_intervals"][video_id][clip_id]:
5569                        score_tensors[video_id][clip_id][:, start:end] = -10
5570                        masked[video_id][clip_id][int(start) : int(end)] = 1
5571                total_all += torch.sum(masked[video_id][clip_id] == 0)
5572                if video_id in mask["masked"] and clip_id in mask["masked"][video_id]:
5573                    # print(f'{real_lens["masked"][video_id][clip_id]=}')
5574                    for start, end in mask["masked"][video_id][clip_id]:
5575                        masked[video_id][clip_id][int(start) : int(end)] = 1
5576                        total_masked += end - start
5577        old_n_frames = sum(
5578            [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5579        )
5580        n_frames = old_n_frames + n_frames
5581        factor = 1
5582        while (
5583            sum([(vv == 0).sum() for v in masked.values() for vv in v.values()])
5584            < n_frames
5585        ):
5586            threshold = float(
5587                torch.mean(
5588                    torch.tensor(
5589                        [
5590                            torch.mean(
5591                                torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5592                            )
5593                            for x in score_tensors.values()
5594                        ]
5595                    )
5596                )
5597            )
5598            threshold = threshold * factor
5599            intervals = []
5600            interval_scores = []
5601            key1 = list(score_tensors.keys())[0]
5602            key2 = list(score_tensors[key1].keys())[0]
5603            num_scores = score_tensors[key1][key2].shape[0]
5604            for i in range(num_scores):
5605                v_dict = task.dataset("train").find_valleys(
5606                    predicted=score_tensors,
5607                    threshold=threshold,
5608                    min_frames=min_length,
5609                    main_class=i,
5610                    low=False,
5611                )
5612                for v_id, interval_list in v_dict.items():
5613                    intervals += [x + [v_id] for x in interval_list]
5614                    interval_scores += [
5615                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5616                        for start, end, clip_id in interval_list
5617                    ]
5618            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5619            i = 0
5620            while sum(
5621                [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5622            ) < n_frames and i < len(intervals):
5623                start, end, clip_id, video_id = intervals[i]
5624                i += 1
5625                masked[video_id][clip_id][int(start) : int(end)] = 0
5626            factor *= 0.9
5627            if factor < 0.05:
5628                warnings.warn(f"Could not find enough frames!")
5629                break
5630        mask["masked"] = {video_id: {} for video_id in video_intervals}
5631        total_masked_new = 0
5632        for video_id in masked:
5633            for clip_id in masked[video_id]:
5634                if (
5635                    video_id in mask["unannotated"]
5636                    and clip_id in mask["unannotated"][video_id]
5637                ):
5638                    for start, end in mask["unannotated"][video_id][clip_id]:
5639                        masked[video_id][clip_id][int(start) : int(end)] = 0
5640                if (
5641                    video_id in mask["val_intervals"]
5642                    and clip_id in mask["val_intervals"][video_id]
5643                ):
5644                    for start, end in mask["val_intervals"][video_id][clip_id]:
5645                        masked[video_id][clip_id][int(start) : int(end)] = 0
5646                indices = np.where(masked[video_id][clip_id].numpy())[0]
5647                mask["masked"][video_id][clip_id] = self._get_intervals(indices)
5648        for video_id in mask["masked"]:
5649            for clip_id in mask["masked"][video_id]:
5650                for start, end in mask["masked"][video_id][clip_id]:
5651                    total_masked_new += end - start
5652        self._save_mask(mask, mask_name)
5653        with open(
5654            os.path.join(
5655                self.project_path, "results", f"{mask_name}.txt", encoding="utf-8"
5656            ),
5657            "a",
5658        ) as f:
5659            f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n")
5660        print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}")
5661
5662    def _visualize_results_label(
5663        self,
5664        episode_name: str,
5665        label: str,
5666        load_epoch: int = None,
5667        parameters_update: Dict = None,
5668        add_legend: bool = True,
5669        ground_truth: bool = True,
5670        hide_axes: bool = False,
5671        width: float = 10,
5672        whole_video: bool = False,
5673        transparent: bool = False,
5674        num_plots: int = 1,
5675        smooth_interval: int = 0,
5676    ):
5677        other_path = os.path.join(self.project_path, "results", "other")
5678        if not os.path.exists(other_path):
5679            os.mkdir(other_path)
5680        if parameters_update is None:
5681            parameters_update = {}
5682        if "model" in parameters_update.keys():
5683            raise ValueError("Cannot change model parameters after training!")
5684        task, parameters, _ = self._make_task_prediction(
5685            "_",
5686            load_episode=episode_name,
5687            parameters_update=parameters_update,
5688            load_epoch=load_epoch,
5689            mode="val",
5690        )
5691        for i in range(num_plots):
5692            print(i)
5693            task._visualize_results_label(
5694                smooth_interval=smooth_interval,
5695                label=label,
5696                save_path=os.path.join(
5697                    other_path, f"{episode_name}_prediction_{i}.jpg"
5698                ),
5699                add_legend=add_legend,
5700                ground_truth=ground_truth,
5701                hide_axes=hide_axes,
5702                whole_video=whole_video,
5703                transparent=transparent,
5704                dataset="val",
5705                width=width,
5706                title=str(i),
5707            )
5708
5709    def plot_confusion_matrix(
5710        self,
5711        episode_name: str,
5712        load_epoch: int = None,
5713        parameters_update: Dict = None,
5714        metric: str = "recall",
5715        mode: str = "val",
5716        remove_saved_features: bool = False,
5717        save_path: str = None,
5718        cmap: str = "viridis",
5719    ) -> Tuple[ndarray, Iterable]:
5720        """Make a confusion matrix plot and return the data.
5721
5722        If the annotation is non-exclusive, only false positive labels are considered.
5723
5724        Parameters
5725        ----------
5726        episode_name : str
5727            the name of the episode to load
5728        load_epoch : int, optional
5729            the index of the epoch to load (by default the last one is loaded)
5730        parameters_update : dict, optional
5731            a dictionary of parameter updates (only for "data" and "general" categories)
5732        metric : {"recall", "precision"}
5733            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
5734            into account, and if `type` is `"precision"`, only false negatives
5735        mode : {'val', 'all', 'test', 'train'}
5736            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
5737        remove_saved_features : bool, default False
5738            if `True`, the dataset that is used for computation is then deleted
5739
5740        Returns
5741        -------
5742        confusion_matrix : np.ndarray
5743            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
5744            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
5745            `N_i` is the number of frames that have the i-th label in the ground truth
5746        classes : list
5747            a list of labels
5748
5749        """
5750        task, parameters, mode = self._make_task_prediction(
5751            "_",
5752            load_episode=episode_name,
5753            load_epoch=load_epoch,
5754            parameters_update=parameters_update,
5755            mode=mode,
5756        )
5757        dataset = task.dataset(mode)
5758        prediction = task.predict(dataset, raw_output=True)
5759        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
5760        if remove_saved_features:
5761            self._remove_stores(parameters)
5762        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
5763        ax.imshow(confusion_matrix, cmap=cmap)
5764        # Show all ticks and label them with the respective list entries
5765        ax.set_xticks(np.arange(len(classes)))
5766        ax.set_xticklabels(classes)
5767        ax.set_yticks(np.arange(len(classes)))
5768        ax.set_yticklabels(classes)
5769        # Rotate the tick labels and set their alignment.
5770        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
5771        # Loop over data dimensions and create text annotations.
5772        for i in range(len(classes)):
5773            for j in range(len(classes)):
5774                ax.text(
5775                    j,
5776                    i,
5777                    np.round(confusion_matrix[i, j], 2),
5778                    ha="center",
5779                    va="center",
5780                    color="w",
5781                )
5782        if metric is not None:
5783            ax.set_title(f"{metric} {episode_name}")
5784        else:
5785            ax.set_title(episode_name)
5786        fig.tight_layout()
5787        if save_path is None:
5788            plt.show()
5789        else:
5790            plt.savefig(save_path)
5791            plt.close()
5792        return confusion_matrix, classes
5793
5794    def _plot_ethograms_gt_pred(
5795        self,
5796        data_gt: dict,
5797        data_pred: dict,
5798        labels_gt: list,
5799        labels_pred: list,
5800        start: int = 0,
5801        end: int = -1,
5802        cmap_pred: str = "binary",
5803        cmap_gt: str = "binary",
5804        save: str = None,
5805        fontsize=22,
5806        time_mode="frames",
5807        fps: int = None,
5808    ) -> None:
5809        """Plot ethograms from start to end time (in frames), mode can be prediction or ground truth depending on the data format."""
5810        # print(data.keys())
5811        best_pred = (
5812            data_pred[list(data_pred.keys())[0]].numpy() > 0.5
5813        )  # Threshold the predictions
5814        data_gt = binarize_data(data_gt, max_frame=end)
5815
5816        # Crop data to min length
5817        if end < 0:
5818            end = min(data_gt.shape[1], best_pred.shape[1])
5819        data_gt = data_gt[:, :end]
5820        best_pred = best_pred[:, :end]
5821
5822        # Reorder behaviors
5823        ind_gt = []
5824        ind_pred = []
5825        labels_pred = [labels_pred[i] for i in range(len(labels_pred))]
5826        labels_pred = np.roll(
5827            labels_pred, 1
5828        ).tolist()  
5829        check_gt = np.where(np.sum(data_gt, axis=1) > 0)[0]
5830        check_pred = np.where(np.sum(best_pred, axis=1) > 0)[0]
5831        for k, gt_beh in enumerate(labels_gt):
5832            if gt_beh in labels_pred:
5833                j = labels_pred.index(gt_beh)
5834                if not k in check_gt and not j in check_pred:
5835                    continue
5836                ind_gt.append(labels_gt.index(gt_beh))
5837                ind_pred.append(j)
5838        # Create label list
5839        labels = np.array(labels_gt)[ind_gt]
5840        assert (labels == np.array(labels_pred)[ind_pred]).all()
5841
5842        # # Create image
5843        image_pred = best_pred[ind_pred].astype(float)
5844        image_gt = data_gt[ind_gt]
5845
5846        f, axs = plt.subplots(
5847            len(labels), 1, figsize=(5 * len(labels), 15), sharex=True
5848        )
5849        end = image_gt.shape[1] if end < 0 else end
5850        for i, (ax, label) in enumerate(zip(axs, labels)):
5851
5852            im1 = np.array([image_gt[i], np.ones_like(image_gt[i]) * (-1)])
5853            im1 = np.ma.masked_array(im1, im1 < 0)
5854
5855            im2 = np.array([np.ones_like(image_pred[i]) * (-1), image_pred[i]])
5856            im2 = np.ma.masked_array(im2, im2 < 0)
5857
5858            ax.imshow(im1, aspect="auto", cmap=cmap_gt, interpolation="nearest")
5859            ax.imshow(im2, aspect="auto", cmap=cmap_pred, interpolation="nearest")
5860
5861            ax.set_yticks(np.arange(2), ["GT", "Pred"], fontsize=fontsize)
5862            ax.tick_params(axis="x", labelsize=fontsize)
5863            ax.set_ylabel(label, fontsize=fontsize)
5864            if time_mode == "frames":
5865                ax.set_xlabel("Num Frames", fontsize=fontsize)
5866            elif time_mode == "seconds":
5867                assert not fps is None, "Please provide fps"
5868                ax.set_xlabel("Time (s)", fontsize=fontsize)
5869                ax.set_xticks(
5870                    np.linspace(0, end, 10),
5871                    np.linspace(0, end / fps, 10).astype(np.int32),
5872                )
5873
5874            ax.set_xlim(start, end)
5875
5876        if save is None:
5877            plt.show()
5878        else:
5879            plt.savefig(save)
5880            plt.close()
5881
5882    def plot_ethograms(
5883        self,
5884        episode_name: str,
5885        prediction_name: str,
5886        start: int = 0,
5887        end: int = -1,
5888        save_path: str = None,
5889        cmap_pred: str = "binary",
5890        cmap_gt: str = "binary",
5891        fontsize: int = 22,
5892        time_mode: str = "frames",
5893        fps: int = None,
5894    ):
5895        """Plot ethograms from start to end time (in frames) for ground truth and prediction"""
5896        params = self._read_parameters(catch_blanks=False)
5897        parameters = self._get_data_pars(
5898            params,
5899        )
5900        if not save_path is None:
5901            os.makedirs(save_path, exist_ok=True)
5902        gt_files = [
5903            f for f in self.data_path if f.endswith(parameters["annotation_suffix"])
5904        ]
5905        pred_path = os.path.join(
5906            self.project_path, "results", "predictions", prediction_name
5907        )
5908        pred_paths = [os.path.join(pred_path, f) for f in os.listdir(pred_path)]
5909        for pred_path in pred_paths:
5910            predictions = load_pickle(pred_path)
5911            behaviors = self.get_behavior_dictionary(episode_name)
5912            gt_filename = os.path.basename(pred_path).replace(
5913                "_".join(["_" + prediction_name, "prediction.pickle"]),
5914                parameters["annotation_suffix"],
5915            )
5916            if os.path.exists(os.path.join(self.data_path, gt_filename)):
5917                gt_data = load_pickle(os.path.join(self.data_path, gt_filename))
5918
5919                self._plot_ethograms_gt_pred(
5920                    gt_data,
5921                    predictions,
5922                    gt_data[1],
5923                    behaviors,
5924                    start=start,
5925                    end=end,
5926                    save=os.path.join(
5927                        save_path,
5928                        os.path.splitext(os.path.basename(pred_path))[0] + "_gt_pred",
5929                    ),
5930                    cmap_pred=cmap_pred,
5931                    cmap_gt=cmap_gt,
5932                    fontsize=fontsize,
5933                    time_mode=time_mode,
5934                    fps=fps,
5935                )
5936            else:
5937                print("GT file not found")
5938
5939    def _create_side_panel(self, height, width, labels_pred, preds, labels_gt, gt=None):
5940        """Create a side panel for video annotation display.
5941
5942        Parameters
5943        ----------
5944        height : int
5945            the height of the panel
5946        width : int
5947            the width of the panel
5948        labels_pred : list
5949            the list of predicted behavior labels
5950        preds : array-like
5951            the prediction values for each behavior
5952        labels_gt : list
5953            the list of ground truth behavior labels
5954        gt : array-like, optional
5955            the ground truth values for each behavior
5956
5957        Returns
5958        -------
5959        side_panel : np.ndarray
5960            the created side panel as an image array
5961
5962        """
5963        side_panel = np.ones((height, int(width / 4), 3), dtype=np.uint8) * 255
5964
5965        beh_indices = np.where(preds)[0]
5966        for i, label in enumerate(labels_pred):
5967            color = (0, 0, 0)
5968            if i in beh_indices:
5969                color = (0, 255, 0)
5970            cv2.putText(
5971                side_panel,
5972                label,
5973                (10, 50 + 50 * i),
5974                cv2.FONT_HERSHEY_SIMPLEX,
5975                1,
5976                color,
5977                2,
5978                cv2.LINE_AA,
5979            )
5980        if gt is not None:
5981            beh_indices_gt = np.where(gt)[0]
5982            for i, label in enumerate(labels_gt):
5983                color = (0, 0, 0)
5984                if i in beh_indices_gt:
5985                    color = (0, 255, 0)
5986                cv2.putText(
5987                    side_panel,
5988                    label,
5989                    (10, 50 + 50 * i + 80 * len(labels_pred)),
5990                    cv2.FONT_HERSHEY_SIMPLEX,
5991                    1,
5992                    color,
5993                    2,
5994                    cv2.LINE_AA,
5995                )
5996        return side_panel
5997
5998    def create_annotated_video(
5999        self,
6000        prediction_file_paths: list,
6001        video_file_paths: list,
6002        episode_name: str,  # To get the list of behaviors
6003        ground_truth_file_paths: list = None,
6004        pred_thresh: float = 0.5,
6005        start: int = 0,
6006        end: int = -1,
6007    ):
6008        """Create a video with the predictions overlaid on the video"""
6009        for k, (pred_path, vid_path) in enumerate(
6010            zip(prediction_file_paths, video_file_paths)
6011        ):
6012            print("Generating video for :", os.path.basename(vid_path))
6013            predictions = load_pickle(pred_path)
6014            best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh
6015            behaviors = self.get_behavior_dictionary(episode_name)
6016            # Load video
6017            labels_pred = [behaviors[i] for i in range(len(behaviors))]
6018            labels_pred = np.roll(
6019                labels_pred, 1
6020            ).tolist() 
6021
6022            gt_data = None
6023            if ground_truth_file_paths is not None:
6024                gt_data = load_pickle(ground_truth_file_paths[k])
6025                labels_gt = gt_data[1]
6026                gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1])
6027
6028            cap = cv2.VideoCapture(vid_path)
6029            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6030            end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
6031            fps = cap.get(cv2.CAP_PROP_FPS)
6032            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
6033            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6034            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
6035            out = cv2.VideoWriter(
6036                os.path.join(
6037                    os.path.dirname(vid_path),
6038                    os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4",
6039                ),
6040                fourcc,
6041                fps,
6042                # (width + int(width/4) , height),
6043                (600, 300),
6044            )
6045            count = 0
6046            bar = tqdm(total=end - start)
6047            while cap.isOpened():
6048                ret, frame = cap.read()
6049                if not ret:
6050                    break
6051
6052                side_panel = self._create_side_panel(
6053                    height,
6054                    width,
6055                    labels_pred,
6056                    best_pred[:, count],
6057                    labels_gt,
6058                    gt_data[:, count],
6059                )
6060                frame = np.concatenate((frame, side_panel), axis=1)
6061                frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
6062                out.write(frame)
6063                count += 1
6064                bar.update(1)
6065
6066                if count > end:
6067                    break
6068
6069            cap.release()
6070            out.release()
6071            cv2.destroyAllWindows()
6072
6073    def plot_predictions(
6074        self,
6075        episode_name: str,
6076        load_epoch: int = None,
6077        parameters_update: Dict = None,
6078        add_legend: bool = True,
6079        ground_truth: bool = True,
6080        colormap: str = "dlc2action",
6081        hide_axes: bool = False,
6082        min_classes: int = 1,
6083        width: float = 10,
6084        whole_video: bool = False,
6085        transparent: bool = False,
6086        drop_classes: Set = None,
6087        search_classes: Set = None,
6088        num_plots: int = 1,
6089        remove_saved_features: bool = False,
6090        smooth_interval_prediction: int = 0,
6091        data_path: str = None,
6092        file_paths: Set = None,
6093        mode: str = "val",
6094        font_size: float = None,
6095        window_size: int = 400,
6096    ) -> None:
6097        """Visualize random predictions.
6098
6099        Parameters
6100        ----------
6101        episode_name : str
6102            the name of the episode to load
6103        load_epoch : int, optional
6104            the epoch to load (by default last)
6105        parameters_update : dict, optional
6106            parameter update dictionary
6107        add_legend : bool, default True
6108            if True, legend will be added to the plot
6109        ground_truth : bool, default True
6110            if True, ground truth will be added to the plot
6111        colormap : str, default 'Accent'
6112            the `matplotlib` colormap to use
6113        hide_axes : bool, default True
6114            if `True`, the axes will be hidden on the plot
6115        min_classes : int, default 1
6116            the minimum number of classes in a displayed interval
6117        width : float, default 10
6118            the width of the plot
6119        whole_video : bool, default False
6120            if `True`, whole videos are plotted instead of segments
6121        transparent : bool, default False
6122            if `True`, the background on the plot is transparent
6123        drop_classes : set, optional
6124            a set of class names to not be displayed
6125        search_classes : set, optional
6126            if given, only intervals where at least one of the classes is in ground truth will be shown
6127        num_plots : int, default 1
6128            the number of plots to make
6129        remove_saved_features : bool, default False
6130            if `True`, the dataset will be deleted after computation
6131        smooth_interval_prediction : int, default 0
6132            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
6133        data_path : str, optional
6134            the data path to run the prediction for
6135        file_paths : set, optional
6136            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
6137            for
6138        mode : {'all', 'test', 'val', 'train'}
6139            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
6140
6141        """
6142        plot_path = os.path.join(self.project_path, "results", "plots")
6143        task, parameters, mode = self._make_task_prediction(
6144            "_",
6145            load_episode=episode_name,
6146            parameters_update=parameters_update,
6147            load_epoch=load_epoch,
6148            data_path=data_path,
6149            file_paths=file_paths,
6150            mode=mode,
6151        )
6152        os.makedirs(plot_path, exist_ok=True)
6153        task.visualize_results(
6154            save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"),
6155            add_legend=add_legend,
6156            ground_truth=ground_truth,
6157            colormap=colormap,
6158            hide_axes=hide_axes,
6159            min_classes=min_classes,
6160            whole_video=whole_video,
6161            transparent=transparent,
6162            dataset=mode,
6163            drop_classes=drop_classes,
6164            search_classes=search_classes,
6165            width=width,
6166            smooth_interval_prediction=smooth_interval_prediction,
6167            font_size=font_size,
6168            num_plots=num_plots,
6169            window_size=window_size,
6170        )
6171        if remove_saved_features:
6172            self._remove_stores(parameters)
6173
6174    def create_video_from_labels(
6175        self,
6176        video_dir_path: str,
6177        mode="ground_truth",
6178        prediction_name: str = None,
6179        save_path: str = None,
6180    ):
6181        if save_path is None:
6182            save_path = os.path.join(
6183                self.project_path, "results", f"annotated_videos_from_{mode}"
6184            )
6185        os.makedirs(save_path, exist_ok=True)
6186
6187        params = self._read_parameters(catch_blanks=False)
6188
6189        if mode == "ground_truth":
6190            source_dir = self.annotation_path
6191            annotation_suffix = params["data"]["annotation_suffix"]
6192        elif mode == "prediction":
6193            assert (
6194                not prediction_name is None
6195            ), "Please provide a prediction name with mode 'prediction'"
6196            source_dir = os.path.join(
6197                self.project_path, "results", "predictions", prediction_name
6198            )
6199            annotation_suffix = f"_{prediction_name}_prediction.pickle"
6200
6201        video_annotation_pairs = [
6202            (
6203                os.path.join(video_dir_path, f),
6204                os.path.join(
6205                    source_dir, f.replace(f.split(".")[-1], annotation_suffix)
6206                ),
6207            )
6208            for f in os.listdir(video_dir_path)
6209            if os.path.exists(
6210                os.path.join(source_dir, f.replace(f.split(".")[-1], annotation_suffix))
6211            )
6212        ]
6213
6214        for video_file, annotation_file in tqdm(video_annotation_pairs):
6215            if not os.path.exists(video_file):
6216                print(f"Video file {video_file} does not exist, skipping.")
6217                continue
6218            if not os.path.exists(annotation_file):
6219                print(f"Annotation file {annotation_file} does not exist, skipping.")
6220                continue
6221
6222            if annotation_file.endswith(".pickle"):
6223                annotations = load_pickle(annotation_file)
6224            elif annotation_file.endswith(".csv"):
6225                annotations = pd.read_csv(annotation_file)
6226
6227            if mode == "ground_truth":
6228                behaviors = annotations[1]
6229                annot_data = annotations[3]
6230            elif mode == "predictions":
6231                behaviors = list(annotations["classes"].values())
6232                annot_data = [
6233                    annotations[key]
6234                    for key in annotations.keys()
6235                    if key not in ["classes", "min_frame", "max_frame"]
6236                ]
6237                if params["general"]["exclusive"]:
6238                    annot_data = [np.argmax(annot, axis=1) for annot in annot_data]
6239                    seqs = [
6240                        [
6241                            self._bin_array_to_sequences(annot, target_value=k)
6242                            for k in range(len(behaviors))
6243                        ]
6244                        for annot in annot_data
6245                    ]
6246                else:
6247                    annot_data = [np.where(annot > 0.5)[0] for annot in annot_data]
6248                    seqs = [
6249                        self._bin_array_to_sequences(annot, target_value=1)
6250                        for annot in annot_data
6251                    ]
6252                annotations = ["", "", seqs]
6253
6254            for individual in annotations[3]:
6255                for behavior in annotations[3][individual]:
6256                    intervals = annotations[3][individual][behavior]
6257                    self._extract_videos(
6258                        video_file,
6259                        intervals,
6260                        behavior,
6261                        individual,
6262                        save_path,
6263                        resolution=(640, 480),
6264                        fps=30,
6265                    )
6266
6267    def _bin_array_to_sequences(
6268        self, annot_data: List[np.ndarray], target_value: int
6269    ) -> List[List[Tuple[int, int]]]:
6270        is_target = annot_data == target_value
6271        changes = np.diff(np.concatenate(([False], is_target, [False])))
6272        indices = np.where(changes)[0].reshape(-1, 2)
6273        subsequences = [list(range(start, end)) for start, end in indices]
6274        return subsequences
6275
6276    def _extract_videos(
6277        self,
6278        video_file: str,
6279        intervals: np.ndarray,
6280        behavior: str,
6281        individual: str,
6282        video_dir: str,
6283        resolution: Tuple[int, int] = (640, 480),
6284        fps: int = 30,
6285    ) -> None:
6286        """Extract frames from a video file from frames in between intervals in behavior folder for a given individual"""
6287        cap = cv2.VideoCapture(video_file)
6288        print("Extracting frames from", video_file)
6289
6290        for start, end, confusion in tqdm(intervals):
6291
6292            frame_count = start
6293            assert start < end, "Start frame should be less than end frame"
6294            if confusion > 0.5:
6295                continue
6296            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6297            output_file = os.path.join(
6298                video_dir,
6299                individual,
6300                behavior,
6301                os.path.splitext(os.path.basename(video_file))[0]
6302                + f"vid_{individual}_{behavior}_{start:05d}_{end:05d}.mp4",
6303            )
6304            fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec, e.g., 'XVID', 'MJPG'
6305            out = cv2.VideoWriter(
6306                output_file, fourcc, fps, (resolution[0], resolution[1])
6307            )
6308            while cap.isOpened():
6309                ret, frame = cap.read()
6310                if not ret:
6311                    break
6312
6313                # Resize large frames
6314                frame = cv2.resize(frame, (640, 480))
6315                out.write(frame)
6316
6317                frame_count += 1
6318                # Break if end frame is reached or max frames per behavior is reached
6319                if frame_count == end:
6320                    break
6321            if frame_count <= 2:
6322                os.remove(output_file)
6323            # cap.release()
6324            out.release()
6325
6326    def create_metadata_backup(self) -> None:
6327        """Create a copy of the meta files."""
6328        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6329        meta_path = os.path.join(self.project_path, "meta")
6330        if os.path.exists(meta_copy_path):
6331            shutil.rmtree(meta_copy_path)
6332        os.mkdir(meta_copy_path)
6333        for file in os.listdir(meta_path):
6334            if file == "backup":
6335                continue
6336            if os.path.isdir(os.path.join(meta_path, file)):
6337                continue
6338            shutil.copy(
6339                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
6340            )
6341
6342    def load_metadata_backup(self) -> None:
6343        """Load from previously created meta data backup (in case of corruption)."""
6344        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6345        meta_path = os.path.join(self.project_path, "meta")
6346        for file in os.listdir(meta_copy_path):
6347            shutil.copy(
6348                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
6349            )
6350
6351    def get_behavior_dictionary(self, episode_name: str) -> Dict:
6352        """Get the behavior dictionary for an episode.
6353
6354        Parameters
6355        ----------
6356        episode_name : str
6357            the name of the episode
6358
6359        Returns
6360        -------
6361        behaviors_dictionary : dict
6362            a dictionary where keys are label indices and values are label names
6363
6364        """
6365        return self._episode(episode_name).get_behaviors_dict()
6366
6367    def import_episodes(
6368        self,
6369        episodes_directory: str,
6370        name_map: Dict = None,
6371        repeat_policy: str = "error",
6372    ) -> None:
6373        """Import episodes exported with `Project.export_episodes`.
6374
6375        Parameters
6376        ----------
6377        episodes_directory : str
6378            the path to the exported episodes directory
6379        name_map : dict, optional
6380            a name change dictionary for the episodes: keys are old names, values are new names
6381        repeat_policy : {'error', 'skip', 'force'}, default 'error'
6382            the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
6383            'force' overwrites existing episodes
6384
6385        """
6386        if name_map is None:
6387            name_map = {}
6388        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
6389        to_remove = []
6390        import_string = "Imported episodes: "
6391        for episode_name in episodes.index:
6392            if episode_name in name_map:
6393                import_string += f"{episode_name} "
6394                episode_name = name_map[episode_name]
6395                import_string += f"({episode_name}), "
6396            else:
6397                import_string += f"{episode_name}, "
6398            try:
6399                self._check_episode_validity(episode_name, allow_doublecolon=True)
6400            except ValueError as e:
6401                if str(e).endswith("is already taken!"):
6402                    if repeat_policy == "skip":
6403                        to_remove.append(episode_name)
6404                    elif repeat_policy == "force":
6405                        self.remove_episode(episode_name)
6406                    elif repeat_policy == "error":
6407                        raise ValueError(
6408                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
6409                        )
6410                    else:
6411                        raise ValueError(
6412                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
6413                        )
6414        episodes = episodes.drop(index=to_remove)
6415        self._episodes().update(
6416            episodes,
6417            name_map=name_map,
6418            force=(repeat_policy == "force"),
6419            data_path=self.data_path,
6420            annotation_path=self.annotation_path,
6421        )
6422        for episode_name in episodes.index:
6423            if episode_name in name_map:
6424                new_episode_name = name_map[episode_name]
6425            else:
6426                new_episode_name = episode_name
6427            model_dir = os.path.join(
6428                self.project_path, "results", "model", new_episode_name
6429            )
6430            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
6431            if os.path.exists(model_dir):
6432                shutil.rmtree(model_dir)
6433            os.mkdir(model_dir)
6434            for file in os.listdir(old_model_dir):
6435                shutil.copyfile(
6436                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
6437                )
6438            log_file = os.path.join(
6439                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6440            )
6441            old_log_file = os.path.join(
6442                episodes_directory, "logs", f"{episode_name}.txt"
6443            )
6444            shutil.copyfile(old_log_file, log_file)
6445        print(import_string)
6446        print("\n")
6447
6448    def export_episodes(
6449        self, episode_names: List, output_directory: str, name: str = None
6450    ) -> None:
6451        """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`.
6452
6453        Parameters
6454        ----------
6455        episode_names : list
6456            a list of string episode names
6457        output_directory : str
6458            the path to the directory where the episodes will be saved
6459        name : str, optional
6460            the name of the episodes directory (by default `exported_episodes`)
6461
6462        """
6463        if name is None:
6464            name = "exported_episodes"
6465        if os.path.exists(
6466            os.path.join(output_directory, name + ".zip")
6467        ) or os.path.exists(os.path.join(output_directory, name)):
6468            i = 1
6469            while os.path.exists(
6470                os.path.join(output_directory, name + f"_{i}.zip")
6471            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
6472                i += 1
6473            name = name + f"_{i}"
6474        dest_dir = os.path.join(output_directory, name)
6475        os.mkdir(dest_dir)
6476        os.mkdir(os.path.join(dest_dir, "model"))
6477        os.mkdir(os.path.join(dest_dir, "logs"))
6478        runs = []
6479        for episode in episode_names:
6480            runs += self._episodes().get_runs(episode)
6481        for run in runs:
6482            shutil.copytree(
6483                os.path.join(self.project_path, "results", "model", run),
6484                os.path.join(dest_dir, "model", run),
6485            )
6486            shutil.copyfile(
6487                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
6488                os.path.join(dest_dir, "logs", f"{run}.txt"),
6489            )
6490        data = self._episodes().get_subset(runs)
6491        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
6492
6493    def get_results_table(
6494        self,
6495        episode_names: List,
6496        metrics: List = None,
6497        mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
6498        print_results: bool = True,
6499        classes: List = None,
6500    ):
6501        """Generate a `pandas` dataframe with a summary of episode results.
6502
6503        Parameters
6504        ----------
6505        episode_names : list
6506            a list of names of episodes to include
6507        metrics : list, optional
6508            a list of metric names to include
6509        mode : bool, optional
6510            the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean"
6511        print_results : bool, optional
6512            if True, the results will be printed to the console, by default True
6513        classes : list, optional
6514            a list of names of classes to include (by default all are included)
6515
6516        Returns
6517        -------
6518        results : pd.DataFrame
6519            a table with the results
6520
6521        """
6522        run_names = []
6523        for episode in episode_names:
6524            run_names += self._episodes().get_runs(episode)
6525        episodes = self.list_episodes(run_names, print_results=False)
6526        metric_columns = [x for x in episodes.columns if x[0] == "results"]
6527        results_df = pd.DataFrame()
6528        if metrics is not None:
6529            metric_columns = [
6530                x for x in metric_columns if x[1].split("_")[0] in metrics
6531            ]
6532        for episode in episode_names:
6533            results = []
6534            metric_set = set()
6535            for run in self._episodes().get_runs(episode):
6536                beh_dict = self.get_behavior_dictionary(run)
6537                res_dict = defaultdict(lambda: {})
6538                for column in metric_columns:
6539                    if np.isnan(episodes.loc[run, column]):
6540                        continue
6541                    split = column[1].split("_")
6542                    if split[-1].isnumeric():
6543                        beh_ind = int(split[-1])
6544                        metric_name = "_".join(split[:-1])
6545                        beh = beh_dict[beh_ind]
6546                    else:
6547                        beh = "average"
6548                        metric_name = column[1]
6549                    res_dict[beh][metric_name] = episodes.loc[run, column]
6550                    metric_set.add(metric_name)
6551                if "average" not in res_dict:
6552                    res_dict["average"] = {}
6553                for metric in metric_set:
6554                    if metric not in res_dict["average"]:
6555                        arr = [
6556                            res_dict[beh][metric]
6557                            for beh in res_dict
6558                            if metric in res_dict[beh]
6559                        ]
6560                        res_dict["average"][metric] = np.mean(arr)
6561                results.append(res_dict)
6562            episode_results = {}
6563            for metric in metric_set:
6564                for beh in results[0].keys():
6565                    if classes is not None and beh not in classes:
6566                        continue
6567                    arr = []
6568                    for res_dict in results:
6569                        if metric in res_dict[beh]:
6570                            arr.append(res_dict[beh][metric])
6571                    if len(arr) > 0:
6572                        if mode == "statistics":
6573                            episode_results[(beh, f"{episode} {metric} mean")] = (
6574                                np.mean(arr)
6575                            )
6576                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
6577                                arr
6578                            )
6579                        elif mode == "mean":
6580                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
6581                        elif mode == "detail":
6582                            for i, val in enumerate(arr):
6583                                episode_results[(beh, f"{episode}::{i} {metric}")] = val
6584            for key, value in episode_results.items():
6585                results_df.loc[key[0], key[1]] = value
6586        if print_results:
6587            print(f"RESULTS:")
6588            print(results_df)
6589            print("\n")
6590        return results_df
6591
6592    def episode_exists(self, episode_name: str) -> bool:
6593        """Check if an episode already exists.
6594
6595        Parameters
6596        ----------
6597        episode_name : str
6598            the episode name
6599
6600        Returns
6601        -------
6602        exists : bool
6603            `True` if the episode exists
6604
6605        """
6606        return self._episodes().check_name_validity(episode_name)
6607
6608    def search_exists(self, search_name: str) -> bool:
6609        """Check if a search already exists.
6610
6611        Parameters
6612        ----------
6613        search_name : str
6614            the search name
6615
6616        Returns
6617        -------
6618        exists : bool
6619            `True` if the search exists
6620
6621        """
6622        return self._searches().check_name_validity(search_name)
6623
6624    def prediction_exists(self, prediction_name: str) -> bool:
6625        """Check if a prediction already exists.
6626
6627        Parameters
6628        ----------
6629        prediction_name : str
6630            the prediction name
6631
6632        Returns
6633        -------
6634        exists : bool
6635            `True` if the prediction exists
6636
6637        """
6638        return self._predictions().check_name_validity(prediction_name)
6639
6640    @staticmethod
6641    def project_name_available(projects_path: str, project_name: str):
6642        """Check if a project name is available.
6643
6644        Parameters
6645        ----------
6646        projects_path : str
6647            the path to the projects directory
6648        project_name : str
6649            the name of the project to check
6650
6651        Returns
6652        -------
6653        available : bool
6654            `True` if the project name is available
6655
6656        """
6657        if projects_path is None:
6658            projects_path = os.path.join(str(Path.home()), "DLC2Action")
6659        return not os.path.exists(os.path.join(projects_path, project_name))
6660
6661    def _update_episode_metrics(self, episode_name: str, metrics: Dict):
6662        """Update meta data with evaluation results.
6663
6664        Parameters
6665        ----------
6666        episode_name : str
6667            the name of the episode
6668        metrics : dict
6669            the metrics dictionary to update with
6670
6671        """
6672        self._episodes().update_episode_metrics(episode_name, metrics)
6673
6674    def rename_episode(self, episode_name: str, new_episode_name: str):
6675        """Rename an episode.
6676
6677        Parameters
6678        ----------
6679        episode_name : str
6680            the current episode name
6681        new_episode_name : str
6682            the new episode name
6683
6684        """
6685        shutil.move(
6686            os.path.join(self.project_path, "results", "model", episode_name),
6687            os.path.join(self.project_path, "results", "model", new_episode_name),
6688        )
6689        shutil.move(
6690            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
6691            os.path.join(
6692                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6693            ),
6694        )
6695        self._episodes().rename_episode(episode_name, new_episode_name)
6696
6697
6698class _Runner:
6699    """A helper class for running hyperparameter searches."""
6700
6701    def __init__(
6702        self,
6703        search_name: str,
6704        search_space: Dict,
6705        load_episode: str,
6706        load_epoch: int,
6707        metric: str,
6708        average: int,
6709        task: Union[TaskDispatcher, None],
6710        remove_saved_features: bool,
6711        project: Project,
6712    ):
6713        """Initialize the class.
6714
6715        Parameters
6716        ----------
6717        task : TaskDispatcher
6718            the task dispatcher object
6719        search_name : str
6720            the name the search should be saved under
6721        search_space : dict
6722            a dictionary representing the search space; of this general structure:
6723            {'group/param_name': ('float/int/float_log/int_log', start, end),
6724            'group/param_name': ('categorical', [choices])}, e.g.
6725            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
6726            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])}
6727        load_episode : str
6728            the name of the episode to load the model from
6729        load_epoch : int
6730            the epoch to load the model from (if not provided, the last checkpoint is used)
6731        metric : str
6732            the metric to maximize/minimize (see direction)
6733        average : int
6734            the number of epochs to average the metric; if 0, the last value is taken
6735        remove_saved_features : bool
6736            if `True`, the old datasets will be deleted when data parameters change
6737        project : Project
6738            the parent `Project` instance
6739
6740        """
6741        self.search_space = search_space
6742        self.load_episode = load_episode
6743        self.load_epoch = load_epoch
6744        self.metric = metric
6745        self.average = average
6746        self.feature_save_path = None
6747        self.remove_saved_featuress = remove_saved_features
6748        self.save_stores = project._save_stores
6749        self.remove_datasets = project.remove_saved_features
6750        self.task = task
6751        self.search_name = search_name
6752        self.update = project._update
6753        self.remove_episode = project.remove_episode
6754        self.fill = project._fill
6755
6756    def clean(self):
6757        """Remove datasets if needed.
6758
6759        This method removes saved feature datasets when the remove_saved_features flag is set.
6760
6761        """
6762        if self.remove_saved_featuress:
6763            self.remove_datasets([os.path.basename(self.feature_save_path)])
6764
6765    def run(self, trial, parameters):
6766        """Make a trial run.
6767
6768        Parameters
6769        ----------
6770        trial : optuna.trial.Trial
6771            the Optuna trial object
6772        parameters : dict
6773            the base parameters dictionary
6774
6775        Returns
6776        -------
6777        value : float
6778            the metric value for this trial
6779
6780        """
6781        params = deepcopy(parameters)
6782        param_update = defaultdict(
6783            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: {})))
6784        )
6785        for full_name, space in self.search_space.items():
6786            group, param_name = (
6787                full_name.split("/")[0],
6788                "/".join(full_name.split("/")[1:]),
6789            )
6790            log = space[0][-3:] == "log"
6791            if space[0].startswith("int"):
6792                value = trial.suggest_int(full_name, space[1], space[2], log=log)
6793            elif space[0].startswith("float"):
6794                value = trial.suggest_float(full_name, space[1], space[2], log=log)
6795            elif space[0] == "categorical":
6796                value = trial.suggest_categorical(full_name, space[1])
6797            else:
6798                raise ValueError(
6799                    "The search space has to be formatted as either "
6800                    '("float"/"int"/"float_log"/"int_log", start, end) '
6801                    f'or ("categorical", [choices]); got {space} for {group}/{param_name}'
6802                )
6803            if len(param_name.split("/")) == 1:
6804                param_update[group][param_name] = value
6805            else:
6806                pars = param_name.split("/")
6807                pars = [int(x) if x.isnumeric() else x for x in pars]
6808                if len(pars) == 2:
6809                    param_update[group][pars[0]][pars[1]] = value
6810                elif len(pars) == 3:
6811                    param_update[group][pars[0]][pars[1]][pars[2]] = value
6812                elif len(pars) == 4:
6813                    param_update[group][pars[0]][pars[1]][pars[2]][pars[3]] = value
6814        param_update = {k: dict(v) for k, v in param_update.items()}
6815        params = self.update(params, param_update)
6816        self.remove_episode(f"_{self.search_name}")
6817        params = self.fill(
6818            params,
6819            f"_{self.search_name}",
6820            self.load_episode,
6821            load_epoch=self.load_epoch,
6822            only_load_model=True,
6823        )
6824        if self.feature_save_path != params["data"]["feature_save_path"]:
6825            if self.feature_save_path is not None:
6826                self.clean()
6827            self.feature_save_path = params["data"]["feature_save_path"]
6828        self.save_stores(params)
6829        if self.task is None:
6830            self.task = TaskDispatcher(deepcopy(params))
6831        else:
6832            self.task.update_task(params)
6833
6834        _, metrics_log = self.task.train(trial, self.metric)
6835        if self.metric in metrics_log["val"].keys():
6836            metric_values = metrics_log["val"][self.metric]
6837            if self.average > 0:
6838                value = np.mean(sorted(metric_values)[-self.average :])
6839            else:
6840                value = metric_values[-1]
6841            return value
6842        else:  # ['accuracy', 'precision', 'f1', 'recall', 'count', 'segmental_precision', 'segmental_recall', 'segmental_f1', 'edit_distance', 'f_beta', 'segmental_f_beta', 'semisegmental_precision', 'semisegmental_recall', 'semisegmental_f1', 'pr-auc', 'semisegmental_pr-auc', 'mAP']
6843            if self.metric in [
6844                "f1",
6845                "precision",
6846                "recall",
6847                "accuracy",
6848                "count",
6849                "segmental_precision",
6850                "segmental_recall",
6851                "segmental_f1",
6852                "f_beta",
6853                "segmental_f_beta",
6854                "semisegmental_precision",
6855                "semisegmental_recall",
6856                "semisegmental_f1",
6857                "pr-auc",
6858                "semisegmental_pr-auc",
6859                "mAP",
6860            ]:
6861                return 0
6862            elif self.metric in ["loss", "mse", "mae", "edit_distance"]:
6863                return float("inf")
class Project:
  55class Project:
  56    """A class to create and maintain the project files + keep track of experiments."""
  57
  58    def __init__(
  59        self,
  60        name: str,
  61        data_type: str = None,
  62        annotation_type: str = "none",
  63        projects_path: str = None,
  64        data_path: Union[str, List] = None,
  65        annotation_path: Union[str, List] = None,
  66        copy: bool = False,
  67    ) -> None:
  68        """Initialize the class.
  69
  70        Parameters
  71        ----------
  72        name : str
  73            name of the project
  74        data_type : str, optional
  75            data type (run Project.data_types() to see available options; has to be provided if the project is being
  76            created)
  77        annotation_type : str, default 'none'
  78            annotation type (run Project.annotation_types() to see available options)
  79        projects_path : str, optional
  80            path to the projects folder (is filled with ~/DLC2Action by default)
  81        data_path : str, optional
  82            path to the folder containing input files for the project (has to be provided if the project is being
  83            created)
  84        annotation_path : str, optional
  85            path to the folder containing annotation files for the project
  86        copy : bool, default False
  87            if True, the files from annotation_path and data_path will be copied to the projects folder;
  88            otherwise they will be moved
  89
  90        """
  91        if projects_path is None:
  92            projects_path = os.path.join(str(Path.home()), "DLC2Action")
  93        if not os.path.exists(projects_path):
  94            os.mkdir(projects_path)
  95        self.project_path = os.path.join(projects_path, name)
  96        self.name = name
  97        self.data_type = data_type
  98        self.annotation_type = annotation_type
  99        self.data_path = data_path
 100        self.annotation_path = annotation_path
 101        if not os.path.exists(self.project_path):
 102            if data_type is None:
 103                raise ValueError(
 104                    "The data_type parameter is necessary when creating a new project!"
 105                )
 106            self._initialize_project(
 107                data_type, annotation_type, data_path, annotation_path, copy
 108            )
 109        else:
 110            self.annotation_type, self.data_type = self._read_types()
 111            if data_type != self.data_type and data_type is not None:
 112                raise ValueError(
 113                    f"The project has already been initialized with data_type={self.data_type}!"
 114                )
 115            if annotation_type != self.annotation_type and annotation_type != "none":
 116                raise ValueError(
 117                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
 118                )
 119            self.annotation_path, data_path = self._read_paths()
 120            if self.data_path is None:
 121                self.data_path = data_path
 122            # if data_path != self.data_path and data_path is not None:
 123            #     raise ValueError(
 124            #         f"The project has already been initialized with data_path={self.data_path}!"
 125            #     )
 126            if annotation_path != self.annotation_path and annotation_path is not None:
 127                raise ValueError(
 128                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
 129                )
 130        self._update_configs()
 131
 132    def _make_prediction(
 133        self,
 134        prediction_name: str,
 135        episode_names: List,
 136        load_epochs: Union[List[int], int] = None,
 137        parameters_update: Dict = None,
 138        data_path: str = None,
 139        file_paths: Set = None,
 140        mode: str = "all",
 141        augment_n: int = 0,
 142        evaluate: bool = False,
 143        task: TaskDispatcher = None,
 144        embedding: bool = False,
 145        annotation_type: str = "none",
 146    ) -> Tuple[TaskDispatcher, Dict, str, torch.Tensor]:
 147        """Generate a prediction.
 148        Parameters
 149        ----------
 150        prediction_name : str
 151            name of the prediction
 152            episode_names : List
 153            names of the episodes to use for the prediction
 154            load_epochs : Union[List[int],int], optional
 155            epochs to load for each episode; if a single integer is provided, it will be used for all episodes;
 156            if None, the last epochs will be used
 157            parameters_update : Dict, optional
 158            dictionary with parameters to update the task parameters
 159            data_path : str, optional
 160            path to the data folder; if None, the data_path from the project will be used
 161            file_paths : Set, optional
 162            set of file paths to use for the prediction; if None, the data_path will be used
 163            mode : str, default "all
 164            mode of the prediction; can be "train", "val", "test" or "all
 165            augment_n : int, default 0
 166            number of augmentations to apply to the data; if 0, no augmentations are applied
 167            evaluate : bool, default False
 168            if True, the prediction will be evaluated and the results will be saved to the episode meta file
 169            task : TaskDispatcher, optional
 170            task object to use for the prediction; if None, a new task object will be created
 171            embedding : bool, default False
 172            if True, the prediction will be returned as an embedding
 173            annotation_type : str, default "none
 174            type of the annotation to use for the prediction; if "none", the annotation will not be used
 175        Returns
 176        -------
 177        task : TaskDispatcher
 178            task object used for the prediction
 179        parameters : Dict
 180            parameters used for the prediction
 181        mode : str
 182            mode of the prediction
 183        prediction : torch.Tensor
 184            prediction tensor of shape (num_videos, num_behaviors, num_frames)
 185        inference_time : str
 186            time taken for the prediction in the format "HH:MM:SS"
 187        behavior_dict : Dict
 188            dictionary with behavior names and their indices
 189        """
 190
 191        names = []
 192        for episode_name in episode_names:
 193            names += self._episodes().get_runs(episode_name)
 194        if len(names) == 0:
 195            warnings.warn(f"None of the episodes {episode_names} exist!")
 196            names = [None]
 197        if load_epochs is None:
 198            load_epochs = [None for _ in names]
 199        elif isinstance(load_epochs, int):
 200            load_epochs = [load_epochs for _ in names]
 201        assert len(load_epochs) == len(
 202            names
 203        ), f"Length of load_epochs ({len(load_epochs)}) must match the number of episodes ({len(names)})!"
 204        prediction = None
 205        decision_thresholds = None
 206        time_total = 0
 207        behavior_dicts = [
 208            self.get_behavior_dictionary(episode_name) for episode_name in names
 209        ]
 210
 211        if not all(
 212            [
 213                set(d.values()) == set(behavior_dicts[0].values())
 214                for d in behavior_dicts[1:]
 215            ]
 216        ):
 217            raise ValueError(
 218                f"Episodes {episode_names} have different sets of behaviors!"
 219            )
 220        behaviors = list(behavior_dicts[0].values())
 221
 222        for episode_name, load_epoch, behavior_dict in zip(
 223            names, load_epochs, behavior_dicts
 224        ):
 225            print(f"episode {episode_name}")
 226            task, parameters, data_mode = self._make_task_prediction(
 227                prediction_name=prediction_name,
 228                load_episode=episode_name,
 229                parameters_update=parameters_update,
 230                load_epoch=load_epoch,
 231                data_path=data_path,
 232                mode=mode,
 233                file_paths=file_paths,
 234                task=task,
 235                decision_thresholds=decision_thresholds,
 236                annotation_type=annotation_type,
 237            )
 238            # data_mode = "train" if mode == "all" else mode
 239            time_start = time.time()
 240            new_pred = task.predict(
 241                data_mode,
 242                raw_output=True,
 243                apply_primary_function=True,
 244                augment_n=augment_n,
 245                embedding=embedding,
 246            )
 247            indices = [
 248                behaviors.index(behavior_dict[i]) for i in range(new_pred.shape[1])
 249            ]
 250            new_pred = new_pred[:, indices, :]
 251            time_end = time.time()
 252            time_total += time_end - time_start
 253            if evaluate:
 254                _, metrics = task.evaluate_prediction(
 255                    new_pred, data=data_mode, indices=indices
 256                )
 257                if mode == "val":
 258                    self._update_episode_metrics(episode_name, metrics)
 259            if prediction is None:
 260                prediction = new_pred
 261            else:
 262                prediction += new_pred
 263            print("\n")
 264        hours = int(time_total // 3600)
 265        time_total -= hours * 3600
 266        minutes = int(time_total // 60)
 267        time_total -= minutes * 60
 268        seconds = int(time_total)
 269        inference_time = f"{hours}:{minutes:02}:{seconds:02}"
 270        prediction /= len(names)
 271        return (
 272            task,
 273            parameters,
 274            data_mode,
 275            prediction,
 276            inference_time,
 277            behavior_dicts[0],
 278        )
 279
 280    def _make_task_prediction(
 281        self,
 282        prediction_name: str,
 283        load_episode: str = None,
 284        parameters_update: Dict = None,
 285        load_epoch: int = None,
 286        data_path: str = None,
 287        annotation_path: str = None,
 288        mode: str = "val",
 289        file_paths: Set = None,
 290        decision_thresholds: List = None,
 291        task: TaskDispatcher = None,
 292        annotation_type: str = "none",
 293    ) -> Tuple[TaskDispatcher, Dict, str]:
 294        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 295        if parameters_update is None:
 296            parameters_update = {}
 297        parameters_update_second = {}
 298        if mode == "all" or data_path is not None or file_paths is not None:
 299            parameters_update_second["training"] = {
 300                "val_frac": 0,
 301                "test_frac": 0,
 302                "partition_method": "random",
 303                "save_split": False,
 304                "split_path": None,
 305            }
 306            mode = "train"
 307        if decision_thresholds is not None:
 308            if (
 309                len(decision_thresholds)
 310                == self._episode(load_episode).get_num_classes()
 311            ):
 312                parameters_update_second["general"] = {
 313                    "threshold_value": decision_thresholds
 314                }
 315            else:
 316                raise ValueError(
 317                    f"The length of the decision thresholds {decision_thresholds} "
 318                    f"must be equal to the length of the behaviors dictionary "
 319                    f"{self._episode(load_episode).get_behaviors_dict()}"
 320                )
 321        data_param_update = {}
 322        if data_path is not None:
 323            data_param_update = {"data_path": data_path}
 324            if annotation_path is None:
 325                data_param_update["annotation_path"] = data_path
 326        if annotation_path is not None:
 327            data_param_update["annotation_path"] = annotation_path
 328        if file_paths is not None:
 329            data_param_update = {"data_path": None, "file_paths": file_paths}
 330        parameters_update = self._update(parameters_update, {"data": data_param_update})
 331        if data_path is not None or file_paths is not None:
 332            general_update = {
 333                "annotation_type": annotation_type,
 334                "only_load_annotated": False,
 335            }
 336        else:
 337            general_update = {}
 338        parameters_update = self._update(parameters_update, {"general": general_update})
 339        task, parameters = self._make_task(
 340            episode_name=prediction_name,
 341            load_episode=load_episode,
 342            parameters_update=parameters_update,
 343            parameters_update_second=parameters_update_second,
 344            load_epoch=load_epoch,
 345            purpose="prediction",
 346            task=task,
 347        )
 348        return task, parameters, mode
 349
 350    def _make_task_training(
 351        self,
 352        episode_name: str,
 353        load_episode: str = None,
 354        parameters_update: Dict = None,
 355        load_epoch: int = None,
 356        load_search: str = None,
 357        load_parameters: list = None,
 358        round_to_binary: list = None,
 359        load_strict: bool = True,
 360        continuing: bool = False,
 361        task: TaskDispatcher = None,
 362        mask_name: str = None,
 363        throwaway: bool = False,
 364    ) -> Tuple[TaskDispatcher, Dict, str]:
 365        """Make a `TaskDispatcher` object that will be used to generate a prediction."""
 366        if parameters_update is None:
 367            parameters_update = {}
 368        if continuing:
 369            purpose = "continuing"
 370        else:
 371            purpose = "training"
 372        if mask_name is not None:
 373            mask_name = os.path.join(self._mask_path(), f"{mask_name}.pickle")
 374        parameters_update_second = {"data": {"real_lens": mask_name}}
 375        if throwaway:
 376            parameters_update = self._update(
 377                parameters_update, {"training": {"normalize": False, "device": "cpu"}}
 378            )
 379        return self._make_task(
 380            episode_name,
 381            load_episode,
 382            parameters_update,
 383            parameters_update_second,
 384            load_epoch,
 385            load_search,
 386            load_parameters,
 387            round_to_binary,
 388            purpose,
 389            task,
 390            load_strict=load_strict,
 391        )
 392
 393    def _make_parameters(
 394        self,
 395        episode_name: str,
 396        load_episode: str = None,
 397        parameters_update: Dict = None,
 398        parameters_update_second: Dict = None,
 399        load_epoch: int = None,
 400        load_search: str = None,
 401        load_parameters: list = None,
 402        round_to_binary: list = None,
 403        purpose: str = "train",
 404        load_strict: bool = True,
 405    ):
 406        """Construct a parameters dictionary."""
 407        if parameters_update is None:
 408            parameters_update = {}
 409        pars_update = deepcopy(parameters_update)
 410        if parameters_update_second is None:
 411            parameters_update_second = {}
 412        if (
 413            purpose == "prediction"
 414            and "model" in pars_update.keys()
 415            and pars_update["general"]["model_name"] != "motionbert"
 416        ):
 417            raise ValueError("Cannot change model parameters after training!")
 418        if purpose in ["continuing", "prediction"] and load_episode is not None:
 419            read_parameters = self._read_parameters()
 420            parameters = self._episodes().load_parameters(load_episode)
 421            parameters["metrics"] = self._update(
 422                read_parameters["metrics"], parameters["metrics"]
 423            )
 424            parameters["ssl"] = self._update(
 425                read_parameters["ssl"], parameters.get("ssl", {})
 426            )
 427        else:
 428            parameters = self._read_parameters()
 429        if "model" in pars_update:
 430            model_params = pars_update.pop("model")
 431        else:
 432            model_params = None
 433        if "features" in pars_update:
 434            feat_params = pars_update.pop("features")
 435        else:
 436            feat_params = None
 437        if "augmentations" in pars_update:
 438            aug_params = pars_update.pop("augmentations")
 439        else:
 440            aug_params = None
 441        parameters = self._update(parameters, pars_update)
 442        if pars_update.get("general", {}).get("model_name") is not None:
 443            model_name = parameters["general"]["model_name"]
 444            parameters["model"] = self._open_yaml(
 445                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
 446            )
 447        if pars_update.get("general", {}).get("feature_extraction") is not None:
 448            feat_name = parameters["general"]["feature_extraction"]
 449            parameters["features"] = self._open_yaml(
 450                os.path.join(
 451                    self.project_path, "config", "features", f"{feat_name}.yaml"
 452                )
 453            )
 454            aug_name = options.extractor_to_transformer[
 455                parameters["general"]["feature_extraction"]
 456            ]
 457            parameters["augmentations"] = self._open_yaml(
 458                os.path.join(
 459                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
 460                )
 461            )
 462        if model_params is not None:
 463            parameters["model"] = self._update(parameters["model"], model_params)
 464        if feat_params is not None:
 465            parameters["features"] = self._update(parameters["features"], feat_params)
 466        if aug_params is not None:
 467            parameters["augmentations"] = self._update(
 468                parameters["augmentations"], aug_params
 469            )
 470        if load_search is not None:
 471            parameters = self._update_with_search(
 472                parameters, load_search, load_parameters, round_to_binary
 473            )
 474        parameters = self._fill(
 475            parameters,
 476            episode_name,
 477            load_episode,
 478            load_epoch=load_epoch,
 479            load_strict=load_strict,
 480            only_load_model=(purpose != "continuing"),
 481            continuing=(purpose in ["prediction", "continuing"]),
 482            enforce_split_parameters=(purpose == "prediction"),
 483        )
 484        parameters = self._update(parameters, parameters_update_second)
 485        return parameters
 486
 487    def _make_task(
 488        self,
 489        episode_name: str,
 490        load_episode: str = None,
 491        parameters_update: Dict = None,
 492        parameters_update_second: Dict = None,
 493        load_epoch: int = None,
 494        load_search: str = None,
 495        load_parameters: list = None,
 496        round_to_binary: list = None,
 497        purpose: str = "train",
 498        task: TaskDispatcher = None,
 499        load_strict: bool = True,
 500    ) -> Tuple[TaskDispatcher, Union[CommentedMap, dict]]:
 501        """Make a `TaskDispatcher` object.
 502
 503        The task parameters are read from the config files and then updated with the
 504        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 505        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 506        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 507        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 508        data parameters are used.
 509
 510        Parameters
 511        ----------
 512        episode_name : str
 513            the name of the episode
 514        load_episode : str, optional
 515            the (previously run) episode name to load the model from
 516        parameters_update : dict, optional
 517            the dictionary used to update the parameters from the config
 518        parameters_update_second : dict, optional
 519            the dictionary used to update the parameters after the automatic fill-out
 520        load_epoch : int, optional
 521            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 522        load_search : str, optional
 523            the hyperparameter search result to load
 524        load_parameters : list, optional
 525            a list of string names of the parameters to load from load_search (if not provided, all parameters
 526            are loaded)
 527        round_to_binary : list, optional
 528            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 529        purpose : {"train", "continuing", "prediction"}
 530            the purpose of the task object (`"train"` for training from scratch, `"continuing"` for continuing
 531            the training of an interrupted episode, `"prediction"` for generating a prediction)
 532        task : TaskDispatcher, optional
 533            a pre-existing task; if provided, the method will update the task instead of creating a new one
 534            (this might save time, mainly on dataset loading)
 535
 536        Returns
 537        -------
 538        task : TaskDispatcher
 539            the `TaskDispatcher` instance
 540        parameters : dict
 541            the parameters dictionary that describes the task
 542
 543        """
 544        parameters = self._make_parameters(
 545            episode_name,
 546            load_episode,
 547            parameters_update,
 548            parameters_update_second,
 549            load_epoch,
 550            load_search,
 551            load_parameters,
 552            round_to_binary,
 553            purpose,
 554            load_strict=load_strict,
 555        )
 556        if task is None:
 557            task = TaskDispatcher(parameters)
 558        else:
 559            task.update_task(parameters)
 560        self._save_stores(parameters)
 561        return task, parameters
 562
 563    def get_decision_thresholds(
 564        self,
 565        episode_names: List,
 566        metric_name: str = "f1",
 567        parameters_update: Dict = None,
 568        load_epochs: List = None,
 569        remove_saved_features: bool = False,
 570    ) -> Tuple[List, List, TaskDispatcher]:
 571        """Compute optimal decision thresholds or load them if they have been computed before.
 572
 573        Parameters
 574        ----------
 575        episode_names : List
 576            a list of episode names
 577        metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
 578            the metric to optimize
 579        parameters_update : dict, optional
 580            the parameter update dictionary
 581        load_epochs : list, optional
 582            a list of epochs to load (by default last are loaded)
 583        remove_saved_features : bool, default False
 584            if `True`, the dataset will be deleted after the computation
 585
 586        Returns
 587        -------
 588        thresholds : list
 589            a list of float decision threshold values
 590        classes : list
 591            the label names corresponding to the values
 592        task : TaskDispatcher | None
 593            the task used in computation
 594
 595        """
 596        parameters = self._make_parameters(
 597            "_",
 598            episode_names[0],
 599            parameters_update,
 600            {},
 601            load_epochs[0],
 602            purpose="prediction",
 603        )
 604        thresholds = self._thresholds().find_thresholds(
 605            episode_names,
 606            load_epochs,
 607            metric_name,
 608            metric_parameters=parameters["metrics"][metric_name],
 609        )
 610        task = None
 611        behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values())
 612        return thresholds, behaviors, task
 613
 614    def run_episode(
 615        self,
 616        episode_name: str,
 617        load_episode: str = None,
 618        parameters_update: Dict = None,
 619        task: TaskDispatcher = None,
 620        load_epoch: int = None,
 621        load_search: str = None,
 622        load_parameters: list = None,
 623        round_to_binary: list = None,
 624        load_strict: bool = True,
 625        n_seeds: int = 1,
 626        force: bool = False,
 627        suppress_name_check: bool = False,
 628        remove_saved_features: bool = False,
 629        mask_name: str = None,
 630        autostop_metric: str = None,
 631        autostop_interval: int = 50,
 632        autostop_threshold: float = 0.001,
 633        loading_bar: bool = False,
 634        trial: Tuple = None,
 635    ) -> TaskDispatcher:
 636        """Run an episode.
 637
 638        The task parameters are read from the config files and then updated with the
 639        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
 640        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 641        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 642        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 643        data parameters are used.
 644
 645        You can use the autostop parameters to finish training when the parameters are not improving. It will be
 646        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 647        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 648        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 649
 650        Parameters
 651        ----------
 652        episode_name : str
 653            the episode name
 654        load_episode : str, optional
 655            the (previously run) episode name to load the model from; if the episode has multiple runs,
 656            the new episode will have the same number of runs, each starting with one of the pre-trained models
 657        parameters_update : dict, optional
 658            the dictionary used to update the parameters from the config files
 659        task : TaskDispatcher, optional
 660            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
 661            a new instance)
 662        load_epoch : int, optional
 663            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
 664        load_search : str, optional
 665            the hyperparameter search result to load
 666        load_parameters : list, optional
 667            a list of string names of the parameters to load from load_search (if not provided, all parameters
 668            are loaded)
 669        round_to_binary : list, optional
 670            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 671        load_strict : bool, default True
 672            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
 673            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
 674        n_seeds : int, default 1
 675            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 676            `test_episode#0` and `test_episode#1`
 677        force : bool, default False
 678            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
 679        suppress_name_check : bool, default False
 680            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 681            why they are usually forbidden)
 682        remove_saved_features : bool, default False
 683            if `True`, the dataset will be deleted after training
 684        mask_name : str, optional
 685            the name of the real_lens to apply
 686        autostop_metric : str, optional
 687            the autostop metric (can be any one of the tracked metrics of `'loss'`)
 688        autostop_interval : int, default 50
 689            the number of epochs to average the autostop metric over
 690        autostop_threshold : float, default 0.001
 691            the autostop difference threshold
 692        loading_bar : bool, default False
 693            if `True`, a loading bar will be displayed
 694        trial : tuple, optional
 695            a tuple of (trial, metric) for hyperparameter search
 696
 697        Returns
 698        -------
 699        TaskDispatcher
 700            the `TaskDispatcher` object
 701
 702        """
 703
 704        import gc
 705
 706        gc.collect()
 707        if torch.cuda.is_available():
 708            torch.cuda.empty_cache()
 709
 710        if type(n_seeds) is not int or n_seeds < 1:
 711            raise ValueError(
 712                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
 713            )
 714        if n_seeds > 1 and mask_name is not None:
 715            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
 716        self._check_episode_validity(
 717            episode_name, allow_doublecolon=suppress_name_check, force=force
 718        )
 719        load_runs = self._episodes().get_runs(load_episode)
 720        if len(load_runs) > 1:
 721            task = self.run_episodes(
 722                episode_names=[
 723                    f'{episode_name}#{run.split("#")[-1]}' for run in load_runs
 724                ],
 725                load_episodes=load_runs,
 726                parameters_updates=[parameters_update for _ in load_runs],
 727                load_epochs=[load_epoch for _ in load_runs],
 728                load_searches=[load_search for _ in load_runs],
 729                load_parameters=[load_parameters for _ in load_runs],
 730                round_to_binary=[round_to_binary for _ in load_runs],
 731                load_strict=[load_strict for _ in load_runs],
 732                suppress_name_check=True,
 733                force=force,
 734                remove_saved_features=False,
 735            )
 736            if remove_saved_features:
 737                self._remove_stores(
 738                    {
 739                        "general": task.general_parameters,
 740                        "data": task.data_parameters,
 741                        "features": task.feature_parameters,
 742                    }
 743                )
 744            if n_seeds > 1:
 745                warnings.warn(
 746                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
 747                )
 748        elif n_seeds > 1:
 749
 750            self.run_episodes(
 751                episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)],
 752                load_episodes=[load_episode for _ in range(n_seeds)],
 753                parameters_updates=[parameters_update for _ in range(n_seeds)],
 754                load_epochs=[load_epoch for _ in range(n_seeds)],
 755                load_searches=[load_search for _ in range(n_seeds)],
 756                load_parameters=[load_parameters for _ in range(n_seeds)],
 757                round_to_binary=[round_to_binary for _ in range(n_seeds)],
 758                load_strict=[load_strict for _ in range(n_seeds)],
 759                suppress_name_check=True,
 760                force=force,
 761                remove_saved_features=remove_saved_features,
 762            )
 763        else:
 764            print(f"TRAINING {episode_name}")
 765            try:
 766                task, parameters = self._make_task_training(
 767                    episode_name,
 768                    load_episode,
 769                    parameters_update,
 770                    load_epoch,
 771                    load_search,
 772                    load_parameters,
 773                    round_to_binary,
 774                    continuing=False,
 775                    task=task,
 776                    mask_name=mask_name,
 777                    load_strict=load_strict,
 778                )
 779                self._save_episode(
 780                    episode_name,
 781                    parameters,
 782                    task.behaviors_dict(),
 783                    norm_stats=task.get_normalization_stats(),
 784                )
 785                time_start = time.time()
 786                if trial is not None:
 787                    trial, metric = trial
 788                else:
 789                    trial, metric = None, None
 790                logs = task.train(
 791                    autostop_metric=autostop_metric,
 792                    autostop_interval=autostop_interval,
 793                    autostop_threshold=autostop_threshold,
 794                    loading_bar=loading_bar,
 795                    trial=trial,
 796                    optimized_metric=metric,
 797                )
 798                time_end = time.time()
 799                time_total = time_end - time_start
 800                hours = int(time_total // 3600)
 801                time_total -= hours * 3600
 802                minutes = int(time_total // 60)
 803                time_total -= minutes * 60
 804                seconds = int(time_total)
 805                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 806                self._update_episode_results(episode_name, logs, training_time)
 807                if remove_saved_features:
 808                    self._remove_stores(parameters)
 809                print("\n")
 810                return task
 811
 812            except Exception as e:
 813                if isinstance(e, optuna.exceptions.TrialPruned):
 814                    raise e
 815                else:
 816                    # if str(e) != f"The {episode_name} episode name is already in use!":
 817                    #     self.remove_episode(episode_name)
 818                    raise RuntimeError(f"Episode {episode_name} could not run")
 819
 820    def run_episodes(
 821        self,
 822        episode_names: List,
 823        load_episodes: List = None,
 824        parameters_updates: List = None,
 825        load_epochs: List = None,
 826        load_searches: List = None,
 827        load_parameters: List = None,
 828        round_to_binary: List = None,
 829        load_strict: List = None,
 830        force: bool = False,
 831        suppress_name_check: bool = False,
 832        remove_saved_features: bool = False,
 833    ) -> TaskDispatcher:
 834        """Run multiple episodes in sequence (and re-use previously loaded information).
 835
 836        For each episode, the task parameters are read from the config files and then updated with the
 837        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
 838        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
 839        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
 840        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
 841        data parameters are used.
 842
 843        Parameters
 844        ----------
 845        episode_names : list
 846            a list of strings of episode names
 847        load_episodes : list, optional
 848            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
 849            the new episode will have the same number of runs, each starting with one of the pre-trained models
 850        parameters_updates : list, optional
 851            a list of dictionaries used to update the parameters from the config
 852        load_epochs : list, optional
 853            a list of integers used to specify the epoch to load (if load_episodes is not None)
 854        load_searches : list, optional
 855            a list of strings of hyperparameter search results to load
 856        load_parameters : list, optional
 857            a list of lists of string names of the parameters to load from the searches
 858        round_to_binary : list, optional
 859            a list of string names of the loaded parameters that should be rounded to the nearest power of two
 860        load_strict : list, optional
 861            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
 862            the corresponding episode and differences in parameter name lists and
 863            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
 864            every episode)
 865        force : bool, default False
 866            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
 867        suppress_name_check : bool, default False
 868            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
 869            why they are usually forbidden)
 870        remove_saved_features : bool, default False
 871            if `True`, the dataset will be deleted after training
 872
 873        Returns
 874        -------
 875        TaskDispatcher
 876            the task dispatcher object
 877
 878        """
 879        task = None
 880        if load_searches is None:
 881            load_searches = [None for _ in episode_names]
 882        if load_episodes is None:
 883            load_episodes = [None for _ in episode_names]
 884        if parameters_updates is None:
 885            parameters_updates = [None for _ in episode_names]
 886        if load_parameters is None:
 887            load_parameters = [None for _ in episode_names]
 888        if load_epochs is None:
 889            load_epochs = [None for _ in episode_names]
 890        if load_strict is None:
 891            load_strict = [True for _ in episode_names]
 892        for (
 893            parameters_update,
 894            episode_name,
 895            load_episode,
 896            load_epoch,
 897            load_search,
 898            load_parameters_list,
 899            load_strict_value,
 900        ) in zip(
 901            parameters_updates,
 902            episode_names,
 903            load_episodes,
 904            load_epochs,
 905            load_searches,
 906            load_parameters,
 907            load_strict,
 908        ):
 909            task = self.run_episode(
 910                episode_name,
 911                load_episode,
 912                parameters_update,
 913                task,
 914                load_epoch,
 915                load_search,
 916                load_parameters_list,
 917                round_to_binary,
 918                load_strict_value,
 919                suppress_name_check=suppress_name_check,
 920                force=force,
 921                remove_saved_features=remove_saved_features,
 922            )
 923        return task
 924
 925    def continue_episode(
 926        self,
 927        episode_name: str,
 928        num_epochs: int = None,
 929        task: TaskDispatcher = None,
 930        n_seeds: int = 1,
 931        remove_saved_features: bool = False,
 932        device: str = "cuda",
 933        num_cpus: int = None,
 934    ) -> TaskDispatcher:
 935        """Load an older episode and continue running from the latest checkpoint.
 936
 937        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 938
 939        Parameters
 940        ----------
 941        episode_name : str
 942            the name of the episode to continue
 943        num_epochs : int, optional
 944            the new number of epochs
 945        task : TaskDispatcher, optional
 946            a pre-existing task; if provided, the method will update the task instead of creating a new one
 947            (this might save time, mainly on dataset loading)
 948        n_seeds : int, default 1
 949            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 950            `test_episode#0` and `test_episode#1`
 951        remove_saved_features : bool, default False
 952            if `True`, pre-computed features will be deleted after the run
 953        device : str, default "cuda"
 954            the torch device to use
 955        num_cpus : int, optional
 956            the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used
 957
 958        Returns
 959        -------
 960        TaskDispatcher
 961            the task dispatcher
 962
 963        """
 964        runs = self._episodes().get_runs(episode_name)
 965        for run in runs:
 966            print(f"TRAINING {run}")
 967            if num_epochs is None and not self._episode(run).unfinished():
 968                continue
 969            parameters_update = {
 970                "training": {
 971                    "num_epochs": num_epochs,
 972                    "device": device,
 973                },
 974                "general": {"num_cpus": num_cpus},
 975            }
 976            task, parameters = self._make_task_training(
 977                run,
 978                load_episode=run,
 979                parameters_update=parameters_update,
 980                continuing=True,
 981                task=task,
 982            )
 983            time_start = time.time()
 984            logs = task.train()
 985            time_end = time.time()
 986            old_time = self._training_time(run)
 987            if not np.isnan(old_time):
 988                time_end += old_time
 989                time_total = time_end - time_start
 990                hours = int(time_total // 3600)
 991                time_total -= hours * 3600
 992                minutes = int(time_total // 60)
 993                time_total -= minutes * 60
 994                seconds = int(time_total)
 995                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 996            else:
 997                training_time = np.nan
 998            self._save_episode(
 999                run,
1000                parameters,
1001                task.behaviors_dict(),
1002                suppress_validation=True,
1003                training_time=training_time,
1004                norm_stats=task.get_normalization_stats(),
1005            )
1006            self._update_episode_results(run, logs)
1007            print("\n")
1008        if len(runs) < n_seeds:
1009            for i in range(len(runs), n_seeds):
1010                self.run_episode(
1011                    f"{episode_name}#{i}",
1012                    parameters_update=self._episodes().load_parameters(runs[0]),
1013                    task=task,
1014                    suppress_name_check=True,
1015                )
1016        if remove_saved_features:
1017            self._remove_stores(parameters)
1018        return task
1019
1020    def run_default_hyperparameter_search(
1021        self,
1022        search_name: str,
1023        model_name: str,
1024        metric: str = "f1",
1025        best_n: int = 3,
1026        direction: str = "maximize",
1027        load_episode: str = None,
1028        load_epoch: int = None,
1029        load_strict: bool = True,
1030        prune: bool = True,
1031        force: bool = False,
1032        remove_saved_features: bool = False,
1033        overlap: float = 0,
1034        num_epochs: int = 50,
1035        test_frac: float = None,
1036        n_trials=150,
1037        batch_size=32,
1038    ):
1039        """Run an optuna hyperparameter search with default parameters for a model.
1040
1041        For the vast majority of cases, optimizing the default parameters should be enough.
1042        Check out `dlc2action.options.model_hyperparameters` for the lists of parameters.
1043        There are also options to set overlap, test fraction and number of epochs parameters for the search without
1044        modifying the project config files. However, if you want something more complex, look into
1045        `Project.run_hyperparameter_search`.
1046
1047        The task parameters are read from the config files and updated with the parameters_update dictionary.
1048        The model can be either initialized from scratch or loaded from a previously run episode.
1049        For each trial, the objective metric is averaged over a few best epochs.
1050
1051        Parameters
1052        ----------
1053        search_name : str
1054            the name of the search to store it in the meta files and load in run_episode
1055        model_name : str
1056            the name
1057        metric : str
1058            the metric to maximize/minimize (see direction); if the metric has an `"average"` parameter and it is set to
1059            `"none"` in the config files, it will be reset to `"macro"` for the search
1060        best_n : int, default 1
1061            the number of epochs to average the metric; if 0, the last value is taken
1062        direction : {'maximize', 'minimize'}
1063            optimization direction
1064        load_episode : str, optional
1065            the name of the episode to load the model from
1066        load_epoch : int, optional
1067            the epoch to load the model from (if not provided, the last checkpoint is used)
1068        load_strict : bool, default True
1069            if `True`, the model will be loaded only if the parameters match exactly
1070        prune : bool, default False
1071            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1072            (with optuna HyperBand pruner)
1073        force : bool, default False
1074            if `True`, existing searches with the same name will be overwritten
1075        remove_saved_features : bool, default False
1076            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1077        overlap : float, default 0
1078            the overlap to use for the search
1079        num_epochs : int, default 50
1080            the number of epochs to use for the search
1081        test_frac : float, optional
1082            the test fraction to use for the search
1083        n_trials : int, default 150
1084            the number of trials to run
1085        batch_size : int, default 32
1086            the batch size to use for the search
1087
1088        Returns
1089        -------
1090        best_parameters : dict
1091            a dictionary of best parameters
1092
1093        """
1094        if model_name not in options.model_hyperparameters:
1095            raise ValueError(
1096                f"There is no default search space for {model_name}! Please choose from {options.model_hyperparameters.keys()} or try project.run_hyperparameter_search()"
1097            )
1098        pars = {
1099            "general": {"overlap": overlap, "model_name": model_name},
1100            "training": {"num_epochs": num_epochs, "batch_size": batch_size},
1101        }
1102        if test_frac is not None:
1103            pars["training"]["test_frac"] = test_frac
1104        if not metric.split("_")[-1].isnumeric():
1105            project_pars = self._read_parameters()
1106            if project_pars["metrics"][metric].get("average") == "none":
1107                pars["metrics"] = {metric: {"average": "macro"}}
1108        return self.run_hyperparameter_search(
1109            search_name=search_name,
1110            search_space=options.model_hyperparameters[model_name],
1111            metric=metric,
1112            n_trials=n_trials,
1113            best_n=best_n,
1114            parameters_update=pars,
1115            direction=direction,
1116            load_episode=load_episode,
1117            load_epoch=load_epoch,
1118            load_strict=load_strict,
1119            prune=prune,
1120            force=force,
1121            remove_saved_features=remove_saved_features,
1122        )
1123
1124    def run_hyperparameter_search(
1125        self,
1126        search_name: str,
1127        search_space: Dict,
1128        metric: str = "f1",
1129        n_trials: int = 20,
1130        best_n: int = 1,
1131        parameters_update: Dict = None,
1132        direction: str = "maximize",
1133        load_episode: str = None,
1134        load_epoch: int = None,
1135        load_strict: bool = True,
1136        prune: bool = False,
1137        force: bool = False,
1138        remove_saved_features: bool = False,
1139        make_plots: bool = True,
1140    ) -> Dict:
1141        """Run an optuna hyperparameter search.
1142
1143        For a simpler function that fits most use cases, check out `Project.run_default_hyperparameter_search()`.
1144
1145        To use a default search space with this method, import `dlc2action.options.model_hyperparameters`. It is
1146        a dictionary where keys are model names and values are default search spaces.
1147
1148        The task parameters are read from the config files and updated with the parameters_update dictionary.
1149        The model can be either initialized from scratch or loaded from a previously run episode.
1150        For each trial, the objective metric is averaged over a few best epochs.
1151
1152        Parameters
1153        ----------
1154        search_name : str
1155            the name of the search to store it in the meta files and load in run_episode
1156        search_space : dict
1157            a dictionary representing the search space; of this general structure:
1158            {'group/param_name': ('float/int/float_log/int_log', start, end),
1159            'group/param_name': ('categorical', [choices])}, e.g.
1160            {'data/overlap': ('int', 5, 100), 'training/lr': ('float_log', 1e-4, 1e-2),
1161            'data/feature_extraction': ('categorical', ['kinematic', 'bones'])};
1162        metric : str, default f1
1163            the metric to maximize/minimize (see direction)
1164        n_trials : int, default 20
1165            the number of optimization trials to run
1166        best_n : int, default 1
1167            the number of epochs to average the metric; if 0, the last value is taken
1168        parameters_update : dict, optional
1169            the parameters update dictionary
1170        direction : {'maximize', 'minimize'}
1171            optimization direction
1172        load_episode : str, optional
1173            the name of the episode to load the model from
1174        load_epoch : int, optional
1175            the epoch to load the model from (if not provided, the last checkpoint is used)
1176        load_strict : bool, default True
1177            if `True`, the model will be loaded only if the parameters match exactly
1178        prune : bool, default False
1179            if `True`, experiments where the optimized metric is improving too slowly will be terminated
1180            (with optuna HyperBand pruner)
1181        force : bool, default False
1182            if `True`, existing searches with the same name will be overwritten
1183        remove_saved_features : bool, default False
1184            if `True`, pre-computed features will be deleted after each run (if the data parameters change)
1185
1186        Returns
1187        -------
1188        dict
1189            a dictionary of best parameters
1190
1191        """
1192        self._check_search_validity(search_name, force=force)
1193        print(f"SEARCH {search_name}")
1194        self.remove_episode(f"_{search_name}")
1195        if parameters_update is None:
1196            parameters_update = {}
1197        parameters_update = self._update(
1198            parameters_update, {"general": {"metric_functions": {metric}}}
1199        )
1200        parameters = self._make_parameters(
1201            f"_{search_name}",
1202            load_episode,
1203            parameters_update,
1204            parameters_update_second={"training": {"model_save_path": None}},
1205            load_epoch=load_epoch,
1206            load_strict=load_strict,
1207        )
1208        task = None
1209
1210        if prune:
1211            pruner = optuna.pruners.HyperbandPruner()
1212        else:
1213            pruner = optuna.pruners.NopPruner()
1214        study = optuna.create_study(direction=direction, pruner=pruner)
1215        runner = _Runner(
1216            search_space=search_space,
1217            load_episode=load_episode,
1218            load_epoch=load_epoch,
1219            metric=metric,
1220            average=best_n,
1221            task=task,
1222            remove_saved_features=remove_saved_features,
1223            project=self,
1224            search_name=search_name,
1225        )
1226        study.optimize(lambda trial: runner.run(trial, parameters), n_trials=n_trials)
1227        if make_plots:
1228            search_path = self._search_path(search_name)
1229            os.mkdir(search_path)
1230            fig = optuna.visualization.plot_contour(study)
1231            plotly.offline.plot(
1232                fig, filename=os.path.join(search_path, f"{search_name}_contour.html")
1233            )
1234            fig = optuna.visualization.plot_param_importances(study)
1235            plotly.offline.plot(
1236                fig,
1237                filename=os.path.join(search_path, f"{search_name}_importances.html"),
1238            )
1239        best_params = study.best_params
1240        best_value = study.best_value
1241        if best_value == 0 or best_value == float("inf"):
1242            raise ValueError(
1243                f"Best metric value is {best_value}, check your partition method and make sure that all behaviors are present in the validation set!"
1244            )
1245        self._save_search(
1246            search_name,
1247            parameters,
1248            n_trials,
1249            best_params,
1250            best_value,
1251            metric,
1252            search_space,
1253        )
1254        self.remove_episode(f"_{search_name}")
1255        runner.clean()
1256        print(f"best parameters: {best_params}")
1257        print("\n")
1258        return best_params
1259
1260    def run_prediction(
1261        self,
1262        prediction_name: str,
1263        episode_names: List,
1264        load_epochs: List = None,
1265        parameters_update: Dict = None,
1266        augment_n: int = 10,
1267        data_path: str = None,
1268        mode: str = "all",
1269        file_paths: Set = None,
1270        remove_saved_features: bool = False,
1271        frame_number_map_file: str = None,
1272        force: bool = False,
1273        embedding: bool = False,
1274    ) -> None:
1275        """Load models from previously run episodes to generate a prediction.
1276
1277        The probabilities predicted by the models are averaged.
1278        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1279        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1280        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1281        are the prediction arrays.
1282
1283        Parameters
1284        ----------
1285        prediction_name : str
1286            the name of the prediction
1287        episode_names : list
1288            a list of string episode names to load the models from
1289        load_epochs : list or int, optional
1290            a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes
1291        parameters_update : dict, optional
1292            a dictionary of parameter updates
1293        augment_n : int, default 10
1294            the number of augmentations to average over
1295        data_path : str, optional
1296            the data path to run the prediction for
1297        mode : {'all', 'test', 'val', 'train'}
1298            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1299        file_paths : set, optional
1300            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1301            for
1302        remove_saved_features : bool, default False
1303            if `True`, pre-computed features will be deleted
1304        submission : bool, default False
1305            if `True`, a MABe-22 style submission file is generated
1306        frame_number_map_file : str, optional
1307            path to the frame number map file
1308        force : bool, default False
1309            if `True`, existing prediction with this name will be overwritten
1310        embedding : bool, default False
1311            if `True`, the prediction is made for the embedding task
1312
1313        """
1314        self._check_prediction_validity(prediction_name, force=force)
1315        print(f"PREDICTION {prediction_name}")
1316        task, parameters, mode, prediction, inference_time, behavior_dict = (
1317            self._make_prediction(
1318                prediction_name,
1319                episode_names,
1320                load_epochs,
1321                parameters_update,
1322                data_path,
1323                file_paths,
1324                mode,
1325                augment_n,
1326                evaluate=False,
1327                embedding=embedding,
1328            )
1329        )
1330        predicted = task.dataset(mode).generate_full_length_prediction(prediction)
1331
1332        if remove_saved_features:
1333            self._remove_stores(parameters)
1334
1335        self._save_prediction(
1336            prediction_name,
1337            predicted,
1338            parameters,
1339            task,
1340            mode,
1341            embedding,
1342            inference_time,
1343            behavior_dict,
1344        )
1345        print("\n")
1346
1347    def evaluate_prediction(
1348        self,
1349        prediction_name: str,
1350        parameters_update: Dict = None,
1351        data_path: str = None,
1352        annotation_path: str = None,
1353        file_paths: Set = None,
1354        mode: str = None,
1355        remove_saved_features: bool = False,
1356        annotation_type: str = "none",
1357        num_classes: int = None,  # Set when using data_path
1358    ) -> Tuple[float, dict]:
1359        """Make predictions and evaluate them
1360        inputs:
1361            prediction_name (str): the name of the prediction
1362            parameters_update (dict): a dictionary of parameter updates
1363            data_path (str): the data path to run the prediction for
1364            annotation_path (str): the annotation path to run the prediction for
1365            file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
1366            mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1367            remove_saved_features (bool): if `True`, pre-computed features will be deleted
1368            annotation_type (str): the type of annotation to use for evaluation
1369            num_classes (int): the number of classes in the dataset, must be set with data_path
1370        outputs:
1371            results (dict): a dictionary of average values of metric functions
1372        """
1373
1374        prediction_path = os.path.join(
1375            self.project_path, "results", "predictions", f"{prediction_name}"
1376        )
1377        prediction_dict = {}
1378        for prediction_file_path in [
1379            os.path.join(prediction_path, i) for i in os.listdir(prediction_path)
1380        ]:
1381            with open(os.path.join(prediction_file_path), "rb") as f:
1382                prediction = pickle.load(f)
1383            video_id = os.path.basename(prediction_file_path).split(
1384                "_" + prediction_name
1385            )[0]
1386            prediction_dict[video_id] = prediction
1387        if parameters_update is None:
1388            parameters_update = {}
1389        parameters_update = self._update(
1390            self._predictions().load_parameters(prediction_name), parameters_update
1391        )
1392        parameters_update.pop("model")
1393        if not data_path is None:
1394            assert (
1395                not num_classes is None
1396            ), "num_classes must be provided if data_path is provided"
1397            parameters_update["general"]["num_classes"] = num_classes + int(
1398                parameters_update["general"]["exclusive"]
1399            )
1400        task, parameters, mode = self._make_task_prediction(
1401            "_",
1402            load_episode=None,
1403            parameters_update=parameters_update,
1404            data_path=data_path,
1405            annotation_path=annotation_path,
1406            file_paths=file_paths,
1407            mode=mode,
1408            annotation_type=annotation_type,
1409        )
1410        results = task.evaluate_prediction(prediction_dict, data=mode)
1411        if remove_saved_features:
1412            self._remove_stores(parameters)
1413        results = Project._reformat_results(
1414            results[1],
1415            task.behaviors_dict(),
1416            exclusive=task.general_parameters["exclusive"],
1417        )
1418        return results
1419
1420    def evaluate(
1421        self,
1422        episode_names: List,
1423        load_epochs: List = None,
1424        augment_n: int = 0,
1425        data_path: str = None,
1426        file_paths: Set = None,
1427        mode: str = None,
1428        parameters_update: Dict = None,
1429        multiple_episode_policy: str = "average",
1430        remove_saved_features: bool = False,
1431        skip_updating_meta: bool = True,
1432        annotation_type: str = "none",
1433    ) -> Dict:
1434        """Load one or several models from previously run episodes to make an evaluation.
1435
1436        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1437
1438        Parameters
1439        ----------
1440        episode_names : list
1441            a list of string episode names to load the models from
1442        load_epochs : list, optional
1443            a list of integer epoch indices to load the model from; if None, the last ones are used
1444        augment_n : int, default 0
1445            the number of augmentations to average over
1446        data_path : str, optional
1447            the data path to run the prediction for
1448        file_paths : set, optional
1449            a set of files to run the prediction for
1450        mode : {'test', 'val', 'train', 'all'}
1451            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1452            by default 'test' if test subset is not empty and 'val' otherwise)
1453        parameters_update : dict, optional
1454            a dictionary with parameter updates (cannot change model parameters)
1455        multiple_episode_policy : {'average', 'statistics'}
1456            the policy to use when multiple episodes are provided
1457        remove_saved_features : bool, default False
1458            if `True`, the dataset will be deleted
1459        skip_updating_meta : bool, default True
1460            if `True`, the meta file will not be updated with the computed metrics
1461
1462        Returns
1463        -------
1464        metric : dict
1465            a dictionary of average values of metric functions
1466
1467        """
1468        names = []
1469        for episode_name in episode_names:
1470            names += self._episodes().get_runs(episode_name)
1471        if len(set(episode_names)) == 1:
1472            print(f"EVALUATION {episode_names[0]}")
1473        else:
1474            print(f"EVALUATION {episode_names}")
1475        if len(names) > 1:
1476            evaluate = True
1477        else:
1478            evaluate = False
1479        if multiple_episode_policy == "average":
1480            task, parameters, mode, prediction, inference_time, behavior_dict = (
1481                self._make_prediction(
1482                    "_",
1483                    episode_names,
1484                    load_epochs,
1485                    parameters_update,
1486                    mode=mode,
1487                    data_path=data_path,
1488                    file_paths=file_paths,
1489                    augment_n=augment_n,
1490                    evaluate=evaluate,
1491                    annotation_type=annotation_type,
1492                )
1493            )
1494            print("EVALUATE PREDICTION:")
1495            indices = [
1496                list(behavior_dict.keys()).index(i) for i in range(len(behavior_dict))
1497            ]
1498            _, results = task.evaluate_prediction(
1499                prediction, data=mode, indices=indices
1500            )
1501            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1502                self._update_episode_metrics(names[0], results)
1503            results = Project._reformat_results(
1504                results,
1505                behavior_dict,
1506                exclusive=task.general_parameters["exclusive"],
1507            )
1508
1509        elif multiple_episode_policy == "statistics":
1510            values = defaultdict(lambda: [])
1511            task = None
1512            for name in names:
1513                (
1514                    task,
1515                    parameters,
1516                    mode,
1517                    prediction,
1518                    inference_time,
1519                    behavior_dict,
1520                ) = self._make_prediction(
1521                    "_",
1522                    [name],
1523                    load_epochs,
1524                    parameters_update,
1525                    mode=mode,
1526                    data_path=data_path,
1527                    file_paths=file_paths,
1528                    augment_n=augment_n,
1529                    evaluate=evaluate,
1530                    task=task,
1531                )
1532                _, metrics = task.evaluate_prediction(
1533                    prediction, data=mode, indices=list(behavior_dict.keys())
1534                )
1535                for name, value in metrics.items():
1536                    values[name].append(value)
1537                if mode == "val" and not skip_updating_meta:
1538                    self._update_episode_metrics(name, metrics)
1539            results = defaultdict(lambda: {})
1540            mean_string = ""
1541            std_string = ""
1542            for key, value_list in values.items():
1543                results[key]["mean"] = np.mean(value_list)
1544                results[key]["std"] = np.std(value_list)
1545                results[key]["all"] = value_list
1546                mean_string += f"{key} {np.mean(value_list):.3f}, "
1547                std_string += f"{key} {np.std(value_list):.3f}, "
1548            print("MEAN:")
1549            print(mean_string)
1550            print("STD:")
1551            print(std_string)
1552        else:
1553            raise ValueError(
1554                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1555                f"from ['average', 'statistics']"
1556            )
1557        if len(names) > 0 and remove_saved_features:
1558            self._remove_stores(parameters)
1559        print(f"Inference time: {inference_time}")
1560        print("\n")
1561        return results
1562
1563    def run_suggestion(
1564        self,
1565        suggestions_name: str,
1566        error_episode: str = None,
1567        error_load_epoch: int = None,
1568        error_class: str = None,
1569        suggestions_prediction: str = None,
1570        suggestion_episodes: List = [None],
1571        suggestion_load_epoch: int = None,
1572        suggestion_classes: List = None,
1573        error_threshold: float = 0.5,
1574        error_threshold_diff: float = 0.1,
1575        error_hysteresis: bool = False,
1576        suggestion_threshold: Union[float, List] = 0.5,
1577        suggestion_threshold_diff: Union[float, List] = 0.1,
1578        suggestion_hysteresis: Union[bool, List] = True,
1579        min_frames_suggestion: int = 10,
1580        min_frames_al: int = 30,
1581        visibility_min_score: float = 0,
1582        visibility_min_frac: float = 0.7,
1583        augment_n: int = 0,
1584        exclude_classes: List = None,
1585        exclude_threshold: Union[float, List] = 0.6,
1586        exclude_threshold_diff: Union[float, List] = 0.1,
1587        exclude_hysteresis: Union[bool, List] = False,
1588        include_classes: List = None,
1589        include_threshold: Union[float, List] = 0.4,
1590        include_threshold_diff: Union[float, List] = 0.1,
1591        include_hysteresis: Union[bool, List] = False,
1592        data_path: str = None,
1593        file_paths: Set = None,
1594        parameters_update: Dict = None,
1595        mode: str = "all",
1596        force: bool = False,
1597        remove_saved_features: bool = False,
1598        cut_annotated: bool = False,
1599        background_threshold: float = None,
1600    ) -> None:
1601        """Create active learning and suggestion files.
1602
1603        Generate predictions with the error and suggestion model and use them to create
1604        suggestion files for the labeling interface. Those files will render as suggested labels
1605        at intervals with high pose estimation quality. Quality here is defined by probability of error
1606        (predicted by the error model) and visibility parameters.
1607
1608        If `error_episode` or `exclude_classes` is not `None`,
1609        an active learning file will be created as well (with frames with high predicted probability of classes
1610        from `exclude_classes` and/or errors excluded from the active learning intervals).
1611
1612        In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals)
1613        you can apply one of three methods.
1614
1615        - **Simple threshold**
1616
1617            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold`
1618            parameter to $\alpha$.
1619            In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will
1620            be considered labeled.
1621
1622        - **Hysteresis threshold**
1623
1624            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1625            parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$.
1626            Now intervals will be marked with a label if the probability of that label for all frames is higher
1627            than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$.
1628
1629        - **Max hysteresis threshold**
1630
1631            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1632            parameter to $\alpha$ and the `threshold_diff` parameter to `None`.
1633            With this combination intervals are marked with a label if that label is more likely than any other
1634            for all frames in this interval and at for at least one of those frames its probability is higher than
1635            $\alpha$.
1636
1637        Parameters
1638        ----------
1639        suggestions_name : str
1640            the name of the suggestions
1641        error_episode : str, optional
1642            the name of the episode where the error model should be loaded from
1643        error_load_epoch : int, optional
1644            the epoch the error model should be loaded from
1645        error_class : str, optional
1646            the name of the error class (in `error_episode`)
1647        suggestions_prediction : str, optional
1648            the name of the predictions that should be used for the suggestion model
1649        suggestion_episodes : list, optional
1650            the names of the episodes where the suggestion models should be loaded from
1651        suggestion_load_epoch : int, optional
1652            the epoch the suggestion model should be loaded from
1653        suggestion_classes : list, optional
1654            a list of string names of the classes that should be suggested (in `suggestion_episode`)
1655        error_threshold : float, default 0.5
1656            the hard threshold for error prediction
1657        error_threshold_diff : float, default 0.1
1658            the difference between soft and hard thresholds for error prediction (in case hysteresis is used)
1659        error_hysteresis : bool, default False
1660            if True, hysteresis is used for error prediction
1661        suggestion_threshold : float | list, default 0.5
1662            the hard threshold for class prediction (use a list to set different rules for different classes)
1663        suggestion_threshold_diff : float | list, default 0.1
1664            the difference between soft and hard thresholds for class prediction (in case hysteresis is used;
1665            use a list to set different rules for different classes)
1666        suggestion_hysteresis : bool | list, default True
1667            if True, hysteresis is used for class prediction (use a list to set different rules for different classes)
1668        min_frames_suggestion : int, default 10
1669            only actions longer than this number of frames will be suggested
1670        min_frames_al : int, default 30
1671            only active learning intervals longer than this number of frames will be suggested
1672        visibility_min_score : float, default 0
1673            the minimum visibility score for visibility filtering
1674        visibility_min_frac : float, default 0.7
1675            the minimum fraction of visible frames for visibility filtering
1676        augment_n : int, default 10
1677            the number of augmentations to average the predictions over
1678        exclude_classes : list, optional
1679            a list of string names of classes that should be excluded from the active learning intervals
1680        exclude_threshold : float | list, default 0.6
1681            the hard threshold for excluded class prediction (use a list to set different rules for different classes)
1682        exclude_threshold_diff : float | list, default 0.1
1683            the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used)
1684        exclude_hysteresis : bool | list, default False
1685            if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes)
1686        include_classes : list, optional
1687            a list of string names of classes that should be included into the active learning intervals
1688        include_threshold : float | list, default 0.6
1689            the hard threshold for included class prediction (use a list to set different rules for different classes)
1690        include_threshold_diff : float | list, default 0.1
1691            the difference between soft and hard thresholds for included class prediction (in case hysteresis is used)
1692        include_hysteresis : bool | list, default False
1693            if True, hysteresis is used for included class prediction (use a list to set different rules for different classes)
1694        data_path : str, optional
1695            the data path to run the prediction for
1696        file_paths : set, optional
1697            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1698            for
1699        parameters_update : dict, optional
1700            the parameters update dictionary
1701        mode : {'all', 'test', 'val', 'train'}
1702            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1703        force : bool, default False
1704            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
1705        remove_saved_features : bool, default False
1706            if `True`, the dataset will be deleted.
1707        cut_annotated : bool, default False
1708            if `True`, annotated frames will be cut from the suggestions
1709        background_threshold : float, default 0.5
1710            the threshold for background prediction
1711
1712        """
1713        self._check_suggestions_validity(suggestions_name, force=force)
1714        if any([x is None for x in suggestion_episodes]):
1715            suggestion_episodes = None
1716        if error_episode is None and (
1717            suggestion_episodes is None and suggestions_prediction is None
1718        ):
1719            raise ValueError(
1720                "Both error_episode and suggestion_episode parameters cannot be None at the same time"
1721            )
1722        print(f"SUGGESTION {suggestions_name}")
1723        task = None
1724        if suggestion_classes is None:
1725            suggestion_classes = []
1726        if exclude_classes is None:
1727            exclude_classes = []
1728        if include_classes is None:
1729            include_classes = []
1730        if isinstance(suggestion_threshold, list):
1731            if len(suggestion_threshold) != len(suggestion_classes):
1732                raise ValueError(
1733                    "The suggestion_threshold parameter has to be either a float value or a list of "
1734                    f"float values of the same length as suggestion_classes (got a list of length "
1735                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1736                )
1737        else:
1738            suggestion_threshold = [suggestion_threshold for _ in suggestion_classes]
1739        if isinstance(suggestion_threshold_diff, list):
1740            if len(suggestion_threshold_diff) != len(suggestion_classes):
1741                raise ValueError(
1742                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1743                    f"float values of the same length as suggestion_classes (got a list of length "
1744                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1745                )
1746        else:
1747            suggestion_threshold_diff = [
1748                suggestion_threshold_diff for _ in suggestion_classes
1749            ]
1750        if isinstance(suggestion_hysteresis, list):
1751            if len(suggestion_hysteresis) != len(suggestion_classes):
1752                raise ValueError(
1753                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1754                    f"float values of the same length as suggestion_classes (got a list of length "
1755                    f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)"
1756                )
1757        else:
1758            suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes]
1759        if isinstance(exclude_threshold, list):
1760            if len(exclude_threshold) != len(exclude_classes):
1761                raise ValueError(
1762                    "The exclude_threshold parameter has to be either a float value or a list of "
1763                    f"float values of the same length as exclude_classes (got a list of length "
1764                    f"{len(exclude_threshold)} for {len(exclude_classes)} classes)"
1765                )
1766        else:
1767            exclude_threshold = [exclude_threshold for _ in exclude_classes]
1768        if isinstance(exclude_threshold_diff, list):
1769            if len(exclude_threshold_diff) != len(exclude_classes):
1770                raise ValueError(
1771                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1772                    f"float values of the same length as exclude_classes (got a list of length "
1773                    f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)"
1774                )
1775        else:
1776            exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes]
1777        if isinstance(exclude_hysteresis, list):
1778            if len(exclude_hysteresis) != len(exclude_classes):
1779                raise ValueError(
1780                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1781                    f"float values of the same length as suggestion_classes (got a list of length "
1782                    f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)"
1783                )
1784        else:
1785            exclude_hysteresis = [exclude_hysteresis for _ in exclude_classes]
1786        if isinstance(include_threshold, list):
1787            if len(include_threshold) != len(include_classes):
1788                raise ValueError(
1789                    "The exclude_threshold parameter has to be either a float value or a list of "
1790                    f"float values of the same length as exclude_classes (got a list of length "
1791                    f"{len(include_threshold)} for {len(include_classes)} classes)"
1792                )
1793        else:
1794            include_threshold = [include_threshold for _ in include_classes]
1795        if isinstance(include_threshold_diff, list):
1796            if len(include_threshold_diff) != len(include_classes):
1797                raise ValueError(
1798                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1799                    f"float values of the same length as exclude_classes (got a list of length "
1800                    f"{len(include_threshold_diff)} for {len(include_classes)} classes)"
1801                )
1802        else:
1803            include_threshold_diff = [include_threshold_diff for _ in include_classes]
1804        if isinstance(include_hysteresis, list):
1805            if len(include_hysteresis) != len(include_classes):
1806                raise ValueError(
1807                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1808                    f"float values of the same length as suggestion_classes (got a list of length "
1809                    f"{len(include_hysteresis)} for {len(include_classes)} classes)"
1810                )
1811        else:
1812            include_hysteresis = [include_hysteresis for _ in include_classes]
1813        if (suggestion_episodes is None and suggestions_prediction is None) and len(
1814            exclude_classes
1815        ) > 0:
1816            raise ValueError(
1817                "In order to exclude classes from the active learning intervals you need to set the "
1818                "suggestion_episode parameter"
1819            )
1820
1821        task = None
1822        if error_episode is not None:
1823            task, parameters, mode = self._make_task_prediction(
1824                prediction_name=suggestions_name,
1825                load_episode=error_episode,
1826                parameters_update=parameters_update,
1827                load_epoch=error_load_epoch,
1828                data_path=data_path,
1829                mode=mode,
1830                file_paths=file_paths,
1831                task=task,
1832            )
1833            predicted_error = task.predict(
1834                data=mode,
1835                raw_output=True,
1836                apply_primary_function=True,
1837                augment_n=augment_n,
1838            )
1839        else:
1840            predicted_error = None
1841
1842        if suggestion_episodes is not None:
1843            (
1844                task,
1845                parameters,
1846                mode,
1847                predicted_classes,
1848                inference_time,
1849                behavior_dict,
1850            ) = self._make_prediction(
1851                prediction_name=suggestions_name,
1852                episode_names=suggestion_episodes,
1853                load_epochs=suggestion_load_epoch,
1854                parameters_update=parameters_update,
1855                data_path=data_path,
1856                file_paths=file_paths,
1857                mode=mode,
1858                task=task,
1859            )
1860        elif suggestions_prediction is not None:
1861            with open(
1862                os.path.join(
1863                    self.project_path,
1864                    "results",
1865                    "predictions",
1866                    f"{suggestions_prediction}.pickle",
1867                ),
1868                "rb",
1869            ) as f:
1870                predicted_classes = pickle.load(f)
1871            if parameters_update is None:
1872                parameters_update = {}
1873            parameters_update = self._update(
1874                self._predictions().load_parameters(suggestions_prediction),
1875                parameters_update,
1876            )
1877            parameters_update.pop("model")
1878            if suggestion_episodes is None:
1879                suggestion_episodes = [
1880                    os.path.basename(
1881                        os.path.dirname(
1882                            parameters_update["training"]["checkpoint_path"]
1883                        )
1884                    )
1885                ]
1886            task, parameters, mode = self._make_task_prediction(
1887                "_",
1888                load_episode=None,
1889                parameters_update=parameters_update,
1890                data_path=data_path,
1891                file_paths=file_paths,
1892                mode=mode,
1893            )
1894        else:
1895            predicted_classes = None
1896
1897        if len(suggestion_classes) > 0 and predicted_classes is not None:
1898            suggestions = self._make_suggestions(
1899                task,
1900                predicted_error,
1901                predicted_classes,
1902                suggestion_threshold,
1903                suggestion_threshold_diff,
1904                suggestion_hysteresis,
1905                suggestion_episodes,
1906                suggestion_classes,
1907                error_threshold,
1908                min_frames_suggestion,
1909                min_frames_al,
1910                visibility_min_score,
1911                visibility_min_frac,
1912                cut_annotated=cut_annotated,
1913            )
1914            videos = list(suggestions.keys())
1915            for v_id in videos:
1916                times_dict = defaultdict(lambda: defaultdict(lambda: []))
1917                clips = set()
1918                for c in suggestions[v_id]:
1919                    for start, end, ind in suggestions[v_id][c]:
1920                        times_dict[ind][c].append([start, end, 2])
1921                        clips.add(ind)
1922                clips = list(clips)
1923                times_dict = dict(times_dict)
1924                times = [
1925                    [times_dict[ind][c] for c in suggestion_classes] for ind in clips
1926                ]
1927                save_path = self._suggestion_path(v_id, suggestions_name)
1928                Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1929                with open(save_path, "wb") as f:
1930                    pickle.dump((None, suggestion_classes, clips, times), f)
1931
1932        if (
1933            error_episode is not None
1934            or len(exclude_classes) > 0
1935            or len(include_classes) > 0
1936        ):
1937            al_points = self._make_al_points(
1938                task,
1939                predicted_error,
1940                predicted_classes,
1941                exclude_classes,
1942                exclude_threshold,
1943                exclude_threshold_diff,
1944                exclude_hysteresis,
1945                include_classes,
1946                include_threshold,
1947                include_threshold_diff,
1948                include_hysteresis,
1949                error_episode,
1950                error_class,
1951                suggestion_episodes,
1952                error_threshold,
1953                error_threshold_diff,
1954                error_hysteresis,
1955                min_frames_al,
1956                visibility_min_score,
1957                visibility_min_frac,
1958            )
1959        else:
1960            al_points = self._make_al_points_from_suggestions(
1961                suggestions_name,
1962                task,
1963                predicted_classes,
1964                background_threshold,
1965                visibility_min_score,
1966                visibility_min_frac,
1967                num_behaviors=len(task.behaviors_dict()),
1968            )
1969        save_path = self._al_points_path(suggestions_name)
1970        Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1971        with open(save_path, "wb") as f:
1972            pickle.dump(al_points, f)
1973
1974        meta_parameters = {
1975            "error_episode": error_episode,
1976            "error_load_epoch": error_load_epoch,
1977            "error_class": error_class,
1978            "suggestion_episode": suggestion_episodes,
1979            "suggestion_load_epoch": suggestion_load_epoch,
1980            "suggestion_classes": suggestion_classes,
1981            "error_threshold": error_threshold,
1982            "error_threshold_diff": error_threshold_diff,
1983            "error_hysteresis": error_hysteresis,
1984            "suggestion_threshold": suggestion_threshold,
1985            "suggestion_threshold_diff": suggestion_threshold_diff,
1986            "suggestion_hysteresis": suggestion_hysteresis,
1987            "min_frames_suggestion": min_frames_suggestion,
1988            "min_frames_al": min_frames_al,
1989            "visibility_min_score": visibility_min_score,
1990            "visibility_min_frac": visibility_min_frac,
1991            "augment_n": augment_n,
1992            "exclude_classes": exclude_classes,
1993            "exclude_threshold": exclude_threshold,
1994            "exclude_threshold_diff": exclude_threshold_diff,
1995            "exclude_hysteresis": exclude_hysteresis,
1996        }
1997        self._save_suggestions(suggestions_name, {}, meta_parameters)
1998        if data_path is not None or file_paths is not None or remove_saved_features:
1999            self._remove_stores(parameters)
2000        print(f"\n")
2001
2002    def _generate_similarity_score(
2003        self,
2004        prediction_name: str,
2005        target_video_id: str,
2006        target_clip: str,
2007        target_start: int,
2008        target_end: int,
2009    ) -> Dict:
2010        with open(
2011            os.path.join(
2012                self.project_path,
2013                "results",
2014                "predictions",
2015                f"{prediction_name}.pickle",
2016            ),
2017            "rb",
2018        ) as f:
2019            prediction = pickle.load(f)
2020        target = prediction[target_video_id][target_clip][:, target_start:target_end]
2021        score_dict = defaultdict(lambda: {})
2022        for video_id in prediction:
2023            for clip_id in prediction[video_id]:
2024                score_dict[video_id][clip_id] = torch.cdist(
2025                    target.T, prediction[video_id][score_dict].T
2026                ).min(0)
2027        return score_dict
2028
2029    def _suggest_intervals_from_dict(self, score_dict, min_length, n_intervals) -> Dict:
2030        """Suggest intervals from a score dictionary.
2031
2032        Parameters
2033        ----------
2034        score_dict : dict
2035            a dictionary containing scores for intervals
2036        min_length : int
2037            minimum length of intervals to suggest
2038        n_intervals : int
2039            number of intervals to suggest
2040
2041        Returns
2042        -------
2043        intervals : dict
2044            a dictionary of suggested intervals
2045
2046        """
2047        interval_address = {}
2048        interval_value = {}
2049        s = 0
2050        n = 0
2051        for video_id, video_dict in score_dict.items():
2052            for clip_id, value in video_dict.items():
2053                s += value.mean()
2054                n += 1
2055        mean_value = s / n
2056        alpha = 1.75
2057        for it in range(10):
2058            id = 0
2059            interval_address = {}
2060            interval_value = {}
2061            for video_id, video_dict in score_dict.items():
2062                for clip_id, value in video_dict.items():
2063                    res_indices_start, res_indices_end = apply_threshold(
2064                        value,
2065                        threshold=(2 - alpha * (0.9**it)) * mean_value,
2066                        low=True,
2067                        error_mask=None,
2068                        min_frames=min_length,
2069                        smooth_interval=0,
2070                    )
2071                    for start, end in zip(res_indices_start, res_indices_end):
2072                        interval_address[id] = [video_id, clip_id, start, end]
2073                        interval_value[id] = score_dict[video_id][clip_id][
2074                            start:end
2075                        ].mean()
2076                        id += 1
2077            if len(interval_address) >= n_intervals:
2078                break
2079        if len(interval_address) < n_intervals:
2080            warnings.warn(
2081                f"Could not get {n_intervals} intervals from the data, saving the result with {len(interval_address)} intervals"
2082            )
2083        sorted_intervals = sorted(
2084            interval_value.items(), key=lambda x: x[1], reverse=True
2085        )
2086        output_intervals = [
2087            interval_address[x[0]]
2088            for x in sorted_intervals[: min(len(sorted_intervals), n_intervals)]
2089        ]
2090        output = defaultdict(lambda: [])
2091        for video_id, clip_id, start, end in output_intervals:
2092            output[video_id].append([start, end, clip_id])
2093        return output
2094
2095    def suggest_intervals_with_similarity(
2096        self,
2097        suggestions_name: str,
2098        prediction_name: str,
2099        target_video_id: str,
2100        target_clip: str,
2101        target_start: int,
2102        target_end: int,
2103        min_length: int = 60,
2104        n_intervals: int = 5,
2105        force: bool = False,
2106    ):
2107        """
2108        Suggest intervals based on similarity to a target interval.
2109
2110        Parameters
2111        ----------
2112        suggestions_name : str
2113            Name of the suggestion.
2114        prediction_name : str
2115            Name of the prediction to use.
2116        target_video_id : str
2117            Video id of the target interval.
2118        target_clip : str
2119            Clip id of the target interval.
2120        target_start : int
2121            Start frame of the target interval.
2122        target_end : int
2123            End frame of the target interval.
2124        min_length : int, default 60
2125            Minimum length of the suggested intervals.
2126        n_intervals : int, default 5
2127            Number of suggested intervals.
2128        force : bool, default False
2129            If True, the suggestion is overwritten if it already exists.
2130
2131        """
2132        self._check_suggestions_validity(suggestions_name, force=force)
2133        print(f"SUGGESTION {suggestions_name}")
2134        score_dict = self._generate_similarity_score(
2135            prediction_name, target_video_id, target_clip, target_start, target_end
2136        )
2137        intervals = self._suggest_intervals_from_dict(
2138            score_dict, min_length, n_intervals
2139        )
2140        suggestions_path = os.path.join(
2141            self.project_path,
2142            "results",
2143            "suggestions",
2144            suggestions_name,
2145        )
2146        if not os.path.exists(suggestions_path):
2147            os.mkdir(suggestions_path)
2148        with open(
2149            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2150        ) as f:
2151            pickle.dump(intervals, f)
2152        meta_parameters = {
2153            "prediction_name": prediction_name,
2154            "min_frames_suggestion": min_length,
2155            "n_intervals": n_intervals,
2156            "target_clip": target_clip,
2157            "target_start": target_start,
2158            "target_end": target_end,
2159        }
2160        self._save_suggestions(suggestions_name, {}, meta_parameters)
2161        print("\n")
2162
2163    def suggest_intervals_with_uncertainty(
2164        self,
2165        suggestions_name: str,
2166        episode_names: List,
2167        load_epochs: List = None,
2168        classes: List = None,
2169        n_frames: int = 10000,
2170        method: str = "least_confidence",
2171        min_length: int = 60,
2172        augment_n: int = 0,
2173        data_path: str = None,
2174        file_paths: Set = None,
2175        parameters_update: Dict = None,
2176        mode: str = "all",
2177        force: bool = False,
2178        remove_saved_features: bool = False,
2179    ) -> None:
2180        """Generate an active learning file based on model uncertainty.
2181
2182        If you provide several episode names, the predicted probabilities will be averaged.
2183
2184        Parameters
2185        ----------
2186        suggestions_name : str
2187            the name of the suggestion
2188        episode_names : list
2189            a list of string episode names to load the models from
2190        load_epochs : list, optional
2191            a list of epoch indices to load the models from (if `None`, the last ones will be used)
2192        classes : list, optional
2193            a list of classes to look at (by default all)
2194        n_frames : int, default 10000
2195            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2196            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2197            with the set parameters)
2198        method : {"least_confidence", "entropy"}
2199            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
2200            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
2201        min_length : int, default 60
2202            the minimum number of frames in one interval
2203        augment_n : int, default 0
2204            the number of augmentations to average the predictions over
2205        data_path : str, optional
2206            the path to a data folder (by default, the project data is used)
2207        file_paths : set, optional
2208            a list of file paths (by default, the project data is used)
2209        parameters_update : dict, optional
2210            a dictionary of parameter updates
2211        mode : {"test", "val", "train", "all"}
2212            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2213            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2214        force : bool, default False
2215            if `True`, existing suggestions with the same name will be overwritten
2216        remove_saved_features : bool, default False
2217            if `True`, the dataset will be deleted after the computation
2218
2219        """
2220        self._check_suggestions_validity(suggestions_name, force=force)
2221        print(f"SUGGESTION {suggestions_name}")
2222        task, parameters, mode, predicted, inference_time, behavior_dict = (
2223            self._make_prediction(
2224                suggestions_name,
2225                episode_names,
2226                load_epochs,
2227                parameters_update,
2228                data_path=data_path,
2229                file_paths=file_paths,
2230                mode=mode,
2231                augment_n=augment_n,
2232                evaluate=False,
2233            )
2234        )
2235        if classes is None:
2236            classes = behavior_dict.values()
2237        episode = self._episodes().get_runs(episode_names[0])[0]
2238        score_tensors = task.generate_uncertainty_score(
2239            classes,
2240            augment_n,
2241            method,
2242            predicted,
2243            self._episode(episode).get_behaviors_dict(),
2244        )
2245        intervals = self._suggest_intervals(
2246            task.dataset(mode), score_tensors, n_frames, min_length
2247        )
2248        for k, v in intervals.items():
2249            l = sum([x[1] - x[0] for x in v])
2250            print(f"{k}: {len(v)} ({l})")
2251        if remove_saved_features:
2252            self._remove_stores(parameters)
2253        suggestions_path = os.path.join(
2254            self.project_path,
2255            "results",
2256            "suggestions",
2257            suggestions_name,
2258        )
2259        if not os.path.exists(suggestions_path):
2260            os.mkdir(suggestions_path)
2261        with open(
2262            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2263        ) as f:
2264            pickle.dump(intervals, f)
2265        meta_parameters = {
2266            "suggestion_episode": episode_names,
2267            "suggestion_load_epoch": load_epochs,
2268            "suggestion_classes": classes,
2269            "min_frames_suggestion": min_length,
2270            "augment_n": augment_n,
2271            "method": method,
2272            "num_frames": n_frames,
2273        }
2274        self._save_suggestions(suggestions_name, {}, meta_parameters)
2275        print("\n")
2276
2277    def suggest_intervals_with_bald(
2278        self,
2279        suggestions_name: str,
2280        episode_name: str,
2281        load_epoch: int = None,
2282        classes: List = None,
2283        n_frames: int = 10000,
2284        num_models: int = 10,
2285        kernel_size: int = 11,
2286        min_length: int = 60,
2287        augment_n: int = 0,
2288        data_path: str = None,
2289        file_paths: Set = None,
2290        parameters_update: Dict = None,
2291        mode: str = "all",
2292        force: bool = False,
2293        remove_saved_features: bool = False,
2294    ):
2295        """Generate an active learning file based on Bayesian Active Learning by Disagreement.
2296
2297        Parameters
2298        ----------
2299        suggestions_name : str
2300            the name of the suggestion
2301        episode_name : str
2302            the name of the episode to load the model from
2303        load_epoch : int, optional
2304            the index of the epoch to load the model from (if `None`, the last one will be used)
2305        classes : list, optional
2306            a list of classes to look at (by default all)
2307        n_frames : int, default 10000
2308            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2309            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2310            with the set parameters)
2311        num_models : int, default 10
2312            the number of dropout masks to apply
2313        kernel_size : int, default 11
2314            the size of the smoothing kernel applied to the discrete results
2315        min_length : int, default 60
2316            the minimum number of frames in one interval
2317        augment_n : int, default 0
2318            the number of augmentations to average the predictions over
2319        data_path : str, optional
2320            the path to a data folder (by default, the project data is used)
2321        file_paths : set, optional
2322            a list of file paths (by default, the project data is used)
2323        parameters_update : dict, optional
2324            a dictionary of parameter updates
2325        mode : {"test", "val", "train", "all"}
2326            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2327            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2328        force : bool, default False
2329            if `True`, existing suggestions with the same name will be overwritten
2330        remove_saved_features : bool, default False
2331            if `True`, the dataset will be deleted after the computation
2332
2333        """
2334        self._check_suggestions_validity(suggestions_name, force=force)
2335        print(f"SUGGESTION {suggestions_name}")
2336        task, parameters, mode = self._make_task_prediction(
2337            suggestions_name,
2338            episode_name,
2339            parameters_update,
2340            load_epoch,
2341            data_path=data_path,
2342            file_paths=file_paths,
2343            mode=mode,
2344        )
2345        if classes is None:
2346            classes = list(task.behaviors_dict().values())
2347        score_tensors = task.generate_bald_score(
2348            classes, augment_n, num_models, kernel_size
2349        )
2350        intervals = self._suggest_intervals(
2351            task.dataset(mode), score_tensors, n_frames, min_length
2352        )
2353        if remove_saved_features:
2354            self._remove_stores(parameters)
2355        suggestions_path = os.path.join(
2356            self.project_path,
2357            "results",
2358            "suggestions",
2359            suggestions_name,
2360        )
2361        if not os.path.exists(suggestions_path):
2362            os.mkdir(suggestions_path)
2363        with open(
2364            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2365        ) as f:
2366            pickle.dump(intervals, f)
2367        meta_parameters = {
2368            "suggestion_episode": episode_name,
2369            "suggestion_load_epoch": load_epoch,
2370            "suggestion_classes": classes,
2371            "min_frames_suggestion": min_length,
2372            "augment_n": augment_n,
2373            "method": f"BALD:{num_models}",
2374            "num_frames": n_frames,
2375        }
2376        self._save_suggestions(suggestions_name, {}, meta_parameters)
2377        print("\n")
2378
2379    def list_episodes(
2380        self,
2381        episode_names: List = None,
2382        value_filter: str = "",
2383        display_parameters: List = None,
2384        print_results: bool = True,
2385    ) -> pd.DataFrame:
2386        """Get a filtered pandas dataframe with episode metadata.
2387
2388        Parameters
2389        ----------
2390        episode_names : list
2391            a list of strings of episode names
2392        value_filter : str
2393            a string of filters to apply; of this general structure:
2394            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
2395            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
2396        display_parameters : list
2397            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2398        print_results : bool, default True
2399            if True, the result will be printed to standard output
2400
2401        Returns
2402        -------
2403        pd.DataFrame
2404            the filtered dataframe
2405
2406        """
2407        episodes = self._episodes().list_episodes(
2408            episode_names, value_filter, display_parameters
2409        )
2410        if print_results:
2411            print("TRAINING EPISODES")
2412            print(episodes)
2413            print("\n")
2414        return episodes
2415
2416    def list_predictions(
2417        self,
2418        episode_names: List = None,
2419        value_filter: str = "",
2420        display_parameters: List = None,
2421        print_results: bool = True,
2422    ) -> pd.DataFrame:
2423        """Get a filtered pandas dataframe with prediction metadata.
2424
2425        Parameters
2426        ----------
2427        episode_names : list
2428            a list of strings of episode names
2429        value_filter : str
2430            a string of filters to apply; of this general structure:
2431            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2432            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2433        display_parameters : list
2434            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2435        print_results : bool, default True
2436            if True, the result will be printed to standard output
2437
2438        Returns
2439        -------
2440        pd.DataFrame
2441            the filtered dataframe
2442
2443        """
2444        predictions = self._predictions().list_episodes(
2445            episode_names, value_filter, display_parameters
2446        )
2447        if print_results:
2448            print("PREDICTIONS")
2449            print(predictions)
2450            print("\n")
2451        return predictions
2452
2453    def list_suggestions(
2454        self,
2455        suggestions_names: List = None,
2456        value_filter: str = "",
2457        display_parameters: List = None,
2458        print_results: bool = True,
2459    ) -> pd.DataFrame:
2460        """Get a filtered pandas dataframe with prediction metadata.
2461
2462        Parameters
2463        ----------
2464        suggestions_names : list
2465            a list of strings of suggestion names
2466        value_filter : str
2467            a string of filters to apply; of this general structure:
2468            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2469            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2470        display_parameters : list
2471            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2472        print_results : bool, default True
2473            if True, the result will be printed to standard output
2474
2475        Returns
2476        -------
2477        pd.DataFrame
2478            the filtered dataframe
2479
2480        """
2481        suggestions = self._suggestions().list_episodes(
2482            suggestions_names, value_filter, display_parameters
2483        )
2484        if print_results:
2485            print("SUGGESTIONS")
2486            print(suggestions)
2487            print("\n")
2488        return suggestions
2489
2490    def list_searches(
2491        self,
2492        search_names: List = None,
2493        value_filter: str = "",
2494        display_parameters: List = None,
2495        print_results: bool = True,
2496    ) -> pd.DataFrame:
2497        """Get a filtered pandas dataframe with hyperparameter search metadata.
2498
2499        Parameters
2500        ----------
2501        search_names : list
2502            a list of strings of search names
2503        value_filter : str
2504            a string of filters to apply; of this general structure:
2505            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2506            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2507        display_parameters : list
2508            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2509        print_results : bool, default True
2510            if True, the result will be printed to standard output
2511
2512        Returns
2513        -------
2514        pd.DataFrame
2515            the filtered dataframe
2516
2517        """
2518        searches = self._searches().list_episodes(
2519            search_names, value_filter, display_parameters
2520        )
2521        if print_results:
2522            print("SEARCHES")
2523            print(searches)
2524            print("\n")
2525        return searches
2526
2527    def get_best_parameters(
2528        self,
2529        search_name: str,
2530        round_to_binary: List = None,
2531    ):
2532        """Get the best parameters found by a search.
2533
2534        Parameters
2535        ----------
2536        search_name : str
2537            the name of the search
2538        round_to_binary : list, default None
2539            a list of parameters to round to binary values
2540
2541        Returns
2542        -------
2543        best_params : dict
2544            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2545
2546        """
2547        params, model = self._searches().get_best_params(
2548            search_name, round_to_binary=round_to_binary
2549        )
2550        params = self._update(params, {"general": {"model_name": model}})
2551        return params
2552
2553    def list_best_parameters(
2554        self, search_name: str, print_results: bool = True
2555    ) -> Dict:
2556        """Get the raw dictionary of best parameters found by a search.
2557
2558        Parameters
2559        ----------
2560        search_name : str
2561            the name of the search
2562        print_results : bool, default True
2563            if True, the result will be printed to standard output
2564
2565        Returns
2566        -------
2567        best_params : dict
2568            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2569
2570        """
2571        params = self._searches().get_best_params_raw(search_name)
2572        if print_results:
2573            print(f"SEARCH RESULTS {search_name}")
2574            for k, v in params.items():
2575                print(f"{k}: {v}")
2576            print("\n")
2577        return params
2578
2579    def plot_episodes(
2580        self,
2581        episode_names: List,
2582        metrics: List | str,
2583        modes: List | str = None,
2584        title: str = None,
2585        episode_labels: List = None,
2586        save_path: str = None,
2587        add_hlines: List = None,
2588        epoch_limits: List = None,
2589        colors: List = None,
2590        add_highpoint_hlines: bool = False,
2591        remove_box: bool = False,
2592        font_size: float = None,
2593        linewidth: float = None,
2594        return_ax: bool = False,
2595    ) -> None:
2596        """Plot episode training curves.
2597
2598        Parameters
2599        ----------
2600        episode_names : list
2601            a list of episode names to plot; to plot to episodes in one line combine them in a list
2602            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
2603        metrics : list
2604            a list of metric to plot
2605        modes : list, optional
2606            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
2607        title : str, optional
2608            title for the plot
2609        episode_labels : list, optional
2610            a list of strings used to label the curves (has to be the same length as episode_names)
2611        save_path : str, optional
2612            the path to save the resulting plot
2613        add_hlines : list, optional
2614            a list of float values (or (value, label) tuples) to mark with horizontal lines
2615        epoch_limits : list, optional
2616            a list of (min, max) tuples to set the x-axis limits for each episode
2617        colors: list, optional
2618            a list of matplotlib colors
2619        add_highpoint_hlines : bool, default False
2620            if `True`, horizontal lines will be added at the highest value of each episode
2621        """
2622
2623        if isinstance(metrics, str):
2624            metrics = [metrics]
2625        if isinstance(modes, str):
2626            modes = [modes]
2627
2628        if font_size is not None:
2629            font = {"size": font_size}
2630            rc("font", **font)
2631        if modes is None:
2632            modes = ["val"]
2633        if add_hlines is None:
2634            add_hlines = []
2635        logs = []
2636        epochs = []
2637        labels = []
2638        if episode_labels is not None:
2639            assert len(episode_labels) == len(episode_names)
2640        for name_i, name in enumerate(episode_names):
2641            log_params = product(metrics, modes)
2642            for metric, mode in log_params:
2643                if episode_labels is not None:
2644                    label = episode_labels[name_i]
2645                else:
2646                    label = deepcopy(name)
2647                if len(modes) != 1:
2648                    label += f"_{mode}"
2649                if len(metrics) != 1:
2650                    label += f"_{metric}"
2651                labels.append(label)
2652                if isinstance(name, Iterable) and not isinstance(name, str):
2653                    epoch_list = defaultdict(lambda: [])
2654                    multi_logs = defaultdict(lambda: [])
2655                    for i, n in enumerate(name):
2656                        runs = self._episodes().get_runs(n)
2657                        if len(runs) > 1:
2658                            for run in runs:
2659                                if "::" in run:
2660                                    index = run.split("::")[-1]
2661                                else:
2662                                    index = run.split("#")[-1]
2663                                if multi_logs[index] == []:
2664                                    if multi_logs["null"] is None:
2665                                        raise RuntimeError(
2666                                            "The run indices are not consistent across episodes!"
2667                                        )
2668                                    else:
2669                                        multi_logs[index] += multi_logs["null"]
2670                                multi_logs[index] += list(
2671                                    self._episode(run).get_metric_log(mode, metric)
2672                                )
2673                                start = (
2674                                    0
2675                                    if len(epoch_list[index]) == 0
2676                                    else epoch_list[index][-1]
2677                                )
2678                                epoch_list[index] += [
2679                                    x + start
2680                                    for x in self._episode(run).get_epoch_list(mode)
2681                                ]
2682                            multi_logs["null"] = None
2683                        else:
2684                            if len(multi_logs.keys()) > 1:
2685                                raise RuntimeError(
2686                                    "Cannot plot a single-run episode after a multi-run episode!"
2687                                )
2688                            multi_logs["null"] += list(
2689                                self._episode(n).get_metric_log(mode, metric)
2690                            )
2691                            start = (
2692                                0
2693                                if len(epoch_list["null"]) == 0
2694                                else epoch_list["null"][-1]
2695                            )
2696                            epoch_list["null"] += [
2697                                x + start for x in self._episode(n).get_epoch_list(mode)
2698                            ]
2699                    if len(multi_logs.keys()) == 1:
2700                        log = multi_logs["null"]
2701                        epochs.append(epoch_list["null"])
2702                    else:
2703                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
2704                        epochs.append(
2705                            tuple([v for k, v in epoch_list.items() if k != "null"])
2706                        )
2707                else:
2708                    runs = self._episodes().get_runs(name)
2709                    if len(runs) > 1:
2710                        log = []
2711                        for run in runs:
2712                            tracked_metrics = self._episode(run).get_metrics()
2713                            if metric in tracked_metrics:
2714                                log.append(
2715                                    list(
2716                                        self._episode(run).get_metric_log(mode, metric)
2717                                    )
2718                                )
2719                            else:
2720                                relevant = []
2721                                for m in tracked_metrics:
2722                                    m_split = m.split("_")
2723                                    if (
2724                                        "_".join(m_split[:-1]) == metric
2725                                        and m_split[-1].isnumeric()
2726                                    ):
2727                                        relevant.append(m)
2728                                if len(relevant) == 0:
2729                                    raise ValueError(
2730                                        f"The {metric} metric was not tracked at {run}"
2731                                    )
2732                                arr = 0
2733                                for m in relevant:
2734                                    arr += self._episode(run).get_metric_log(mode, m)
2735                                arr /= len(relevant)
2736                                log.append(list(arr))
2737                        log = tuple(log)
2738                        epochs.append(
2739                            tuple(
2740                                [
2741                                    self._episode(run).get_epoch_list(mode)
2742                                    for run in runs
2743                                ]
2744                            )
2745                        )
2746                    else:
2747                        tracked_metrics = self._episode(name).get_metrics()
2748                        if metric in tracked_metrics:
2749                            log = list(self._episode(name).get_metric_log(mode, metric))
2750                        else:
2751                            relevant = []
2752                            for m in tracked_metrics:
2753                                m_split = m.split("_")
2754                                if (
2755                                    "_".join(m_split[:-1]) == metric
2756                                    and m_split[-1].isnumeric()
2757                                ):
2758                                    relevant.append(m)
2759                            if len(relevant) == 0:
2760                                raise ValueError(
2761                                    f"The {metric} metric was not tracked at {name}"
2762                                )
2763                            arr = 0
2764                            for m in relevant:
2765                                arr += self._episode(name).get_metric_log(mode, m)
2766                            arr /= len(relevant)
2767                            log = list(arr)
2768                        epochs.append(self._episode(name).get_epoch_list(mode))
2769                logs.append(log)
2770        # if episode_labels is not None:
2771        #     print(f'{len(episode_labels)=}, {len(logs)=}')
2772        #     if len(episode_labels) != len(logs):
2773
2774        #         raise ValueError(
2775        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
2776        #             f"curves ({len(logs)})!"
2777        #         )
2778        #     else:
2779        #         labels = episode_labels
2780        if colors is None:
2781            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
2782        if len(colors) != len(logs):
2783            raise ValueError(
2784                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
2785            )
2786        f, ax = plt.subplots()
2787        length = 0
2788        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
2789            if type(log) is list:
2790                if len(log) > length:
2791                    length = len(log)
2792                ax.plot(
2793                    epoch_list,
2794                    log,
2795                    label=label,
2796                    color=color,
2797                )
2798                if add_highpoint_hlines:
2799                    ax.axhline(np.max(log), linestyle="dashed", color=color)
2800            else:
2801                for l, xx in zip(log, epoch_list):
2802                    if len(l) > length:
2803                        length = len(l)
2804                    ax.plot(
2805                        xx,
2806                        l,
2807                        color=color,
2808                        alpha=0.2,
2809                    )
2810                if not all([len(x) == len(log[0]) for x in log]):
2811                    warnings.warn(
2812                        f"Got logs with unequal lengths in parallel runs for {label}"
2813                    )
2814                    log = list(log)
2815                    epoch_list = list(epoch_list)
2816                    for i, x in enumerate(epoch_list):
2817                        to_remove = []
2818                        for j, y in enumerate(x[1:]):
2819                            if y <= x[j - 1]:
2820                                y_ind = x.index(y)
2821                                to_remove += list(range(y_ind, j))
2822                        epoch_list[i] = [
2823                            y for j, y in enumerate(x) if j not in to_remove
2824                        ]
2825                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2826                    length = min([len(x) for x in log])
2827                    for i in range(len(log)):
2828                        log[i] = log[i][:length]
2829                        epoch_list[i] = epoch_list[i][:length]
2830                    if not all([x == epoch_list[0] for x in epoch_list]):
2831                        raise RuntimeError(
2832                            f"Got different epoch indices in parallel runs for {label}"
2833                        )
2834                mean = np.array(log).mean(0)
2835                ax.plot(
2836                    epoch_list[0],
2837                    mean,
2838                    label=label,
2839                    color=color,
2840                    linewidth=linewidth,
2841                )
2842                if add_highpoint_hlines:
2843                    ax.axhline(np.max(mean), linestyle="dashed", color=color)
2844        for x in add_hlines:
2845            label = None
2846            if isinstance(x, Iterable):
2847                x, label = x
2848            ax.axhline(x, label=label)
2849            ax.set_xlim((0, length))
2850
2851        ax.legend()
2852        ax.set_xlabel("epochs")
2853        if len(metrics) == 1:
2854            ax.set_ylabel(metrics[0])
2855        else:
2856            ax.set_ylabel("value")
2857        if title is None:
2858            if len(episode_names) == 1:
2859                title = episode_names[0]
2860            elif len(metrics) == 1:
2861                title = metrics[0]
2862        if epoch_limits is not None:
2863            ax.set_xlim(epoch_limits)
2864        if title is not None:
2865            ax.set_title(title)
2866        if remove_box:
2867            ax.box(False)
2868        if return_ax:
2869            return ax
2870        if save_path is not None:
2871            plt.savefig(save_path)
2872        plt.show()
2873
2874    def update_parameters(
2875        self,
2876        parameters_update: Dict = None,
2877        load_search: str = None,
2878        load_parameters: List = None,
2879        round_to_binary: List = None,
2880    ) -> None:
2881        """Update the parameters in the project config files.
2882
2883        Parameters
2884        ----------
2885        parameters_update : dict, optional
2886            a dictionary of parameter updates
2887        load_search : str, optional
2888            the name of hyperparameter search results to load to config
2889        load_parameters : list, optional
2890            a list of lists of string names of the parameters to load from the searches
2891        round_to_binary : list, optional
2892            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2893
2894        """
2895        keys = [
2896            "general",
2897            "losses",
2898            "metrics",
2899            "ssl",
2900            "training",
2901            "data",
2902        ]
2903        parameters = self._read_parameters(catch_blanks=False)
2904        if parameters_update is not None:
2905            model_params = (
2906                parameters_update.pop("model") if "model" in parameters_update else None
2907            )
2908            feat_params = (
2909                parameters_update.pop("features")
2910                if "features" in parameters_update
2911                else None
2912            )
2913            aug_params = (
2914                parameters_update.pop("augmentations")
2915                if "augmentations" in parameters_update
2916                else None
2917            )
2918
2919            parameters = self._update(parameters, parameters_update)
2920            model_name = parameters["general"]["model_name"]
2921            parameters["model"] = self._open_yaml(
2922                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2923            )
2924            if model_params is not None:
2925                parameters["model"] = self._update(parameters["model"], model_params)
2926            feat_name = parameters["general"]["feature_extraction"]
2927            parameters["features"] = self._open_yaml(
2928                os.path.join(
2929                    self.project_path, "config", "features", f"{feat_name}.yaml"
2930                )
2931            )
2932            if feat_params is not None:
2933                parameters["features"] = self._update(
2934                    parameters["features"], feat_params
2935                )
2936            aug_name = options.extractor_to_transformer[
2937                parameters["general"]["feature_extraction"]
2938            ]
2939            parameters["augmentations"] = self._open_yaml(
2940                os.path.join(
2941                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2942                )
2943            )
2944            if aug_params is not None:
2945                parameters["augmentations"] = self._update(
2946                    parameters["augmentations"], aug_params
2947                )
2948        if load_search is not None:
2949            parameters_update, model_name = self._searches().get_best_params(
2950                load_search, load_parameters, round_to_binary
2951            )
2952            parameters["general"]["model_name"] = model_name
2953            parameters["model"] = self._open_yaml(
2954                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2955            )
2956            parameters = self._update(parameters, parameters_update)
2957        for key in keys:
2958            with open(
2959                os.path.join(self.project_path, "config", f"{key}.yaml"),
2960                "w",
2961                encoding="utf-8",
2962            ) as f:
2963                YAML().dump(parameters[key], f)
2964        model_name = parameters["general"]["model_name"]
2965        model_path = os.path.join(
2966            self.project_path, "config", "model", f"{model_name}.yaml"
2967        )
2968        with open(model_path, "w", encoding="utf-8") as f:
2969            YAML().dump(parameters["model"], f)
2970        features_name = parameters["general"]["feature_extraction"]
2971        features_path = os.path.join(
2972            self.project_path, "config", "features", f"{features_name}.yaml"
2973        )
2974        with open(features_path, "w", encoding="utf-8") as f:
2975            YAML().dump(parameters["features"], f)
2976        aug_name = options.extractor_to_transformer[features_name]
2977        aug_path = os.path.join(
2978            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2979        )
2980        with open(aug_path, "w", encoding="utf-8") as f:
2981            YAML().dump(parameters["augmentations"], f)
2982
2983    def get_summary(
2984        self,
2985        episode_names: list,
2986        method: str = "last",
2987        average: int = 1,
2988        metrics: List = None,
2989        return_values: bool = False,
2990    ) -> Dict:
2991        """Get a summary of episode statistics.
2992
2993        If an episode has multiple runs, the statistics will be aggregated over all of them.
2994
2995        Parameters
2996        ----------
2997        episode_names : str
2998            the names of the episodes
2999        method : ["best", "last"]
3000            the method for choosing the epochs
3001        average : int, default 1
3002            the number of epochs to average over (for each run)
3003        metrics : list, optional
3004            a list of metrics
3005
3006        Returns
3007        -------
3008        statistics : dict
3009            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
3010            and 'std' for the standard deviation
3011
3012        """
3013        runs = []
3014        for episode_name in episode_names:
3015            runs_ep = self._episodes().get_runs(episode_name)
3016            if len(runs_ep) == 0:
3017                raise RuntimeError(
3018                    f"There is no {episode_name} episode in the project memory"
3019                )
3020            runs += runs_ep
3021        if metrics is None:
3022            metrics = self._episode(runs[0]).get_metrics()
3023
3024        values = {m: [] for m in metrics}
3025        for run in runs:
3026            for m in metrics:
3027                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
3028                if method == "best":
3029                    log = sorted(log)
3030                    values[m] += list(log[-average:])
3031                elif method == "last":
3032                    if len(log) == 0:
3033                        episodes = self._episodes().data
3034                        if average == 1 and ("results", m) in episodes.columns:
3035                            values[m] += [episodes.loc[run, ("results", m)]]
3036                        else:
3037                            raise RuntimeError(f"Did not find {m} metric for {run} run")
3038                    values[m] += list(log[-average:])
3039                elif method.startswith("epoch"):
3040                    epoch = int(method[5:]) - 1
3041                    pars = self._episodes().load_parameters(run)
3042                    step = int(pars["training"]["validation_interval"])
3043                    values[m] += [log[epoch // step]]
3044                else:
3045                    raise ValueError(
3046                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
3047                    )
3048        statistics = defaultdict(lambda: {})
3049        for m, v in values.items():
3050            statistics[m]["mean"] = np.mean(v)
3051            statistics[m]["std"] = np.std(v)
3052        print(f"SUMMARY {episode_names}")
3053        for m, v in statistics.items():
3054            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
3055        print("\n")
3056
3057        return (dict(statistics), values) if return_values else dict(statistics)
3058
3059    @staticmethod
3060    def remove_project(name: str, projects_path: str = None) -> None:
3061        """Remove all project files and experiment records and results.
3062
3063        Parameters
3064        ----------
3065        name : str
3066            the name of the project to remove
3067        projects_path : str, optional
3068            the path to the projects directory (by default the home DLC2Action directory)
3069
3070        """
3071        if projects_path is None:
3072            projects_path = os.path.join(str(Path.home()), "DLC2Action")
3073        project_path = os.path.join(projects_path, name)
3074        if os.path.exists(project_path):
3075            shutil.rmtree(project_path)
3076
3077    def remove_saved_features(
3078        self,
3079        dataset_names: List = None,
3080        exceptions: List = None,
3081        remove_active: bool = False,
3082    ) -> None:
3083        """Remove saved pre-computed dataset feature files.
3084
3085        By default, all features will be deleted.
3086        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
3087        while training or inference is happening though.
3088
3089        Parameters
3090        ----------
3091        dataset_names : list, optional
3092            a list of dataset names to delete (by default all names are added)
3093        exceptions : list, optional
3094            a list of dataset names to not be deleted
3095        remove_active : bool, default False
3096            if `False`, datasets used by unfinished episodes will not be deleted
3097
3098        """
3099        print("Removing datasets...")
3100        if dataset_names is None:
3101            dataset_names = []
3102        if exceptions is None:
3103            exceptions = []
3104        if not remove_active:
3105            exceptions += self._episodes().get_active_datasets()
3106        dataset_path = os.path.join(self.project_path, "saved_datasets")
3107        if os.path.exists(dataset_path):
3108            if dataset_names == []:
3109                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
3110
3111            to_remove = [
3112                x
3113                for x in dataset_names
3114                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
3115            ]
3116            if len(to_remove) > 2:
3117                to_remove = tqdm(to_remove)
3118            for dataset in to_remove:
3119                shutil.rmtree(os.path.join(dataset_path, dataset))
3120            to_remove = [
3121                f"{x}.pickle"
3122                for x in dataset_names
3123                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
3124                and x not in exceptions
3125            ]
3126            for dataset in to_remove:
3127                os.remove(os.path.join(dataset_path, dataset))
3128            names = self._saved_datasets().dataset_names()
3129            self._saved_datasets().remove(names)
3130        print("\n")
3131
3132    def remove_extra_checkpoints(
3133        self, episode_names: List = None, exceptions: List = None
3134    ) -> None:
3135        """Remove intermediate model checkpoint files (only leave the files for the last epoch).
3136
3137        By default, all intermediate checkpoints will be deleted.
3138        Files in the model folder that are not associated with any record in the meta files are also deleted.
3139
3140        Parameters
3141        ----------
3142        episode_names : list, optional
3143            a list of episode names to clean (by default all names are added)
3144        exceptions : list, optional
3145            a list of episode names to not clean
3146
3147        """
3148        model_path = os.path.join(self.project_path, "results", "model")
3149        try:
3150            all_names = self._episodes().data.index
3151        except:
3152            all_names = os.listdir(model_path)
3153        if episode_names is None:
3154            episode_names = all_names
3155        if exceptions is None:
3156            exceptions = []
3157        to_remove = [x for x in episode_names if x not in exceptions]
3158        folders = os.listdir(model_path)
3159        for folder in folders:
3160            if folder not in all_names:
3161                shutil.rmtree(os.path.join(model_path, folder))
3162            elif folder in to_remove:
3163                files = os.listdir(os.path.join(model_path, folder))
3164                for file in sorted(files)[:-1]:
3165                    os.remove(os.path.join(model_path, folder, file))
3166
3167    def remove_search(self, search_name: str) -> None:
3168        """Remove a hyperparameter search record.
3169
3170        Parameters
3171        ----------
3172        search_name : str
3173            the name of the search to remove
3174
3175        """
3176        self._searches().remove_episode(search_name)
3177        graph_path = os.path.join(self.project_path, "results", "searches", search_name)
3178        if os.path.exists(graph_path):
3179            shutil.rmtree(graph_path)
3180
3181    def remove_suggestion(self, suggestion_name: str) -> None:
3182        """Remove a suggestion record.
3183
3184        Parameters
3185        ----------
3186        suggestion_name : str
3187            the name of the suggestion to remove
3188
3189        """
3190        self._suggestions().remove_episode(suggestion_name)
3191        suggestion_path = os.path.join(
3192            self.project_path, "results", "suggestions", suggestion_name
3193        )
3194        if os.path.exists(suggestion_path):
3195            shutil.rmtree(suggestion_path)
3196
3197    def remove_prediction(self, prediction_name: str) -> None:
3198        """Remove a prediction record.
3199
3200        Parameters
3201        ----------
3202        prediction_name : str
3203            the name of the prediction to remove
3204
3205        """
3206        self._predictions().remove_episode(prediction_name)
3207        prediction_path = self.prediction_path(prediction_name)
3208        if os.path.exists(prediction_path):
3209            shutil.rmtree(prediction_path)
3210
3211    def check_prediction_exists(self, prediction_name: str) -> str | None:
3212        """Check if a prediction exists.
3213
3214        Parameters
3215        ----------
3216        prediction_name : str
3217            the name of the prediction to check
3218
3219        Returns
3220        -------
3221        str | None
3222            the path to the prediction if it exists, `None` otherwise
3223
3224        """
3225        prediction_path = self.prediction_path(prediction_name)
3226        if os.path.exists(prediction_path):
3227            return prediction_path
3228        return None
3229
3230    def remove_episode(self, episode_name: str) -> None:
3231        """Remove all model, logs and metafile records related to an episode.
3232
3233        Parameters
3234        ----------
3235        episode_name : str
3236            the name of the episode to remove
3237
3238        """
3239        runs = self._episodes().get_runs(episode_name)
3240        runs.append(episode_name)
3241        for run in runs:
3242            self._episodes().remove_episode(run)
3243            model_path = os.path.join(self.project_path, "results", "model", run)
3244            if os.path.exists(model_path):
3245                shutil.rmtree(model_path)
3246            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
3247            if os.path.exists(log_path):
3248                os.remove(log_path)
3249
3250    @abstractmethod
3251    def _reformat_results(res: dict, classes: dict, exclusive=False):
3252        """Add classes to micro metrics in results from evaluation"""
3253        results = deepcopy(res)
3254        for key in results.keys():
3255            if isinstance(results[key], list):
3256                if exclusive and len(classes) == len(results[key]) + 1:
3257                    other_ind = list(classes.keys())[
3258                        list(classes.values()).index("other")
3259                    ]
3260                    classes = {
3261                        (i if i < other_ind else i - 1): c
3262                        for i, c in classes.items()
3263                        if i != other_ind
3264                    }
3265                assert len(results[key]) == len(
3266                    classes
3267                ), f"Results for {key} have {len(results[key])} values, but {len(classes)} classes were provided!"
3268                results[key] = {
3269                    classes[i]: float(v) for i, v in enumerate(results[key])
3270                }
3271        return results
3272
3273    def prune_unfinished(self, exceptions: List = None) -> List:
3274        """Remove all interrupted episodes.
3275
3276        Remove all episodes that either don't have a log file or have less epochs in the log file than in
3277        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
3278        currently running!
3279
3280        Parameters
3281        ----------
3282        exceptions : list
3283            the episodes to keep even if they are interrupted
3284
3285        Returns
3286        -------
3287        pruned : list
3288            a list of the episode names that were pruned
3289
3290        """
3291        if exceptions is None:
3292            exceptions = []
3293        unfinished = self._episodes().unfinished_episodes()
3294        unfinished = [x for x in unfinished if x not in exceptions]
3295        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
3296        unfinished += [
3297            x for x in model_folders if x not in self._episodes().list_episodes().index
3298        ]
3299        print(f"PRUNING {unfinished}")
3300        for episode_name in unfinished:
3301            self.remove_episode(episode_name)
3302        print(f"\n")
3303        return unfinished
3304
3305    def prediction_path(self, prediction_name: str) -> str:
3306        """Get the path where prediction files are saved.
3307
3308        Parameters
3309        ----------
3310        prediction_name : str
3311            name of the prediction
3312
3313        Returns
3314        -------
3315        prediction_path : str
3316            the file path
3317
3318        """
3319        return os.path.join(
3320            self.project_path, "results", "predictions", f"{prediction_name}"
3321        )
3322
3323    def suggestion_path(self, suggestion_name: str) -> str:
3324        """Get the path where suggestion files are saved.
3325
3326        Parameters
3327        ----------
3328        suggestion_name : str
3329            name of the prediction
3330
3331        Returns
3332        -------
3333        suggestion_path : str
3334            the file path
3335
3336        """
3337        return os.path.join(
3338            self.project_path, "results", "suggestions", f"{suggestion_name}"
3339        )
3340
3341    @classmethod
3342    def print_data_types(cls):
3343        """Print available data types."""
3344        print("DATA TYPES:")
3345        for key, value in cls.data_types().items():
3346            print(f"{key}:")
3347            print(value.__doc__)
3348
3349    @classmethod
3350    def print_annotation_types(cls):
3351        """Print available annotation types."""
3352        print("ANNOTATION TYPES:")
3353        for key, value in cls.annotation_types():
3354            print(f"{key}:")
3355            print(value.__doc__)
3356
3357    @staticmethod
3358    def data_types() -> List:
3359        """Get available data types.
3360
3361        Returns
3362        -------
3363        data_types : list
3364            available data types
3365
3366        """
3367        return options.input_stores
3368
3369    @staticmethod
3370    def annotation_types() -> List:
3371        """Get available annotation types.
3372
3373        Returns
3374        -------
3375        list
3376            available annotation types
3377
3378        """
3379        return options.annotation_stores
3380
3381    def _save_mask(self, file: Dict, mask_name: str):
3382        """Save a mask file.
3383
3384        Parameters
3385        ----------
3386        file : dict
3387            the mask file data to save
3388        mask_name : str
3389            the name of the mask file
3390
3391        """
3392        if not os.path.exists(self._mask_path()):
3393            os.mkdir(self._mask_path())
3394        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "wb") as f:
3395            pickle.dump(file, f)
3396
3397    def _load_mask(self, mask_name: str) -> Dict:
3398        """Load a mask file.
3399
3400        Parameters
3401        ----------
3402        mask_name : str
3403            the name of the mask file to load
3404
3405        Returns
3406        -------
3407        mask : dict
3408            the loaded mask data
3409
3410        """
3411        with open(os.path.join(self._mask_path(), mask_name + ".pickle"), "rb") as f:
3412            data = pickle.load(f)
3413        return data
3414
3415    def _thresholds(self) -> DecisionThresholds:
3416        """Get the decision thresholds meta object.
3417
3418        Returns
3419        -------
3420        thresholds : DecisionThresholds
3421            the decision thresholds meta object
3422
3423        """
3424        return DecisionThresholds(self._thresholds_path())
3425
3426    def _episodes(self) -> SavedRuns:
3427        """Get the episodes meta object.
3428
3429        Returns
3430        -------
3431        episodes : SavedRuns
3432            the episodes meta object
3433
3434        """
3435        try:
3436            return SavedRuns(self._episodes_path(), self.project_path)
3437        except:
3438            self.load_metadata_backup()
3439            return SavedRuns(self._episodes_path(), self.project_path)
3440
3441    def _suggestions(self) -> Suggestions:
3442        """Get the suggestions meta object.
3443
3444        Returns
3445        -------
3446        suggestions : Suggestions
3447            the suggestions meta object
3448
3449        """
3450        try:
3451            return Suggestions(self._suggestions_path(), self.project_path)
3452        except:
3453            self.load_metadata_backup()
3454            return Suggestions(self._suggestions_path(), self.project_path)
3455
3456    def _predictions(self) -> SavedRuns:
3457        """Get the predictions meta object.
3458
3459        Returns
3460        -------
3461        predictions : SavedRuns
3462            the predictions meta object
3463
3464        """
3465        try:
3466            return SavedRuns(self._predictions_path(), self.project_path)
3467        except:
3468            self.load_metadata_backup()
3469            return SavedRuns(self._predictions_path(), self.project_path)
3470
3471    def _saved_datasets(self) -> SavedStores:
3472        """Get the datasets meta object.
3473
3474        Returns
3475        -------
3476        datasets : SavedStores
3477            the datasets meta object
3478
3479        """
3480        try:
3481            return SavedStores(self._saved_datasets_path())
3482        except:
3483            self.load_metadata_backup()
3484            return SavedStores(self._saved_datasets_path())
3485
3486    def _prediction(self, name: str) -> Run:
3487        """Get a prediction meta object.
3488
3489        Parameters
3490        ----------
3491        name : str
3492            episode name
3493
3494        Returns
3495        -------
3496        prediction : Run
3497            the prediction meta object
3498
3499        """
3500        try:
3501            return Run(name, self.project_path, meta_path=self._predictions_path())
3502        except:
3503            self.load_metadata_backup()
3504            return Run(name, self.project_path, meta_path=self._predictions_path())
3505
3506    def _episode(self, name: str) -> Run:
3507        """Get an episode meta object.
3508
3509        Parameters
3510        ----------
3511        name : str
3512            episode name
3513
3514        Returns
3515        -------
3516        episode : Run
3517            the episode meta object
3518
3519        """
3520        try:
3521            return Run(name, self.project_path, meta_path=self._episodes_path())
3522        except:
3523            self.load_metadata_backup()
3524            return Run(name, self.project_path, meta_path=self._episodes_path())
3525
3526    def _searches(self) -> Searches:
3527        """Get the hyperparameter search meta object.
3528
3529        Returns
3530        -------
3531        searches : Searches
3532            the searches meta object
3533
3534        """
3535        try:
3536            return Searches(self._searches_path(), self.project_path)
3537        except:
3538            self.load_metadata_backup()
3539            return Searches(self._searches_path(), self.project_path)
3540
3541    def _update_configs(self) -> None:
3542        """Update the project config files with newly added files and parameters.
3543
3544        This method updates the project configuration with the data path and copies
3545        any new configuration files from the original package to the project.
3546
3547        """
3548        self.update_parameters({"data": {"data_path": self.data_path}})
3549        folders = ["augmentations", "features", "model"]
3550        original_path = os.path.join(
3551            os.path.dirname(os.path.dirname(__file__)), "config"
3552        )
3553        project_path = os.path.join(self.project_path, "config")
3554        filenames = [x for x in os.listdir(original_path) if x.endswith("yaml")]
3555        for folder in folders:
3556            filenames += [
3557                os.path.join(folder, x)
3558                for x in os.listdir(os.path.join(original_path, folder))
3559            ]
3560        filenames.append(os.path.join("data", f"{self.data_type}.yaml"))
3561        if self.annotation_type != "none":
3562            filenames.append(os.path.join("annotation", f"{self.annotation_type}.yaml"))
3563        for file in filenames:
3564            filepath_original = os.path.join(original_path, file)
3565            if file.startswith("data") or file.startswith("annotation"):
3566                file = os.path.basename(file)
3567            filepath_project = os.path.join(project_path, file)
3568            if not os.path.exists(filepath_project):
3569                shutil.copy(filepath_original, filepath_project)
3570            else:
3571                original_pars = self._open_yaml(filepath_original)
3572                project_pars = self._open_yaml(filepath_project)
3573                to_remove = []
3574                for key, value in project_pars.items():
3575                    if key not in original_pars:
3576                        if key not in ["data_type", "annotation_type"]:
3577                            to_remove.append(key)
3578                for key in to_remove:
3579                    project_pars.pop(key)
3580                to_remove = []
3581                for key, value in original_pars.items():
3582                    if key in project_pars:
3583                        to_remove.append(key)
3584                for key in to_remove:
3585                    original_pars.pop(key)
3586                project_pars = self._update(project_pars, original_pars)
3587                with open(filepath_project, "w", encoding="utf-8") as f:
3588                    YAML().dump(project_pars, f)
3589
3590    def _update_project(self) -> None:
3591        """Update project files with the current version."""
3592        version_file = self._version_path()
3593        ok = True
3594        if not os.path.exists(version_file):
3595            ok = False
3596        else:
3597            with open(version_file, encoding="utf-8") as f:
3598                project_version = f.read()
3599            if project_version < __version__:
3600                ok = False
3601            elif project_version > __version__:
3602                warnings.warn(
3603                    f"The project expects a higher dlc2action version ({project_version}), please update!"
3604                )
3605        if not ok:
3606            project_config_path = os.path.join(self.project_path, "config")
3607            config_path = os.path.join(
3608                os.path.dirname(os.path.dirname(__path__)), "config"
3609            )
3610            episodes = self._episodes()
3611            folders = ["annotation", "augmentations", "data", "features", "model"]
3612
3613            project_annotation_configs = os.listdir(
3614                os.path.join(project_config_path, "annotation")
3615            )
3616            annotation_configs = os.listdir(os.path.join(config_path, "annotation"))
3617            for ann_config in annotation_configs:
3618                if ann_config not in project_annotation_configs:
3619                    shutil.copytree(
3620                        os.path.join(config_path, "annotation", ann_config),
3621                        os.path.join(project_config_path, "annotation", ann_config),
3622                        dirs_exist_ok=True,
3623                    )
3624                else:
3625                    project_pars = self._open_yaml(
3626                        os.path.join(project_config_path, "annotation", ann_config)
3627                    )
3628                    pars = self._open_yaml(
3629                        os.path.join(config_path, "annotation", ann_config)
3630                    )
3631                    new_keys = set(pars.keys()) - set(project_pars.keys())
3632                    for key in new_keys:
3633                        project_pars[key] = pars[key]
3634                        c = self._get_comment(pars.ca.items.get(key))
3635                        project_pars.yaml_add_eol_comment(c, key=key)
3636                        episodes.update(
3637                            condition=f"general/annotation_type::={ann_config}",
3638                            update={f"data/{key}": pars[key]},
3639                        )
3640
3641    def _initialize_project(
3642        self,
3643        data_type: str,
3644        annotation_type: str = None,
3645        data_path: str = None,
3646        annotation_path: str = None,
3647        copy: bool = True,
3648    ) -> None:
3649        """Initialize a new project."""
3650        if data_type not in self.data_types():
3651            raise ValueError(
3652                f"The {data_type} data type is not available. "
3653                f"Please choose from {self.data_types()}"
3654            )
3655        if annotation_type not in self.annotation_types():
3656            raise ValueError(
3657                f"The {annotation_type} annotation type is not available. "
3658                f"Please choose from {self.annotation_types()}"
3659            )
3660        os.mkdir(self.project_path)
3661        folders = ["results", "saved_datasets", "meta", "config"]
3662        for f in folders:
3663            os.mkdir(os.path.join(self.project_path, f))
3664        results_subfolders = [
3665            "model",
3666            "logs",
3667            "predictions",
3668            "splits",
3669            "searches",
3670            "suggestions",
3671        ]
3672        for sf in results_subfolders:
3673            os.mkdir(os.path.join(self.project_path, "results", sf))
3674        if data_path is not None:
3675            if copy:
3676                os.mkdir(os.path.join(self.project_path, "data"))
3677                shutil.copytree(
3678                    data_path,
3679                    os.path.join(self.project_path, "data"),
3680                    dirs_exist_ok=True,
3681                )
3682                data_path = os.path.join(self.project_path, "data")
3683        if annotation_path is not None:
3684            if copy:
3685                os.mkdir(os.path.join(self.project_path, "annotation"))
3686                shutil.copytree(
3687                    annotation_path,
3688                    os.path.join(self.project_path, "annotation"),
3689                    dirs_exist_ok=True,
3690                )
3691                annotation_path = os.path.join(self.project_path, "annotation")
3692        self._generate_config(
3693            data_type,
3694            annotation_type,
3695            data_path=data_path,
3696            annotation_path=annotation_path,
3697        )
3698        self._generate_meta()
3699
3700    def _read_types(self) -> Tuple[str, str]:
3701        """Get data type and annotation type from existing project files."""
3702        config_path = os.path.join(self.project_path, "config", "general.yaml")
3703        with open(config_path, encoding="utf-8") as f:
3704            pars = YAML().load(f)
3705        data_type = pars["data_type"]
3706        annotation_type = pars["annotation_type"]
3707        return annotation_type, data_type
3708
3709    def _read_paths(self) -> Tuple[str, str]:
3710        """Get data type and annotation type from existing project files."""
3711        config_path = os.path.join(self.project_path, "config", "data.yaml")
3712        with open(config_path, encoding="utf-8") as f:
3713            pars = YAML().load(f)
3714        data_path = pars["data_path"]
3715        annotation_path = pars["annotation_path"]
3716        return annotation_path, data_path
3717
3718    def _generate_config(
3719        self, data_type: str, annotation_type: str, data_path: str, annotation_path: str
3720    ) -> None:
3721        """Initialize the config files."""
3722        default_path = os.path.join(
3723            os.path.dirname(os.path.dirname(__file__)), "config"
3724        )
3725        config_path = os.path.join(self.project_path, "config")
3726        files = ["losses", "metrics", "ssl", "training"]
3727        for f in files:
3728            shutil.copy(os.path.join(default_path, f"{f}.yaml"), config_path)
3729        shutil.copytree(
3730            os.path.join(default_path, "model"), os.path.join(config_path, "model")
3731        )
3732        shutil.copytree(
3733            os.path.join(default_path, "features"),
3734            os.path.join(config_path, "features"),
3735        )
3736        shutil.copytree(
3737            os.path.join(default_path, "augmentations"),
3738            os.path.join(config_path, "augmentations"),
3739        )
3740        yaml = YAML()
3741        data_param_path = os.path.join(default_path, "data", f"{data_type}.yaml")
3742        if os.path.exists(data_param_path):
3743            with open(data_param_path, encoding="utf-8") as f:
3744                data_params = yaml.load(f)
3745        if data_params is None:
3746            data_params = {}
3747        if annotation_type is None:
3748            ann_params = {}
3749        else:
3750            ann_param_path = os.path.join(
3751                default_path, "annotation", f"{annotation_type}.yaml"
3752            )
3753            if os.path.exists(ann_param_path):
3754                ann_params = self._open_yaml(ann_param_path)
3755            elif annotation_type == "none":
3756                ann_params = {}
3757            else:
3758                raise ValueError(
3759                    f"The {annotation_type} data type is not available. "
3760                    f"Please choose from {BehaviorDataset.annotation_types()}"
3761                )
3762        if ann_params is None:
3763            ann_params = {}
3764        data_params = self._update(data_params, ann_params)
3765        data_params["data_path"] = data_path
3766        data_params["annotation_path"] = annotation_path
3767        with open(os.path.join(config_path, "data.yaml"), "w", encoding="utf-8") as f:
3768            yaml.dump(data_params, f)
3769        with open(os.path.join(default_path, "general.yaml"), encoding="utf-8") as f:
3770            general_params = yaml.load(f)
3771        general_params["data_type"] = data_type
3772        general_params["annotation_type"] = annotation_type
3773        with open(
3774            os.path.join(config_path, "general.yaml"), "w", encoding="utf-8"
3775        ) as f:
3776            yaml.dump(general_params, f)
3777
3778    def _generate_meta(self) -> None:
3779        """Initialize the meta files."""
3780        config_file = os.path.join(self.project_path, "config")
3781        meta_fields = ["time"]
3782        columns = [("meta", field) for field in meta_fields]
3783        episodes = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3784        episodes.to_pickle(self._episodes_path())
3785        meta_fields = ["time", "objective"]
3786        result_fields = ["best_params", "best_value"]
3787        columns = [("meta", field) for field in meta_fields] + [
3788            ("results", field) for field in result_fields
3789        ]
3790        searches = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3791        searches.to_pickle(self._searches_path())
3792        meta_fields = ["time"]
3793        columns = [("meta", field) for field in meta_fields]
3794        predictions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3795        predictions.to_pickle(self._predictions_path())
3796        suggestions = pd.DataFrame(columns=pd.MultiIndex.from_tuples(columns))
3797        suggestions.to_pickle(self._suggestions_path())
3798        with open(os.path.join(config_file, "data.yaml"), encoding="utf-8") as f:
3799            data_keys = list(YAML().load(f).keys())
3800        saved_data = pd.DataFrame(columns=data_keys)
3801        saved_data.to_pickle(self._saved_datasets_path())
3802        pd.DataFrame().to_pickle(self._thresholds_path())
3803        # with open(self._version_path()) as f:
3804        #     f.write(__version__)
3805
3806    def _open_yaml(self, path: str) -> CommentedMap:
3807        """Load a parameter dictionary from a .yaml file."""
3808        with open(path, encoding="utf-8") as f:
3809            data = YAML().load(f)
3810        if data is None:
3811            data = {}
3812        return data
3813
3814    def _compare(self, d: Dict, u: Dict, allow_diff: float = 1e-7):
3815        """Compare nested dictionaries with 'almost equal' condition."""
3816        ok = True
3817        if u.keys() != d.keys():
3818            ok = False
3819        else:
3820            for k, v in u.items():
3821                if isinstance(v, Mapping):
3822                    ok = self._compare(d[k], v, allow_diff=allow_diff)
3823                else:
3824                    if isinstance(v, float) or isinstance(d[k], float):
3825                        if not isinstance(d[k], float) and not isinstance(d[k], int):
3826                            ok = False
3827                        elif not isinstance(v, float) and not isinstance(v, int):
3828                            ok = False
3829                        elif np.abs(v - d[k]) > allow_diff:
3830                            ok = False
3831                    elif v != d[k]:
3832                        ok = False
3833        return ok
3834
3835    def _check_comment(self, comment_sequence: List) -> bool:
3836        """Check if a comment already exists in a ruamel.yaml comment sequence."""
3837        if comment_sequence is None:
3838            return False
3839        c = self._get_comment(comment_sequence)
3840        if c != "":
3841            return True
3842        else:
3843            return False
3844
3845    def _get_comment(self, comment_sequence: List, strip=True) -> str:
3846        """Get the comment string from a ruamel.yaml comment sequence."""
3847        if comment_sequence is None:
3848            return ""
3849        c = ""
3850        for cm in comment_sequence:
3851            if cm is not None:
3852                if isinstance(cm, Iterable):
3853                    for c in cm:
3854                        if c is not None:
3855                            c = c.value
3856                            break
3857                    break
3858                else:
3859                    c = cm.value
3860                    break
3861        if strip:
3862            c = c.strip()
3863        return c
3864
3865    def _update(self, d: Union[CommentedMap, Dict], u: Union[CommentedMap, Dict]):
3866        """Update a nested dictionary."""
3867        if "general" in u and "model_name" in u["general"] and "model" in d:
3868            model_name = u["general"]["model_name"]
3869            if d["general"]["model_name"] != model_name:
3870                d["model"] = self._open_yaml(
3871                    os.path.join(
3872                        self.project_path, "config", "model", f"{model_name}.yaml"
3873                    )
3874                )
3875        d_copied = deepcopy(d)
3876        for k, v in u.items():
3877            if (
3878                k in d_copied
3879                and isinstance(d_copied[k], list)
3880                and isinstance(v, Mapping)
3881                and all([isinstance(x, int) for x in v.keys()])
3882            ):
3883                for kk, vv in v.items():
3884                    d_copied[k][kk] = vv
3885            elif (
3886                isinstance(v, Mapping)
3887                and k in d_copied
3888                and isinstance(d_copied[k], Mapping)
3889            ):
3890                if d_copied[k] is None:
3891                    d_k = CommentedMap()
3892                else:
3893                    d_k = d_copied[k]
3894                d_copied[k] = self._update(d_k, v)
3895            else:
3896                d_copied[k] = v
3897                if isinstance(u, CommentedMap) and u.ca.items.get(k) is not None:
3898                    c = self._get_comment(u.ca.items.get(k), strip=False)
3899                    if isinstance(d_copied, CommentedMap) and not self._check_comment(
3900                        d_copied.ca.items.get(k)
3901                    ):
3902                        d_copied.yaml_add_eol_comment(c, key=k)
3903        return d_copied
3904
3905    def _update_with_search(
3906        self,
3907        d: Dict,
3908        search_name: str,
3909        load_parameters: list = None,
3910        round_to_binary: list = None,
3911    ):
3912        """Update a dictionary with best parameters from a hyperparameter search."""
3913        u, _ = self._searches().get_best_params(
3914            search_name, load_parameters, round_to_binary
3915        )
3916        return self._update(d, u)
3917
3918    def _read_parameters(self, catch_blanks=True) -> Dict:
3919        """Compose a parameter dictionary to create a task from the config files."""
3920        config_path = os.path.join(self.project_path, "config")
3921        keys = [
3922            "data",
3923            "general",
3924            "losses",
3925            "metrics",
3926            "ssl",
3927            "training",
3928        ]
3929        parameters = {}
3930        for key in keys:
3931            parameters[key] = self._open_yaml(os.path.join(config_path, f"{key}.yaml"))
3932        features = parameters["general"]["feature_extraction"]
3933        parameters["features"] = self._open_yaml(
3934            os.path.join(config_path, "features", f"{features}.yaml")
3935        )
3936        transformer = options.extractor_to_transformer[features]
3937        parameters["augmentations"] = self._open_yaml(
3938            os.path.join(config_path, "augmentations", f"{transformer}.yaml")
3939        )
3940        model = parameters["general"]["model_name"]
3941        parameters["model"] = self._open_yaml(
3942            os.path.join(config_path, "model", f"{model}.yaml")
3943        )
3944        # input = parameters["general"]["input"]
3945        # parameters["model"] = self._open_yaml(
3946        #     os.path.join(config_path, "model", f"{model}.yaml")
3947        # )
3948        if catch_blanks:
3949            blanks = self._get_blanks()
3950            if len(blanks) > 0:
3951                self.list_blanks()
3952                raise ValueError(
3953                    f"Please fill in all the blanks before running experiments"
3954                )
3955        return parameters
3956
3957    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3958        """Select the model and the metrics.
3959
3960        Parameters
3961        ----------
3962        model_name : str, optional
3963            model name; run `project.help("model") to find out more
3964        metric_names : list, optional
3965            a list of metric function names; run `project.help("metrics") to find out more
3966
3967        """
3968        pars = {"general": {}}
3969        if model_name is not None:
3970            assert model_name in options.models
3971            pars["general"]["model_name"] = model_name
3972        if metric_names is not None:
3973            for metric in metric_names:
3974                assert metric in options.metrics
3975            pars["general"]["metric_functions"] = metric_names
3976        self.update_parameters(pars)
3977
3978    def help(self, keyword: str = None):
3979        """Get information on available options.
3980
3981        Parameters
3982        ----------
3983        keyword : str, optional
3984            the keyword for options (run without arguments to see which keywords are available)
3985
3986        """
3987        if keyword is None:
3988            print("AVAILABLE HELP FUNCTIONS:")
3989            print("- Try running `project.help(keyword)` with the following keywords:")
3990            print("    - model: to get more information on available models,")
3991            print(
3992                "    - features: to get more information on available feature extraction modes,"
3993            )
3994            print(
3995                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3996            )
3997            print("    - metrics: to see a list of available metric functions.")
3998            print("    - data: to see help for expected data structure")
3999            print(
4000                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
4001            )
4002            print(
4003                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
4004            )
4005            print(
4006                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
4007            )
4008        elif keyword == "model":
4009            print("MODELS:")
4010            for key, model in options.models.items():
4011                print(f"{key}:")
4012                print(model.__doc__)
4013        elif keyword == "features":
4014            print("FEATURE EXTRACTORS:")
4015            for key, extractor in options.feature_extractors.items():
4016                print(f"{key}:")
4017                print(extractor.__doc__)
4018        elif keyword == "partition_method":
4019            print("PARTITION METHODS:")
4020            print(
4021                BehaviorDataset.partition_train_test_val.__doc__.split(
4022                    "The partitioning method:"
4023                )[1].split("val_frac :")[0]
4024            )
4025        elif keyword == "metrics":
4026            print("METRICS:")
4027            for key, metric in options.metrics.items():
4028                print(f"{key}:")
4029                print(metric.__doc__)
4030        elif keyword == "data":
4031            print("DATA:")
4032            print(f"Video data: {self.data_type}")
4033            print(options.input_stores[self.data_type].__doc__)
4034            print(f"Annotation data: {self.annotation_type}")
4035            print(options.annotation_stores[self.annotation_type].__doc__)
4036            print(
4037                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
4038            )
4039        else:
4040            raise ValueError(f"The {keyword} keyword is not recognized")
4041        print("\n")
4042
4043    def _process_value(self, value):
4044        """Process a configuration value for display.
4045
4046        Parameters
4047        ----------
4048        value : any
4049            the value to process
4050
4051        Returns
4052        -------
4053        processed_value : any
4054            the processed value
4055
4056        """
4057        if isinstance(value, str):
4058            value = f'"{value}"'
4059        elif isinstance(value, CommentedSet):
4060            value = {x for x in value}
4061        return value
4062
4063    def _get_blanks(self):
4064        """Get a list of blank (unset) parameters in the configuration.
4065
4066        Returns
4067        -------
4068        caught : list
4069            a list of parameter keys that have blank values
4070
4071        """
4072        caught = []
4073        parameters = self._read_parameters(catch_blanks=False)
4074        for big_key, big_value in parameters.items():
4075            for key, value in big_value.items():
4076                if value == "???":
4077                    caught.append(
4078                        (big_key, key, self._get_comment(big_value.ca.items.get(key)))
4079                    )
4080        return caught
4081
4082    def list_blanks(self, blanks=None):
4083        """List parameters that need to be filled in.
4084
4085        Parameters
4086        ----------
4087        blanks : list, optional
4088            a list of the parameters to list, if already known
4089
4090        """
4091        if blanks is None:
4092            blanks = self._get_blanks()
4093        if len(blanks) > 0:
4094            to_update = defaultdict(lambda: [])
4095            for b, k, c in blanks:
4096                to_update[b].append((k, c))
4097            print("Before running experiments, please update all the blanks.")
4098            print("To do that, you can run this.")
4099            print("--------------------------------------------------------")
4100            print(f"project.update_parameters(")
4101            print(f"    {{")
4102            for big_key, keys in to_update.items():
4103                print(f'        "{big_key}": {{')
4104                for key, comment in keys:
4105                    print(f'            "{key}": ..., {comment}')
4106                print(f"        }}")
4107            print(f"    }}")
4108            print(")")
4109            print("--------------------------------------------------------")
4110            print("Replace ... with relevant values.")
4111        else:
4112            print("There is no blanks left!")
4113
4114    def list_basic_parameters(
4115        self,
4116    ):
4117        """Get a list of most relevant parameters and code to modify them."""
4118        parameters = self._read_parameters()
4119        print("BASIC PARAMETERS:")
4120        model_name = parameters["general"]["model_name"]
4121        metric_names = parameters["general"]["metric_functions"]
4122        loss_name = parameters["general"]["loss_function"]
4123        feature_extraction = parameters["general"]["feature_extraction"]
4124        print("Here is a list of current parameters.")
4125        print(
4126            "You can copy this code, change the parameters you want to set and run it to update the project config."
4127        )
4128        print("--------------------------------------------------------")
4129        print("project.update_parameters(")
4130        print("    {")
4131        for group in ["general", "data", "training"]:
4132            print(f'        "{group}": {{')
4133            for key in options.basic_parameters[group]:
4134                if key in parameters[group]:
4135                    print(
4136                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
4137                    )
4138            print("        },")
4139        print('        "losses": {')
4140        print(f'            "{loss_name}": {{')
4141        for key in options.basic_parameters["losses"][loss_name]:
4142            if key in parameters["losses"][loss_name]:
4143                print(
4144                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
4145                )
4146        print("            },")
4147        print("        },")
4148        print('        "metrics": {')
4149        for metric in metric_names:
4150            print(f'            "{metric}": {{')
4151            for key in parameters["metrics"][metric]:
4152                print(
4153                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
4154                )
4155            print("            },")
4156        print("        },")
4157        print('        "model": {')
4158        for key in options.basic_parameters["model"][model_name]:
4159            if key in parameters["model"]:
4160                print(
4161                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
4162                )
4163
4164        print("        },")
4165        print('        "features": {')
4166        for key in options.basic_parameters["features"][feature_extraction]:
4167            if key in parameters["features"]:
4168                print(
4169                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
4170                )
4171
4172        print("        },")
4173        print('        "augmentations": {')
4174        for key in options.basic_parameters["augmentations"][feature_extraction]:
4175            if key in parameters["augmentations"]:
4176                print(
4177                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
4178                )
4179        print("        },")
4180        print("    },")
4181        print(")")
4182        print("--------------------------------------------------------")
4183        print("\n")
4184
4185    def _create_record(
4186        self,
4187        episode_name: str,
4188        behaviors_dict: Dict,
4189        load_episode: str = None,
4190        parameters_update: Dict = None,
4191        task: TaskDispatcher = None,
4192        load_epoch: int = None,
4193        load_search: str = None,
4194        load_parameters: list = None,
4195        round_to_binary: list = None,
4196        load_strict: bool = True,
4197        n_seeds: int = 1,
4198    ) -> TaskDispatcher:
4199        """Create a meta data episode record."""
4200        if episode_name in self._episodes().data.index:
4201            return
4202        if type(n_seeds) is not int or n_seeds < 1:
4203            raise ValueError(
4204                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
4205            )
4206        if parameters_update is None:
4207            parameters_update = {}
4208        parameters = self._read_parameters()
4209        parameters = self._update(parameters, parameters_update)
4210        if load_search is not None:
4211            parameters = self._update_with_search(
4212                parameters, load_search, load_parameters, round_to_binary
4213            )
4214        parameters = self._fill(
4215            parameters,
4216            episode_name,
4217            load_episode,
4218            load_epoch=load_epoch,
4219            only_load_model=True,
4220            load_strict=load_strict,
4221            continuing=True,
4222        )
4223        self._save_episode(episode_name, parameters, behaviors_dict)
4224        return task
4225
4226    def _save_thresholds(
4227        self,
4228        episode_names: List,
4229        metric_name: str,
4230        parameters: Dict,
4231        thresholds: List,
4232        load_epochs: List,
4233    ):
4234        """Save optimal decision thresholds in the meta records."""
4235        metric_parameters = parameters["metrics"][metric_name]
4236        self._thresholds().save_thresholds(
4237            episode_names, load_epochs, metric_name, metric_parameters, thresholds
4238        )
4239
4240    def _save_episode(
4241        self,
4242        episode_name: str,
4243        parameters: Dict,
4244        behaviors_dict: Dict,
4245        suppress_validation: bool = False,
4246        training_time: str = None,
4247        norm_stats: Dict = None,
4248    ) -> None:
4249        """Save an episode in the meta files."""
4250        try:
4251            split_info = self._split_info_from_filename(
4252                parameters["training"]["split_path"]
4253            )
4254            parameters["training"]["partition_method"] = split_info["partition_method"]
4255        except:
4256            pass
4257        if norm_stats is not None:
4258            norm_stats = dict(norm_stats)
4259        parameters["training"]["stats"] = norm_stats
4260        self._episodes().save_episode(
4261            episode_name,
4262            parameters,
4263            behaviors_dict,
4264            suppress_validation=suppress_validation,
4265            training_time=training_time,
4266        )
4267
4268    def _save_suggestions(
4269        self, suggestions_name: str, parameters: Dict, meta_parameters: Dict
4270    ) -> None:
4271        """Save a suggestion in the meta files."""
4272        self._suggestions().save_suggestion(
4273            suggestions_name, parameters, meta_parameters
4274        )
4275
4276    def _update_episode_results(
4277        self,
4278        episode_name: str,
4279        logs: Tuple,
4280        training_time: str = None,
4281    ) -> None:
4282        """Save the results of a run in the meta files."""
4283        self._episodes().update_episode_results(episode_name, logs, training_time)
4284
4285    def _save_prediction(
4286        self,
4287        prediction_name: str,
4288        predicted: Dict[str, Dict],
4289        parameters: Dict,
4290        task: TaskDispatcher,
4291        mode: str = "test",
4292        embedding: bool = False,
4293        inference_time: str = None,
4294        behavior_dict: List[Dict[str, Any]] = None,
4295    ) -> None:
4296        """Save a prediction in the meta files."""
4297
4298        folder = self.prediction_path(prediction_name)
4299        os.mkdir(folder)
4300        for video_id, prediction in predicted.items():
4301            with open(
4302                os.path.join(
4303                    folder, video_id + f"_{prediction_name}_prediction.pickle"
4304                ),
4305                "wb",
4306            ) as f:
4307                prediction["min_frames"], prediction["max_frames"] = task.dataset(
4308                    mode
4309                ).get_min_max_frames(video_id)
4310                prediction["classes"] = behavior_dict
4311                pickle.dump(prediction, f)
4312
4313        parameters = self._update(
4314            parameters,
4315            {"meta": {"embedding": embedding, "inference_time": inference_time}},
4316        )
4317        self._predictions().save_episode(
4318            prediction_name, parameters, task.behaviors_dict()
4319        )
4320
4321    def _save_search(
4322        self,
4323        search_name: str,
4324        parameters: Dict,
4325        n_trials: int,
4326        best_params: Dict,
4327        best_value: float,
4328        metric: str,
4329        search_space: Dict,
4330    ) -> None:
4331        """Save a hyperparameter search in the meta files."""
4332        self._searches().save_search(
4333            search_name,
4334            parameters,
4335            n_trials,
4336            best_params,
4337            best_value,
4338            metric,
4339            search_space,
4340        )
4341
4342    def _save_stores(self, parameters: Dict) -> None:
4343        """Save a pickled dataset in the meta files."""
4344        name = os.path.basename(parameters["data"]["feature_save_path"])
4345        self._saved_datasets().save_store(name, self._get_data_pars(parameters))
4346        self.create_metadata_backup()
4347
4348    def _remove_stores(self, parameters: Dict, remove_active: bool = False) -> None:
4349        """Remove the pre-computed features folder."""
4350        name = os.path.basename(parameters["data"]["feature_save_path"])
4351        if remove_active or name not in self._episodes().get_active_datasets():
4352            self.remove_saved_features([name])
4353
4354    def _check_episode_validity(
4355        self, episode_name: str, allow_doublecolon: bool = False, force: bool = False
4356    ) -> None:
4357        """Check whether the episode name is valid."""
4358        if episode_name.startswith("_"):
4359            raise ValueError(
4360                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4361            )
4362        elif "." in episode_name:
4363            raise ValueError("Names containing '.' cannot be used!")
4364        if not allow_doublecolon and "#" in episode_name:
4365            raise ValueError(
4366                "Names containing '#' are reserved by dlc2action and cannot be used!"
4367            )
4368        if "::" in episode_name:
4369            raise ValueError(
4370                "Names containing '::' are reserved by dlc2action and cannot be used!"
4371            )
4372        if force:
4373            self.remove_episode(episode_name)
4374        elif not self._episodes().check_name_validity(episode_name):
4375            raise ValueError(
4376                f"The {episode_name} name is already taken! Set force=True to overwrite."
4377            )
4378
4379    def _check_search_validity(self, search_name: str, force: bool = False) -> None:
4380        """Check whether the search name is valid."""
4381        if search_name.startswith("_"):
4382            raise ValueError(
4383                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4384            )
4385        elif "." in search_name:
4386            raise ValueError("Names containing '.' cannot be used!")
4387        if force:
4388            self.remove_search(search_name)
4389        elif not self._searches().check_name_validity(search_name):
4390            raise ValueError(f"The {search_name} name is already taken!")
4391
4392    def _check_prediction_validity(
4393        self, prediction_name: str, force: bool = False
4394    ) -> None:
4395        """Check whether the prediction name is valid."""
4396        if prediction_name.startswith("_"):
4397            raise ValueError(
4398                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4399            )
4400        elif "." in prediction_name:
4401            raise ValueError("Names containing '.' cannot be used!")
4402        if force:
4403            self.remove_prediction(prediction_name)
4404        elif not self._predictions().check_name_validity(prediction_name):
4405            raise ValueError(f"The {prediction_name} name is already taken!")
4406
4407    def _check_suggestions_validity(
4408        self, suggestions_name: str, force: bool = False
4409    ) -> None:
4410        """Check whether the suggestions name is valid."""
4411        if suggestions_name.startswith("_"):
4412            raise ValueError(
4413                "Names starting with an underscore are reserved by dlc2action and cannot be used!"
4414            )
4415        elif "." in suggestions_name:
4416            raise ValueError("Names containing '.' cannot be used!")
4417        if force:
4418            self.remove_suggestion(suggestions_name)
4419        elif not self._suggestions().check_name_validity(suggestions_name):
4420            raise ValueError(f"The {suggestions_name} name is already taken!")
4421
4422    def _training_time(self, episode_name: str) -> int:
4423        """Get the training time of an episode in seconds."""
4424        return self._episode(episode_name).training_time()
4425
4426    def _mask_path(self) -> str:
4427        """Get the path to the masks folder.
4428
4429        Returns
4430        -------
4431        path : str
4432            the path to the masks folder
4433
4434        """
4435        return os.path.join(self.project_path, "results", "masks")
4436
4437    def _thresholds_path(self) -> str:
4438        """Get the path to the thresholds meta file.
4439
4440        Returns
4441        -------
4442        path : str
4443            the path to the thresholds meta file
4444
4445        """
4446        return os.path.join(self.project_path, "meta", "thresholds.pickle")
4447
4448    def _episodes_path(self) -> str:
4449        """Get the path to the episodes meta file.
4450
4451        Returns
4452        -------
4453        path : str
4454            the path to the episodes meta file
4455
4456        """
4457        return os.path.join(self.project_path, "meta", "episodes.pickle")
4458
4459    def _suggestions_path(self) -> str:
4460        """Get the path to the suggestions meta file.
4461
4462        Returns
4463        -------
4464        path : str
4465            the path to the suggestions meta file
4466
4467        """
4468        return os.path.join(self.project_path, "meta", "suggestions.pickle")
4469
4470    def _saved_datasets_path(self) -> str:
4471        """Get the path to the datasets meta file.
4472
4473        Returns
4474        -------
4475        path : str
4476            the path to the datasets meta file
4477
4478        """
4479        return os.path.join(self.project_path, "meta", "saved_datasets.pickle")
4480
4481    def _predictions_path(self) -> str:
4482        """Get the path to the predictions meta file.
4483
4484        Returns
4485        -------
4486        path : str
4487            the path to the predictions meta file
4488
4489        """
4490        return os.path.join(self.project_path, "meta", "predictions.pickle")
4491
4492    def _dataset_store_path(self, name: str) -> str:
4493        """Get the path to a specific pickled dataset.
4494
4495        Parameters
4496        ----------
4497        name : str
4498            the name of the dataset
4499
4500        Returns
4501        -------
4502        path : str
4503            the path to the dataset file
4504
4505        """
4506        return os.path.join(self.project_path, "saved_datasets", f"{name}.pickle")
4507
4508    def _al_points_path(self, suggestions_name: str) -> str:
4509        """Get the path to an active learning intervals file.
4510
4511        Parameters
4512        ----------
4513        suggestions_name : str
4514            the name of the suggestions
4515
4516        Returns
4517        -------
4518        path : str
4519            the path to the active learning points file
4520
4521        """
4522        path = os.path.join(
4523            self.project_path,
4524            "results",
4525            "suggestions",
4526            suggestions_name,
4527            f"{suggestions_name}_al_points.pickle",
4528        )
4529        return path
4530
4531    def _suggestion_path(self, v_id: str, suggestions_name: str) -> str:
4532        """Get the path to a suggestion file.
4533
4534        Parameters
4535        ----------
4536        v_id : str
4537            the video ID
4538        suggestions_name : str
4539            the name of the suggestions
4540
4541        Returns
4542        -------
4543        path : str
4544            the path to the suggestion file
4545
4546        """
4547        path = os.path.join(
4548            self.project_path,
4549            "results",
4550            "suggestions",
4551            suggestions_name,
4552            f"{v_id}_suggestion.pickle",
4553        )
4554        return path
4555
4556    def _searches_path(self) -> str:
4557        """Get the path to the hyperparameter search meta file.
4558
4559        Returns
4560        -------
4561        path : str
4562            the path to the searches meta file
4563
4564        """
4565        return os.path.join(self.project_path, "meta", "searches.pickle")
4566
4567    def _search_path(self, name: str) -> str:
4568        """Get the default path to the graph folder for a specific hyperparameter search.
4569
4570        Parameters
4571        ----------
4572        name : str
4573            the name of the search
4574
4575        Returns
4576        -------
4577        path : str
4578            the path to the search folder
4579
4580        """
4581        return os.path.join(self.project_path, "results", "searches", name)
4582
4583    def _version_path(self) -> str:
4584        """Get the path to the version file.
4585
4586        Returns
4587        -------
4588        path : str
4589            the path to the version file
4590
4591        """
4592        return os.path.join(self.project_path, "meta", "version.txt")
4593
4594    def _default_split_file(self, split_info: Dict) -> Optional[str]:
4595        """Generate a path to a split file from split parameters.
4596
4597        Parameters
4598        ----------
4599        split_info : dict
4600            the split parameters dictionary
4601
4602        Returns
4603        -------
4604        split_file_path : str or None
4605            the path to the split file, or None if not applicable
4606
4607        """
4608        if split_info["partition_method"].startswith("time"):
4609            return None
4610        val_frac = split_info["val_frac"]
4611        test_frac = split_info["test_frac"]
4612        split_name = f'{split_info["partition_method"]}_val{val_frac * 100}%_test{test_frac * 100}%_len{split_info["len_segment"]}_overlap{split_info["overlap"]}'
4613        if not split_info["only_load_annotated"]:
4614            split_name += "_all"
4615        split_name += ".txt"
4616        return os.path.join(self.project_path, "results", "splits", split_name)
4617
4618    def _split_info_from_filename(self, split_name: str) -> Dict:
4619        """Get split parameters from default path to a split file.
4620
4621        Parameters
4622        ----------
4623        split_name : str
4624            the name/path of the split file
4625
4626        Returns
4627        -------
4628        split_info : dict
4629            the split parameters dictionary
4630
4631        """
4632        if split_name is None:
4633            return {}
4634        try:
4635            name = os.path.basename(split_name)[:-4]
4636            split = name.split("_")
4637            if len(split) == 6:
4638                only_load_annotated = False
4639            else:
4640                only_load_annotated = True
4641            len_segment = int(split[3][3:])
4642            overlap = float(split[4][7:])
4643            if overlap > 1:
4644                overlap = int(overlap)
4645            method, val, test = split[:3]
4646            val = float(val[3:-1]) / 100
4647            test = float(test[4:-1]) / 100
4648            return {
4649                "partition_method": method,
4650                "val_frac": val,
4651                "test_frac": test,
4652                "only_load_annotated": only_load_annotated,
4653                "len_segment": len_segment,
4654                "overlap": overlap,
4655            }
4656        except:
4657            return {"partition_method": "file"}
4658
4659    def _fill(
4660        self,
4661        parameters: Dict,
4662        episode_name: str,
4663        load_experiment: str = None,
4664        load_epoch: int = None,
4665        load_strict: bool = True,
4666        only_load_model: bool = False,
4667        continuing: bool = False,
4668        enforce_split_parameters: bool = False,
4669    ) -> Dict:
4670        """Update the parameters from the config files with project specific information.
4671
4672        Fill in the constant file path parameters and generate a unique log file and a model folder.
4673        Fill in the split file if the same split has been run before in the project and change partition method to
4674        from_file.
4675        Fill in saved data path if a dataset with the same data parameters already exists in the project.
4676        If load_experiment is not None, fill in the checkpoint path as well.
4677        The only_load_model training parameter is defined by the corresponding argument.
4678        If continuing is True, new files are not created and all information is loaded from load_experiment.
4679        If prediction is True, log and model files are not created.
4680        The enforce_split_parameters parameter is used to resolve conflicts
4681        between split file path and split parameters when they arise.
4682
4683        Parameters
4684        ----------
4685        parameters : dict
4686            the parameters dictionary to update
4687        episode_name : str
4688            the name of the episode
4689        load_experiment : str, optional
4690            the name of the experiment to load from
4691        load_epoch : int, optional
4692            the epoch to load (by default the last one)
4693        load_strict : bool, default True
4694            if `True`, strict loading is enforced
4695        only_load_model : bool, default False
4696            if `True`, only the model is loaded
4697        continuing : bool, default False
4698            if `True`, continues from existing files
4699        enforce_split_parameters : bool, default False
4700            if `True`, split parameters are enforced
4701
4702        Returns
4703        -------
4704        parameters : dict
4705            the updated parameters dictionary
4706
4707        """
4708        pars = deepcopy(parameters)
4709        if episode_name == "_":
4710            self.remove_episode("_")
4711        log = os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt")
4712        model_save_path = os.path.join(
4713            self.project_path, "results", "model", episode_name
4714        )
4715        if not continuing and (os.path.exists(log) or os.path.exists(model_save_path)):
4716            raise ValueError(
4717                f"The {episode_name} episode name is already in use! Set force=True to overwrite."
4718            )
4719        keys = ["val_frac", "test_frac", "partition_method"]
4720        if "len_segment" not in pars["general"] and "len_segment" in pars["data"]:
4721            pars["general"]["len_segment"] = pars["data"]["len_segment"]
4722        if "overlap" not in pars["general"] and "overlap" in pars["data"]:
4723            pars["general"]["overlap"] = pars["data"]["overlap"]
4724        if "len_segment" in pars["data"]:
4725            pars["data"].pop("len_segment")
4726        if "overlap" in pars["data"]:
4727            pars["data"].pop("overlap")
4728        split_info = {k: pars["training"][k] for k in keys}
4729        split_info["only_load_annotated"] = pars["general"]["only_load_annotated"]
4730        split_info["len_segment"] = pars["general"]["len_segment"]
4731        split_info["overlap"] = pars["general"]["overlap"]
4732        pars["training"]["log_file"] = log
4733        if not os.path.exists(model_save_path):
4734            os.mkdir(model_save_path)
4735        pars["training"]["model_save_path"] = model_save_path
4736        if load_experiment is not None:
4737            if load_experiment not in self._episodes().data.index:
4738                raise ValueError(f"The {load_experiment} episode does not exist!")
4739            old_episode = self._episode(load_experiment)
4740            old_file = old_episode.split_file()
4741            old_info = self._split_info_from_filename(old_file)
4742            if len(old_info) == 0:
4743                old_info = old_episode.split_info()
4744            if enforce_split_parameters:
4745                if split_info["partition_method"] != "file":
4746                    pars["training"]["split_path"] = self._default_split_file(
4747                        split_info
4748                    )
4749            else:
4750                equal = True
4751                if old_info["partition_method"] != split_info["partition_method"]:
4752                    equal = False
4753                if old_info["partition_method"] != "file":
4754                    if (
4755                        old_info["val_frac"] != split_info["val_frac"]
4756                        or old_info["test_frac"] != split_info["test_frac"]
4757                    ):
4758                        equal = False
4759                if not continuing and not equal:
4760                    warnings.warn(
4761                        f"The partitioning parameters in the loaded experiment ({old_info}) "
4762                        f"are not equal to the current partitioning parameters ({split_info}). "
4763                        f"The current parameters are replaced."
4764                    )
4765                pars["training"]["split_path"] = old_file
4766                for k, v in old_info.items():
4767                    pars["training"][k] = v
4768            pars["training"]["checkpoint_path"] = old_episode.model_file(load_epoch)
4769            pars["training"]["load_strict"] = load_strict
4770        else:
4771            pars["training"]["checkpoint_path"] = None
4772            if pars["training"]["partition_method"] == "file":
4773                if (
4774                    "split_path" not in pars["training"]
4775                    or pars["training"]["split_path"] is None
4776                ):
4777                    raise ValueError(
4778                        "The partition_method parameter is set to file but the "
4779                        "split_path parameter is not set!"
4780                    )
4781                elif not os.path.exists(pars["training"]["split_path"]):
4782                    raise ValueError(
4783                        f'The {pars["training"]["split_path"]} split file does not exist'
4784                    )
4785            else:
4786                pars["training"]["split_path"] = self._default_split_file(split_info)
4787        pars["training"]["only_load_model"] = only_load_model
4788        pars["data"]["saved_data_path"] = None
4789        pars["data"]["feature_save_path"] = None
4790        pars_data_copy = self._get_data_pars(pars)
4791        saved_data_name = self._saved_datasets().find_name(pars_data_copy)
4792        if saved_data_name is not None:
4793            pars["data"]["saved_data_path"] = self._dataset_store_path(saved_data_name)
4794            pars["data"]["feature_save_path"] = self._dataset_store_path(
4795                saved_data_name
4796            ).split(".")[0]
4797        else:
4798            dataset_path = self._dataset_store_path(episode_name)
4799            if os.path.exists(dataset_path):
4800                name, ext = dataset_path.split(".")
4801                i = 0
4802                while os.path.exists(f"{name}_{i}.{ext}"):
4803                    i += 1
4804                dataset_path = f"{name}_{i}.{ext}"
4805            pars["data"]["saved_data_path"] = dataset_path
4806            pars["data"]["feature_save_path"] = dataset_path.split(".")[0]
4807        split_split = pars["training"]["partition_method"].split(":")
4808        random = True
4809        for partition_method in options.partition_methods["fixed"]:
4810            method_split = partition_method.split(":")
4811            if len(split_split) != len(method_split):
4812                continue
4813            equal = True
4814            for x, y in zip(split_split, method_split):
4815                if y.startswith("{"):
4816                    continue
4817                if x != y:
4818                    equal = False
4819                    break
4820            if equal:
4821                random = False
4822                break
4823        if random and os.path.exists(pars["training"]["split_path"]):
4824            pars["training"]["partition_method"] = "file"
4825        pars["general"]["save_dataset"] = True
4826        # Check len_segment for c2f models
4827        if pars["general"]["model_name"].startswith("c2f"):
4828            if int(pars["general"]["len_segment"]) < 512:
4829                raise ValueError(
4830                    "The segment length should be higher than 512 when using one of the C2F models"
4831                )
4832        return pars
4833
4834    def _get_data_pars(self, pars: Dict) -> Dict:
4835        """Get a complete description of the data from a general parameters dictionary.
4836
4837        Parameters
4838        ----------
4839        pars : dict
4840            the general parameters dictionary
4841
4842        Returns
4843        -------
4844        pars_data : dict
4845            the complete data parameters dictionary
4846
4847        """
4848        pars_data_copy = deepcopy(pars["data"])
4849        for par in [
4850            "only_load_annotated",
4851            "exclusive",
4852            "feature_extraction",
4853            "ignored_clips",
4854            "len_segment",
4855            "overlap",
4856        ]:
4857            pars_data_copy[par] = pars["general"].get(par, None)
4858        pars_data_copy.update(pars["features"])
4859        return pars_data_copy
4860
4861    def _make_al_points_from_suggestions(
4862        self,
4863        suggestions_name: str,
4864        task: TaskDispatcher,
4865        predicted_classes: Dict,
4866        background_threshold: Optional[float],
4867        visibility_min_score: float,
4868        visibility_min_frac: float,
4869        num_behaviors: int,
4870    ):
4871        valleys = []
4872        if background_threshold is not None:
4873            for i in range(num_behaviors):
4874                print(f"generating background for behavior {i}...")
4875                valleys.append(
4876                    task.dataset("train").find_valleys(
4877                        predicted_classes,
4878                        threshold=background_threshold,
4879                        visibility_min_score=visibility_min_score,
4880                        visibility_min_frac=visibility_min_frac,
4881                        main_class=i,
4882                        low=True,
4883                        cut_annotated=True,
4884                        min_frames=1,
4885                    )
4886                )
4887        valleys = task.dataset("train").valleys_intersection(valleys)
4888        folder = os.path.join(
4889            self.project_path, "results", "suggestions", suggestions_name
4890        )
4891        os.makedirs(os.path.dirname(folder), exist_ok=True)
4892        res = {}
4893        for file in os.listdir(folder):
4894            video_id = file.split("_suggestion.p")[0]
4895            res[video_id] = []
4896            with open(os.path.join(folder, file), "rb") as f:
4897                data = pickle.load(f)
4898            for clip_id, ind_list in zip(data[2], data[3]):
4899                max_len = max(
4900                    [
4901                        max([x[1] for x in cat_list]) if len(cat_list) > 0 else 0
4902                        for cat_list in ind_list
4903                    ]
4904                )
4905                if max_len == 0:
4906                    continue
4907                arr = torch.zeros(max_len)
4908                for cat_list in ind_list:
4909                    for start, end, amb in cat_list:
4910                        arr[start:end] = 1
4911                if video_id in valleys:
4912                    for start, end, clip in valleys[video_id]:
4913                        if clip == clip_id:
4914                            arr[start:end] = 1
4915                output, indices, counts = torch.unique_consecutive(
4916                    arr > 0, return_inverse=True, return_counts=True
4917                )
4918                long_indices = torch.where(output)[0]
4919                res[video_id] += [
4920                    (
4921                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
4922                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
4923                        clip_id,
4924                    )
4925                    for i in long_indices
4926                ]
4927        return res
4928
4929    def _make_al_points(
4930        self,
4931        task: TaskDispatcher,
4932        predicted_error: torch.Tensor,
4933        predicted_classes: torch.Tensor,
4934        exclude_classes: List,
4935        exclude_threshold: List,
4936        exclude_threshold_diff: List,
4937        exclude_hysteresis: List,
4938        include_classes: List,
4939        include_threshold: List,
4940        include_threshold_diff: List,
4941        include_hysteresis: List,
4942        error_episode: str = None,
4943        error_class: str = None,
4944        suggestion_episodes: List = None,
4945        error_threshold: float = 0.5,
4946        error_threshold_diff: float = 0.1,
4947        error_hysteresis: bool = False,
4948        min_frames_al: int = 30,
4949        visibility_min_score: float = 5,
4950        visibility_min_frac: float = 0.7,
4951    ) -> Dict:
4952        """Generate an active learning file."""
4953        if len(exclude_classes) > 0 or len(include_classes) > 0:
4954            valleys = []
4955            included = None
4956            excluded = None
4957            for class_name, thr, thr_diff, hysteresis in zip(
4958                exclude_classes,
4959                exclude_threshold,
4960                exclude_threshold_diff,
4961                exclude_hysteresis,
4962            ):
4963                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
4964                class_index = self._episode(episode).get_class_ind(class_name)
4965                valleys.append(
4966                    task.dataset("train").find_valleys(
4967                        predicted_classes,
4968                        predicted_error=predicted_error,
4969                        min_frames=min_frames_al,
4970                        threshold=thr,
4971                        visibility_min_score=visibility_min_score,
4972                        visibility_min_frac=visibility_min_frac,
4973                        error_threshold=error_threshold,
4974                        main_class=class_index,
4975                        low=True,
4976                        threshold_diff=thr_diff,
4977                        min_frames_error=min_frames_al,
4978                        hysteresis=hysteresis,
4979                    )
4980                )
4981            if len(valleys) > 0:
4982                included = task.dataset("train").valleys_union(valleys)
4983            valleys = []
4984            for class_name, thr, thr_diff, hysteresis in zip(
4985                include_classes,
4986                include_threshold,
4987                include_threshold_diff,
4988                include_hysteresis,
4989            ):
4990                episode = self._episodes().get_runs(suggestion_episodes[0])[0]
4991                class_index = self._episode(episode).get_class_ind(class_name)
4992                valleys.append(
4993                    task.dataset("train").find_valleys(
4994                        predicted_classes,
4995                        predicted_error=predicted_error,
4996                        min_frames=min_frames_al,
4997                        threshold=thr,
4998                        visibility_min_score=visibility_min_score,
4999                        visibility_min_frac=visibility_min_frac,
5000                        error_threshold=error_threshold,
5001                        main_class=class_index,
5002                        low=False,
5003                        threshold_diff=thr_diff,
5004                        min_frames_error=min_frames_al,
5005                        hysteresis=hysteresis,
5006                    )
5007                )
5008            if len(valleys) > 0:
5009                excluded = task.dataset("train").valleys_union(valleys)
5010            al_points = task.dataset("train").valleys_intersection([included, excluded])
5011        else:
5012            class_index = self._episode(error_episode).get_class_ind(error_class)
5013            print("generating active learning intervals...")
5014            al_points = task.dataset("train").find_valleys(
5015                predicted_error,
5016                min_frames=min_frames_al,
5017                threshold=error_threshold,
5018                visibility_min_score=visibility_min_score,
5019                visibility_min_frac=visibility_min_frac,
5020                main_class=class_index,
5021                low=True,
5022                threshold_diff=error_threshold_diff,
5023                min_frames_error=min_frames_al,
5024                hysteresis=error_hysteresis,
5025            )
5026        for v_id in al_points:
5027            clip_dict = defaultdict(lambda: [])
5028            res = []
5029            for x in al_points[v_id]:
5030                clip_dict[x[-1]].append(x)
5031            for clip_id in clip_dict:
5032                clip_dict[clip_id] = sorted(clip_dict[clip_id])
5033                i = 0
5034                j = 1
5035                while j < len(clip_dict[clip_id]):
5036                    end = clip_dict[clip_id][i][1]
5037                    start = clip_dict[clip_id][j][0]
5038                    if start - end < 30:
5039                        clip_dict[clip_id][i][1] = clip_dict[clip_id][j][1]
5040                    else:
5041                        res.append(clip_dict[clip_id][i])
5042                        i = j
5043                    j += 1
5044                res.append(clip_dict[clip_id][i])
5045            al_points[v_id] = sorted(res)
5046        return al_points
5047
5048    def _make_suggestions(
5049        self,
5050        task: TaskDispatcher,
5051        predicted_error: torch.Tensor,
5052        predicted_classes: torch.Tensor,
5053        suggestion_threshold: List,
5054        suggestion_threshold_diff: List,
5055        suggestion_hysteresis: List,
5056        suggestion_episodes: List = None,
5057        suggestion_classes: List = None,
5058        error_threshold: float = 0.5,
5059        min_frames_suggestion: int = 3,
5060        min_frames_al: int = 30,
5061        visibility_min_score: float = 0,
5062        visibility_min_frac: float = 0.7,
5063        cut_annotated: bool = False,
5064    ) -> Dict:
5065        """Make a suggestions dictionary."""
5066        suggestions = defaultdict(lambda: {})
5067        for class_name, thr, thr_diff, hysteresis in zip(
5068            suggestion_classes,
5069            suggestion_threshold,
5070            suggestion_threshold_diff,
5071            suggestion_hysteresis,
5072        ):
5073            episode = self._episodes().get_runs(suggestion_episodes[0])[0]
5074            class_index = self._episode(episode).get_class_ind(class_name)
5075            print(f"generating suggestions for {class_name}...")
5076            found = task.dataset("train").find_valleys(
5077                predicted_classes,
5078                smooth_interval=2,
5079                predicted_error=predicted_error,
5080                min_frames=min_frames_suggestion,
5081                threshold=thr,
5082                visibility_min_score=visibility_min_score,
5083                visibility_min_frac=visibility_min_frac,
5084                error_threshold=error_threshold,
5085                main_class=class_index,
5086                low=False,
5087                threshold_diff=thr_diff,
5088                min_frames_error=min_frames_al,
5089                hysteresis=hysteresis,
5090                cut_annotated=cut_annotated,
5091            )
5092            for v_id in found:
5093                suggestions[v_id][class_name] = found[v_id]
5094        suggestions = dict(suggestions)
5095        return suggestions
5096
5097    def count_classes(
5098        self,
5099        load_episode: str = None,
5100        parameters_update: Dict = None,
5101        remove_saved_features: bool = False,
5102        bouts: bool = True,
5103    ) -> Dict:
5104        """Get a dictionary of class counts in different modes.
5105
5106        Parameters
5107        ----------
5108        load_episode : str, optional
5109            the episode settings to load
5110        parameters_update : dict, optional
5111            a dictionary of parameter updates (only for "data" and "general" categories)
5112        remove_saved_features : bool, default False
5113            if `True`, the dataset that is used for computation is then deleted
5114        bouts : bool, default False
5115            if `True`, instead of frame counts segment counts are returned
5116
5117        Returns
5118        -------
5119        class_counts : dict
5120            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
5121            class names and values are class counts (in frames)
5122
5123        """
5124        if load_episode is None:
5125            task, parameters = self._make_task_training(
5126                episode_name="_", parameters_update=parameters_update, throwaway=True
5127            )
5128        else:
5129            task, parameters, _ = self._make_task_prediction(
5130                "_",
5131                load_episode=load_episode,
5132                parameters_update=parameters_update,
5133            )
5134        class_counts = task.count_classes(bouts=bouts)
5135        behaviors = task.behaviors_dict()
5136        class_counts = {
5137            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
5138            for kk, vv in class_counts.items()
5139        }
5140        if remove_saved_features:
5141            self._remove_stores(parameters)
5142        return class_counts
5143
5144    def plot_class_distribution(
5145        self,
5146        parameters_update: Dict = None,
5147        frame_cutoff: int = 1,
5148        bout_cutoff: int = 1,
5149        print_full: bool = False,
5150        remove_saved_features: bool = False,
5151        save: str = None,
5152    ) -> None:
5153        """Make a class distribution plot.
5154
5155        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
5156        is created or loaded for the computation with the default parameters).
5157
5158        Parameters
5159        ----------
5160        parameters_update : dict, optional
5161            a dictionary of parameter updates (only for "data" and "general" categories)
5162        frame_cutoff : int, default 1
5163            the minimum number of frames for a segment to be considered
5164        bout_cutoff : int, default 1
5165            the minimum number of bouts for a class to be considered
5166        print_full : bool, default False
5167            if `True`, the full class distribution is printed
5168        remove_saved_features : bool, default False
5169            if `True`, the dataset that is used for computation is then deleted
5170
5171        """
5172        task, parameters = self._make_task_training(
5173            episode_name="_", parameters_update=parameters_update, throwaway=True
5174        )
5175        cutoff = {True: bout_cutoff, False: frame_cutoff}
5176        for bouts in [True, False]:
5177            class_counts = task.count_classes(bouts=bouts)
5178            if print_full:
5179                print("Bouts:" if bouts else "Frames:")
5180                for k, v in class_counts.items():
5181                    if sum(v.values()) != 0:
5182                        print(f"  {k}:")
5183                        values, keys = zip(
5184                            *[
5185                                x
5186                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
5187                                if x[-1] != -100
5188                            ]
5189                        )
5190                        for kk, vv in zip(keys, values):
5191                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
5192            class_counts = {
5193                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
5194                for kk, vv in class_counts.items()
5195            }
5196            for key, d in class_counts.items():
5197                if sum(d.values()) != 0:
5198                    values, keys = zip(
5199                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
5200                    )
5201                    keys = [task.behaviors_dict()[x] for x in keys]
5202                    plt.bar(keys, values)
5203                    plt.title(key)
5204                    plt.xticks(rotation=45, ha="right")
5205                    if bouts:
5206                        plt.ylabel("bouts")
5207                    else:
5208                        plt.ylabel("frames")
5209                    plt.tight_layout()
5210
5211                    if save is None:
5212                        plt.savefig(save)
5213                        plt.close()
5214                    else:
5215                        plt.show()
5216        if remove_saved_features:
5217            self._remove_stores(parameters)
5218
5219    def _generate_mask(
5220        self,
5221        mask_name: str,
5222        perc_annotated: float = 0.1,
5223        parameters_update: Dict = None,
5224        remove_saved_features: bool = False,
5225    ) -> None:
5226        """Generate a real_lens for active learning simulation.
5227
5228        Parameters
5229        ----------
5230        mask_name : str
5231            the name of the real_lens
5232        perc_annotated : float, default 0.1
5233            a
5234
5235        """
5236        print(f"GENERATING {mask_name}")
5237        task, parameters = self._make_task_training(
5238            f"_{mask_name}", parameters_update=parameters_update, throwaway=True
5239        )
5240        val_intervals, val_ids = task.dataset("val").get_intervals()  # 1
5241        unannotated_intervals = task.dataset("train").get_unannotated_intervals()  # 2
5242        unannotated_intervals = task.dataset("val").get_unannotated_intervals(
5243            first_intervals=unannotated_intervals
5244        )
5245        ids = task.dataset("train").get_ids()
5246        mask = {video_id: {} for video_id in ids}
5247        total_all = 0
5248        total_masked = 0
5249        for video_id, clip_ids in ids.items():
5250            for clip_id in clip_ids:
5251                frames = np.ones(task.dataset("train").get_len(video_id, clip_id))
5252                if clip_id in val_intervals[video_id]:
5253                    for start, end in val_intervals[video_id][clip_id]:
5254                        frames[start:end] = 0
5255                if clip_id in unannotated_intervals[video_id]:
5256                    for start, end in unannotated_intervals[video_id][clip_id]:
5257                        frames[start:end] = 0
5258                annotated = np.where(frames)[0]
5259                total_all += len(annotated)
5260                masked = annotated[-int(len(annotated) * (1 - perc_annotated)) :]
5261                total_masked += len(masked)
5262                mask[video_id][clip_id] = self._get_intervals(masked)
5263        file = {
5264            "masked": mask,
5265            "val_intervals": val_intervals,
5266            "val_ids": val_ids,
5267            "unannotated": unannotated_intervals,
5268        }
5269        self._save_mask(file, mask_name)
5270        if remove_saved_features:
5271            self._remove_stores(parameters)
5272        print("\n")
5273        # print(f'Unmasked: {sum([(vv == 0).sum() for v in real_lens.values() for vv in v.values()])} frames')
5274
5275    def _get_intervals(self, frame_indices: np.ndarray):
5276        """Get a list of intervals from a list of frame indices.
5277
5278        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
5279
5280        Parameters
5281        ----------
5282        frame_indices : np.ndarray
5283            a list of frame indices
5284
5285        Returns
5286        -------
5287        intervals : list
5288            a list of interval boundaries
5289
5290        """
5291        masked_intervals = []
5292        if len(frame_indices) > 0:
5293            breaks = np.where(np.diff(frame_indices) != 1)[0]
5294            start = frame_indices[0]
5295            for k in breaks:
5296                masked_intervals.append([start, frame_indices[k] + 1])
5297                start = frame_indices[k + 1]
5298            masked_intervals.append([start, frame_indices[-1] + 1])
5299        return masked_intervals
5300
5301    def _update_mask_with_uncertainty(
5302        self,
5303        mask_name: str,
5304        episode_name: Union[str, None],
5305        classes: List,
5306        load_epoch: int = None,
5307        n_frames: int = 10000,
5308        method: str = "least_confidence",
5309        min_length: int = 30,
5310        augment_n: int = 0,
5311        parameters_update: Dict = None,
5312    ):
5313        """Update real_lens with frame-wise uncertainty scores for active learning.
5314
5315        Parameters
5316        ----------
5317        mask_name : str
5318            the name of the real_lens
5319        episode_name : str
5320            the name of the episode to load
5321        classes : list
5322            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5323        load_epoch : int, optional
5324            the epoch to load (by default last; if this epoch is not saved the closest checkpoint is chosen)
5325        n_frames : int, default 10000
5326            the number of frames to "annotate"
5327        method : {"least_confidence", "entropy"}
5328            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
5329            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
5330        min_length : int
5331            the minimum length (in frames) of the annotated intervals
5332        augment_n : int, default 0
5333            the number of augmentations to average over
5334        parameters_update : dict, optional
5335            the dictionary used to update the parameters from the config
5336
5337        Returns
5338        -------
5339        score_dicts : dict
5340            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5341            are score tensors
5342
5343        """
5344        print(f"UPDATING {mask_name}")
5345        task, parameters, _ = self._make_task_prediction(
5346            prediction_name=mask_name,
5347            load_episode=episode_name,
5348            parameters_update=parameters_update,
5349            load_epoch=load_epoch,
5350            mode="train",
5351        )
5352        score_tensors = task.generate_uncertainty_score(classes, augment_n, method)
5353        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5354        print("\n")
5355
5356    def _update_mask_with_BALD(
5357        self,
5358        mask_name: str,
5359        episode_name: str,
5360        classes: List,
5361        load_epoch: int = None,
5362        augment_n: int = 0,
5363        n_frames: int = 10000,
5364        num_models: int = 10,
5365        kernel_size: int = 11,
5366        min_length: int = 30,
5367        parameters_update: Dict = None,
5368    ):
5369        """Update real_lens with frame-wise Bayesian Active Learning by Disagreement scores for active learning.
5370
5371        Parameters
5372        ----------
5373        mask_name : str
5374            the name of the real_lens
5375        episode_name : str
5376            the name of the episode to load
5377        classes : list
5378            a list of class names or indices; their uncertainty scores will be computed separately and stacked
5379        load_epoch : int, optional
5380            the epoch to load (by default last)
5381        augment_n : int, default 0
5382            the number of augmentations to average over
5383        n_frames : int, default 10000
5384            the number of frames to "annotate"
5385        num_models : int, default 10
5386            the number of dropout masks to apply
5387        kernel_size : int, default 11
5388            the size of the smoothing gaussian kernel
5389        min_length : int
5390            the minimum length (in frames) of the annotated intervals
5391        parameters_update : dict, optional
5392            the dictionary used to update the parameters from the config
5393
5394        Returns
5395        -------
5396        score_dicts : dict
5397            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
5398            are score tensors
5399
5400        """
5401        print(f"UPDATING {mask_name}")
5402        task, parameters, mode = self._make_task_prediction(
5403            mask_name,
5404            load_episode=episode_name,
5405            parameters_update=parameters_update,
5406            load_epoch=load_epoch,
5407        )
5408        score_tensors = task.generate_bald_score(
5409            classes, augment_n, num_models, kernel_size
5410        )
5411        self._update_mask(task, mask_name, score_tensors, n_frames, min_length)
5412        print("\n")
5413
5414    def _suggest_intervals(
5415        self,
5416        dataset: BehaviorDataset,
5417        score_tensors: Dict,
5418        n_frames: int,
5419        min_length: int,
5420    ) -> Dict:
5421        """Suggest intervals with highest score of total length `n_frames`.
5422
5423        Parameters
5424        ----------
5425        dataset : BehaviorDataset
5426            the dataset
5427        score_tensors : dict
5428            a dictionary where keys are clip ids and values are framewise score tensors
5429        n_frames : int
5430            the number of frames to "annotate"
5431        min_length : int
5432            minimum length of suggested intervals
5433
5434        Returns
5435        -------
5436        active_learning_intervals : Dict
5437            active learning dictionary with suggested intervals
5438
5439        """
5440        video_intervals, _ = dataset.get_intervals()
5441        taken = {
5442            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5443        }
5444        annotated = dataset.get_annotated_intervals()
5445        for video_id in video_intervals:
5446            for clip_id in video_intervals[video_id]:
5447                taken[video_id][clip_id] = torch.zeros(
5448                    dataset.get_len(video_id, clip_id)
5449                )
5450                if video_id in annotated and clip_id in annotated[video_id]:
5451                    for start, end in annotated[video_id][clip_id]:
5452                        score_tensors[video_id][clip_id][:, start:end] = -10
5453                        taken[video_id][clip_id][int(start) : int(end)] = 1
5454        n_frames = (
5455            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5456            + n_frames
5457        )
5458        factor = 1
5459        threshold_start = float(
5460            torch.mean(
5461                torch.tensor(
5462                    [
5463                        torch.mean(
5464                            torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5465                        )
5466                        for x in score_tensors.values()
5467                    ]
5468                )
5469            )
5470        )
5471        while (
5472            sum([(vv == 1).sum() for v in taken.values() for vv in v.values()])
5473            < n_frames
5474        ):
5475            threshold = threshold_start * factor
5476            intervals = []
5477            interval_scores = []
5478            key1 = list(score_tensors.keys())[0]
5479            key2 = list(score_tensors[key1].keys())[0]
5480            num_scores = score_tensors[key1][key2].shape[0]
5481            for i in range(num_scores):
5482                v_dict = dataset.find_valleys(
5483                    predicted=score_tensors,
5484                    threshold=threshold,
5485                    min_frames=min_length,
5486                    main_class=i,
5487                    low=False,
5488                )
5489                for v_id, interval_list in v_dict.items():
5490                    intervals += [x + [v_id] for x in interval_list]
5491                    interval_scores += [
5492                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5493                        for start, end, clip_id in interval_list
5494                    ]
5495            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5496            i = 0
5497            while sum(
5498                [(vv == 1).sum() for v in taken.values() for vv in v.values()]
5499            ) < n_frames and i < len(intervals):
5500                start, end, clip_id, video_id = intervals[i]
5501                i += 1
5502                taken[video_id][clip_id][int(start) : int(end)] = 1
5503            factor *= 0.9
5504            if factor < 0.05:
5505                warnings.warn(f"Could not find enough frames!")
5506                break
5507        active_learning_intervals = {video_id: [] for video_id in video_intervals}
5508        for video_id in taken:
5509            for clip_id in taken[video_id]:
5510                if video_id in annotated and clip_id in annotated[video_id]:
5511                    for start, end in annotated[video_id][clip_id]:
5512                        taken[video_id][clip_id][int(start) : int(end)] = 0
5513                if (taken[video_id][clip_id] == 1).sum() == 0:
5514                    continue
5515                indices = np.where(taken[video_id][clip_id].numpy())[0]
5516                boundaries = self._get_intervals(indices)
5517                active_learning_intervals[video_id] += [
5518                    [start, end, clip_id] for start, end in boundaries
5519                ]
5520        return active_learning_intervals
5521
5522    def _update_mask(
5523        self,
5524        task: TaskDispatcher,
5525        mask_name: str,
5526        score_tensors: Dict,
5527        n_frames: int,
5528        min_length: int,
5529    ) -> None:
5530        """Update the real_lens with intervals with the highest score of total length `n_frames`.
5531
5532        Parameters
5533        ----------
5534        task : TaskDispatcher
5535            the task dispatcher object
5536        mask_name : str
5537            the name of the real_lens
5538        score_tensors : dict
5539            a dictionary where keys are clip ids and values are framewise score tensors
5540        n_frames : int
5541            the number of frames to "annotate"
5542        min_length : int
5543            the minimum length of the annotated intervals
5544
5545        """
5546        mask = self._load_mask(mask_name)
5547        video_intervals, _ = task.dataset("train").get_intervals()
5548        masked = {
5549            video_id: defaultdict(lambda: {}) for video_id in video_intervals.keys()
5550        }
5551        total_masked = 0
5552        total_all = 0
5553        for video_id in video_intervals:
5554            for clip_id in video_intervals[video_id]:
5555                masked[video_id][clip_id] = torch.zeros(
5556                    task.dataset("train").get_len(video_id, clip_id)
5557                )
5558                if (
5559                    video_id in mask["unannotated"]
5560                    and clip_id in mask["unannotated"][video_id]
5561                ):
5562                    for start, end in mask["unannotated"][video_id][clip_id]:
5563                        score_tensors[video_id][clip_id][:, start:end] = -10
5564                        masked[video_id][clip_id][int(start) : int(end)] = 1
5565                if (
5566                    video_id in mask["val_intervals"]
5567                    and clip_id in mask["val_intervals"][video_id]
5568                ):
5569                    for start, end in mask["val_intervals"][video_id][clip_id]:
5570                        score_tensors[video_id][clip_id][:, start:end] = -10
5571                        masked[video_id][clip_id][int(start) : int(end)] = 1
5572                total_all += torch.sum(masked[video_id][clip_id] == 0)
5573                if video_id in mask["masked"] and clip_id in mask["masked"][video_id]:
5574                    # print(f'{real_lens["masked"][video_id][clip_id]=}')
5575                    for start, end in mask["masked"][video_id][clip_id]:
5576                        masked[video_id][clip_id][int(start) : int(end)] = 1
5577                        total_masked += end - start
5578        old_n_frames = sum(
5579            [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5580        )
5581        n_frames = old_n_frames + n_frames
5582        factor = 1
5583        while (
5584            sum([(vv == 0).sum() for v in masked.values() for vv in v.values()])
5585            < n_frames
5586        ):
5587            threshold = float(
5588                torch.mean(
5589                    torch.tensor(
5590                        [
5591                            torch.mean(
5592                                torch.tensor([torch.mean(y[y > 0]) for y in x.values()])
5593                            )
5594                            for x in score_tensors.values()
5595                        ]
5596                    )
5597                )
5598            )
5599            threshold = threshold * factor
5600            intervals = []
5601            interval_scores = []
5602            key1 = list(score_tensors.keys())[0]
5603            key2 = list(score_tensors[key1].keys())[0]
5604            num_scores = score_tensors[key1][key2].shape[0]
5605            for i in range(num_scores):
5606                v_dict = task.dataset("train").find_valleys(
5607                    predicted=score_tensors,
5608                    threshold=threshold,
5609                    min_frames=min_length,
5610                    main_class=i,
5611                    low=False,
5612                )
5613                for v_id, interval_list in v_dict.items():
5614                    intervals += [x + [v_id] for x in interval_list]
5615                    interval_scores += [
5616                        float(torch.mean(score_tensors[v_id][clip_id][i, start:end]))
5617                        for start, end, clip_id in interval_list
5618                    ]
5619            intervals = np.array(intervals)[np.argsort(interval_scores)[::-1]]
5620            i = 0
5621            while sum(
5622                [(vv == 0).sum() for v in masked.values() for vv in v.values()]
5623            ) < n_frames and i < len(intervals):
5624                start, end, clip_id, video_id = intervals[i]
5625                i += 1
5626                masked[video_id][clip_id][int(start) : int(end)] = 0
5627            factor *= 0.9
5628            if factor < 0.05:
5629                warnings.warn(f"Could not find enough frames!")
5630                break
5631        mask["masked"] = {video_id: {} for video_id in video_intervals}
5632        total_masked_new = 0
5633        for video_id in masked:
5634            for clip_id in masked[video_id]:
5635                if (
5636                    video_id in mask["unannotated"]
5637                    and clip_id in mask["unannotated"][video_id]
5638                ):
5639                    for start, end in mask["unannotated"][video_id][clip_id]:
5640                        masked[video_id][clip_id][int(start) : int(end)] = 0
5641                if (
5642                    video_id in mask["val_intervals"]
5643                    and clip_id in mask["val_intervals"][video_id]
5644                ):
5645                    for start, end in mask["val_intervals"][video_id][clip_id]:
5646                        masked[video_id][clip_id][int(start) : int(end)] = 0
5647                indices = np.where(masked[video_id][clip_id].numpy())[0]
5648                mask["masked"][video_id][clip_id] = self._get_intervals(indices)
5649        for video_id in mask["masked"]:
5650            for clip_id in mask["masked"][video_id]:
5651                for start, end in mask["masked"][video_id][clip_id]:
5652                    total_masked_new += end - start
5653        self._save_mask(mask, mask_name)
5654        with open(
5655            os.path.join(
5656                self.project_path, "results", f"{mask_name}.txt", encoding="utf-8"
5657            ),
5658            "a",
5659        ) as f:
5660            f.write(f"from {total_masked} to {total_masked_new} / {total_all}" + "\n")
5661        print(f"Unmasked from {total_masked} to {total_masked_new} / {total_all}")
5662
5663    def _visualize_results_label(
5664        self,
5665        episode_name: str,
5666        label: str,
5667        load_epoch: int = None,
5668        parameters_update: Dict = None,
5669        add_legend: bool = True,
5670        ground_truth: bool = True,
5671        hide_axes: bool = False,
5672        width: float = 10,
5673        whole_video: bool = False,
5674        transparent: bool = False,
5675        num_plots: int = 1,
5676        smooth_interval: int = 0,
5677    ):
5678        other_path = os.path.join(self.project_path, "results", "other")
5679        if not os.path.exists(other_path):
5680            os.mkdir(other_path)
5681        if parameters_update is None:
5682            parameters_update = {}
5683        if "model" in parameters_update.keys():
5684            raise ValueError("Cannot change model parameters after training!")
5685        task, parameters, _ = self._make_task_prediction(
5686            "_",
5687            load_episode=episode_name,
5688            parameters_update=parameters_update,
5689            load_epoch=load_epoch,
5690            mode="val",
5691        )
5692        for i in range(num_plots):
5693            print(i)
5694            task._visualize_results_label(
5695                smooth_interval=smooth_interval,
5696                label=label,
5697                save_path=os.path.join(
5698                    other_path, f"{episode_name}_prediction_{i}.jpg"
5699                ),
5700                add_legend=add_legend,
5701                ground_truth=ground_truth,
5702                hide_axes=hide_axes,
5703                whole_video=whole_video,
5704                transparent=transparent,
5705                dataset="val",
5706                width=width,
5707                title=str(i),
5708            )
5709
5710    def plot_confusion_matrix(
5711        self,
5712        episode_name: str,
5713        load_epoch: int = None,
5714        parameters_update: Dict = None,
5715        metric: str = "recall",
5716        mode: str = "val",
5717        remove_saved_features: bool = False,
5718        save_path: str = None,
5719        cmap: str = "viridis",
5720    ) -> Tuple[ndarray, Iterable]:
5721        """Make a confusion matrix plot and return the data.
5722
5723        If the annotation is non-exclusive, only false positive labels are considered.
5724
5725        Parameters
5726        ----------
5727        episode_name : str
5728            the name of the episode to load
5729        load_epoch : int, optional
5730            the index of the epoch to load (by default the last one is loaded)
5731        parameters_update : dict, optional
5732            a dictionary of parameter updates (only for "data" and "general" categories)
5733        metric : {"recall", "precision"}
5734            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
5735            into account, and if `type` is `"precision"`, only false negatives
5736        mode : {'val', 'all', 'test', 'train'}
5737            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
5738        remove_saved_features : bool, default False
5739            if `True`, the dataset that is used for computation is then deleted
5740
5741        Returns
5742        -------
5743        confusion_matrix : np.ndarray
5744            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
5745            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
5746            `N_i` is the number of frames that have the i-th label in the ground truth
5747        classes : list
5748            a list of labels
5749
5750        """
5751        task, parameters, mode = self._make_task_prediction(
5752            "_",
5753            load_episode=episode_name,
5754            load_epoch=load_epoch,
5755            parameters_update=parameters_update,
5756            mode=mode,
5757        )
5758        dataset = task.dataset(mode)
5759        prediction = task.predict(dataset, raw_output=True)
5760        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
5761        if remove_saved_features:
5762            self._remove_stores(parameters)
5763        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
5764        ax.imshow(confusion_matrix, cmap=cmap)
5765        # Show all ticks and label them with the respective list entries
5766        ax.set_xticks(np.arange(len(classes)))
5767        ax.set_xticklabels(classes)
5768        ax.set_yticks(np.arange(len(classes)))
5769        ax.set_yticklabels(classes)
5770        # Rotate the tick labels and set their alignment.
5771        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
5772        # Loop over data dimensions and create text annotations.
5773        for i in range(len(classes)):
5774            for j in range(len(classes)):
5775                ax.text(
5776                    j,
5777                    i,
5778                    np.round(confusion_matrix[i, j], 2),
5779                    ha="center",
5780                    va="center",
5781                    color="w",
5782                )
5783        if metric is not None:
5784            ax.set_title(f"{metric} {episode_name}")
5785        else:
5786            ax.set_title(episode_name)
5787        fig.tight_layout()
5788        if save_path is None:
5789            plt.show()
5790        else:
5791            plt.savefig(save_path)
5792            plt.close()
5793        return confusion_matrix, classes
5794
5795    def _plot_ethograms_gt_pred(
5796        self,
5797        data_gt: dict,
5798        data_pred: dict,
5799        labels_gt: list,
5800        labels_pred: list,
5801        start: int = 0,
5802        end: int = -1,
5803        cmap_pred: str = "binary",
5804        cmap_gt: str = "binary",
5805        save: str = None,
5806        fontsize=22,
5807        time_mode="frames",
5808        fps: int = None,
5809    ) -> None:
5810        """Plot ethograms from start to end time (in frames), mode can be prediction or ground truth depending on the data format."""
5811        # print(data.keys())
5812        best_pred = (
5813            data_pred[list(data_pred.keys())[0]].numpy() > 0.5
5814        )  # Threshold the predictions
5815        data_gt = binarize_data(data_gt, max_frame=end)
5816
5817        # Crop data to min length
5818        if end < 0:
5819            end = min(data_gt.shape[1], best_pred.shape[1])
5820        data_gt = data_gt[:, :end]
5821        best_pred = best_pred[:, :end]
5822
5823        # Reorder behaviors
5824        ind_gt = []
5825        ind_pred = []
5826        labels_pred = [labels_pred[i] for i in range(len(labels_pred))]
5827        labels_pred = np.roll(
5828            labels_pred, 1
5829        ).tolist()  
5830        check_gt = np.where(np.sum(data_gt, axis=1) > 0)[0]
5831        check_pred = np.where(np.sum(best_pred, axis=1) > 0)[0]
5832        for k, gt_beh in enumerate(labels_gt):
5833            if gt_beh in labels_pred:
5834                j = labels_pred.index(gt_beh)
5835                if not k in check_gt and not j in check_pred:
5836                    continue
5837                ind_gt.append(labels_gt.index(gt_beh))
5838                ind_pred.append(j)
5839        # Create label list
5840        labels = np.array(labels_gt)[ind_gt]
5841        assert (labels == np.array(labels_pred)[ind_pred]).all()
5842
5843        # # Create image
5844        image_pred = best_pred[ind_pred].astype(float)
5845        image_gt = data_gt[ind_gt]
5846
5847        f, axs = plt.subplots(
5848            len(labels), 1, figsize=(5 * len(labels), 15), sharex=True
5849        )
5850        end = image_gt.shape[1] if end < 0 else end
5851        for i, (ax, label) in enumerate(zip(axs, labels)):
5852
5853            im1 = np.array([image_gt[i], np.ones_like(image_gt[i]) * (-1)])
5854            im1 = np.ma.masked_array(im1, im1 < 0)
5855
5856            im2 = np.array([np.ones_like(image_pred[i]) * (-1), image_pred[i]])
5857            im2 = np.ma.masked_array(im2, im2 < 0)
5858
5859            ax.imshow(im1, aspect="auto", cmap=cmap_gt, interpolation="nearest")
5860            ax.imshow(im2, aspect="auto", cmap=cmap_pred, interpolation="nearest")
5861
5862            ax.set_yticks(np.arange(2), ["GT", "Pred"], fontsize=fontsize)
5863            ax.tick_params(axis="x", labelsize=fontsize)
5864            ax.set_ylabel(label, fontsize=fontsize)
5865            if time_mode == "frames":
5866                ax.set_xlabel("Num Frames", fontsize=fontsize)
5867            elif time_mode == "seconds":
5868                assert not fps is None, "Please provide fps"
5869                ax.set_xlabel("Time (s)", fontsize=fontsize)
5870                ax.set_xticks(
5871                    np.linspace(0, end, 10),
5872                    np.linspace(0, end / fps, 10).astype(np.int32),
5873                )
5874
5875            ax.set_xlim(start, end)
5876
5877        if save is None:
5878            plt.show()
5879        else:
5880            plt.savefig(save)
5881            plt.close()
5882
5883    def plot_ethograms(
5884        self,
5885        episode_name: str,
5886        prediction_name: str,
5887        start: int = 0,
5888        end: int = -1,
5889        save_path: str = None,
5890        cmap_pred: str = "binary",
5891        cmap_gt: str = "binary",
5892        fontsize: int = 22,
5893        time_mode: str = "frames",
5894        fps: int = None,
5895    ):
5896        """Plot ethograms from start to end time (in frames) for ground truth and prediction"""
5897        params = self._read_parameters(catch_blanks=False)
5898        parameters = self._get_data_pars(
5899            params,
5900        )
5901        if not save_path is None:
5902            os.makedirs(save_path, exist_ok=True)
5903        gt_files = [
5904            f for f in self.data_path if f.endswith(parameters["annotation_suffix"])
5905        ]
5906        pred_path = os.path.join(
5907            self.project_path, "results", "predictions", prediction_name
5908        )
5909        pred_paths = [os.path.join(pred_path, f) for f in os.listdir(pred_path)]
5910        for pred_path in pred_paths:
5911            predictions = load_pickle(pred_path)
5912            behaviors = self.get_behavior_dictionary(episode_name)
5913            gt_filename = os.path.basename(pred_path).replace(
5914                "_".join(["_" + prediction_name, "prediction.pickle"]),
5915                parameters["annotation_suffix"],
5916            )
5917            if os.path.exists(os.path.join(self.data_path, gt_filename)):
5918                gt_data = load_pickle(os.path.join(self.data_path, gt_filename))
5919
5920                self._plot_ethograms_gt_pred(
5921                    gt_data,
5922                    predictions,
5923                    gt_data[1],
5924                    behaviors,
5925                    start=start,
5926                    end=end,
5927                    save=os.path.join(
5928                        save_path,
5929                        os.path.splitext(os.path.basename(pred_path))[0] + "_gt_pred",
5930                    ),
5931                    cmap_pred=cmap_pred,
5932                    cmap_gt=cmap_gt,
5933                    fontsize=fontsize,
5934                    time_mode=time_mode,
5935                    fps=fps,
5936                )
5937            else:
5938                print("GT file not found")
5939
5940    def _create_side_panel(self, height, width, labels_pred, preds, labels_gt, gt=None):
5941        """Create a side panel for video annotation display.
5942
5943        Parameters
5944        ----------
5945        height : int
5946            the height of the panel
5947        width : int
5948            the width of the panel
5949        labels_pred : list
5950            the list of predicted behavior labels
5951        preds : array-like
5952            the prediction values for each behavior
5953        labels_gt : list
5954            the list of ground truth behavior labels
5955        gt : array-like, optional
5956            the ground truth values for each behavior
5957
5958        Returns
5959        -------
5960        side_panel : np.ndarray
5961            the created side panel as an image array
5962
5963        """
5964        side_panel = np.ones((height, int(width / 4), 3), dtype=np.uint8) * 255
5965
5966        beh_indices = np.where(preds)[0]
5967        for i, label in enumerate(labels_pred):
5968            color = (0, 0, 0)
5969            if i in beh_indices:
5970                color = (0, 255, 0)
5971            cv2.putText(
5972                side_panel,
5973                label,
5974                (10, 50 + 50 * i),
5975                cv2.FONT_HERSHEY_SIMPLEX,
5976                1,
5977                color,
5978                2,
5979                cv2.LINE_AA,
5980            )
5981        if gt is not None:
5982            beh_indices_gt = np.where(gt)[0]
5983            for i, label in enumerate(labels_gt):
5984                color = (0, 0, 0)
5985                if i in beh_indices_gt:
5986                    color = (0, 255, 0)
5987                cv2.putText(
5988                    side_panel,
5989                    label,
5990                    (10, 50 + 50 * i + 80 * len(labels_pred)),
5991                    cv2.FONT_HERSHEY_SIMPLEX,
5992                    1,
5993                    color,
5994                    2,
5995                    cv2.LINE_AA,
5996                )
5997        return side_panel
5998
5999    def create_annotated_video(
6000        self,
6001        prediction_file_paths: list,
6002        video_file_paths: list,
6003        episode_name: str,  # To get the list of behaviors
6004        ground_truth_file_paths: list = None,
6005        pred_thresh: float = 0.5,
6006        start: int = 0,
6007        end: int = -1,
6008    ):
6009        """Create a video with the predictions overlaid on the video"""
6010        for k, (pred_path, vid_path) in enumerate(
6011            zip(prediction_file_paths, video_file_paths)
6012        ):
6013            print("Generating video for :", os.path.basename(vid_path))
6014            predictions = load_pickle(pred_path)
6015            best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh
6016            behaviors = self.get_behavior_dictionary(episode_name)
6017            # Load video
6018            labels_pred = [behaviors[i] for i in range(len(behaviors))]
6019            labels_pred = np.roll(
6020                labels_pred, 1
6021            ).tolist() 
6022
6023            gt_data = None
6024            if ground_truth_file_paths is not None:
6025                gt_data = load_pickle(ground_truth_file_paths[k])
6026                labels_gt = gt_data[1]
6027                gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1])
6028
6029            cap = cv2.VideoCapture(vid_path)
6030            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6031            end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
6032            fps = cap.get(cv2.CAP_PROP_FPS)
6033            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
6034            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6035            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
6036            out = cv2.VideoWriter(
6037                os.path.join(
6038                    os.path.dirname(vid_path),
6039                    os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4",
6040                ),
6041                fourcc,
6042                fps,
6043                # (width + int(width/4) , height),
6044                (600, 300),
6045            )
6046            count = 0
6047            bar = tqdm(total=end - start)
6048            while cap.isOpened():
6049                ret, frame = cap.read()
6050                if not ret:
6051                    break
6052
6053                side_panel = self._create_side_panel(
6054                    height,
6055                    width,
6056                    labels_pred,
6057                    best_pred[:, count],
6058                    labels_gt,
6059                    gt_data[:, count],
6060                )
6061                frame = np.concatenate((frame, side_panel), axis=1)
6062                frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
6063                out.write(frame)
6064                count += 1
6065                bar.update(1)
6066
6067                if count > end:
6068                    break
6069
6070            cap.release()
6071            out.release()
6072            cv2.destroyAllWindows()
6073
6074    def plot_predictions(
6075        self,
6076        episode_name: str,
6077        load_epoch: int = None,
6078        parameters_update: Dict = None,
6079        add_legend: bool = True,
6080        ground_truth: bool = True,
6081        colormap: str = "dlc2action",
6082        hide_axes: bool = False,
6083        min_classes: int = 1,
6084        width: float = 10,
6085        whole_video: bool = False,
6086        transparent: bool = False,
6087        drop_classes: Set = None,
6088        search_classes: Set = None,
6089        num_plots: int = 1,
6090        remove_saved_features: bool = False,
6091        smooth_interval_prediction: int = 0,
6092        data_path: str = None,
6093        file_paths: Set = None,
6094        mode: str = "val",
6095        font_size: float = None,
6096        window_size: int = 400,
6097    ) -> None:
6098        """Visualize random predictions.
6099
6100        Parameters
6101        ----------
6102        episode_name : str
6103            the name of the episode to load
6104        load_epoch : int, optional
6105            the epoch to load (by default last)
6106        parameters_update : dict, optional
6107            parameter update dictionary
6108        add_legend : bool, default True
6109            if True, legend will be added to the plot
6110        ground_truth : bool, default True
6111            if True, ground truth will be added to the plot
6112        colormap : str, default 'Accent'
6113            the `matplotlib` colormap to use
6114        hide_axes : bool, default True
6115            if `True`, the axes will be hidden on the plot
6116        min_classes : int, default 1
6117            the minimum number of classes in a displayed interval
6118        width : float, default 10
6119            the width of the plot
6120        whole_video : bool, default False
6121            if `True`, whole videos are plotted instead of segments
6122        transparent : bool, default False
6123            if `True`, the background on the plot is transparent
6124        drop_classes : set, optional
6125            a set of class names to not be displayed
6126        search_classes : set, optional
6127            if given, only intervals where at least one of the classes is in ground truth will be shown
6128        num_plots : int, default 1
6129            the number of plots to make
6130        remove_saved_features : bool, default False
6131            if `True`, the dataset will be deleted after computation
6132        smooth_interval_prediction : int, default 0
6133            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
6134        data_path : str, optional
6135            the data path to run the prediction for
6136        file_paths : set, optional
6137            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
6138            for
6139        mode : {'all', 'test', 'val', 'train'}
6140            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
6141
6142        """
6143        plot_path = os.path.join(self.project_path, "results", "plots")
6144        task, parameters, mode = self._make_task_prediction(
6145            "_",
6146            load_episode=episode_name,
6147            parameters_update=parameters_update,
6148            load_epoch=load_epoch,
6149            data_path=data_path,
6150            file_paths=file_paths,
6151            mode=mode,
6152        )
6153        os.makedirs(plot_path, exist_ok=True)
6154        task.visualize_results(
6155            save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"),
6156            add_legend=add_legend,
6157            ground_truth=ground_truth,
6158            colormap=colormap,
6159            hide_axes=hide_axes,
6160            min_classes=min_classes,
6161            whole_video=whole_video,
6162            transparent=transparent,
6163            dataset=mode,
6164            drop_classes=drop_classes,
6165            search_classes=search_classes,
6166            width=width,
6167            smooth_interval_prediction=smooth_interval_prediction,
6168            font_size=font_size,
6169            num_plots=num_plots,
6170            window_size=window_size,
6171        )
6172        if remove_saved_features:
6173            self._remove_stores(parameters)
6174
6175    def create_video_from_labels(
6176        self,
6177        video_dir_path: str,
6178        mode="ground_truth",
6179        prediction_name: str = None,
6180        save_path: str = None,
6181    ):
6182        if save_path is None:
6183            save_path = os.path.join(
6184                self.project_path, "results", f"annotated_videos_from_{mode}"
6185            )
6186        os.makedirs(save_path, exist_ok=True)
6187
6188        params = self._read_parameters(catch_blanks=False)
6189
6190        if mode == "ground_truth":
6191            source_dir = self.annotation_path
6192            annotation_suffix = params["data"]["annotation_suffix"]
6193        elif mode == "prediction":
6194            assert (
6195                not prediction_name is None
6196            ), "Please provide a prediction name with mode 'prediction'"
6197            source_dir = os.path.join(
6198                self.project_path, "results", "predictions", prediction_name
6199            )
6200            annotation_suffix = f"_{prediction_name}_prediction.pickle"
6201
6202        video_annotation_pairs = [
6203            (
6204                os.path.join(video_dir_path, f),
6205                os.path.join(
6206                    source_dir, f.replace(f.split(".")[-1], annotation_suffix)
6207                ),
6208            )
6209            for f in os.listdir(video_dir_path)
6210            if os.path.exists(
6211                os.path.join(source_dir, f.replace(f.split(".")[-1], annotation_suffix))
6212            )
6213        ]
6214
6215        for video_file, annotation_file in tqdm(video_annotation_pairs):
6216            if not os.path.exists(video_file):
6217                print(f"Video file {video_file} does not exist, skipping.")
6218                continue
6219            if not os.path.exists(annotation_file):
6220                print(f"Annotation file {annotation_file} does not exist, skipping.")
6221                continue
6222
6223            if annotation_file.endswith(".pickle"):
6224                annotations = load_pickle(annotation_file)
6225            elif annotation_file.endswith(".csv"):
6226                annotations = pd.read_csv(annotation_file)
6227
6228            if mode == "ground_truth":
6229                behaviors = annotations[1]
6230                annot_data = annotations[3]
6231            elif mode == "predictions":
6232                behaviors = list(annotations["classes"].values())
6233                annot_data = [
6234                    annotations[key]
6235                    for key in annotations.keys()
6236                    if key not in ["classes", "min_frame", "max_frame"]
6237                ]
6238                if params["general"]["exclusive"]:
6239                    annot_data = [np.argmax(annot, axis=1) for annot in annot_data]
6240                    seqs = [
6241                        [
6242                            self._bin_array_to_sequences(annot, target_value=k)
6243                            for k in range(len(behaviors))
6244                        ]
6245                        for annot in annot_data
6246                    ]
6247                else:
6248                    annot_data = [np.where(annot > 0.5)[0] for annot in annot_data]
6249                    seqs = [
6250                        self._bin_array_to_sequences(annot, target_value=1)
6251                        for annot in annot_data
6252                    ]
6253                annotations = ["", "", seqs]
6254
6255            for individual in annotations[3]:
6256                for behavior in annotations[3][individual]:
6257                    intervals = annotations[3][individual][behavior]
6258                    self._extract_videos(
6259                        video_file,
6260                        intervals,
6261                        behavior,
6262                        individual,
6263                        save_path,
6264                        resolution=(640, 480),
6265                        fps=30,
6266                    )
6267
6268    def _bin_array_to_sequences(
6269        self, annot_data: List[np.ndarray], target_value: int
6270    ) -> List[List[Tuple[int, int]]]:
6271        is_target = annot_data == target_value
6272        changes = np.diff(np.concatenate(([False], is_target, [False])))
6273        indices = np.where(changes)[0].reshape(-1, 2)
6274        subsequences = [list(range(start, end)) for start, end in indices]
6275        return subsequences
6276
6277    def _extract_videos(
6278        self,
6279        video_file: str,
6280        intervals: np.ndarray,
6281        behavior: str,
6282        individual: str,
6283        video_dir: str,
6284        resolution: Tuple[int, int] = (640, 480),
6285        fps: int = 30,
6286    ) -> None:
6287        """Extract frames from a video file from frames in between intervals in behavior folder for a given individual"""
6288        cap = cv2.VideoCapture(video_file)
6289        print("Extracting frames from", video_file)
6290
6291        for start, end, confusion in tqdm(intervals):
6292
6293            frame_count = start
6294            assert start < end, "Start frame should be less than end frame"
6295            if confusion > 0.5:
6296                continue
6297            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6298            output_file = os.path.join(
6299                video_dir,
6300                individual,
6301                behavior,
6302                os.path.splitext(os.path.basename(video_file))[0]
6303                + f"vid_{individual}_{behavior}_{start:05d}_{end:05d}.mp4",
6304            )
6305            fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec, e.g., 'XVID', 'MJPG'
6306            out = cv2.VideoWriter(
6307                output_file, fourcc, fps, (resolution[0], resolution[1])
6308            )
6309            while cap.isOpened():
6310                ret, frame = cap.read()
6311                if not ret:
6312                    break
6313
6314                # Resize large frames
6315                frame = cv2.resize(frame, (640, 480))
6316                out.write(frame)
6317
6318                frame_count += 1
6319                # Break if end frame is reached or max frames per behavior is reached
6320                if frame_count == end:
6321                    break
6322            if frame_count <= 2:
6323                os.remove(output_file)
6324            # cap.release()
6325            out.release()
6326
6327    def create_metadata_backup(self) -> None:
6328        """Create a copy of the meta files."""
6329        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6330        meta_path = os.path.join(self.project_path, "meta")
6331        if os.path.exists(meta_copy_path):
6332            shutil.rmtree(meta_copy_path)
6333        os.mkdir(meta_copy_path)
6334        for file in os.listdir(meta_path):
6335            if file == "backup":
6336                continue
6337            if os.path.isdir(os.path.join(meta_path, file)):
6338                continue
6339            shutil.copy(
6340                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
6341            )
6342
6343    def load_metadata_backup(self) -> None:
6344        """Load from previously created meta data backup (in case of corruption)."""
6345        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6346        meta_path = os.path.join(self.project_path, "meta")
6347        for file in os.listdir(meta_copy_path):
6348            shutil.copy(
6349                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
6350            )
6351
6352    def get_behavior_dictionary(self, episode_name: str) -> Dict:
6353        """Get the behavior dictionary for an episode.
6354
6355        Parameters
6356        ----------
6357        episode_name : str
6358            the name of the episode
6359
6360        Returns
6361        -------
6362        behaviors_dictionary : dict
6363            a dictionary where keys are label indices and values are label names
6364
6365        """
6366        return self._episode(episode_name).get_behaviors_dict()
6367
6368    def import_episodes(
6369        self,
6370        episodes_directory: str,
6371        name_map: Dict = None,
6372        repeat_policy: str = "error",
6373    ) -> None:
6374        """Import episodes exported with `Project.export_episodes`.
6375
6376        Parameters
6377        ----------
6378        episodes_directory : str
6379            the path to the exported episodes directory
6380        name_map : dict, optional
6381            a name change dictionary for the episodes: keys are old names, values are new names
6382        repeat_policy : {'error', 'skip', 'force'}, default 'error'
6383            the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
6384            'force' overwrites existing episodes
6385
6386        """
6387        if name_map is None:
6388            name_map = {}
6389        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
6390        to_remove = []
6391        import_string = "Imported episodes: "
6392        for episode_name in episodes.index:
6393            if episode_name in name_map:
6394                import_string += f"{episode_name} "
6395                episode_name = name_map[episode_name]
6396                import_string += f"({episode_name}), "
6397            else:
6398                import_string += f"{episode_name}, "
6399            try:
6400                self._check_episode_validity(episode_name, allow_doublecolon=True)
6401            except ValueError as e:
6402                if str(e).endswith("is already taken!"):
6403                    if repeat_policy == "skip":
6404                        to_remove.append(episode_name)
6405                    elif repeat_policy == "force":
6406                        self.remove_episode(episode_name)
6407                    elif repeat_policy == "error":
6408                        raise ValueError(
6409                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
6410                        )
6411                    else:
6412                        raise ValueError(
6413                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
6414                        )
6415        episodes = episodes.drop(index=to_remove)
6416        self._episodes().update(
6417            episodes,
6418            name_map=name_map,
6419            force=(repeat_policy == "force"),
6420            data_path=self.data_path,
6421            annotation_path=self.annotation_path,
6422        )
6423        for episode_name in episodes.index:
6424            if episode_name in name_map:
6425                new_episode_name = name_map[episode_name]
6426            else:
6427                new_episode_name = episode_name
6428            model_dir = os.path.join(
6429                self.project_path, "results", "model", new_episode_name
6430            )
6431            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
6432            if os.path.exists(model_dir):
6433                shutil.rmtree(model_dir)
6434            os.mkdir(model_dir)
6435            for file in os.listdir(old_model_dir):
6436                shutil.copyfile(
6437                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
6438                )
6439            log_file = os.path.join(
6440                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6441            )
6442            old_log_file = os.path.join(
6443                episodes_directory, "logs", f"{episode_name}.txt"
6444            )
6445            shutil.copyfile(old_log_file, log_file)
6446        print(import_string)
6447        print("\n")
6448
6449    def export_episodes(
6450        self, episode_names: List, output_directory: str, name: str = None
6451    ) -> None:
6452        """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`.
6453
6454        Parameters
6455        ----------
6456        episode_names : list
6457            a list of string episode names
6458        output_directory : str
6459            the path to the directory where the episodes will be saved
6460        name : str, optional
6461            the name of the episodes directory (by default `exported_episodes`)
6462
6463        """
6464        if name is None:
6465            name = "exported_episodes"
6466        if os.path.exists(
6467            os.path.join(output_directory, name + ".zip")
6468        ) or os.path.exists(os.path.join(output_directory, name)):
6469            i = 1
6470            while os.path.exists(
6471                os.path.join(output_directory, name + f"_{i}.zip")
6472            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
6473                i += 1
6474            name = name + f"_{i}"
6475        dest_dir = os.path.join(output_directory, name)
6476        os.mkdir(dest_dir)
6477        os.mkdir(os.path.join(dest_dir, "model"))
6478        os.mkdir(os.path.join(dest_dir, "logs"))
6479        runs = []
6480        for episode in episode_names:
6481            runs += self._episodes().get_runs(episode)
6482        for run in runs:
6483            shutil.copytree(
6484                os.path.join(self.project_path, "results", "model", run),
6485                os.path.join(dest_dir, "model", run),
6486            )
6487            shutil.copyfile(
6488                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
6489                os.path.join(dest_dir, "logs", f"{run}.txt"),
6490            )
6491        data = self._episodes().get_subset(runs)
6492        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))
6493
6494    def get_results_table(
6495        self,
6496        episode_names: List,
6497        metrics: List = None,
6498        mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
6499        print_results: bool = True,
6500        classes: List = None,
6501    ):
6502        """Generate a `pandas` dataframe with a summary of episode results.
6503
6504        Parameters
6505        ----------
6506        episode_names : list
6507            a list of names of episodes to include
6508        metrics : list, optional
6509            a list of metric names to include
6510        mode : bool, optional
6511            the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean"
6512        print_results : bool, optional
6513            if True, the results will be printed to the console, by default True
6514        classes : list, optional
6515            a list of names of classes to include (by default all are included)
6516
6517        Returns
6518        -------
6519        results : pd.DataFrame
6520            a table with the results
6521
6522        """
6523        run_names = []
6524        for episode in episode_names:
6525            run_names += self._episodes().get_runs(episode)
6526        episodes = self.list_episodes(run_names, print_results=False)
6527        metric_columns = [x for x in episodes.columns if x[0] == "results"]
6528        results_df = pd.DataFrame()
6529        if metrics is not None:
6530            metric_columns = [
6531                x for x in metric_columns if x[1].split("_")[0] in metrics
6532            ]
6533        for episode in episode_names:
6534            results = []
6535            metric_set = set()
6536            for run in self._episodes().get_runs(episode):
6537                beh_dict = self.get_behavior_dictionary(run)
6538                res_dict = defaultdict(lambda: {})
6539                for column in metric_columns:
6540                    if np.isnan(episodes.loc[run, column]):
6541                        continue
6542                    split = column[1].split("_")
6543                    if split[-1].isnumeric():
6544                        beh_ind = int(split[-1])
6545                        metric_name = "_".join(split[:-1])
6546                        beh = beh_dict[beh_ind]
6547                    else:
6548                        beh = "average"
6549                        metric_name = column[1]
6550                    res_dict[beh][metric_name] = episodes.loc[run, column]
6551                    metric_set.add(metric_name)
6552                if "average" not in res_dict:
6553                    res_dict["average"] = {}
6554                for metric in metric_set:
6555                    if metric not in res_dict["average"]:
6556                        arr = [
6557                            res_dict[beh][metric]
6558                            for beh in res_dict
6559                            if metric in res_dict[beh]
6560                        ]
6561                        res_dict["average"][metric] = np.mean(arr)
6562                results.append(res_dict)
6563            episode_results = {}
6564            for metric in metric_set:
6565                for beh in results[0].keys():
6566                    if classes is not None and beh not in classes:
6567                        continue
6568                    arr = []
6569                    for res_dict in results:
6570                        if metric in res_dict[beh]:
6571                            arr.append(res_dict[beh][metric])
6572                    if len(arr) > 0:
6573                        if mode == "statistics":
6574                            episode_results[(beh, f"{episode} {metric} mean")] = (
6575                                np.mean(arr)
6576                            )
6577                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
6578                                arr
6579                            )
6580                        elif mode == "mean":
6581                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
6582                        elif mode == "detail":
6583                            for i, val in enumerate(arr):
6584                                episode_results[(beh, f"{episode}::{i} {metric}")] = val
6585            for key, value in episode_results.items():
6586                results_df.loc[key[0], key[1]] = value
6587        if print_results:
6588            print(f"RESULTS:")
6589            print(results_df)
6590            print("\n")
6591        return results_df
6592
6593    def episode_exists(self, episode_name: str) -> bool:
6594        """Check if an episode already exists.
6595
6596        Parameters
6597        ----------
6598        episode_name : str
6599            the episode name
6600
6601        Returns
6602        -------
6603        exists : bool
6604            `True` if the episode exists
6605
6606        """
6607        return self._episodes().check_name_validity(episode_name)
6608
6609    def search_exists(self, search_name: str) -> bool:
6610        """Check if a search already exists.
6611
6612        Parameters
6613        ----------
6614        search_name : str
6615            the search name
6616
6617        Returns
6618        -------
6619        exists : bool
6620            `True` if the search exists
6621
6622        """
6623        return self._searches().check_name_validity(search_name)
6624
6625    def prediction_exists(self, prediction_name: str) -> bool:
6626        """Check if a prediction already exists.
6627
6628        Parameters
6629        ----------
6630        prediction_name : str
6631            the prediction name
6632
6633        Returns
6634        -------
6635        exists : bool
6636            `True` if the prediction exists
6637
6638        """
6639        return self._predictions().check_name_validity(prediction_name)
6640
6641    @staticmethod
6642    def project_name_available(projects_path: str, project_name: str):
6643        """Check if a project name is available.
6644
6645        Parameters
6646        ----------
6647        projects_path : str
6648            the path to the projects directory
6649        project_name : str
6650            the name of the project to check
6651
6652        Returns
6653        -------
6654        available : bool
6655            `True` if the project name is available
6656
6657        """
6658        if projects_path is None:
6659            projects_path = os.path.join(str(Path.home()), "DLC2Action")
6660        return not os.path.exists(os.path.join(projects_path, project_name))
6661
6662    def _update_episode_metrics(self, episode_name: str, metrics: Dict):
6663        """Update meta data with evaluation results.
6664
6665        Parameters
6666        ----------
6667        episode_name : str
6668            the name of the episode
6669        metrics : dict
6670            the metrics dictionary to update with
6671
6672        """
6673        self._episodes().update_episode_metrics(episode_name, metrics)
6674
6675    def rename_episode(self, episode_name: str, new_episode_name: str):
6676        """Rename an episode.
6677
6678        Parameters
6679        ----------
6680        episode_name : str
6681            the current episode name
6682        new_episode_name : str
6683            the new episode name
6684
6685        """
6686        shutil.move(
6687            os.path.join(self.project_path, "results", "model", episode_name),
6688            os.path.join(self.project_path, "results", "model", new_episode_name),
6689        )
6690        shutil.move(
6691            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
6692            os.path.join(
6693                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6694            ),
6695        )
6696        self._episodes().rename_episode(episode_name, new_episode_name)

A class to create and maintain the project files + keep track of experiments.

Project( name: str, data_type: str = None, annotation_type: str = 'none', projects_path: str = None, data_path: Union[str, List] = None, annotation_path: Union[str, List] = None, copy: bool = False)
 58    def __init__(
 59        self,
 60        name: str,
 61        data_type: str = None,
 62        annotation_type: str = "none",
 63        projects_path: str = None,
 64        data_path: Union[str, List] = None,
 65        annotation_path: Union[str, List] = None,
 66        copy: bool = False,
 67    ) -> None:
 68        """Initialize the class.
 69
 70        Parameters
 71        ----------
 72        name : str
 73            name of the project
 74        data_type : str, optional
 75            data type (run Project.data_types() to see available options; has to be provided if the project is being
 76            created)
 77        annotation_type : str, default 'none'
 78            annotation type (run Project.annotation_types() to see available options)
 79        projects_path : str, optional
 80            path to the projects folder (is filled with ~/DLC2Action by default)
 81        data_path : str, optional
 82            path to the folder containing input files for the project (has to be provided if the project is being
 83            created)
 84        annotation_path : str, optional
 85            path to the folder containing annotation files for the project
 86        copy : bool, default False
 87            if True, the files from annotation_path and data_path will be copied to the projects folder;
 88            otherwise they will be moved
 89
 90        """
 91        if projects_path is None:
 92            projects_path = os.path.join(str(Path.home()), "DLC2Action")
 93        if not os.path.exists(projects_path):
 94            os.mkdir(projects_path)
 95        self.project_path = os.path.join(projects_path, name)
 96        self.name = name
 97        self.data_type = data_type
 98        self.annotation_type = annotation_type
 99        self.data_path = data_path
100        self.annotation_path = annotation_path
101        if not os.path.exists(self.project_path):
102            if data_type is None:
103                raise ValueError(
104                    "The data_type parameter is necessary when creating a new project!"
105                )
106            self._initialize_project(
107                data_type, annotation_type, data_path, annotation_path, copy
108            )
109        else:
110            self.annotation_type, self.data_type = self._read_types()
111            if data_type != self.data_type and data_type is not None:
112                raise ValueError(
113                    f"The project has already been initialized with data_type={self.data_type}!"
114                )
115            if annotation_type != self.annotation_type and annotation_type != "none":
116                raise ValueError(
117                    f"The project has already been initialized with annotation_type={self.annotation_type}!"
118                )
119            self.annotation_path, data_path = self._read_paths()
120            if self.data_path is None:
121                self.data_path = data_path
122            # if data_path != self.data_path and data_path is not None:
123            #     raise ValueError(
124            #         f"The project has already been initialized with data_path={self.data_path}!"
125            #     )
126            if annotation_path != self.annotation_path and annotation_path is not None:
127                raise ValueError(
128                    f"The project has already been initialized with annotation_path={self.annotation_path}!"
129                )
130        self._update_configs()

Initialize the class.

Parameters

name : str
    name of the project
data_type : str, optional
    data type (run Project.data_types() to see available options; has to be provided if the project is being created)
annotation_type : str, default 'none'
    annotation type (run Project.annotation_types() to see available options)
projects_path : str, optional
    path to the projects folder (is filled with ~/DLC2Action by default)
data_path : str, optional
    path to the folder containing input files for the project (has to be provided if the project is being created)
annotation_path : str, optional
    path to the folder containing annotation files for the project
copy : bool, default False
    if True, the files from annotation_path and data_path will be copied to the projects folder; otherwise they will be moved

project_path
name
data_type
annotation_type
data_path
annotation_path
def get_decision_thresholds( self, episode_names: List, metric_name: str = 'f1', parameters_update: Dict = None, load_epochs: List = None, remove_saved_features: bool = False) -> Tuple[List, List, dlc2action.task.task_dispatcher.TaskDispatcher]:
563    def get_decision_thresholds(
564        self,
565        episode_names: List,
566        metric_name: str = "f1",
567        parameters_update: Dict = None,
568        load_epochs: List = None,
569        remove_saved_features: bool = False,
570    ) -> Tuple[List, List, TaskDispatcher]:
571        """Compute optimal decision thresholds or load them if they have been computed before.
572
573        Parameters
574        ----------
575        episode_names : List
576            a list of episode names
577        metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
578            the metric to optimize
579        parameters_update : dict, optional
580            the parameter update dictionary
581        load_epochs : list, optional
582            a list of epochs to load (by default last are loaded)
583        remove_saved_features : bool, default False
584            if `True`, the dataset will be deleted after the computation
585
586        Returns
587        -------
588        thresholds : list
589            a list of float decision threshold values
590        classes : list
591            the label names corresponding to the values
592        task : TaskDispatcher | None
593            the task used in computation
594
595        """
596        parameters = self._make_parameters(
597            "_",
598            episode_names[0],
599            parameters_update,
600            {},
601            load_epochs[0],
602            purpose="prediction",
603        )
604        thresholds = self._thresholds().find_thresholds(
605            episode_names,
606            load_epochs,
607            metric_name,
608            metric_parameters=parameters["metrics"][metric_name],
609        )
610        task = None
611        behaviors = list(self._episode(episode_names[0]).get_behaviors_dict().values())
612        return thresholds, behaviors, task

Compute optimal decision thresholds or load them if they have been computed before.

Parameters

episode_names : List
    a list of episode names
metric_name : {"f1", "segmental_f1", "semisegmental_f1", "f_beta", "segmental_f_beta"}
    the metric to optimize
parameters_update : dict, optional
    the parameter update dictionary
load_epochs : list, optional
    a list of epochs to load (by default the last ones are loaded)
remove_saved_features : bool, default False
    if True, the dataset will be deleted after the computation

Returns

thresholds : list
    a list of float decision threshold values
classes : list
    the label names corresponding to the values
task : TaskDispatcher | None
    the task used in computation

def run_episode( self, episode_name: str, load_episode: str = None, parameters_update: Dict = None, task: dlc2action.task.task_dispatcher.TaskDispatcher = None, load_epoch: int = None, load_search: str = None, load_parameters: list = None, round_to_binary: list = None, load_strict: bool = True, n_seeds: int = 1, force: bool = False, suppress_name_check: bool = False, remove_saved_features: bool = False, mask_name: str = None, autostop_metric: str = None, autostop_interval: int = 50, autostop_threshold: float = 0.001, loading_bar: bool = False, trial: Tuple = None) -> dlc2action.task.task_dispatcher.TaskDispatcher:
614    def run_episode(
615        self,
616        episode_name: str,
617        load_episode: str = None,
618        parameters_update: Dict = None,
619        task: TaskDispatcher = None,
620        load_epoch: int = None,
621        load_search: str = None,
622        load_parameters: list = None,
623        round_to_binary: list = None,
624        load_strict: bool = True,
625        n_seeds: int = 1,
626        force: bool = False,
627        suppress_name_check: bool = False,
628        remove_saved_features: bool = False,
629        mask_name: str = None,
630        autostop_metric: str = None,
631        autostop_interval: int = 50,
632        autostop_threshold: float = 0.001,
633        loading_bar: bool = False,
634        trial: Tuple = None,
635    ) -> TaskDispatcher:
636        """Run an episode.
637
638        The task parameters are read from the config files and then updated with the
639        parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the
640        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
641        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
642        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
643        data parameters are used.
644
645        You can use the autostop parameters to finish training when the parameters are not improving. It will be
646        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
647        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
648        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
649
650        Parameters
651        ----------
652        episode_name : str
653            the episode name
654        load_episode : str, optional
655            the (previously run) episode name to load the model from; if the episode has multiple runs,
656            the new episode will have the same number of runs, each starting with one of the pre-trained models
657        parameters_update : dict, optional
658            the dictionary used to update the parameters from the config files
659        task : TaskDispatcher, optional
660            a pre-existing `TaskDispatcher` object (if provided, the method will update it instead of creating
661            a new instance)
662        load_epoch : int, optional
663            the epoch to load (if load_episodes is not None); if not provided, the last epoch is used
664        load_search : str, optional
665            the hyperparameter search result to load
666        load_parameters : list, optional
667            a list of string names of the parameters to load from load_search (if not provided, all parameters
668            are loaded)
669        round_to_binary : list, optional
670            a list of string names of the loaded parameters that should be rounded to the nearest power of two
671        load_strict : bool, default True
672            if `False`, matching weights will be loaded from `load_episode` and differences in parameter name lists and
673            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError`
674        n_seeds : int, default 1
675            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
676            `test_episode#0` and `test_episode#1`
677        force : bool, default False
678            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
679        suppress_name_check : bool, default False
680            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
681            why they are usually forbidden)
682        remove_saved_features : bool, default False
683            if `True`, the dataset will be deleted after training
684        mask_name : str, optional
685            the name of the real_lens to apply
686        autostop_metric : str, optional
687            the autostop metric (can be any one of the tracked metrics of `'loss'`)
688        autostop_interval : int, default 50
689            the number of epochs to average the autostop metric over
690        autostop_threshold : float, default 0.001
691            the autostop difference threshold
692        loading_bar : bool, default False
693            if `True`, a loading bar will be displayed
694        trial : tuple, optional
695            a tuple of (trial, metric) for hyperparameter search
696
697        Returns
698        -------
699        TaskDispatcher
700            the `TaskDispatcher` object
701
702        """
703
704        import gc
705
706        gc.collect()
707        if torch.cuda.is_available():
708            torch.cuda.empty_cache()
709
710        if type(n_seeds) is not int or n_seeds < 1:
711            raise ValueError(
712                f"The n_seeds parameter has to be an integer larger than 0; got {n_seeds}"
713            )
714        if n_seeds > 1 and mask_name is not None:
715            raise ValueError("Cannot apply a real_lens with n_seeds > 1")
716        self._check_episode_validity(
717            episode_name, allow_doublecolon=suppress_name_check, force=force
718        )
719        load_runs = self._episodes().get_runs(load_episode)
720        if len(load_runs) > 1:
721            task = self.run_episodes(
722                episode_names=[
723                    f'{episode_name}#{run.split("#")[-1]}' for run in load_runs
724                ],
725                load_episodes=load_runs,
726                parameters_updates=[parameters_update for _ in load_runs],
727                load_epochs=[load_epoch for _ in load_runs],
728                load_searches=[load_search for _ in load_runs],
729                load_parameters=[load_parameters for _ in load_runs],
730                round_to_binary=[round_to_binary for _ in load_runs],
731                load_strict=[load_strict for _ in load_runs],
732                suppress_name_check=True,
733                force=force,
734                remove_saved_features=False,
735            )
736            if remove_saved_features:
737                self._remove_stores(
738                    {
739                        "general": task.general_parameters,
740                        "data": task.data_parameters,
741                        "features": task.feature_parameters,
742                    }
743                )
744            if n_seeds > 1:
745                warnings.warn(
746                    f"The n_seeds parameter is disregarded since load_episode={load_episode} has multiple runs"
747                )
748        elif n_seeds > 1:
749
750            self.run_episodes(
751                episode_names=[f"{episode_name}#{i}" for i in range(n_seeds)],
752                load_episodes=[load_episode for _ in range(n_seeds)],
753                parameters_updates=[parameters_update for _ in range(n_seeds)],
754                load_epochs=[load_epoch for _ in range(n_seeds)],
755                load_searches=[load_search for _ in range(n_seeds)],
756                load_parameters=[load_parameters for _ in range(n_seeds)],
757                round_to_binary=[round_to_binary for _ in range(n_seeds)],
758                load_strict=[load_strict for _ in range(n_seeds)],
759                suppress_name_check=True,
760                force=force,
761                remove_saved_features=remove_saved_features,
762            )
763        else:
764            print(f"TRAINING {episode_name}")
765            try:
766                task, parameters = self._make_task_training(
767                    episode_name,
768                    load_episode,
769                    parameters_update,
770                    load_epoch,
771                    load_search,
772                    load_parameters,
773                    round_to_binary,
774                    continuing=False,
775                    task=task,
776                    mask_name=mask_name,
777                    load_strict=load_strict,
778                )
779                self._save_episode(
780                    episode_name,
781                    parameters,
782                    task.behaviors_dict(),
783                    norm_stats=task.get_normalization_stats(),
784                )
785                time_start = time.time()
786                if trial is not None:
787                    trial, metric = trial
788                else:
789                    trial, metric = None, None
790                logs = task.train(
791                    autostop_metric=autostop_metric,
792                    autostop_interval=autostop_interval,
793                    autostop_threshold=autostop_threshold,
794                    loading_bar=loading_bar,
795                    trial=trial,
796                    optimized_metric=metric,
797                )
798                time_end = time.time()
799                time_total = time_end - time_start
800                hours = int(time_total // 3600)
801                time_total -= hours * 3600
802                minutes = int(time_total // 60)
803                time_total -= minutes * 60
804                seconds = int(time_total)
805                training_time = f"{hours}:{minutes:02}:{seconds:02}"
806                self._update_episode_results(episode_name, logs, training_time)
807                if remove_saved_features:
808                    self._remove_stores(parameters)
809                print("\n")
810                return task
811
812            except Exception as e:
813                if isinstance(e, optuna.exceptions.TrialPruned):
814                    raise e
815                else:
816                    # if str(e) != f"The {episode_name} episode name is already in use!":
817                    #     self.remove_episode(episode_name)
818                    raise RuntimeError(f"Episode {episode_name} could not run")

Run an episode.

The task parameters are read from the config files and then updated with the parameters_update dictionary. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.

You can use the autostop parameters to finish training when the tracked metric stops improving. Training will be stopped if the average value of autostop_metric over the last autostop_interval epochs is smaller than the average over the previous autostop_interval epochs + autostop_threshold. For example, if the current epoch is 120 and autostop_interval is 50, the averages over epochs 70-120 and 20-70 will be compared.

Parameters

episode_name : str the episode name load_episode : str, optional the (previously run) episode name to load the model from; if the episode has multiple runs, the new episode will have the same number of runs, each starting with one of the pre-trained models parameters_update : dict, optional the dictionary used to update the parameters from the config files task : TaskDispatcher, optional a pre-existing TaskDispatcher object (if provided, the method will update it instead of creating a new instance) load_epoch : int, optional the epoch to load (if load_episodes is not None); if not provided, the last epoch is used load_search : str, optional the hyperparameter search result to load load_parameters : list, optional a list of string names of the parameters to load from load_search (if not provided, all parameters are loaded) round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two load_strict : bool, default True if False, matching weights will be loaded from load_episode and differences in parameter name lists and weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError n_seeds : int, default 1 the number of runs to perform; if n_seeds > 1, the episodes will be named episode_name#run_index, e.g. test_episode#0 and test_episode#1 force : bool, default False if True and an episode with name episode_name already exists, it will be overwritten (use with caution!) 
suppress_name_check : bool, default False if True, episode names with a double colon are allowed (please don't use this option unless you understand why they are usually forbidden) remove_saved_features : bool, default False if True, the dataset will be deleted after training mask_name : str, optional the name of the real_lens to apply autostop_metric : str, optional the autostop metric (can be any one of the tracked metrics or 'loss') autostop_interval : int, default 50 the number of epochs to average the autostop metric over autostop_threshold : float, default 0.001 the autostop difference threshold loading_bar : bool, default False if True, a loading bar will be displayed trial : tuple, optional a tuple of (trial, metric) for hyperparameter search

Returns

TaskDispatcher the TaskDispatcher object

def run_episodes( self, episode_names: List, load_episodes: List = None, parameters_updates: List = None, load_epochs: List = None, load_searches: List = None, load_parameters: List = None, round_to_binary: List = None, load_strict: List = None, force: bool = False, suppress_name_check: bool = False, remove_saved_features: bool = False) -> dlc2action.task.task_dispatcher.TaskDispatcher:
820    def run_episodes(
821        self,
822        episode_names: List,
823        load_episodes: List = None,
824        parameters_updates: List = None,
825        load_epochs: List = None,
826        load_searches: List = None,
827        load_parameters: List = None,
828        round_to_binary: List = None,
829        load_strict: List = None,
830        force: bool = False,
831        suppress_name_check: bool = False,
832        remove_saved_features: bool = False,
833    ) -> TaskDispatcher:
834        """Run multiple episodes in sequence (and re-use previously loaded information).
835
836        For each episode, the task parameters are read from the config files and then updated with the
837        parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the
838        previous experiments. All parameters and results are saved in the meta files and can be accessed with the
839        list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the
840        same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same
841        data parameters are used.
842
843        Parameters
844        ----------
845        episode_names : list
846            a list of strings of episode names
847        load_episodes : list, optional
848            a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs,
849            the new episode will have the same number of runs, each starting with one of the pre-trained models
850        parameters_updates : list, optional
851            a list of dictionaries used to update the parameters from the config
852        load_epochs : list, optional
853            a list of integers used to specify the epoch to load (if load_episodes is not None)
854        load_searches : list, optional
855            a list of strings of hyperparameter search results to load
856        load_parameters : list, optional
857            a list of lists of string names of the parameters to load from the searches
858        round_to_binary : list, optional
859            a list of string names of the loaded parameters that should be rounded to the nearest power of two
860        load_strict : list, optional
861            a list of boolean values specifying weight loading policy: if `False`, matching weights will be loaded from
862            the corresponding episode and differences in parameter name lists and
863            weight shapes will be ignored; otherwise mismatches will prompt a `RuntimeError` (by default `True` for
864            every episode)
865        force : bool, default False
866            if `True` and an episode name is already taken, it will be overwritten (use with caution!)
867        suppress_name_check : bool, default False
868            if `True`, episode names with a double colon are allowed (please don't use this option unless you understand
869            why they are usually forbidden)
870        remove_saved_features : bool, default False
871            if `True`, the dataset will be deleted after training
872
873        Returns
874        -------
875        TaskDispatcher
876            the task dispatcher object
877
878        """
879        task = None
880        if load_searches is None:
881            load_searches = [None for _ in episode_names]
882        if load_episodes is None:
883            load_episodes = [None for _ in episode_names]
884        if parameters_updates is None:
885            parameters_updates = [None for _ in episode_names]
886        if load_parameters is None:
887            load_parameters = [None for _ in episode_names]
888        if load_epochs is None:
889            load_epochs = [None for _ in episode_names]
890        if load_strict is None:
891            load_strict = [True for _ in episode_names]
892        for (
893            parameters_update,
894            episode_name,
895            load_episode,
896            load_epoch,
897            load_search,
898            load_parameters_list,
899            load_strict_value,
900        ) in zip(
901            parameters_updates,
902            episode_names,
903            load_episodes,
904            load_epochs,
905            load_searches,
906            load_parameters,
907            load_strict,
908        ):
909            task = self.run_episode(
910                episode_name,
911                load_episode,
912                parameters_update,
913                task,
914                load_epoch,
915                load_search,
916                load_parameters_list,
917                round_to_binary,
918                load_strict_value,
919                suppress_name_check=suppress_name_check,
920                force=force,
921                remove_saved_features=remove_saved_features,
922            )
923        return task

Run multiple episodes in sequence (and re-use previously loaded information).

For each episode, the task parameters are read from the config files and then updated with the parameter_update dictionary. The model can be either initialized from scratch or loaded from one of the previous experiments. All parameters and results are saved in the meta files and can be accessed with the list_episodes() function. The train/test/validation split is saved and loaded from a file whenever the same split parameters are used. The pre-computed datasets are also saved and loaded whenever the same data parameters are used.

Parameters

episode_names : list a list of strings of episode names load_episodes : list, optional a list of strings of (previously run) episode names to load the model from; if the episode has multiple runs, the new episode will have the same number of runs, each starting with one of the pre-trained models parameters_updates : list, optional a list of dictionaries used to update the parameters from the config load_epochs : list, optional a list of integers used to specify the epoch to load (if load_episodes is not None) load_searches : list, optional a list of strings of hyperparameter search results to load load_parameters : list, optional a list of lists of string names of the parameters to load from the searches round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two load_strict : list, optional a list of boolean values specifying weight loading policy: if False, matching weights will be loaded from the corresponding episode and differences in parameter name lists and weight shapes will be ignored; otherwise mismatches will prompt a RuntimeError (by default True for every episode) force : bool, default False if True and an episode name is already taken, it will be overwritten (use with caution!) suppress_name_check : bool, default False if True, episode names with a double colon are allowed (please don't use this option unless you understand why they are usually forbidden) remove_saved_features : bool, default False if True, the dataset will be deleted after training

Returns

TaskDispatcher the task dispatcher object

def continue_episode( self, episode_name: str, num_epochs: int = None, task: dlc2action.task.task_dispatcher.TaskDispatcher = None, n_seeds: int = 1, remove_saved_features: bool = False, device: str = 'cuda', num_cpus: int = None) -> dlc2action.task.task_dispatcher.TaskDispatcher:
 925    def continue_episode(
 926        self,
 927        episode_name: str,
 928        num_epochs: int = None,
 929        task: TaskDispatcher = None,
 930        n_seeds: int = 1,
 931        remove_saved_features: bool = False,
 932        device: str = "cuda",
 933        num_cpus: int = None,
 934    ) -> TaskDispatcher:
 935        """Load an older episode and continue running from the latest checkpoint.
 936
 937        All parameters as well as the model and optimizer state dictionaries are loaded from the episode.
 938
 939        Parameters
 940        ----------
 941        episode_name : str
 942            the name of the episode to continue
 943        num_epochs : int, optional
 944            the new number of epochs
 945        task : TaskDispatcher, optional
 946            a pre-existing task; if provided, the method will update the task instead of creating a new one
 947            (this might save time, mainly on dataset loading)
 948        n_seeds : int, default 1
 949            the number of runs to perform; if `n_seeds > 1`, the episodes will be named `episode_name#run_index`, e.g.
 950            `test_episode#0` and `test_episode#1`
 951        remove_saved_features : bool, default False
 952            if `True`, pre-computed features will be deleted after the run
 953        device : str, default "cuda"
 954            the torch device to use
 955        num_cpus : int, optional
 956            the number of CPUs to use for data loading; if `None`, the number of available CPUs will be used
 957
 958        Returns
 959        -------
 960        TaskDispatcher
 961            the task dispatcher
 962
 963        """
 964        runs = self._episodes().get_runs(episode_name)
 965        for run in runs:
 966            print(f"TRAINING {run}")
 967            if num_epochs is None and not self._episode(run).unfinished():
 968                continue
 969            parameters_update = {
 970                "training": {
 971                    "num_epochs": num_epochs,
 972                    "device": device,
 973                },
 974                "general": {"num_cpus": num_cpus},
 975            }
 976            task, parameters = self._make_task_training(
 977                run,
 978                load_episode=run,
 979                parameters_update=parameters_update,
 980                continuing=True,
 981                task=task,
 982            )
 983            time_start = time.time()
 984            logs = task.train()
 985            time_end = time.time()
 986            old_time = self._training_time(run)
 987            if not np.isnan(old_time):
 988                time_end += old_time
 989                time_total = time_end - time_start
 990                hours = int(time_total // 3600)
 991                time_total -= hours * 3600
 992                minutes = int(time_total // 60)
 993                time_total -= minutes * 60
 994                seconds = int(time_total)
 995                training_time = f"{hours}:{minutes:02}:{seconds:02}"
 996            else:
 997                training_time = np.nan
 998            self._save_episode(
 999                run,
1000                parameters,
1001                task.behaviors_dict(),
1002                suppress_validation=True,
1003                training_time=training_time,
1004                norm_stats=task.get_normalization_stats(),
1005            )
1006            self._update_episode_results(run, logs)
1007            print("\n")
1008        if len(runs) < n_seeds:
1009            for i in range(len(runs), n_seeds):
1010                self.run_episode(
1011                    f"{episode_name}#{i}",
1012                    parameters_update=self._episodes().load_parameters(runs[0]),
1013                    task=task,
1014                    suppress_name_check=True,
1015                )
1016        if remove_saved_features:
1017            self._remove_stores(parameters)
1018        return task

Load an older episode and continue running from the latest checkpoint.

All parameters as well as the model and optimizer state dictionaries are loaded from the episode.

Parameters

episode_name : str the name of the episode to continue num_epochs : int, optional the new number of epochs task : TaskDispatcher, optional a pre-existing task; if provided, the method will update the task instead of creating a new one (this might save time, mainly on dataset loading) n_seeds : int, default 1 the number of runs to perform; if n_seeds > 1, the episodes will be named episode_name#run_index, e.g. test_episode#0 and test_episode#1 remove_saved_features : bool, default False if True, pre-computed features will be deleted after the run device : str, default "cuda" the torch device to use num_cpus : int, optional the number of CPUs to use for data loading; if None, the number of available CPUs will be used

Returns

TaskDispatcher the task dispatcher

def run_prediction( self, prediction_name: str, episode_names: List, load_epochs: List = None, parameters_update: Dict = None, augment_n: int = 10, data_path: str = None, mode: str = 'all', file_paths: Set = None, remove_saved_features: bool = False, frame_number_map_file: str = None, force: bool = False, embedding: bool = False) -> None:
1260    def run_prediction(
1261        self,
1262        prediction_name: str,
1263        episode_names: List,
1264        load_epochs: List = None,
1265        parameters_update: Dict = None,
1266        augment_n: int = 10,
1267        data_path: str = None,
1268        mode: str = "all",
1269        file_paths: Set = None,
1270        remove_saved_features: bool = False,
1271        frame_number_map_file: str = None,
1272        force: bool = False,
1273        embedding: bool = False,
1274    ) -> None:
1275        """Load models from previously run episodes to generate a prediction.
1276
1277        The probabilities predicted by the models are averaged.
1278        Unless `submission` is `True`, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder
1279        under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level
1280        keys are the video ids, the second-level keys are the clip ids (like individual names) and the values
1281        are the prediction arrays.
1282
1283        Parameters
1284        ----------
1285        prediction_name : str
1286            the name of the prediction
1287        episode_names : list
1288            a list of string episode names to load the models from
1289        load_epochs : list or int, optional
1290            a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes
1291        parameters_update : dict, optional
1292            a dictionary of parameter updates
1293        augment_n : int, default 10
1294            the number of augmentations to average over
1295        data_path : str, optional
1296            the data path to run the prediction for
1297        mode : {'all', 'test', 'val', 'train'}
1298            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1299        file_paths : set, optional
1300            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1301            for
1302        remove_saved_features : bool, default False
1303            if `True`, pre-computed features will be deleted
1304        submission : bool, default False
1305            if `True`, a MABe-22 style submission file is generated
1306        frame_number_map_file : str, optional
1307            path to the frame number map file
1308        force : bool, default False
1309            if `True`, existing prediction with this name will be overwritten
1310        embedding : bool, default False
1311            if `True`, the prediction is made for the embedding task
1312
1313        """
1314        self._check_prediction_validity(prediction_name, force=force)
1315        print(f"PREDICTION {prediction_name}")
1316        task, parameters, mode, prediction, inference_time, behavior_dict = (
1317            self._make_prediction(
1318                prediction_name,
1319                episode_names,
1320                load_epochs,
1321                parameters_update,
1322                data_path,
1323                file_paths,
1324                mode,
1325                augment_n,
1326                evaluate=False,
1327                embedding=embedding,
1328            )
1329        )
1330        predicted = task.dataset(mode).generate_full_length_prediction(prediction)
1331
1332        if remove_saved_features:
1333            self._remove_stores(parameters)
1334
1335        self._save_prediction(
1336            prediction_name,
1337            predicted,
1338            parameters,
1339            task,
1340            mode,
1341            embedding,
1342            inference_time,
1343            behavior_dict,
1344        )
1345        print("\n")

Load models from previously run episodes to generate a prediction.

The probabilities predicted by the models are averaged. Unless submission is True, the prediction results are saved as a pickled dictionary in the project_name/results/predictions folder under the {episode_name}_{load_epoch}.pickle name. The file is a nested dictionary where the first-level keys are the video ids, the second-level keys are the clip ids (like individual names) and the values are the prediction arrays.

Parameters

prediction_name : str the name of the prediction episode_names : list a list of string episode names to load the models from load_epochs : list or int, optional a list of integer epoch indices to load the model from; if None, the last ones are used, if int the same epoch is used for all episodes parameters_update : dict, optional a dictionary of parameter updates augment_n : int, default 10 the number of augmentations to average over data_path : str, optional the data path to run the prediction for mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for remove_saved_features : bool, default False if True, pre-computed features will be deleted submission : bool, default False if True, a MABe-22 style submission file is generated frame_number_map_file : str, optional path to the frame number map file force : bool, default False if True, existing prediction with this name will be overwritten embedding : bool, default False if True, the prediction is made for the embedding task

def evaluate_prediction( self, prediction_name: str, parameters_update: Dict = None, data_path: str = None, annotation_path: str = None, file_paths: Set = None, mode: str = None, remove_saved_features: bool = False, annotation_type: str = 'none', num_classes: int = None) -> Tuple[float, dict]:
1347    def evaluate_prediction(
1348        self,
1349        prediction_name: str,
1350        parameters_update: Dict = None,
1351        data_path: str = None,
1352        annotation_path: str = None,
1353        file_paths: Set = None,
1354        mode: str = None,
1355        remove_saved_features: bool = False,
1356        annotation_type: str = "none",
1357        num_classes: int = None,  # Set when using data_path
1358    ) -> Tuple[float, dict]:
1359        """Make predictions and evaluate them
1360        inputs:
1361            prediction_name (str): the name of the prediction
1362            parameters_update (dict): a dictionary of parameter updates
1363            data_path (str): the data path to run the prediction for
1364            annotation_path (str): the annotation path to run the prediction for
1365            file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for
1366            mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1367            remove_saved_features (bool): if `True`, pre-computed features will be deleted
1368            annotation_type (str): the type of annotation to use for evaluation
1369            num_classes (int): the number of classes in the dataset, must be set with data_path
1370        outputs:
1371            results (dict): a dictionary of average values of metric functions
1372        """
1373
1374        prediction_path = os.path.join(
1375            self.project_path, "results", "predictions", f"{prediction_name}"
1376        )
1377        prediction_dict = {}
1378        for prediction_file_path in [
1379            os.path.join(prediction_path, i) for i in os.listdir(prediction_path)
1380        ]:
1381            with open(os.path.join(prediction_file_path), "rb") as f:
1382                prediction = pickle.load(f)
1383            video_id = os.path.basename(prediction_file_path).split(
1384                "_" + prediction_name
1385            )[0]
1386            prediction_dict[video_id] = prediction
1387        if parameters_update is None:
1388            parameters_update = {}
1389        parameters_update = self._update(
1390            self._predictions().load_parameters(prediction_name), parameters_update
1391        )
1392        parameters_update.pop("model")
1393        if not data_path is None:
1394            assert (
1395                not num_classes is None
1396            ), "num_classes must be provided if data_path is provided"
1397            parameters_update["general"]["num_classes"] = num_classes + int(
1398                parameters_update["general"]["exclusive"]
1399            )
1400        task, parameters, mode = self._make_task_prediction(
1401            "_",
1402            load_episode=None,
1403            parameters_update=parameters_update,
1404            data_path=data_path,
1405            annotation_path=annotation_path,
1406            file_paths=file_paths,
1407            mode=mode,
1408            annotation_type=annotation_type,
1409        )
1410        results = task.evaluate_prediction(prediction_dict, data=mode)
1411        if remove_saved_features:
1412            self._remove_stores(parameters)
1413        results = Project._reformat_results(
1414            results[1],
1415            task.behaviors_dict(),
1416            exclusive=task.general_parameters["exclusive"],
1417        )
1418        return results

Make predictions and evaluate them inputs: prediction_name (str): the name of the prediction parameters_update (dict): a dictionary of parameter updates data_path (str): the data path to run the prediction for annotation_path (str): the annotation path to run the prediction for file_paths (set): a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for mode (str): the subset of the data to make the prediction for (forced to 'all' if data_path is not None) remove_saved_features (bool): if True, pre-computed features will be deleted annotation_type (str): the type of annotation to use for evaluation num_classes (int): the number of classes in the dataset, must be set with data_path outputs: results (dict): a dictionary of average values of metric functions

def evaluate( self, episode_names: List, load_epochs: List = None, augment_n: int = 0, data_path: str = None, file_paths: Set = None, mode: str = None, parameters_update: Dict = None, multiple_episode_policy: str = 'average', remove_saved_features: bool = False, skip_updating_meta: bool = True, annotation_type: str = 'none') -> Dict:
1420    def evaluate(
1421        self,
1422        episode_names: List,
1423        load_epochs: List = None,
1424        augment_n: int = 0,
1425        data_path: str = None,
1426        file_paths: Set = None,
1427        mode: str = None,
1428        parameters_update: Dict = None,
1429        multiple_episode_policy: str = "average",
1430        remove_saved_features: bool = False,
1431        skip_updating_meta: bool = True,
1432        annotation_type: str = "none",
1433    ) -> Dict:
1434        """Load one or several models from previously run episodes to make an evaluation.
1435
1436        By default it will run on the test (or validation, if there is no test) subset of the project dataset.
1437
1438        Parameters
1439        ----------
1440        episode_names : list
1441            a list of string episode names to load the models from
1442        load_epochs : list, optional
1443            a list of integer epoch indices to load the model from; if None, the last ones are used
1444        augment_n : int, default 0
1445            the number of augmentations to average over
1446        data_path : str, optional
1447            the data path to run the prediction for
1448        file_paths : set, optional
1449            a set of files to run the prediction for
1450        mode : {'test', 'val', 'train', 'all'}
1451            the subset of the data to make the prediction for (forced to 'all' if data_path is not None;
1452            by default 'test' if test subset is not empty and 'val' otherwise)
1453        parameters_update : dict, optional
1454            a dictionary with parameter updates (cannot change model parameters)
1455        multiple_episode_policy : {'average', 'statistics'}
1456            the policy to use when multiple episodes are provided
1457        remove_saved_features : bool, default False
1458            if `True`, the dataset will be deleted
1459        skip_updating_meta : bool, default True
1460            if `True`, the meta file will not be updated with the computed metrics
1461
1462        Returns
1463        -------
1464        metric : dict
1465            a dictionary of average values of metric functions
1466
1467        """
1468        names = []
1469        for episode_name in episode_names:
1470            names += self._episodes().get_runs(episode_name)
1471        if len(set(episode_names)) == 1:
1472            print(f"EVALUATION {episode_names[0]}")
1473        else:
1474            print(f"EVALUATION {episode_names}")
1475        if len(names) > 1:
1476            evaluate = True
1477        else:
1478            evaluate = False
1479        if multiple_episode_policy == "average":
1480            task, parameters, mode, prediction, inference_time, behavior_dict = (
1481                self._make_prediction(
1482                    "_",
1483                    episode_names,
1484                    load_epochs,
1485                    parameters_update,
1486                    mode=mode,
1487                    data_path=data_path,
1488                    file_paths=file_paths,
1489                    augment_n=augment_n,
1490                    evaluate=evaluate,
1491                    annotation_type=annotation_type,
1492                )
1493            )
1494            print("EVALUATE PREDICTION:")
1495            indices = [
1496                list(behavior_dict.keys()).index(i) for i in range(len(behavior_dict))
1497            ]
1498            _, results = task.evaluate_prediction(
1499                prediction, data=mode, indices=indices
1500            )
1501            if len(names) == 1 and mode == "val" and not skip_updating_meta:
1502                self._update_episode_metrics(names[0], results)
1503            results = Project._reformat_results(
1504                results,
1505                behavior_dict,
1506                exclusive=task.general_parameters["exclusive"],
1507            )
1508
1509        elif multiple_episode_policy == "statistics":
1510            values = defaultdict(lambda: [])
1511            task = None
1512            for name in names:
1513                (
1514                    task,
1515                    parameters,
1516                    mode,
1517                    prediction,
1518                    inference_time,
1519                    behavior_dict,
1520                ) = self._make_prediction(
1521                    "_",
1522                    [name],
1523                    load_epochs,
1524                    parameters_update,
1525                    mode=mode,
1526                    data_path=data_path,
1527                    file_paths=file_paths,
1528                    augment_n=augment_n,
1529                    evaluate=evaluate,
1530                    task=task,
1531                )
1532                _, metrics = task.evaluate_prediction(
1533                    prediction, data=mode, indices=list(behavior_dict.keys())
1534                )
1535                for name, value in metrics.items():
1536                    values[name].append(value)
1537                if mode == "val" and not skip_updating_meta:
1538                    self._update_episode_metrics(name, metrics)
1539            results = defaultdict(lambda: {})
1540            mean_string = ""
1541            std_string = ""
1542            for key, value_list in values.items():
1543                results[key]["mean"] = np.mean(value_list)
1544                results[key]["std"] = np.std(value_list)
1545                results[key]["all"] = value_list
1546                mean_string += f"{key} {np.mean(value_list):.3f}, "
1547                std_string += f"{key} {np.std(value_list):.3f}, "
1548            print("MEAN:")
1549            print(mean_string)
1550            print("STD:")
1551            print(std_string)
1552        else:
1553            raise ValueError(
1554                f"The {multiple_episode_policy} multiple episode policy is not recognized; please choose "
1555                f"from ['average', 'statistics']"
1556            )
1557        if len(names) > 0 and remove_saved_features:
1558            self._remove_stores(parameters)
1559        print(f"Inference time: {inference_time}")
1560        print("\n")
1561        return results

Load one or several models from previously run episodes to make an evaluation.

By default it will run on the test (or validation, if there is no test) subset of the project dataset.

Parameters

episode_names : list a list of string episode names to load the models from load_epochs : list, optional a list of integer epoch indices to load the model from; if None, the last ones are used augment_n : int, default 0 the number of augmentations to average over data_path : str, optional the data path to run the prediction for file_paths : set, optional a set of files to run the prediction for mode : {'test', 'val', 'train', 'all'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None; by default 'test' if test subset is not empty and 'val' otherwise) parameters_update : dict, optional a dictionary with parameter updates (cannot change model parameters) multiple_episode_policy : {'average', 'statistics'} the policy to use when multiple episodes are provided remove_saved_features : bool, default False if True, the dataset will be deleted skip_updating_meta : bool, default True if True, the meta file will not be updated with the computed metrics

Returns

metric : dict a dictionary of average values of metric functions

def run_suggestion( self, suggestions_name: str, error_episode: str = None, error_load_epoch: int = None, error_class: str = None, suggestions_prediction: str = None, suggestion_episodes: List = [None], suggestion_load_epoch: int = None, suggestion_classes: List = None, error_threshold: float = 0.5, error_threshold_diff: float = 0.1, error_hysteresis: bool = False, suggestion_threshold: Union[float, List] = 0.5, suggestion_threshold_diff: Union[float, List] = 0.1, suggestion_hysteresis: Union[bool, List] = True, min_frames_suggestion: int = 10, min_frames_al: int = 30, visibility_min_score: float = 0, visibility_min_frac: float = 0.7, augment_n: int = 0, exclude_classes: List = None, exclude_threshold: Union[float, List] = 0.6, exclude_threshold_diff: Union[float, List] = 0.1, exclude_hysteresis: Union[bool, List] = False, include_classes: List = None, include_threshold: Union[float, List] = 0.4, include_threshold_diff: Union[float, List] = 0.1, include_hysteresis: Union[bool, List] = False, data_path: str = None, file_paths: Set = None, parameters_update: Dict = None, mode: str = 'all', force: bool = False, remove_saved_features: bool = False, cut_annotated: bool = False, background_threshold: float = None) -> None:
1563    def run_suggestion(
1564        self,
1565        suggestions_name: str,
1566        error_episode: str = None,
1567        error_load_epoch: int = None,
1568        error_class: str = None,
1569        suggestions_prediction: str = None,
1570        suggestion_episodes: List = [None],
1571        suggestion_load_epoch: int = None,
1572        suggestion_classes: List = None,
1573        error_threshold: float = 0.5,
1574        error_threshold_diff: float = 0.1,
1575        error_hysteresis: bool = False,
1576        suggestion_threshold: Union[float, List] = 0.5,
1577        suggestion_threshold_diff: Union[float, List] = 0.1,
1578        suggestion_hysteresis: Union[bool, List] = True,
1579        min_frames_suggestion: int = 10,
1580        min_frames_al: int = 30,
1581        visibility_min_score: float = 0,
1582        visibility_min_frac: float = 0.7,
1583        augment_n: int = 0,
1584        exclude_classes: List = None,
1585        exclude_threshold: Union[float, List] = 0.6,
1586        exclude_threshold_diff: Union[float, List] = 0.1,
1587        exclude_hysteresis: Union[bool, List] = False,
1588        include_classes: List = None,
1589        include_threshold: Union[float, List] = 0.4,
1590        include_threshold_diff: Union[float, List] = 0.1,
1591        include_hysteresis: Union[bool, List] = False,
1592        data_path: str = None,
1593        file_paths: Set = None,
1594        parameters_update: Dict = None,
1595        mode: str = "all",
1596        force: bool = False,
1597        remove_saved_features: bool = False,
1598        cut_annotated: bool = False,
1599        background_threshold: float = None,
1600    ) -> None:
1601        """Create active learning and suggestion files.
1602
1603        Generate predictions with the error and suggestion model and use them to create
1604        suggestion files for the labeling interface. Those files will render as suggested labels
1605        at intervals with high pose estimation quality. Quality here is defined by probability of error
1606        (predicted by the error model) and visibility parameters.
1607
1608        If `error_episode` or `exclude_classes` is not `None`,
1609        an active learning file will be created as well (with frames with high predicted probability of classes
1610        from `exclude_classes` and/or errors excluded from the active learning intervals).
1611
1612        In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals)
1613        you can apply one of three methods.
1614
1615        - **Simple threshold**
1616
1617            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `False` and the `threshold`
1618            parameter to $\alpha$.
1619            In this case if the probability of a label is predicted to be higher than $\alpha$ the frame will
1620            be considered labeled.
1621
1622        - **Hysteresis threshold**
1623
1624            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1625            parameter to $\alpha$ and the `threshold_diff` parameter to $\beta$.
1626            Now intervals will be marked with a label if the probability of that label for all frames is higher
1627            than $\alpha - \beta$ and at least for one frame in that interval it is higher than $\alpha$.
1628
1629        - **Max hysteresis threshold**
1630
1631            Set the `hysteresis` parameter (e.g. `error_hysteresis`) to `True`, the `threshold`
1632            parameter to $\alpha$ and the `threshold_diff` parameter to `None`.
1633            With this combination intervals are marked with a label if that label is more likely than any other
1634            for all frames in this interval and at for at least one of those frames its probability is higher than
1635            $\alpha$.
1636
1637        Parameters
1638        ----------
1639        suggestions_name : str
1640            the name of the suggestions
1641        error_episode : str, optional
1642            the name of the episode where the error model should be loaded from
1643        error_load_epoch : int, optional
1644            the epoch the error model should be loaded from
1645        error_class : str, optional
1646            the name of the error class (in `error_episode`)
1647        suggestions_prediction : str, optional
1648            the name of the predictions that should be used for the suggestion model
1649        suggestion_episodes : list, optional
1650            the names of the episodes where the suggestion models should be loaded from
1651        suggestion_load_epoch : int, optional
1652            the epoch the suggestion model should be loaded from
1653        suggestion_classes : list, optional
1654            a list of string names of the classes that should be suggested (in `suggestion_episode`)
1655        error_threshold : float, default 0.5
1656            the hard threshold for error prediction
1657        error_threshold_diff : float, default 0.1
1658            the difference between soft and hard thresholds for error prediction (in case hysteresis is used)
1659        error_hysteresis : bool, default False
1660            if True, hysteresis is used for error prediction
1661        suggestion_threshold : float | list, default 0.5
1662            the hard threshold for class prediction (use a list to set different rules for different classes)
1663        suggestion_threshold_diff : float | list, default 0.1
1664            the difference between soft and hard thresholds for class prediction (in case hysteresis is used;
1665            use a list to set different rules for different classes)
1666        suggestion_hysteresis : bool | list, default True
1667            if True, hysteresis is used for class prediction (use a list to set different rules for different classes)
1668        min_frames_suggestion : int, default 10
1669            only actions longer than this number of frames will be suggested
1670        min_frames_al : int, default 30
1671            only active learning intervals longer than this number of frames will be suggested
1672        visibility_min_score : float, default 0
1673            the minimum visibility score for visibility filtering
1674        visibility_min_frac : float, default 0.7
1675            the minimum fraction of visible frames for visibility filtering
1676        augment_n : int, default 10
1677            the number of augmentations to average the predictions over
1678        exclude_classes : list, optional
1679            a list of string names of classes that should be excluded from the active learning intervals
1680        exclude_threshold : float | list, default 0.6
1681            the hard threshold for excluded class prediction (use a list to set different rules for different classes)
1682        exclude_threshold_diff : float | list, default 0.1
1683            the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used)
1684        exclude_hysteresis : bool | list, default False
1685            if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes)
1686        include_classes : list, optional
1687            a list of string names of classes that should be included into the active learning intervals
1688        include_threshold : float | list, default 0.6
1689            the hard threshold for included class prediction (use a list to set different rules for different classes)
1690        include_threshold_diff : float | list, default 0.1
1691            the difference between soft and hard thresholds for included class prediction (in case hysteresis is used)
1692        include_hysteresis : bool | list, default False
1693            if True, hysteresis is used for included class prediction (use a list to set different rules for different classes)
1694        data_path : str, optional
1695            the data path to run the prediction for
1696        file_paths : set, optional
1697            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
1698            for
1699        parameters_update : dict, optional
1700            the parameters update dictionary
1701        mode : {'all', 'test', 'val', 'train'}
1702            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
1703        force : bool, default False
1704            if `True` and an episode with name `episode_name` already exists, it will be overwritten (use with caution!)
1705        remove_saved_features : bool, default False
1706            if `True`, the dataset will be deleted.
1707        cut_annotated : bool, default False
1708            if `True`, annotated frames will be cut from the suggestions
1709        background_threshold : float, default 0.5
1710            the threshold for background prediction
1711
1712        """
1713        self._check_suggestions_validity(suggestions_name, force=force)
1714        if any([x is None for x in suggestion_episodes]):
1715            suggestion_episodes = None
1716        if error_episode is None and (
1717            suggestion_episodes is None and suggestions_prediction is None
1718        ):
1719            raise ValueError(
1720                "Both error_episode and suggestion_episode parameters cannot be None at the same time"
1721            )
1722        print(f"SUGGESTION {suggestions_name}")
1723        task = None
1724        if suggestion_classes is None:
1725            suggestion_classes = []
1726        if exclude_classes is None:
1727            exclude_classes = []
1728        if include_classes is None:
1729            include_classes = []
1730        if isinstance(suggestion_threshold, list):
1731            if len(suggestion_threshold) != len(suggestion_classes):
1732                raise ValueError(
1733                    "The suggestion_threshold parameter has to be either a float value or a list of "
1734                    f"float values of the same length as suggestion_classes (got a list of length "
1735                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1736                )
1737        else:
1738            suggestion_threshold = [suggestion_threshold for _ in suggestion_classes]
1739        if isinstance(suggestion_threshold_diff, list):
1740            if len(suggestion_threshold_diff) != len(suggestion_classes):
1741                raise ValueError(
1742                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1743                    f"float values of the same length as suggestion_classes (got a list of length "
1744                    f"{len(suggestion_threshold)} for {len(suggestion_classes)} classes)"
1745                )
1746        else:
1747            suggestion_threshold_diff = [
1748                suggestion_threshold_diff for _ in suggestion_classes
1749            ]
1750        if isinstance(suggestion_hysteresis, list):
1751            if len(suggestion_hysteresis) != len(suggestion_classes):
1752                raise ValueError(
1753                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1754                    f"float values of the same length as suggestion_classes (got a list of length "
1755                    f"{len(suggestion_hysteresis)} for {len(suggestion_classes)} classes)"
1756                )
1757        else:
1758            suggestion_hysteresis = [suggestion_hysteresis for _ in suggestion_classes]
1759        if isinstance(exclude_threshold, list):
1760            if len(exclude_threshold) != len(exclude_classes):
1761                raise ValueError(
1762                    "The exclude_threshold parameter has to be either a float value or a list of "
1763                    f"float values of the same length as exclude_classes (got a list of length "
1764                    f"{len(exclude_threshold)} for {len(exclude_classes)} classes)"
1765                )
1766        else:
1767            exclude_threshold = [exclude_threshold for _ in exclude_classes]
1768        if isinstance(exclude_threshold_diff, list):
1769            if len(exclude_threshold_diff) != len(exclude_classes):
1770                raise ValueError(
1771                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1772                    f"float values of the same length as exclude_classes (got a list of length "
1773                    f"{len(exclude_threshold_diff)} for {len(exclude_classes)} classes)"
1774                )
1775        else:
1776            exclude_threshold_diff = [exclude_threshold_diff for _ in exclude_classes]
1777        if isinstance(exclude_hysteresis, list):
1778            if len(exclude_hysteresis) != len(exclude_classes):
1779                raise ValueError(
1780                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1781                    f"float values of the same length as suggestion_classes (got a list of length "
1782                    f"{len(exclude_hysteresis)} for {len(exclude_classes)} classes)"
1783                )
1784        else:
1785            exclude_hysteresis = [exclude_hysteresis for _ in exclude_classes]
1786        if isinstance(include_threshold, list):
1787            if len(include_threshold) != len(include_classes):
1788                raise ValueError(
1789                    "The exclude_threshold parameter has to be either a float value or a list of "
1790                    f"float values of the same length as exclude_classes (got a list of length "
1791                    f"{len(include_threshold)} for {len(include_classes)} classes)"
1792                )
1793        else:
1794            include_threshold = [include_threshold for _ in include_classes]
1795        if isinstance(include_threshold_diff, list):
1796            if len(include_threshold_diff) != len(include_classes):
1797                raise ValueError(
1798                    "The exclude_threshold_diff parameter has to be either a float value or a list of "
1799                    f"float values of the same length as exclude_classes (got a list of length "
1800                    f"{len(include_threshold_diff)} for {len(include_classes)} classes)"
1801                )
1802        else:
1803            include_threshold_diff = [include_threshold_diff for _ in include_classes]
1804        if isinstance(include_hysteresis, list):
1805            if len(include_hysteresis) != len(include_classes):
1806                raise ValueError(
1807                    "The suggestion_threshold_diff parameter has to be either a float value or a list of "
1808                    f"float values of the same length as suggestion_classes (got a list of length "
1809                    f"{len(include_hysteresis)} for {len(include_classes)} classes)"
1810                )
1811        else:
1812            include_hysteresis = [include_hysteresis for _ in include_classes]
1813        if (suggestion_episodes is None and suggestions_prediction is None) and len(
1814            exclude_classes
1815        ) > 0:
1816            raise ValueError(
1817                "In order to exclude classes from the active learning intervals you need to set the "
1818                "suggestion_episode parameter"
1819            )
1820
1821        task = None
1822        if error_episode is not None:
1823            task, parameters, mode = self._make_task_prediction(
1824                prediction_name=suggestions_name,
1825                load_episode=error_episode,
1826                parameters_update=parameters_update,
1827                load_epoch=error_load_epoch,
1828                data_path=data_path,
1829                mode=mode,
1830                file_paths=file_paths,
1831                task=task,
1832            )
1833            predicted_error = task.predict(
1834                data=mode,
1835                raw_output=True,
1836                apply_primary_function=True,
1837                augment_n=augment_n,
1838            )
1839        else:
1840            predicted_error = None
1841
1842        if suggestion_episodes is not None:
1843            (
1844                task,
1845                parameters,
1846                mode,
1847                predicted_classes,
1848                inference_time,
1849                behavior_dict,
1850            ) = self._make_prediction(
1851                prediction_name=suggestions_name,
1852                episode_names=suggestion_episodes,
1853                load_epochs=suggestion_load_epoch,
1854                parameters_update=parameters_update,
1855                data_path=data_path,
1856                file_paths=file_paths,
1857                mode=mode,
1858                task=task,
1859            )
1860        elif suggestions_prediction is not None:
1861            with open(
1862                os.path.join(
1863                    self.project_path,
1864                    "results",
1865                    "predictions",
1866                    f"{suggestions_prediction}.pickle",
1867                ),
1868                "rb",
1869            ) as f:
1870                predicted_classes = pickle.load(f)
1871            if parameters_update is None:
1872                parameters_update = {}
1873            parameters_update = self._update(
1874                self._predictions().load_parameters(suggestions_prediction),
1875                parameters_update,
1876            )
1877            parameters_update.pop("model")
1878            if suggestion_episodes is None:
1879                suggestion_episodes = [
1880                    os.path.basename(
1881                        os.path.dirname(
1882                            parameters_update["training"]["checkpoint_path"]
1883                        )
1884                    )
1885                ]
1886            task, parameters, mode = self._make_task_prediction(
1887                "_",
1888                load_episode=None,
1889                parameters_update=parameters_update,
1890                data_path=data_path,
1891                file_paths=file_paths,
1892                mode=mode,
1893            )
1894        else:
1895            predicted_classes = None
1896
1897        if len(suggestion_classes) > 0 and predicted_classes is not None:
1898            suggestions = self._make_suggestions(
1899                task,
1900                predicted_error,
1901                predicted_classes,
1902                suggestion_threshold,
1903                suggestion_threshold_diff,
1904                suggestion_hysteresis,
1905                suggestion_episodes,
1906                suggestion_classes,
1907                error_threshold,
1908                min_frames_suggestion,
1909                min_frames_al,
1910                visibility_min_score,
1911                visibility_min_frac,
1912                cut_annotated=cut_annotated,
1913            )
1914            videos = list(suggestions.keys())
1915            for v_id in videos:
1916                times_dict = defaultdict(lambda: defaultdict(lambda: []))
1917                clips = set()
1918                for c in suggestions[v_id]:
1919                    for start, end, ind in suggestions[v_id][c]:
1920                        times_dict[ind][c].append([start, end, 2])
1921                        clips.add(ind)
1922                clips = list(clips)
1923                times_dict = dict(times_dict)
1924                times = [
1925                    [times_dict[ind][c] for c in suggestion_classes] for ind in clips
1926                ]
1927                save_path = self._suggestion_path(v_id, suggestions_name)
1928                Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1929                with open(save_path, "wb") as f:
1930                    pickle.dump((None, suggestion_classes, clips, times), f)
1931
1932        if (
1933            error_episode is not None
1934            or len(exclude_classes) > 0
1935            or len(include_classes) > 0
1936        ):
1937            al_points = self._make_al_points(
1938                task,
1939                predicted_error,
1940                predicted_classes,
1941                exclude_classes,
1942                exclude_threshold,
1943                exclude_threshold_diff,
1944                exclude_hysteresis,
1945                include_classes,
1946                include_threshold,
1947                include_threshold_diff,
1948                include_hysteresis,
1949                error_episode,
1950                error_class,
1951                suggestion_episodes,
1952                error_threshold,
1953                error_threshold_diff,
1954                error_hysteresis,
1955                min_frames_al,
1956                visibility_min_score,
1957                visibility_min_frac,
1958            )
1959        else:
1960            al_points = self._make_al_points_from_suggestions(
1961                suggestions_name,
1962                task,
1963                predicted_classes,
1964                background_threshold,
1965                visibility_min_score,
1966                visibility_min_frac,
1967                num_behaviors=len(task.behaviors_dict()),
1968            )
1969        save_path = self._al_points_path(suggestions_name)
1970        Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
1971        with open(save_path, "wb") as f:
1972            pickle.dump(al_points, f)
1973
1974        meta_parameters = {
1975            "error_episode": error_episode,
1976            "error_load_epoch": error_load_epoch,
1977            "error_class": error_class,
1978            "suggestion_episode": suggestion_episodes,
1979            "suggestion_load_epoch": suggestion_load_epoch,
1980            "suggestion_classes": suggestion_classes,
1981            "error_threshold": error_threshold,
1982            "error_threshold_diff": error_threshold_diff,
1983            "error_hysteresis": error_hysteresis,
1984            "suggestion_threshold": suggestion_threshold,
1985            "suggestion_threshold_diff": suggestion_threshold_diff,
1986            "suggestion_hysteresis": suggestion_hysteresis,
1987            "min_frames_suggestion": min_frames_suggestion,
1988            "min_frames_al": min_frames_al,
1989            "visibility_min_score": visibility_min_score,
1990            "visibility_min_frac": visibility_min_frac,
1991            "augment_n": augment_n,
1992            "exclude_classes": exclude_classes,
1993            "exclude_threshold": exclude_threshold,
1994            "exclude_threshold_diff": exclude_threshold_diff,
1995            "exclude_hysteresis": exclude_hysteresis,
1996        }
1997        self._save_suggestions(suggestions_name, {}, meta_parameters)
1998        if data_path is not None or file_paths is not None or remove_saved_features:
1999            self._remove_stores(parameters)
2000        print(f"\n")

Create active learning and suggestion files.

Generate predictions with the error and suggestion model and use them to create suggestion files for the labeling interface. Those files will render as suggested labels at intervals with high pose estimation quality. Quality here is defined by probability of error (predicted by the error model) and visibility parameters.

If error_episode or exclude_classes is not None, an active learning file will be created as well (with frames with high predicted probability of classes from exclude_classes and/or errors excluded from the active learning intervals).

In all three steps (predicting errors, suggesting labels and excluding them from active learning intervals) you can apply one of three methods.

  • Simple threshold

    Set the hysteresis parameter (e.g. error_hysteresis) to False and the threshold parameter to $\alpha$. In this case, if the probability of a label is predicted to be higher than $\alpha$, the frame will be considered labeled.

  • Hysteresis threshold

    Set the hysteresis parameter (e.g. error_hysteresis) to True, the threshold parameter to $\alpha$ and the threshold_diff parameter to $\beta$. Now intervals will be marked with a label if the probability of that label for all frames is higher than $\alpha - \beta$ and for at least one frame in that interval it is higher than $\alpha$.

  • Max hysteresis threshold

    Set the hysteresis parameter (e.g. error_hysteresis) to True, the threshold parameter to $\alpha$ and the threshold_diff parameter to None. With this combination intervals are marked with a label if that label is more likely than any other for all frames in this interval and for at least one of those frames its probability is higher than $\alpha$.

Parameters

suggestions_name : str the name of the suggestions error_episode : str, optional the name of the episode where the error model should be loaded from error_load_epoch : int, optional the epoch the error model should be loaded from error_class : str, optional the name of the error class (in error_episode) suggestions_prediction : str, optional the name of the predictions that should be used for the suggestion model suggestion_episodes : list, optional the names of the episodes where the suggestion models should be loaded from suggestion_load_epoch : int, optional the epoch the suggestion model should be loaded from suggestion_classes : list, optional a list of string names of the classes that should be suggested (in suggestion_episode) error_threshold : float, default 0.5 the hard threshold for error prediction error_threshold_diff : float, default 0.1 the difference between soft and hard thresholds for error prediction (in case hysteresis is used) error_hysteresis : bool, default False if True, hysteresis is used for error prediction suggestion_threshold : float | list, default 0.5 the hard threshold for class prediction (use a list to set different rules for different classes) suggestion_threshold_diff : float | list, default 0.1 the difference between soft and hard thresholds for class prediction (in case hysteresis is used; use a list to set different rules for different classes) suggestion_hysteresis : bool | list, default True if True, hysteresis is used for class prediction (use a list to set different rules for different classes) min_frames_suggestion : int, default 10 only actions longer than this number of frames will be suggested min_frames_al : int, default 30 only active learning intervals longer than this number of frames will be suggested visibility_min_score : float, default 0 the minimum visibility score for visibility filtering visibility_min_frac : float, default 0.7 the minimum fraction of visible frames for visibility filtering augment_n : int, 
default 10 the number of augmentations to average the predictions over exclude_classes : list, optional a list of string names of classes that should be excluded from the active learning intervals exclude_threshold : float | list, default 0.6 the hard threshold for excluded class prediction (use a list to set different rules for different classes) exclude_threshold_diff : float | list, default 0.1 the difference between soft and hard thresholds for excluded class prediction (in case hysteresis is used) exclude_hysteresis : bool | list, default False if True, hysteresis is used for excluded class prediction (use a list to set different rules for different classes) include_classes : list, optional a list of string names of classes that should be included into the active learning intervals include_threshold : float | list, default 0.6 the hard threshold for included class prediction (use a list to set different rules for different classes) include_threshold_diff : float | list, default 0.1 the difference between soft and hard thresholds for included class prediction (in case hysteresis is used) include_hysteresis : bool | list, default False if True, hysteresis is used for included class prediction (use a list to set different rules for different classes) data_path : str, optional the data path to run the prediction for file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for parameters_update : dict, optional the parameters update dictionary mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) force : bool, default False if True and an episode with name episode_name already exists, it will be overwritten (use with caution!) remove_saved_features : bool, default False if True, the dataset will be deleted. 
cut_annotated : bool, default False if True, annotated frames will be cut from the suggestions background_threshold : float, default 0.5 the threshold for background prediction

def suggest_intervals_with_similarity( self, suggestions_name: str, prediction_name: str, target_video_id: str, target_clip: str, target_start: int, target_end: int, min_length: int = 60, n_intervals: int = 5, force: bool = False):
2095    def suggest_intervals_with_similarity(
2096        self,
2097        suggestions_name: str,
2098        prediction_name: str,
2099        target_video_id: str,
2100        target_clip: str,
2101        target_start: int,
2102        target_end: int,
2103        min_length: int = 60,
2104        n_intervals: int = 5,
2105        force: bool = False,
2106    ):
2107        """
2108        Suggest intervals based on similarity to a target interval.
2109
2110        Parameters
2111        ----------
2112        suggestions_name : str
2113            Name of the suggestion.
2114        prediction_name : str
2115            Name of the prediction to use.
2116        target_video_id : str
2117            Video id of the target interval.
2118        target_clip : str
2119            Clip id of the target interval.
2120        target_start : int
2121            Start frame of the target interval.
2122        target_end : int
2123            End frame of the target interval.
2124        min_length : int, default 60
2125            Minimum length of the suggested intervals.
2126        n_intervals : int, default 5
2127            Number of suggested intervals.
2128        force : bool, default False
2129            If True, the suggestion is overwritten if it already exists.
2130
2131        """
2132        self._check_suggestions_validity(suggestions_name, force=force)
2133        print(f"SUGGESTION {suggestions_name}")
2134        score_dict = self._generate_similarity_score(
2135            prediction_name, target_video_id, target_clip, target_start, target_end
2136        )
2137        intervals = self._suggest_intervals_from_dict(
2138            score_dict, min_length, n_intervals
2139        )
2140        suggestions_path = os.path.join(
2141            self.project_path,
2142            "results",
2143            "suggestions",
2144            suggestions_name,
2145        )
2146        if not os.path.exists(suggestions_path):
2147            os.mkdir(suggestions_path)
2148        with open(
2149            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2150        ) as f:
2151            pickle.dump(intervals, f)
2152        meta_parameters = {
2153            "prediction_name": prediction_name,
2154            "min_frames_suggestion": min_length,
2155            "n_intervals": n_intervals,
2156            "target_clip": target_clip,
2157            "target_start": target_start,
2158            "target_end": target_end,
2159        }
2160        self._save_suggestions(suggestions_name, {}, meta_parameters)
2161        print("\n")

Suggest intervals based on similarity to a target interval.

Parameters

suggestions_name : str Name of the suggestion. prediction_name : str Name of the prediction to use. target_video_id : str Video id of the target interval. target_clip : str Clip id of the target interval. target_start : int Start frame of the target interval. target_end : int End frame of the target interval. min_length : int, default 60 Minimum length of the suggested intervals. n_intervals : int, default 5 Number of suggested intervals. force : bool, default False If True, the suggestion is overwritten if it already exists.

def suggest_intervals_with_uncertainty( self, suggestions_name: str, episode_names: List, load_epochs: List = None, classes: List = None, n_frames: int = 10000, method: str = 'least_confidence', min_length: int = 60, augment_n: int = 0, data_path: str = None, file_paths: Set = None, parameters_update: Dict = None, mode: str = 'all', force: bool = False, remove_saved_features: bool = False) -> None:
2163    def suggest_intervals_with_uncertainty(
2164        self,
2165        suggestions_name: str,
2166        episode_names: List,
2167        load_epochs: List = None,
2168        classes: List = None,
2169        n_frames: int = 10000,
2170        method: str = "least_confidence",
2171        min_length: int = 60,
2172        augment_n: int = 0,
2173        data_path: str = None,
2174        file_paths: Set = None,
2175        parameters_update: Dict = None,
2176        mode: str = "all",
2177        force: bool = False,
2178        remove_saved_features: bool = False,
2179    ) -> None:
2180        """Generate an active learning file based on model uncertainty.
2181
2182        If you provide several episode names, the predicted probabilities will be averaged.
2183
2184        Parameters
2185        ----------
2186        suggestions_name : str
2187            the name of the suggestion
2188        episode_names : list
2189            a list of string episode names to load the models from
2190        load_epochs : list, optional
2191            a list of epoch indices to load the models from (if `None`, the last ones will be used)
2192        classes : list, optional
2193            a list of classes to look at (by default all)
2194        n_frames : int, default 10000
2195            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2196            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2197            with the set parameters)
2198        method : {"least_confidence", "entropy"}
2199            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
2200            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
2201        min_length : int, default 60
2202            the minimum number of frames in one interval
2203        augment_n : int, default 0
2204            the number of augmentations to average the predictions over
2205        data_path : str, optional
2206            the path to a data folder (by default, the project data is used)
2207        file_paths : set, optional
2208            a list of file paths (by default, the project data is used)
2209        parameters_update : dict, optional
2210            a dictionary of parameter updates
2211        mode : {"test", "val", "train", "all"}
2212            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2213            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2214        force : bool, default False
2215            if `True`, existing suggestions with the same name will be overwritten
2216        remove_saved_features : bool, default False
2217            if `True`, the dataset will be deleted after the computation
2218
2219        """
2220        self._check_suggestions_validity(suggestions_name, force=force)
2221        print(f"SUGGESTION {suggestions_name}")
2222        task, parameters, mode, predicted, inference_time, behavior_dict = (
2223            self._make_prediction(
2224                suggestions_name,
2225                episode_names,
2226                load_epochs,
2227                parameters_update,
2228                data_path=data_path,
2229                file_paths=file_paths,
2230                mode=mode,
2231                augment_n=augment_n,
2232                evaluate=False,
2233            )
2234        )
2235        if classes is None:
2236            classes = behavior_dict.values()
2237        episode = self._episodes().get_runs(episode_names[0])[0]
2238        score_tensors = task.generate_uncertainty_score(
2239            classes,
2240            augment_n,
2241            method,
2242            predicted,
2243            self._episode(episode).get_behaviors_dict(),
2244        )
2245        intervals = self._suggest_intervals(
2246            task.dataset(mode), score_tensors, n_frames, min_length
2247        )
2248        for k, v in intervals.items():
2249            l = sum([x[1] - x[0] for x in v])
2250            print(f"{k}: {len(v)} ({l})")
2251        if remove_saved_features:
2252            self._remove_stores(parameters)
2253        suggestions_path = os.path.join(
2254            self.project_path,
2255            "results",
2256            "suggestions",
2257            suggestions_name,
2258        )
2259        if not os.path.exists(suggestions_path):
2260            os.mkdir(suggestions_path)
2261        with open(
2262            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2263        ) as f:
2264            pickle.dump(intervals, f)
2265        meta_parameters = {
2266            "suggestion_episode": episode_names,
2267            "suggestion_load_epoch": load_epochs,
2268            "suggestion_classes": classes,
2269            "min_frames_suggestion": min_length,
2270            "augment_n": augment_n,
2271            "method": method,
2272            "num_frames": n_frames,
2273        }
2274        self._save_suggestions(suggestions_name, {}, meta_parameters)
2275        print("\n")

Generate an active learning file based on model uncertainty.

If you provide several episode names, the predicted probabilities will be averaged.

Parameters

suggestions_name : str the name of the suggestion episode_names : list a list of string episode names to load the models from load_epochs : list, optional a list of epoch indices to load the models from (if None, the last ones will be used) classes : list, optional a list of classes to look at (by default all) n_frames : int, default 10000 the threshold total number of frames in the suggested intervals (in the end result it will most likely be slightly larger; it will only be smaller if the algorithm fails to find enough intervals with the set parameters) method : {"least_confidence", "entropy"} the method used to calculate the scores from the probability predictions ("least_confidence": 1 - p_i if p_i > 0.5 or p_i if p_i < 0.5; "entropy": - p_i * log(p_i) - (1 - p_i) * log(1 - p_i)) min_length : int, default 60 the minimum number of frames in one interval augment_n : int, default 0 the number of augmentations to average the predictions over data_path : str, optional the path to a data folder (by default, the project data is used) file_paths : set, optional a set of file paths (by default, the project data is used) parameters_update : dict, optional a dictionary of parameter updates mode : {"test", "val", "train", "all"} the subset of the data to make the prediction for (forced to 'all' if data_path is not None; by default set to 'test' if the test subset is not empty, or to 'val' otherwise) force : bool, default False if True, existing suggestions with the same name will be overwritten remove_saved_features : bool, default False if True, the dataset will be deleted after the computation

def suggest_intervals_with_bald( self, suggestions_name: str, episode_name: str, load_epoch: int = None, classes: List = None, n_frames: int = 10000, num_models: int = 10, kernel_size: int = 11, min_length: int = 60, augment_n: int = 0, data_path: str = None, file_paths: Set = None, parameters_update: Dict = None, mode: str = 'all', force: bool = False, remove_saved_features: bool = False):
2277    def suggest_intervals_with_bald(
2278        self,
2279        suggestions_name: str,
2280        episode_name: str,
2281        load_epoch: int = None,
2282        classes: List = None,
2283        n_frames: int = 10000,
2284        num_models: int = 10,
2285        kernel_size: int = 11,
2286        min_length: int = 60,
2287        augment_n: int = 0,
2288        data_path: str = None,
2289        file_paths: Set = None,
2290        parameters_update: Dict = None,
2291        mode: str = "all",
2292        force: bool = False,
2293        remove_saved_features: bool = False,
2294    ):
2295        """Generate an active learning file based on Bayesian Active Learning by Disagreement.
2296
2297        Parameters
2298        ----------
2299        suggestions_name : str
2300            the name of the suggestion
2301        episode_name : str
2302            the name of the episode to load the model from
2303        load_epoch : int, optional
2304            the index of the epoch to load the model from (if `None`, the last one will be used)
2305        classes : list, optional
2306            a list of classes to look at (by default all)
2307        n_frames : int, default 10000
2308            the threshold total number of frames in the suggested intervals (in the end result it will most likely
2309            be slightly larger; it will only be smaller if the algorithm fails to find enough intervals
2310            with the set parameters)
2311        num_models : int, default 10
2312            the number of dropout masks to apply
2313        kernel_size : int, default 11
2314            the size of the smoothing kernel applied to the discrete results
2315        min_length : int, default 60
2316            the minimum number of frames in one interval
2317        augment_n : int, default 0
2318            the number of augmentations to average the predictions over
2319        data_path : str, optional
2320            the path to a data folder (by default, the project data is used)
2321        file_paths : set, optional
2322            a list of file paths (by default, the project data is used)
2323        parameters_update : dict, optional
2324            a dictionary of parameter updates
2325        mode : {"test", "val", "train", "all"}
2326            the subset of the data to make the prediction for (forced to 'all' if `data_path` is not `None`;
2327            by default set to `'test'` if the test subset if not empty, or to `'val'` otherwise)
2328        force : bool, default False
2329            if `True`, existing suggestions with the same name will be overwritten
2330        remove_saved_features : bool, default False
2331            if `True`, the dataset will be deleted after the computation
2332
2333        """
2334        self._check_suggestions_validity(suggestions_name, force=force)
2335        print(f"SUGGESTION {suggestions_name}")
2336        task, parameters, mode = self._make_task_prediction(
2337            suggestions_name,
2338            episode_name,
2339            parameters_update,
2340            load_epoch,
2341            data_path=data_path,
2342            file_paths=file_paths,
2343            mode=mode,
2344        )
2345        if classes is None:
2346            classes = list(task.behaviors_dict().values())
2347        score_tensors = task.generate_bald_score(
2348            classes, augment_n, num_models, kernel_size
2349        )
2350        intervals = self._suggest_intervals(
2351            task.dataset(mode), score_tensors, n_frames, min_length
2352        )
2353        if remove_saved_features:
2354            self._remove_stores(parameters)
2355        suggestions_path = os.path.join(
2356            self.project_path,
2357            "results",
2358            "suggestions",
2359            suggestions_name,
2360        )
2361        if not os.path.exists(suggestions_path):
2362            os.mkdir(suggestions_path)
2363        with open(
2364            os.path.join(suggestions_path, f"{suggestions_name}_al_points.pickle"), "wb"
2365        ) as f:
2366            pickle.dump(intervals, f)
2367        meta_parameters = {
2368            "suggestion_episode": episode_name,
2369            "suggestion_load_epoch": load_epoch,
2370            "suggestion_classes": classes,
2371            "min_frames_suggestion": min_length,
2372            "augment_n": augment_n,
2373            "method": f"BALD:{num_models}",
2374            "num_frames": n_frames,
2375        }
2376        self._save_suggestions(suggestions_name, {}, meta_parameters)
2377        print("\n")

Generate an active learning file based on Bayesian Active Learning by Disagreement.

Parameters

suggestions_name : str the name of the suggestion episode_name : str the name of the episode to load the model from load_epoch : int, optional the index of the epoch to load the model from (if None, the last one will be used) classes : list, optional a list of classes to look at (by default all) n_frames : int, default 10000 the threshold total number of frames in the suggested intervals (in the end result it will most likely be slightly larger; it will only be smaller if the algorithm fails to find enough intervals with the set parameters) num_models : int, default 10 the number of dropout masks to apply kernel_size : int, default 11 the size of the smoothing kernel applied to the discrete results min_length : int, default 60 the minimum number of frames in one interval augment_n : int, default 0 the number of augmentations to average the predictions over data_path : str, optional the path to a data folder (by default, the project data is used) file_paths : set, optional a set of file paths (by default, the project data is used) parameters_update : dict, optional a dictionary of parameter updates mode : {"test", "val", "train", "all"} the subset of the data to make the prediction for (forced to 'all' if data_path is not None; by default set to 'test' if the test subset is not empty, or to 'val' otherwise) force : bool, default False if True, existing suggestions with the same name will be overwritten remove_saved_features : bool, default False if True, the dataset will be deleted after the computation

def list_episodes( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.core.frame.DataFrame:
2379    def list_episodes(
2380        self,
2381        episode_names: List = None,
2382        value_filter: str = "",
2383        display_parameters: List = None,
2384        print_results: bool = True,
2385    ) -> pd.DataFrame:
2386        """Get a filtered pandas dataframe with episode metadata.
2387
2388        Parameters
2389        ----------
2390        episode_names : list
2391            a list of strings of episode names
2392        value_filter : str
2393            a string of filters to apply; of this general structure:
2394            'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g.
2395            'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10'
2396        display_parameters : list
2397            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2398        print_results : bool, default True
2399            if True, the result will be printed to standard output
2400
2401        Returns
2402        -------
2403        pd.DataFrame
2404            the filtered dataframe
2405
2406        """
2407        episodes = self._episodes().list_episodes(
2408            episode_names, value_filter, display_parameters
2409        )
2410        if print_results:
2411            print("TRAINING EPISODES")
2412            print(episodes)
2413            print("\n")
2414        return episodes

Get a filtered pandas dataframe with episode metadata.

Parameters

episode_names : list a list of strings of episode names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1::(</>/<=/>=/=)value1,group_name2/par_name2::(</>/<=/>=/=)value2', e.g. 'data/overlap::=50,results/recall::>0.5,data/feature_extraction::=kinematic,meta/training_time::>=00:00:10' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_predictions( self, episode_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.core.frame.DataFrame:
2416    def list_predictions(
2417        self,
2418        episode_names: List = None,
2419        value_filter: str = "",
2420        display_parameters: List = None,
2421        print_results: bool = True,
2422    ) -> pd.DataFrame:
2423        """Get a filtered pandas dataframe with prediction metadata.
2424
2425        Parameters
2426        ----------
2427        episode_names : list
2428            a list of strings of episode names
2429        value_filter : str
2430            a string of filters to apply; of this general structure:
2431            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2432            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2433        display_parameters : list
2434            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2435        print_results : bool, default True
2436            if True, the result will be printed to standard output
2437
2438        Returns
2439        -------
2440        pd.DataFrame
2441            the filtered dataframe
2442
2443        """
2444        predictions = self._predictions().list_episodes(
2445            episode_names, value_filter, display_parameters
2446        )
2447        if print_results:
2448            print("PREDICTIONS")
2449            print(predictions)
2450            print("\n")
2451        return predictions

Get a filtered pandas dataframe with prediction metadata.

Parameters

episode_names : list a list of strings of episode names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_suggestions( self, suggestions_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.core.frame.DataFrame:
2453    def list_suggestions(
2454        self,
2455        suggestions_names: List = None,
2456        value_filter: str = "",
2457        display_parameters: List = None,
2458        print_results: bool = True,
2459    ) -> pd.DataFrame:
2460        """Get a filtered pandas dataframe with prediction metadata.
2461
2462        Parameters
2463        ----------
2464        suggestions_names : list
2465            a list of strings of suggestion names
2466        value_filter : str
2467            a string of filters to apply; of this general structure:
2468            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2469            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2470        display_parameters : list
2471            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2472        print_results : bool, default True
2473            if True, the result will be printed to standard output
2474
2475        Returns
2476        -------
2477        pd.DataFrame
2478            the filtered dataframe
2479
2480        """
2481        suggestions = self._suggestions().list_episodes(
2482            suggestions_names, value_filter, display_parameters
2483        )
2484        if print_results:
2485            print("SUGGESTIONS")
2486            print(suggestions)
2487            print("\n")
2488        return suggestions

Get a filtered pandas dataframe with suggestion metadata.

Parameters

suggestions_names : list a list of strings of suggestion names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def list_searches( self, search_names: List = None, value_filter: str = '', display_parameters: List = None, print_results: bool = True) -> pandas.core.frame.DataFrame:
2490    def list_searches(
2491        self,
2492        search_names: List = None,
2493        value_filter: str = "",
2494        display_parameters: List = None,
2495        print_results: bool = True,
2496    ) -> pd.DataFrame:
2497        """Get a filtered pandas dataframe with hyperparameter search metadata.
2498
2499        Parameters
2500        ----------
2501        search_names : list
2502            a list of strings of search names
2503        value_filter : str
2504            a string of filters to apply; of this general structure:
2505            'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g.
2506            'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic'
2507        display_parameters : list
2508            list of parameters to display (e.g. ['data/overlap', 'results/recall'])
2509        print_results : bool, default True
2510            if True, the result will be printed to standard output
2511
2512        Returns
2513        -------
2514        pd.DataFrame
2515            the filtered dataframe
2516
2517        """
2518        searches = self._searches().list_episodes(
2519            search_names, value_filter, display_parameters
2520        )
2521        if print_results:
2522            print("SEARCHES")
2523            print(searches)
2524            print("\n")
2525        return searches

Get a filtered pandas dataframe with hyperparameter search metadata.

Parameters

search_names : list a list of strings of search names value_filter : str a string of filters to apply; of this general structure: 'group_name1/par_name1:(<>=)value1,group_name2/par_name2:(<>=)value2', e.g. 'data/overlap:=50,results/recall:>0.5,data/feature_extraction:=kinematic' display_parameters : list list of parameters to display (e.g. ['data/overlap', 'results/recall']) print_results : bool, default True if True, the result will be printed to standard output

Returns

pd.DataFrame the filtered dataframe

def get_best_parameters(self, search_name: str, round_to_binary: List = None):
2527    def get_best_parameters(
2528        self,
2529        search_name: str,
2530        round_to_binary: List = None,
2531    ):
2532        """Get the best parameters found by a search.
2533
2534        Parameters
2535        ----------
2536        search_name : str
2537            the name of the search
2538        round_to_binary : list, default None
2539            a list of parameters to round to binary values
2540
2541        Returns
2542        -------
2543        best_params : dict
2544            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2545
2546        """
2547        params, model = self._searches().get_best_params(
2548            search_name, round_to_binary=round_to_binary
2549        )
2550        params = self._update(params, {"general": {"model_name": model}})
2551        return params

Get the best parameters found by a search.

Parameters

search_name : str the name of the search round_to_binary : list, default None a list of parameters to round to binary values

Returns

best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format

def list_best_parameters(self, search_name: str, print_results: bool = True) -> Dict:
2553    def list_best_parameters(
2554        self, search_name: str, print_results: bool = True
2555    ) -> Dict:
2556        """Get the raw dictionary of best parameters found by a search.
2557
2558        Parameters
2559        ----------
2560        search_name : str
2561            the name of the search
2562        print_results : bool, default True
2563            if True, the result will be printed to standard output
2564
2565        Returns
2566        -------
2567        best_params : dict
2568            a dictionary of the best parameters where the keys are in '{group}/{name}' format
2569
2570        """
2571        params = self._searches().get_best_params_raw(search_name)
2572        if print_results:
2573            print(f"SEARCH RESULTS {search_name}")
2574            for k, v in params.items():
2575                print(f"{k}: {v}")
2576            print("\n")
2577        return params

Get the raw dictionary of best parameters found by a search.

Parameters

search_name : str the name of the search print_results : bool, default True if True, the result will be printed to standard output

Returns

best_params : dict a dictionary of the best parameters where the keys are in '{group}/{name}' format

def plot_episodes( self, episode_names: List, metrics: Union[List, str], modes: Union[List, str] = None, title: str = None, episode_labels: List = None, save_path: str = None, add_hlines: List = None, epoch_limits: List = None, colors: List = None, add_highpoint_hlines: bool = False, remove_box: bool = False, font_size: float = None, linewidth: float = None, return_ax: bool = False) -> None:
2579    def plot_episodes(
2580        self,
2581        episode_names: List,
2582        metrics: List | str,
2583        modes: List | str = None,
2584        title: str = None,
2585        episode_labels: List = None,
2586        save_path: str = None,
2587        add_hlines: List = None,
2588        epoch_limits: List = None,
2589        colors: List = None,
2590        add_highpoint_hlines: bool = False,
2591        remove_box: bool = False,
2592        font_size: float = None,
2593        linewidth: float = None,
2594        return_ax: bool = False,
2595    ) -> None:
2596        """Plot episode training curves.
2597
2598        Parameters
2599        ----------
2600        episode_names : list
2601            a list of episode names to plot; to plot to episodes in one line combine them in a list
2602            (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment)
2603        metrics : list
2604            a list of metric to plot
2605        modes : list, optional
2606            a list of modes to plot ('train' and/or 'val'; `['val']` by default)
2607        title : str, optional
2608            title for the plot
2609        episode_labels : list, optional
2610            a list of strings used to label the curves (has to be the same length as episode_names)
2611        save_path : str, optional
2612            the path to save the resulting plot
2613        add_hlines : list, optional
2614            a list of float values (or (value, label) tuples) to mark with horizontal lines
2615        epoch_limits : list, optional
2616            a list of (min, max) tuples to set the x-axis limits for each episode
2617        colors: list, optional
2618            a list of matplotlib colors
2619        add_highpoint_hlines : bool, default False
2620            if `True`, horizontal lines will be added at the highest value of each episode
2621        """
2622
2623        if isinstance(metrics, str):
2624            metrics = [metrics]
2625        if isinstance(modes, str):
2626            modes = [modes]
2627
2628        if font_size is not None:
2629            font = {"size": font_size}
2630            rc("font", **font)
2631        if modes is None:
2632            modes = ["val"]
2633        if add_hlines is None:
2634            add_hlines = []
2635        logs = []
2636        epochs = []
2637        labels = []
2638        if episode_labels is not None:
2639            assert len(episode_labels) == len(episode_names)
2640        for name_i, name in enumerate(episode_names):
2641            log_params = product(metrics, modes)
2642            for metric, mode in log_params:
2643                if episode_labels is not None:
2644                    label = episode_labels[name_i]
2645                else:
2646                    label = deepcopy(name)
2647                if len(modes) != 1:
2648                    label += f"_{mode}"
2649                if len(metrics) != 1:
2650                    label += f"_{metric}"
2651                labels.append(label)
2652                if isinstance(name, Iterable) and not isinstance(name, str):
2653                    epoch_list = defaultdict(lambda: [])
2654                    multi_logs = defaultdict(lambda: [])
2655                    for i, n in enumerate(name):
2656                        runs = self._episodes().get_runs(n)
2657                        if len(runs) > 1:
2658                            for run in runs:
2659                                if "::" in run:
2660                                    index = run.split("::")[-1]
2661                                else:
2662                                    index = run.split("#")[-1]
2663                                if multi_logs[index] == []:
2664                                    if multi_logs["null"] is None:
2665                                        raise RuntimeError(
2666                                            "The run indices are not consistent across episodes!"
2667                                        )
2668                                    else:
2669                                        multi_logs[index] += multi_logs["null"]
2670                                multi_logs[index] += list(
2671                                    self._episode(run).get_metric_log(mode, metric)
2672                                )
2673                                start = (
2674                                    0
2675                                    if len(epoch_list[index]) == 0
2676                                    else epoch_list[index][-1]
2677                                )
2678                                epoch_list[index] += [
2679                                    x + start
2680                                    for x in self._episode(run).get_epoch_list(mode)
2681                                ]
2682                            multi_logs["null"] = None
2683                        else:
2684                            if len(multi_logs.keys()) > 1:
2685                                raise RuntimeError(
2686                                    "Cannot plot a single-run episode after a multi-run episode!"
2687                                )
2688                            multi_logs["null"] += list(
2689                                self._episode(n).get_metric_log(mode, metric)
2690                            )
2691                            start = (
2692                                0
2693                                if len(epoch_list["null"]) == 0
2694                                else epoch_list["null"][-1]
2695                            )
2696                            epoch_list["null"] += [
2697                                x + start for x in self._episode(n).get_epoch_list(mode)
2698                            ]
2699                    if len(multi_logs.keys()) == 1:
2700                        log = multi_logs["null"]
2701                        epochs.append(epoch_list["null"])
2702                    else:
2703                        log = tuple([v for k, v in multi_logs.items() if k != "null"])
2704                        epochs.append(
2705                            tuple([v for k, v in epoch_list.items() if k != "null"])
2706                        )
2707                else:
2708                    runs = self._episodes().get_runs(name)
2709                    if len(runs) > 1:
2710                        log = []
2711                        for run in runs:
2712                            tracked_metrics = self._episode(run).get_metrics()
2713                            if metric in tracked_metrics:
2714                                log.append(
2715                                    list(
2716                                        self._episode(run).get_metric_log(mode, metric)
2717                                    )
2718                                )
2719                            else:
2720                                relevant = []
2721                                for m in tracked_metrics:
2722                                    m_split = m.split("_")
2723                                    if (
2724                                        "_".join(m_split[:-1]) == metric
2725                                        and m_split[-1].isnumeric()
2726                                    ):
2727                                        relevant.append(m)
2728                                if len(relevant) == 0:
2729                                    raise ValueError(
2730                                        f"The {metric} metric was not tracked at {run}"
2731                                    )
2732                                arr = 0
2733                                for m in relevant:
2734                                    arr += self._episode(run).get_metric_log(mode, m)
2735                                arr /= len(relevant)
2736                                log.append(list(arr))
2737                        log = tuple(log)
2738                        epochs.append(
2739                            tuple(
2740                                [
2741                                    self._episode(run).get_epoch_list(mode)
2742                                    for run in runs
2743                                ]
2744                            )
2745                        )
2746                    else:
2747                        tracked_metrics = self._episode(name).get_metrics()
2748                        if metric in tracked_metrics:
2749                            log = list(self._episode(name).get_metric_log(mode, metric))
2750                        else:
2751                            relevant = []
2752                            for m in tracked_metrics:
2753                                m_split = m.split("_")
2754                                if (
2755                                    "_".join(m_split[:-1]) == metric
2756                                    and m_split[-1].isnumeric()
2757                                ):
2758                                    relevant.append(m)
2759                            if len(relevant) == 0:
2760                                raise ValueError(
2761                                    f"The {metric} metric was not tracked at {name}"
2762                                )
2763                            arr = 0
2764                            for m in relevant:
2765                                arr += self._episode(name).get_metric_log(mode, m)
2766                            arr /= len(relevant)
2767                            log = list(arr)
2768                        epochs.append(self._episode(name).get_epoch_list(mode))
2769                logs.append(log)
2770        # if episode_labels is not None:
2771        #     print(f'{len(episode_labels)=}, {len(logs)=}')
2772        #     if len(episode_labels) != len(logs):
2773
2774        #         raise ValueError(
2775        #             f"The length of episode_labels ({len(episode_labels)}) has to be equal to the length of "
2776        #             f"curves ({len(logs)})!"
2777        #         )
2778        #     else:
2779        #         labels = episode_labels
2780        if colors is None:
2781            colors = cm.rainbow(np.linspace(0, 1, len(logs)))
2782        if len(colors) != len(logs):
2783            raise ValueError(
2784                "The length of colors has to be equal to the length of curves (metrics * modes * episode_names)!"
2785            )
2786        f, ax = plt.subplots()
2787        length = 0
2788        for log, label, color, epoch_list in zip(logs, labels, colors, epochs):
2789            if type(log) is list:
2790                if len(log) > length:
2791                    length = len(log)
2792                ax.plot(
2793                    epoch_list,
2794                    log,
2795                    label=label,
2796                    color=color,
2797                )
2798                if add_highpoint_hlines:
2799                    ax.axhline(np.max(log), linestyle="dashed", color=color)
2800            else:
2801                for l, xx in zip(log, epoch_list):
2802                    if len(l) > length:
2803                        length = len(l)
2804                    ax.plot(
2805                        xx,
2806                        l,
2807                        color=color,
2808                        alpha=0.2,
2809                    )
2810                if not all([len(x) == len(log[0]) for x in log]):
2811                    warnings.warn(
2812                        f"Got logs with unequal lengths in parallel runs for {label}"
2813                    )
2814                    log = list(log)
2815                    epoch_list = list(epoch_list)
2816                    for i, x in enumerate(epoch_list):
2817                        to_remove = []
2818                        for j, y in enumerate(x[1:]):
2819                            if y <= x[j - 1]:
2820                                y_ind = x.index(y)
2821                                to_remove += list(range(y_ind, j))
2822                        epoch_list[i] = [
2823                            y for j, y in enumerate(x) if j not in to_remove
2824                        ]
2825                        log[i] = [y for j, y in enumerate(log[i]) if j not in to_remove]
2826                    length = min([len(x) for x in log])
2827                    for i in range(len(log)):
2828                        log[i] = log[i][:length]
2829                        epoch_list[i] = epoch_list[i][:length]
2830                    if not all([x == epoch_list[0] for x in epoch_list]):
2831                        raise RuntimeError(
2832                            f"Got different epoch indices in parallel runs for {label}"
2833                        )
2834                mean = np.array(log).mean(0)
2835                ax.plot(
2836                    epoch_list[0],
2837                    mean,
2838                    label=label,
2839                    color=color,
2840                    linewidth=linewidth,
2841                )
2842                if add_highpoint_hlines:
2843                    ax.axhline(np.max(mean), linestyle="dashed", color=color)
2844        for x in add_hlines:
2845            label = None
2846            if isinstance(x, Iterable):
2847                x, label = x
2848            ax.axhline(x, label=label)
2849            ax.set_xlim((0, length))
2850
2851        ax.legend()
2852        ax.set_xlabel("epochs")
2853        if len(metrics) == 1:
2854            ax.set_ylabel(metrics[0])
2855        else:
2856            ax.set_ylabel("value")
2857        if title is None:
2858            if len(episode_names) == 1:
2859                title = episode_names[0]
2860            elif len(metrics) == 1:
2861                title = metrics[0]
2862        if epoch_limits is not None:
2863            ax.set_xlim(epoch_limits)
2864        if title is not None:
2865            ax.set_title(title)
2866        if remove_box:
2867            ax.box(False)
2868        if return_ax:
2869            return ax
2870        if save_path is not None:
2871            plt.savefig(save_path)
2872        plt.show()

Plot episode training curves.

Parameters

episode_names : list a list of episode names to plot; to plot several episodes as one line combine them in a list (e.g. ['episode1', ['episode2', 'episode3']] to plot episode2 and episode3 as one experiment) metrics : list a list of metrics to plot modes : list, optional a list of modes to plot ('train' and/or 'val'; ['val'] by default) title : str, optional title for the plot episode_labels : list, optional a list of strings used to label the curves (has to be the same length as episode_names) save_path : str, optional the path to save the resulting plot add_hlines : list, optional a list of float values (or (value, label) tuples) to mark with horizontal lines epoch_limits : list, optional a list of (min, max) tuples to set the x-axis limits for each episode colors: list, optional a list of matplotlib colors add_highpoint_hlines : bool, default False if True, horizontal lines will be added at the highest value of each episode

def update_parameters( self, parameters_update: Dict = None, load_search: str = None, load_parameters: List = None, round_to_binary: List = None) -> None:
2874    def update_parameters(
2875        self,
2876        parameters_update: Dict = None,
2877        load_search: str = None,
2878        load_parameters: List = None,
2879        round_to_binary: List = None,
2880    ) -> None:
2881        """Update the parameters in the project config files.
2882
2883        Parameters
2884        ----------
2885        parameters_update : dict, optional
2886            a dictionary of parameter updates
2887        load_search : str, optional
2888            the name of hyperparameter search results to load to config
2889        load_parameters : list, optional
2890            a list of lists of string names of the parameters to load from the searches
2891        round_to_binary : list, optional
2892            a list of string names of the loaded parameters that should be rounded to the nearest power of two
2893
2894        """
2895        keys = [
2896            "general",
2897            "losses",
2898            "metrics",
2899            "ssl",
2900            "training",
2901            "data",
2902        ]
2903        parameters = self._read_parameters(catch_blanks=False)
2904        if parameters_update is not None:
2905            model_params = (
2906                parameters_update.pop("model") if "model" in parameters_update else None
2907            )
2908            feat_params = (
2909                parameters_update.pop("features")
2910                if "features" in parameters_update
2911                else None
2912            )
2913            aug_params = (
2914                parameters_update.pop("augmentations")
2915                if "augmentations" in parameters_update
2916                else None
2917            )
2918
2919            parameters = self._update(parameters, parameters_update)
2920            model_name = parameters["general"]["model_name"]
2921            parameters["model"] = self._open_yaml(
2922                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2923            )
2924            if model_params is not None:
2925                parameters["model"] = self._update(parameters["model"], model_params)
2926            feat_name = parameters["general"]["feature_extraction"]
2927            parameters["features"] = self._open_yaml(
2928                os.path.join(
2929                    self.project_path, "config", "features", f"{feat_name}.yaml"
2930                )
2931            )
2932            if feat_params is not None:
2933                parameters["features"] = self._update(
2934                    parameters["features"], feat_params
2935                )
2936            aug_name = options.extractor_to_transformer[
2937                parameters["general"]["feature_extraction"]
2938            ]
2939            parameters["augmentations"] = self._open_yaml(
2940                os.path.join(
2941                    self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2942                )
2943            )
2944            if aug_params is not None:
2945                parameters["augmentations"] = self._update(
2946                    parameters["augmentations"], aug_params
2947                )
2948        if load_search is not None:
2949            parameters_update, model_name = self._searches().get_best_params(
2950                load_search, load_parameters, round_to_binary
2951            )
2952            parameters["general"]["model_name"] = model_name
2953            parameters["model"] = self._open_yaml(
2954                os.path.join(self.project_path, "config", "model", f"{model_name}.yaml")
2955            )
2956            parameters = self._update(parameters, parameters_update)
2957        for key in keys:
2958            with open(
2959                os.path.join(self.project_path, "config", f"{key}.yaml"),
2960                "w",
2961                encoding="utf-8",
2962            ) as f:
2963                YAML().dump(parameters[key], f)
2964        model_name = parameters["general"]["model_name"]
2965        model_path = os.path.join(
2966            self.project_path, "config", "model", f"{model_name}.yaml"
2967        )
2968        with open(model_path, "w", encoding="utf-8") as f:
2969            YAML().dump(parameters["model"], f)
2970        features_name = parameters["general"]["feature_extraction"]
2971        features_path = os.path.join(
2972            self.project_path, "config", "features", f"{features_name}.yaml"
2973        )
2974        with open(features_path, "w", encoding="utf-8") as f:
2975            YAML().dump(parameters["features"], f)
2976        aug_name = options.extractor_to_transformer[features_name]
2977        aug_path = os.path.join(
2978            self.project_path, "config", "augmentations", f"{aug_name}.yaml"
2979        )
2980        with open(aug_path, "w", encoding="utf-8") as f:
2981            YAML().dump(parameters["augmentations"], f)

Update the parameters in the project config files.

Parameters

parameters_update : dict, optional a dictionary of parameter updates load_search : str, optional the name of hyperparameter search results to load to config load_parameters : list, optional a list of lists of string names of the parameters to load from the searches round_to_binary : list, optional a list of string names of the loaded parameters that should be rounded to the nearest power of two

def get_summary( self, episode_names: list, method: str = 'last', average: int = 1, metrics: List = None, return_values: bool = False) -> Dict:
2983    def get_summary(
2984        self,
2985        episode_names: list,
2986        method: str = "last",
2987        average: int = 1,
2988        metrics: List = None,
2989        return_values: bool = False,
2990    ) -> Dict:
2991        """Get a summary of episode statistics.
2992
2993        If an episode has multiple runs, the statistics will be aggregated over all of them.
2994
2995        Parameters
2996        ----------
2997        episode_names : str
2998            the names of the episodes
2999        method : ["best", "last"]
3000            the method for choosing the epochs
3001        average : int, default 1
3002            the number of epochs to average over (for each run)
3003        metrics : list, optional
3004            a list of metrics
3005
3006        Returns
3007        -------
3008        statistics : dict
3009            a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean
3010            and 'std' for the standard deviation
3011
3012        """
3013        runs = []
3014        for episode_name in episode_names:
3015            runs_ep = self._episodes().get_runs(episode_name)
3016            if len(runs_ep) == 0:
3017                raise RuntimeError(
3018                    f"There is no {episode_name} episode in the project memory"
3019                )
3020            runs += runs_ep
3021        if metrics is None:
3022            metrics = self._episode(runs[0]).get_metrics()
3023
3024        values = {m: [] for m in metrics}
3025        for run in runs:
3026            for m in metrics:
3027                log = self._episode(run).get_metric_log(mode="val", metric_name=m)
3028                if method == "best":
3029                    log = sorted(log)
3030                    values[m] += list(log[-average:])
3031                elif method == "last":
3032                    if len(log) == 0:
3033                        episodes = self._episodes().data
3034                        if average == 1 and ("results", m) in episodes.columns:
3035                            values[m] += [episodes.loc[run, ("results", m)]]
3036                        else:
3037                            raise RuntimeError(f"Did not find {m} metric for {run} run")
3038                    values[m] += list(log[-average:])
3039                elif method.startswith("epoch"):
3040                    epoch = int(method[5:]) - 1
3041                    pars = self._episodes().load_parameters(run)
3042                    step = int(pars["training"]["validation_interval"])
3043                    values[m] += [log[epoch // step]]
3044                else:
3045                    raise ValueError(
3046                        f"The {method} method is not recognized! Please choose from ['last', 'best', 'epoch...']"
3047                    )
3048        statistics = defaultdict(lambda: {})
3049        for m, v in values.items():
3050            statistics[m]["mean"] = np.mean(v)
3051            statistics[m]["std"] = np.std(v)
3052        print(f"SUMMARY {episode_names}")
3053        for m, v in statistics.items():
3054            print(f'{m}: mean {v["mean"]:.3f}, std {v["std"]:.3f}')
3055        print("\n")
3056
3057        return (dict(statistics), values) if return_values else dict(statistics)

Get a summary of episode statistics.

If an episode has multiple runs, the statistics will be aggregated over all of them.

Parameters

episode_names : list the names of the episodes method : ["best", "last"] the method for choosing the epochs average : int, default 1 the number of epochs to average over (for each run) metrics : list, optional a list of metrics

Returns

statistics : dict a nested dictionary where first-level keys are metric names and second-level keys are 'mean' for the mean and 'std' for the standard deviation

@staticmethod
def remove_project(name: str, projects_path: str = None) -> None:
3059    @staticmethod
3060    def remove_project(name: str, projects_path: str = None) -> None:
3061        """Remove all project files and experiment records and results.
3062
3063        Parameters
3064        ----------
3065        name : str
3066            the name of the project to remove
3067        projects_path : str, optional
3068            the path to the projects directory (by default the home DLC2Action directory)
3069
3070        """
3071        if projects_path is None:
3072            projects_path = os.path.join(str(Path.home()), "DLC2Action")
3073        project_path = os.path.join(projects_path, name)
3074        if os.path.exists(project_path):
3075            shutil.rmtree(project_path)

Remove all project files and experiment records and results.

Parameters

name : str the name of the project to remove projects_path : str, optional the path to the projects directory (by default the home DLC2Action directory)

def remove_saved_features( self, dataset_names: List = None, exceptions: List = None, remove_active: bool = False) -> None:
3077    def remove_saved_features(
3078        self,
3079        dataset_names: List = None,
3080        exceptions: List = None,
3081        remove_active: bool = False,
3082    ) -> None:
3083        """Remove saved pre-computed dataset feature files.
3084
3085        By default, all features will be deleted.
3086        No essential information can get lost, storing them only saves time. Be careful with deleting datasets
3087        while training or inference is happening though.
3088
3089        Parameters
3090        ----------
3091        dataset_names : list, optional
3092            a list of dataset names to delete (by default all names are added)
3093        exceptions : list, optional
3094            a list of dataset names to not be deleted
3095        remove_active : bool, default False
3096            if `False`, datasets used by unfinished episodes will not be deleted
3097
3098        """
3099        print("Removing datasets...")
3100        if dataset_names is None:
3101            dataset_names = []
3102        if exceptions is None:
3103            exceptions = []
3104        if not remove_active:
3105            exceptions += self._episodes().get_active_datasets()
3106        dataset_path = os.path.join(self.project_path, "saved_datasets")
3107        if os.path.exists(dataset_path):
3108            if dataset_names == []:
3109                dataset_names = set([f.split(".")[0] for f in os.listdir(dataset_path)])
3110
3111            to_remove = [
3112                x
3113                for x in dataset_names
3114                if os.path.exists(os.path.join(dataset_path, x)) and x not in exceptions
3115            ]
3116            if len(to_remove) > 2:
3117                to_remove = tqdm(to_remove)
3118            for dataset in to_remove:
3119                shutil.rmtree(os.path.join(dataset_path, dataset))
3120            to_remove = [
3121                f"{x}.pickle"
3122                for x in dataset_names
3123                if os.path.exists(os.path.join(dataset_path, f"{x}.pickle"))
3124                and x not in exceptions
3125            ]
3126            for dataset in to_remove:
3127                os.remove(os.path.join(dataset_path, dataset))
3128            names = self._saved_datasets().dataset_names()
3129            self._saved_datasets().remove(names)
3130        print("\n")

Remove saved pre-computed dataset feature files.

By default, all features will be deleted. No essential information can get lost, storing them only saves time. Be careful with deleting datasets while training or inference is happening though.

Parameters

dataset_names : list, optional a list of dataset names to delete (by default all names are added) exceptions : list, optional a list of dataset names to not be deleted remove_active : bool, default False if False, datasets used by unfinished episodes will not be deleted

def remove_extra_checkpoints(self, episode_names: List = None, exceptions: List = None) -> None:
3132    def remove_extra_checkpoints(
3133        self, episode_names: List = None, exceptions: List = None
3134    ) -> None:
3135        """Remove intermediate model checkpoint files (only leave the files for the last epoch).
3136
3137        By default, all intermediate checkpoints will be deleted.
3138        Files in the model folder that are not associated with any record in the meta files are also deleted.
3139
3140        Parameters
3141        ----------
3142        episode_names : list, optional
3143            a list of episode names to clean (by default all names are added)
3144        exceptions : list, optional
3145            a list of episode names to not clean
3146
3147        """
3148        model_path = os.path.join(self.project_path, "results", "model")
3149        try:
3150            all_names = self._episodes().data.index
3151        except:
3152            all_names = os.listdir(model_path)
3153        if episode_names is None:
3154            episode_names = all_names
3155        if exceptions is None:
3156            exceptions = []
3157        to_remove = [x for x in episode_names if x not in exceptions]
3158        folders = os.listdir(model_path)
3159        for folder in folders:
3160            if folder not in all_names:
3161                shutil.rmtree(os.path.join(model_path, folder))
3162            elif folder in to_remove:
3163                files = os.listdir(os.path.join(model_path, folder))
3164                for file in sorted(files)[:-1]:
3165                    os.remove(os.path.join(model_path, folder, file))

Remove intermediate model checkpoint files (only leave the files for the last epoch).

By default, all intermediate checkpoints will be deleted. Files in the model folder that are not associated with any record in the meta files are also deleted.

Parameters

episode_names : list, optional a list of episode names to clean (by default all names are added) exceptions : list, optional a list of episode names to not clean

def remove_suggestion(self, suggestion_name: str) -> None:
3181    def remove_suggestion(self, suggestion_name: str) -> None:
3182        """Remove a suggestion record.
3183
3184        Parameters
3185        ----------
3186        suggestion_name : str
3187            the name of the suggestion to remove
3188
3189        """
3190        self._suggestions().remove_episode(suggestion_name)
3191        suggestion_path = os.path.join(
3192            self.project_path, "results", "suggestions", suggestion_name
3193        )
3194        if os.path.exists(suggestion_path):
3195            shutil.rmtree(suggestion_path)

Remove a suggestion record.

Parameters

suggestion_name : str the name of the suggestion to remove

def remove_prediction(self, prediction_name: str) -> None:
3197    def remove_prediction(self, prediction_name: str) -> None:
3198        """Remove a prediction record.
3199
3200        Parameters
3201        ----------
3202        prediction_name : str
3203            the name of the prediction to remove
3204
3205        """
3206        self._predictions().remove_episode(prediction_name)
3207        prediction_path = self.prediction_path(prediction_name)
3208        if os.path.exists(prediction_path):
3209            shutil.rmtree(prediction_path)

Remove a prediction record.

Parameters

prediction_name : str the name of the prediction to remove

def check_prediction_exists(self, prediction_name: str) -> str | None:
3211    def check_prediction_exists(self, prediction_name: str) -> str | None:
3212        """Check if a prediction exists.
3213
3214        Parameters
3215        ----------
3216        prediction_name : str
3217            the name of the prediction to check
3218
3219        Returns
3220        -------
3221        str | None
3222            the path to the prediction if it exists, `None` otherwise
3223
3224        """
3225        prediction_path = self.prediction_path(prediction_name)
3226        if os.path.exists(prediction_path):
3227            return prediction_path
3228        return None

Check if a prediction exists.

Parameters

prediction_name : str the name of the prediction to check

Returns

str | None the path to the prediction if it exists, None otherwise

def remove_episode(self, episode_name: str) -> None:
3230    def remove_episode(self, episode_name: str) -> None:
3231        """Remove all model, logs and metafile records related to an episode.
3232
3233        Parameters
3234        ----------
3235        episode_name : str
3236            the name of the episode to remove
3237
3238        """
3239        runs = self._episodes().get_runs(episode_name)
3240        runs.append(episode_name)
3241        for run in runs:
3242            self._episodes().remove_episode(run)
3243            model_path = os.path.join(self.project_path, "results", "model", run)
3244            if os.path.exists(model_path):
3245                shutil.rmtree(model_path)
3246            log_path = os.path.join(self.project_path, "results", "logs", f"{run}.txt")
3247            if os.path.exists(log_path):
3248                os.remove(log_path)

Remove all model, logs and metafile records related to an episode.

Parameters

episode_name : str the name of the episode to remove

def prune_unfinished(self, exceptions: List = None) -> List:
3273    def prune_unfinished(self, exceptions: List = None) -> List:
3274        """Remove all interrupted episodes.
3275
3276        Remove all episodes that either don't have a log file or have less epochs in the log file than in
3277        the training parameters or have a model folder but not a record. Note that it can remove episodes that are
3278        currently running!
3279
3280        Parameters
3281        ----------
3282        exceptions : list
3283            the episodes to keep even if they are interrupted
3284
3285        Returns
3286        -------
3287        pruned : list
3288            a list of the episode names that were pruned
3289
3290        """
3291        if exceptions is None:
3292            exceptions = []
3293        unfinished = self._episodes().unfinished_episodes()
3294        unfinished = [x for x in unfinished if x not in exceptions]
3295        model_folders = os.listdir(os.path.join(self.project_path, "results", "model"))
3296        unfinished += [
3297            x for x in model_folders if x not in self._episodes().list_episodes().index
3298        ]
3299        print(f"PRUNING {unfinished}")
3300        for episode_name in unfinished:
3301            self.remove_episode(episode_name)
3302        print(f"\n")
3303        return unfinished

Remove all interrupted episodes.

Remove all episodes that either don't have a log file or have fewer epochs in the log file than in the training parameters or have a model folder but not a record. Note that it can remove episodes that are currently running!

Parameters

exceptions : list the episodes to keep even if they are interrupted

Returns

pruned : list a list of the episode names that were pruned

def prediction_path(self, prediction_name: str) -> str:
3305    def prediction_path(self, prediction_name: str) -> str:
3306        """Get the path where prediction files are saved.
3307
3308        Parameters
3309        ----------
3310        prediction_name : str
3311            name of the prediction
3312
3313        Returns
3314        -------
3315        prediction_path : str
3316            the file path
3317
3318        """
3319        return os.path.join(
3320            self.project_path, "results", "predictions", f"{prediction_name}"
3321        )

Get the path where prediction files are saved.

Parameters

prediction_name : str name of the prediction

Returns

prediction_path : str the file path

def suggestion_path(self, suggestion_name: str) -> str:
3323    def suggestion_path(self, suggestion_name: str) -> str:
3324        """Get the path where suggestion files are saved.
3325
3326        Parameters
3327        ----------
3328        suggestion_name : str
3329            name of the prediction
3330
3331        Returns
3332        -------
3333        suggestion_path : str
3334            the file path
3335
3336        """
3337        return os.path.join(
3338            self.project_path, "results", "suggestions", f"{suggestion_name}"
3339        )

Get the path where suggestion files are saved.

Parameters

suggestion_name : str name of the suggestion

Returns

suggestion_path : str the file path

@classmethod
def print_data_types(cls):
3341    @classmethod
3342    def print_data_types(cls):
3343        """Print available data types."""
3344        print("DATA TYPES:")
3345        for key, value in cls.data_types().items():
3346            print(f"{key}:")
3347            print(value.__doc__)

Print available data types.

@classmethod
def print_annotation_types(cls):
3349    @classmethod
3350    def print_annotation_types(cls):
3351        """Print available annotation types."""
3352        print("ANNOTATION TYPES:")
3353        for key, value in cls.annotation_types():
3354            print(f"{key}:")
3355            print(value.__doc__)

Print available annotation types.

@staticmethod
def data_types() -> List:
    @staticmethod
    def data_types() -> Dict:
        """Get available data types.

        Returns
        -------
        data_types : dict
            the `options.input_stores` registry; it is used as a mapping in this
            class (iterated with `.items()` and indexed by data type name), with
            data type names as keys and the objects whose docstrings describe
            them as values

        """
        return options.input_stores

Get available data types.

Returns

data_types : list available data types

@staticmethod
def annotation_types() -> List:
    @staticmethod
    def annotation_types() -> Dict:
        """Get available annotation types.

        Returns
        -------
        dict
            the `options.annotation_stores` registry; it is indexed by annotation
            type name in this class, and the values are the objects whose
            docstrings describe the annotation types

        """
        return options.annotation_stores

Get available annotation types.

Returns

list available annotation types

def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3957    def set_main_parameters(self, model_name: str = None, metric_names: List = None):
3958        """Select the model and the metrics.
3959
3960        Parameters
3961        ----------
3962        model_name : str, optional
3963            model name; run `project.help("model") to find out more
3964        metric_names : list, optional
3965            a list of metric function names; run `project.help("metrics") to find out more
3966
3967        """
3968        pars = {"general": {}}
3969        if model_name is not None:
3970            assert model_name in options.models
3971            pars["general"]["model_name"] = model_name
3972        if metric_names is not None:
3973            for metric in metric_names:
3974                assert metric in options.metrics
3975            pars["general"]["metric_functions"] = metric_names
3976        self.update_parameters(pars)

Select the model and the metrics.

Parameters

model_name : str, optional model name; run project.help("model") to find out more metric_names : list, optional a list of metric function names; run project.help("metrics") to find out more

def help(self, keyword: str = None):
3978    def help(self, keyword: str = None):
3979        """Get information on available options.
3980
3981        Parameters
3982        ----------
3983        keyword : str, optional
3984            the keyword for options (run without arguments to see which keywords are available)
3985
3986        """
3987        if keyword is None:
3988            print("AVAILABLE HELP FUNCTIONS:")
3989            print("- Try running `project.help(keyword)` with the following keywords:")
3990            print("    - model: to get more information on available models,")
3991            print(
3992                "    - features: to get more information on available feature extraction modes,"
3993            )
3994            print(
3995                "    - partition_method: to get more information on available train/test/val partitioning methods,"
3996            )
3997            print("    - metrics: to see a list of available metric functions.")
3998            print("    - data: to see help for expected data structure")
3999            print(
4000                "- To start working with this project, first run `project.list_blanks()` to check which parameters need to be filled in."
4001            )
4002            print(
4003                "- After a model and metrics are set, run `project.list_basic_parameters()` to see a list of the most important parameters that you might want to modify"
4004            )
4005            print(
4006                f"- If you want to dig deeper, get the full dictionary with project._read_parameters() (it is a `ruamel.yaml.comments.CommentedMap` instance)."
4007            )
4008        elif keyword == "model":
4009            print("MODELS:")
4010            for key, model in options.models.items():
4011                print(f"{key}:")
4012                print(model.__doc__)
4013        elif keyword == "features":
4014            print("FEATURE EXTRACTORS:")
4015            for key, extractor in options.feature_extractors.items():
4016                print(f"{key}:")
4017                print(extractor.__doc__)
4018        elif keyword == "partition_method":
4019            print("PARTITION METHODS:")
4020            print(
4021                BehaviorDataset.partition_train_test_val.__doc__.split(
4022                    "The partitioning method:"
4023                )[1].split("val_frac :")[0]
4024            )
4025        elif keyword == "metrics":
4026            print("METRICS:")
4027            for key, metric in options.metrics.items():
4028                print(f"{key}:")
4029                print(metric.__doc__)
4030        elif keyword == "data":
4031            print("DATA:")
4032            print(f"Video data: {self.data_type}")
4033            print(options.input_stores[self.data_type].__doc__)
4034            print(f"Annotation data: {self.annotation_type}")
4035            print(options.annotation_stores[self.annotation_type].__doc__)
4036            print(
4037                "Annotation path and data path don't have to be separate, you can keep everything in one folder."
4038            )
4039        else:
4040            raise ValueError(f"The {keyword} keyword is not recognized")
4041        print("\n")

Get information on available options.

Parameters

keyword : str, optional the keyword for options (run without arguments to see which keywords are available)

def list_blanks(self, blanks=None):
4082    def list_blanks(self, blanks=None):
4083        """List parameters that need to be filled in.
4084
4085        Parameters
4086        ----------
4087        blanks : list, optional
4088            a list of the parameters to list, if already known
4089
4090        """
4091        if blanks is None:
4092            blanks = self._get_blanks()
4093        if len(blanks) > 0:
4094            to_update = defaultdict(lambda: [])
4095            for b, k, c in blanks:
4096                to_update[b].append((k, c))
4097            print("Before running experiments, please update all the blanks.")
4098            print("To do that, you can run this.")
4099            print("--------------------------------------------------------")
4100            print(f"project.update_parameters(")
4101            print(f"    {{")
4102            for big_key, keys in to_update.items():
4103                print(f'        "{big_key}": {{')
4104                for key, comment in keys:
4105                    print(f'            "{key}": ..., {comment}')
4106                print(f"        }}")
4107            print(f"    }}")
4108            print(")")
4109            print("--------------------------------------------------------")
4110            print("Replace ... with relevant values.")
4111        else:
4112            print("There is no blanks left!")

List parameters that need to be filled in.

Parameters

blanks : list, optional a list of the parameters to list, if already known

def list_basic_parameters(self):
4114    def list_basic_parameters(
4115        self,
4116    ):
4117        """Get a list of most relevant parameters and code to modify them."""
4118        parameters = self._read_parameters()
4119        print("BASIC PARAMETERS:")
4120        model_name = parameters["general"]["model_name"]
4121        metric_names = parameters["general"]["metric_functions"]
4122        loss_name = parameters["general"]["loss_function"]
4123        feature_extraction = parameters["general"]["feature_extraction"]
4124        print("Here is a list of current parameters.")
4125        print(
4126            "You can copy this code, change the parameters you want to set and run it to update the project config."
4127        )
4128        print("--------------------------------------------------------")
4129        print("project.update_parameters(")
4130        print("    {")
4131        for group in ["general", "data", "training"]:
4132            print(f'        "{group}": {{')
4133            for key in options.basic_parameters[group]:
4134                if key in parameters[group]:
4135                    print(
4136                        f'            "{key}": {self._process_value(parameters[group][key])}, {self._get_comment(parameters[group].ca.items.get(key))}'
4137                    )
4138            print("        },")
4139        print('        "losses": {')
4140        print(f'            "{loss_name}": {{')
4141        for key in options.basic_parameters["losses"][loss_name]:
4142            if key in parameters["losses"][loss_name]:
4143                print(
4144                    f'                "{key}": {self._process_value(parameters["losses"][loss_name][key])}, {self._get_comment(parameters["losses"][loss_name].ca.items.get(key))}'
4145                )
4146        print("            },")
4147        print("        },")
4148        print('        "metrics": {')
4149        for metric in metric_names:
4150            print(f'            "{metric}": {{')
4151            for key in parameters["metrics"][metric]:
4152                print(
4153                    f'                "{key}": {self._process_value(parameters["metrics"][metric][key])}, {self._get_comment(parameters["metrics"][metric].ca.items.get(key))}'
4154                )
4155            print("            },")
4156        print("        },")
4157        print('        "model": {')
4158        for key in options.basic_parameters["model"][model_name]:
4159            if key in parameters["model"]:
4160                print(
4161                    f'            "{key}": {self._process_value(parameters["model"][key])}, {self._get_comment(parameters["model"].ca.items.get(key))}'
4162                )
4163
4164        print("        },")
4165        print('        "features": {')
4166        for key in options.basic_parameters["features"][feature_extraction]:
4167            if key in parameters["features"]:
4168                print(
4169                    f'            "{key}": {self._process_value(parameters["features"][key])}, {self._get_comment(parameters["features"].ca.items.get(key))}'
4170                )
4171
4172        print("        },")
4173        print('        "augmentations": {')
4174        for key in options.basic_parameters["augmentations"][feature_extraction]:
4175            if key in parameters["augmentations"]:
4176                print(
4177                    f'            "{key}": {self._process_value(parameters["augmentations"][key])}, {self._get_comment(parameters["augmentations"].ca.items.get(key))}'
4178                )
4179        print("        },")
4180        print("    },")
4181        print(")")
4182        print("--------------------------------------------------------")
4183        print("\n")

Get a list of most relevant parameters and code to modify them.

def count_classes( self, load_episode: str = None, parameters_update: Dict = None, remove_saved_features: bool = False, bouts: bool = True) -> Dict:
5097    def count_classes(
5098        self,
5099        load_episode: str = None,
5100        parameters_update: Dict = None,
5101        remove_saved_features: bool = False,
5102        bouts: bool = True,
5103    ) -> Dict:
5104        """Get a dictionary of class counts in different modes.
5105
5106        Parameters
5107        ----------
5108        load_episode : str, optional
5109            the episode settings to load
5110        parameters_update : dict, optional
5111            a dictionary of parameter updates (only for "data" and "general" categories)
5112        remove_saved_features : bool, default False
5113            if `True`, the dataset that is used for computation is then deleted
5114        bouts : bool, default False
5115            if `True`, instead of frame counts segment counts are returned
5116
5117        Returns
5118        -------
5119        class_counts : dict
5120            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
5121            class names and values are class counts (in frames)
5122
5123        """
5124        if load_episode is None:
5125            task, parameters = self._make_task_training(
5126                episode_name="_", parameters_update=parameters_update, throwaway=True
5127            )
5128        else:
5129            task, parameters, _ = self._make_task_prediction(
5130                "_",
5131                load_episode=load_episode,
5132                parameters_update=parameters_update,
5133            )
5134        class_counts = task.count_classes(bouts=bouts)
5135        behaviors = task.behaviors_dict()
5136        class_counts = {
5137            kk: {behaviors.get(k, "unknown"): v for k, v in vv.items()}
5138            for kk, vv in class_counts.items()
5139        }
5140        if remove_saved_features:
5141            self._remove_stores(parameters)
5142        return class_counts

Get a dictionary of class counts in different modes.

Parameters

load_episode : str, optional the episode settings to load parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted bouts : bool, default True if True, instead of frame counts segment counts are returned

Returns

class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)

def plot_class_distribution( self, parameters_update: Dict = None, frame_cutoff: int = 1, bout_cutoff: int = 1, print_full: bool = False, remove_saved_features: bool = False, save: str = None) -> None:
5144    def plot_class_distribution(
5145        self,
5146        parameters_update: Dict = None,
5147        frame_cutoff: int = 1,
5148        bout_cutoff: int = 1,
5149        print_full: bool = False,
5150        remove_saved_features: bool = False,
5151        save: str = None,
5152    ) -> None:
5153        """Make a class distribution plot.
5154
5155        You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset
5156        is created or loaded for the computation with the default parameters).
5157
5158        Parameters
5159        ----------
5160        parameters_update : dict, optional
5161            a dictionary of parameter updates (only for "data" and "general" categories)
5162        frame_cutoff : int, default 1
5163            the minimum number of frames for a segment to be considered
5164        bout_cutoff : int, default 1
5165            the minimum number of bouts for a class to be considered
5166        print_full : bool, default False
5167            if `True`, the full class distribution is printed
5168        remove_saved_features : bool, default False
5169            if `True`, the dataset that is used for computation is then deleted
5170
5171        """
5172        task, parameters = self._make_task_training(
5173            episode_name="_", parameters_update=parameters_update, throwaway=True
5174        )
5175        cutoff = {True: bout_cutoff, False: frame_cutoff}
5176        for bouts in [True, False]:
5177            class_counts = task.count_classes(bouts=bouts)
5178            if print_full:
5179                print("Bouts:" if bouts else "Frames:")
5180                for k, v in class_counts.items():
5181                    if sum(v.values()) != 0:
5182                        print(f"  {k}:")
5183                        values, keys = zip(
5184                            *[
5185                                x
5186                                for x in sorted(zip(v.values(), v.keys()), reverse=True)
5187                                if x[-1] != -100
5188                            ]
5189                        )
5190                        for kk, vv in zip(keys, values):
5191                            print(f"    {task.behaviors_dict()[kk]}: {vv}")
5192            class_counts = {
5193                kk: {k: v for k, v in vv.items() if v >= cutoff[bouts]}
5194                for kk, vv in class_counts.items()
5195            }
5196            for key, d in class_counts.items():
5197                if sum(d.values()) != 0:
5198                    values, keys = zip(
5199                        *[x for x in sorted(zip(d.values(), d.keys())) if x[-1] != -100]
5200                    )
5201                    keys = [task.behaviors_dict()[x] for x in keys]
5202                    plt.bar(keys, values)
5203                    plt.title(key)
5204                    plt.xticks(rotation=45, ha="right")
5205                    if bouts:
5206                        plt.ylabel("bouts")
5207                    else:
5208                        plt.ylabel("frames")
5209                    plt.tight_layout()
5210
5211                    if save is None:
5212                        plt.savefig(save)
5213                        plt.close()
5214                    else:
5215                        plt.show()
5216        if remove_saved_features:
5217            self._remove_stores(parameters)

Make a class distribution plot.

You can either specify the parameters, choose an existing dataset or do neither (in that case a dataset is created or loaded for the computation with the default parameters).

Parameters

parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) frame_cutoff : int, default 1 the minimum number of frames for a segment to be considered bout_cutoff : int, default 1 the minimum number of bouts for a class to be considered print_full : bool, default False if True, the full class distribution is printed remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted

def plot_confusion_matrix( self, episode_name: str, load_epoch: int = None, parameters_update: Dict = None, metric: str = 'recall', mode: str = 'val', remove_saved_features: bool = False, save_path: str = None, cmap: str = 'viridis') -> Tuple[numpy.ndarray, Iterable]:
5710    def plot_confusion_matrix(
5711        self,
5712        episode_name: str,
5713        load_epoch: int = None,
5714        parameters_update: Dict = None,
5715        metric: str = "recall",
5716        mode: str = "val",
5717        remove_saved_features: bool = False,
5718        save_path: str = None,
5719        cmap: str = "viridis",
5720    ) -> Tuple[ndarray, Iterable]:
5721        """Make a confusion matrix plot and return the data.
5722
5723        If the annotation is non-exclusive, only false positive labels are considered.
5724
5725        Parameters
5726        ----------
5727        episode_name : str
5728            the name of the episode to load
5729        load_epoch : int, optional
5730            the index of the epoch to load (by default the last one is loaded)
5731        parameters_update : dict, optional
5732            a dictionary of parameter updates (only for "data" and "general" categories)
5733        metric : {"recall", "precision"}
5734            for datasets with non-exclusive annotation, if `type` is `"recall"`, only false positives are taken
5735            into account, and if `type` is `"precision"`, only false negatives
5736        mode : {'val', 'all', 'test', 'train'}
5737            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
5738        remove_saved_features : bool, default False
5739            if `True`, the dataset that is used for computation is then deleted
5740
5741        Returns
5742        -------
5743        confusion_matrix : np.ndarray
5744            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
5745            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
5746            `N_i` is the number of frames that have the i-th label in the ground truth
5747        classes : list
5748            a list of labels
5749
5750        """
5751        task, parameters, mode = self._make_task_prediction(
5752            "_",
5753            load_episode=episode_name,
5754            load_epoch=load_epoch,
5755            parameters_update=parameters_update,
5756            mode=mode,
5757        )
5758        dataset = task.dataset(mode)
5759        prediction = task.predict(dataset, raw_output=True)
5760        confusion_matrix, classes, type = dataset.get_confusion_matrix(prediction, type)
5761        if remove_saved_features:
5762            self._remove_stores(parameters)
5763        fig, ax = plt.subplots(figsize=(len(classes), len(classes)))
5764        ax.imshow(confusion_matrix, cmap=cmap)
5765        # Show all ticks and label them with the respective list entries
5766        ax.set_xticks(np.arange(len(classes)))
5767        ax.set_xticklabels(classes)
5768        ax.set_yticks(np.arange(len(classes)))
5769        ax.set_yticklabels(classes)
5770        # Rotate the tick labels and set their alignment.
5771        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
5772        # Loop over data dimensions and create text annotations.
5773        for i in range(len(classes)):
5774            for j in range(len(classes)):
5775                ax.text(
5776                    j,
5777                    i,
5778                    np.round(confusion_matrix[i, j], 2),
5779                    ha="center",
5780                    va="center",
5781                    color="w",
5782                )
5783        if metric is not None:
5784            ax.set_title(f"{metric} {episode_name}")
5785        else:
5786            ax.set_title(episode_name)
5787        fig.tight_layout()
5788        if save_path is None:
5789            plt.show()
5790        else:
5791            plt.savefig(save_path)
5792            plt.close()
5793        return confusion_matrix, classes

Make a confusion matrix plot and return the data.

If the annotation is non-exclusive, only false positive labels are considered.

Parameters

episode_name : str the name of the episode to load load_epoch : int, optional the index of the epoch to load (by default the last one is loaded) parameters_update : dict, optional a dictionary of parameter updates (only for "data" and "general" categories) metric : {"recall", "precision"} for datasets with non-exclusive annotation, if metric is "recall", only false positives are taken into account, and if metric is "precision", only false negatives mode : {'val', 'all', 'test', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None) remove_saved_features : bool, default False if True, the dataset that is used for computation is then deleted save_path : str, optional the path to save the plot at (if None, the plot is shown instead) cmap : str, default 'viridis' the colormap used for the matrix

Returns

confusion_matrix : np.ndarray a confusion matrix of shape (#classes, #classes) where A[i, j] = F_ij/N_i, F_ij is the number of frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, N_i is the number of frames that have the i-th label in the ground truth classes : list a list of labels

def plot_ethograms( self, episode_name: str, prediction_name: str, start: int = 0, end: int = -1, save_path: str = None, cmap_pred: str = 'binary', cmap_gt: str = 'binary', fontsize: int = 22, time_mode: str = 'frames', fps: int = None):
5883    def plot_ethograms(
5884        self,
5885        episode_name: str,
5886        prediction_name: str,
5887        start: int = 0,
5888        end: int = -1,
5889        save_path: str = None,
5890        cmap_pred: str = "binary",
5891        cmap_gt: str = "binary",
5892        fontsize: int = 22,
5893        time_mode: str = "frames",
5894        fps: int = None,
5895    ):
5896        """Plot ethograms from start to end time (in frames) for ground truth and prediction"""
5897        params = self._read_parameters(catch_blanks=False)
5898        parameters = self._get_data_pars(
5899            params,
5900        )
5901        if not save_path is None:
5902            os.makedirs(save_path, exist_ok=True)
5903        gt_files = [
5904            f for f in self.data_path if f.endswith(parameters["annotation_suffix"])
5905        ]
5906        pred_path = os.path.join(
5907            self.project_path, "results", "predictions", prediction_name
5908        )
5909        pred_paths = [os.path.join(pred_path, f) for f in os.listdir(pred_path)]
5910        for pred_path in pred_paths:
5911            predictions = load_pickle(pred_path)
5912            behaviors = self.get_behavior_dictionary(episode_name)
5913            gt_filename = os.path.basename(pred_path).replace(
5914                "_".join(["_" + prediction_name, "prediction.pickle"]),
5915                parameters["annotation_suffix"],
5916            )
5917            if os.path.exists(os.path.join(self.data_path, gt_filename)):
5918                gt_data = load_pickle(os.path.join(self.data_path, gt_filename))
5919
5920                self._plot_ethograms_gt_pred(
5921                    gt_data,
5922                    predictions,
5923                    gt_data[1],
5924                    behaviors,
5925                    start=start,
5926                    end=end,
5927                    save=os.path.join(
5928                        save_path,
5929                        os.path.splitext(os.path.basename(pred_path))[0] + "_gt_pred",
5930                    ),
5931                    cmap_pred=cmap_pred,
5932                    cmap_gt=cmap_gt,
5933                    fontsize=fontsize,
5934                    time_mode=time_mode,
5935                    fps=fps,
5936                )
5937            else:
5938                print("GT file not found")

Plot ethograms from start to end time (in frames) for ground truth and prediction

def create_annotated_video( self, prediction_file_paths: list, video_file_paths: list, episode_name: str, ground_truth_file_paths: list = None, pred_thresh: float = 0.5, start: int = 0, end: int = -1):
5999    def create_annotated_video(
6000        self,
6001        prediction_file_paths: list,
6002        video_file_paths: list,
6003        episode_name: str,  # To get the list of behaviors
6004        ground_truth_file_paths: list = None,
6005        pred_thresh: float = 0.5,
6006        start: int = 0,
6007        end: int = -1,
6008    ):
6009        """Create a video with the predictions overlaid on the video"""
6010        for k, (pred_path, vid_path) in enumerate(
6011            zip(prediction_file_paths, video_file_paths)
6012        ):
6013            print("Generating video for :", os.path.basename(vid_path))
6014            predictions = load_pickle(pred_path)
6015            best_pred = predictions[list(predictions.keys())[0]].numpy() > pred_thresh
6016            behaviors = self.get_behavior_dictionary(episode_name)
6017            # Load video
6018            labels_pred = [behaviors[i] for i in range(len(behaviors))]
6019            labels_pred = np.roll(
6020                labels_pred, 1
6021            ).tolist() 
6022
6023            gt_data = None
6024            if ground_truth_file_paths is not None:
6025                gt_data = load_pickle(ground_truth_file_paths[k])
6026                labels_gt = gt_data[1]
6027                gt_data = binarize_data(gt_data, max_frame=best_pred.shape[1])
6028
6029            cap = cv2.VideoCapture(vid_path)
6030            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
6031            end = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if end < 0 else end
6032            fps = cap.get(cv2.CAP_PROP_FPS)
6033            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
6034            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6035            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
6036            out = cv2.VideoWriter(
6037                os.path.join(
6038                    os.path.dirname(vid_path),
6039                    os.path.splitext(os.path.basename(vid_path))[0] + "_annotated.mp4",
6040                ),
6041                fourcc,
6042                fps,
6043                # (width + int(width/4) , height),
6044                (600, 300),
6045            )
6046            count = 0
6047            bar = tqdm(total=end - start)
6048            while cap.isOpened():
6049                ret, frame = cap.read()
6050                if not ret:
6051                    break
6052
6053                side_panel = self._create_side_panel(
6054                    height,
6055                    width,
6056                    labels_pred,
6057                    best_pred[:, count],
6058                    labels_gt,
6059                    gt_data[:, count],
6060                )
6061                frame = np.concatenate((frame, side_panel), axis=1)
6062                frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
6063                out.write(frame)
6064                count += 1
6065                bar.update(1)
6066
6067                if count > end:
6068                    break
6069
6070            cap.release()
6071            out.release()
6072            cv2.destroyAllWindows()

Create a video with the predictions overlaid on the video

def plot_predictions( self, episode_name: str, load_epoch: int = None, parameters_update: Dict = None, add_legend: bool = True, ground_truth: bool = True, colormap: str = 'dlc2action', hide_axes: bool = False, min_classes: int = 1, width: float = 10, whole_video: bool = False, transparent: bool = False, drop_classes: Set = None, search_classes: Set = None, num_plots: int = 1, remove_saved_features: bool = False, smooth_interval_prediction: int = 0, data_path: str = None, file_paths: Set = None, mode: str = 'val', font_size: float = None, window_size: int = 400) -> None:
6074    def plot_predictions(
6075        self,
6076        episode_name: str,
6077        load_epoch: int = None,
6078        parameters_update: Dict = None,
6079        add_legend: bool = True,
6080        ground_truth: bool = True,
6081        colormap: str = "dlc2action",
6082        hide_axes: bool = False,
6083        min_classes: int = 1,
6084        width: float = 10,
6085        whole_video: bool = False,
6086        transparent: bool = False,
6087        drop_classes: Set = None,
6088        search_classes: Set = None,
6089        num_plots: int = 1,
6090        remove_saved_features: bool = False,
6091        smooth_interval_prediction: int = 0,
6092        data_path: str = None,
6093        file_paths: Set = None,
6094        mode: str = "val",
6095        font_size: float = None,
6096        window_size: int = 400,
6097    ) -> None:
6098        """Visualize random predictions.
6099
6100        Parameters
6101        ----------
6102        episode_name : str
6103            the name of the episode to load
6104        load_epoch : int, optional
6105            the epoch to load (by default last)
6106        parameters_update : dict, optional
6107            parameter update dictionary
6108        add_legend : bool, default True
6109            if True, legend will be added to the plot
6110        ground_truth : bool, default True
6111            if True, ground truth will be added to the plot
6112        colormap : str, default 'Accent'
6113            the `matplotlib` colormap to use
6114        hide_axes : bool, default True
6115            if `True`, the axes will be hidden on the plot
6116        min_classes : int, default 1
6117            the minimum number of classes in a displayed interval
6118        width : float, default 10
6119            the width of the plot
6120        whole_video : bool, default False
6121            if `True`, whole videos are plotted instead of segments
6122        transparent : bool, default False
6123            if `True`, the background on the plot is transparent
6124        drop_classes : set, optional
6125            a set of class names to not be displayed
6126        search_classes : set, optional
6127            if given, only intervals where at least one of the classes is in ground truth will be shown
6128        num_plots : int, default 1
6129            the number of plots to make
6130        remove_saved_features : bool, default False
6131            if `True`, the dataset will be deleted after computation
6132        smooth_interval_prediction : int, default 0
6133            if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame)
6134        data_path : str, optional
6135            the data path to run the prediction for
6136        file_paths : set, optional
6137            a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction
6138            for
6139        mode : {'all', 'test', 'val', 'train'}
6140            the subset of the data to make the prediction for (forced to 'all' if data_path is not None)
6141
6142        """
6143        plot_path = os.path.join(self.project_path, "results", "plots")
6144        task, parameters, mode = self._make_task_prediction(
6145            "_",
6146            load_episode=episode_name,
6147            parameters_update=parameters_update,
6148            load_epoch=load_epoch,
6149            data_path=data_path,
6150            file_paths=file_paths,
6151            mode=mode,
6152        )
6153        os.makedirs(plot_path, exist_ok=True)
6154        task.visualize_results(
6155            save_path=os.path.join(plot_path, f"{episode_name}_prediction.svg"),
6156            add_legend=add_legend,
6157            ground_truth=ground_truth,
6158            colormap=colormap,
6159            hide_axes=hide_axes,
6160            min_classes=min_classes,
6161            whole_video=whole_video,
6162            transparent=transparent,
6163            dataset=mode,
6164            drop_classes=drop_classes,
6165            search_classes=search_classes,
6166            width=width,
6167            smooth_interval_prediction=smooth_interval_prediction,
6168            font_size=font_size,
6169            num_plots=num_plots,
6170            window_size=window_size,
6171        )
6172        if remove_saved_features:
6173            self._remove_stores(parameters)

Visualize random predictions.

Parameters

episode_name : str the name of the episode to load load_epoch : int, optional the epoch to load (by default last) parameters_update : dict, optional parameter update dictionary add_legend : bool, default True if True, legend will be added to the plot ground_truth : bool, default True if True, ground truth will be added to the plot colormap : str, default 'dlc2action' the colormap to use hide_axes : bool, default False if True, the axes will be hidden on the plot min_classes : int, default 1 the minimum number of classes in a displayed interval width : float, default 10 the width of the plot whole_video : bool, default False if True, whole videos are plotted instead of segments transparent : bool, default False if True, the background on the plot is transparent drop_classes : set, optional a set of class names to not be displayed search_classes : set, optional if given, only intervals where at least one of the classes is in ground truth will be shown num_plots : int, default 1 the number of plots to make remove_saved_features : bool, default False if True, the dataset will be deleted after computation smooth_interval_prediction : int, default 0 if >0, predictions shorter than this number of frames are removed (filled with prediction for the previous frame) data_path : str, optional the data path to run the prediction for file_paths : set, optional a set of string file paths (data with all prefixes + feature files, in any order) to run the prediction for mode : {'all', 'test', 'val', 'train'} the subset of the data to make the prediction for (forced to 'all' if data_path is not None)

def create_video_from_labels( self, video_dir_path: str, mode='ground_truth', prediction_name: str = None, save_path: str = None):
6175    def create_video_from_labels(
6176        self,
6177        video_dir_path: str,
6178        mode="ground_truth",
6179        prediction_name: str = None,
6180        save_path: str = None,
6181    ):
6182        if save_path is None:
6183            save_path = os.path.join(
6184                self.project_path, "results", f"annotated_videos_from_{mode}"
6185            )
6186        os.makedirs(save_path, exist_ok=True)
6187
6188        params = self._read_parameters(catch_blanks=False)
6189
6190        if mode == "ground_truth":
6191            source_dir = self.annotation_path
6192            annotation_suffix = params["data"]["annotation_suffix"]
6193        elif mode == "prediction":
6194            assert (
6195                not prediction_name is None
6196            ), "Please provide a prediction name with mode 'prediction'"
6197            source_dir = os.path.join(
6198                self.project_path, "results", "predictions", prediction_name
6199            )
6200            annotation_suffix = f"_{prediction_name}_prediction.pickle"
6201
6202        video_annotation_pairs = [
6203            (
6204                os.path.join(video_dir_path, f),
6205                os.path.join(
6206                    source_dir, f.replace(f.split(".")[-1], annotation_suffix)
6207                ),
6208            )
6209            for f in os.listdir(video_dir_path)
6210            if os.path.exists(
6211                os.path.join(source_dir, f.replace(f.split(".")[-1], annotation_suffix))
6212            )
6213        ]
6214
6215        for video_file, annotation_file in tqdm(video_annotation_pairs):
6216            if not os.path.exists(video_file):
6217                print(f"Video file {video_file} does not exist, skipping.")
6218                continue
6219            if not os.path.exists(annotation_file):
6220                print(f"Annotation file {annotation_file} does not exist, skipping.")
6221                continue
6222
6223            if annotation_file.endswith(".pickle"):
6224                annotations = load_pickle(annotation_file)
6225            elif annotation_file.endswith(".csv"):
6226                annotations = pd.read_csv(annotation_file)
6227
6228            if mode == "ground_truth":
6229                behaviors = annotations[1]
6230                annot_data = annotations[3]
6231            elif mode == "predictions":
6232                behaviors = list(annotations["classes"].values())
6233                annot_data = [
6234                    annotations[key]
6235                    for key in annotations.keys()
6236                    if key not in ["classes", "min_frame", "max_frame"]
6237                ]
6238                if params["general"]["exclusive"]:
6239                    annot_data = [np.argmax(annot, axis=1) for annot in annot_data]
6240                    seqs = [
6241                        [
6242                            self._bin_array_to_sequences(annot, target_value=k)
6243                            for k in range(len(behaviors))
6244                        ]
6245                        for annot in annot_data
6246                    ]
6247                else:
6248                    annot_data = [np.where(annot > 0.5)[0] for annot in annot_data]
6249                    seqs = [
6250                        self._bin_array_to_sequences(annot, target_value=1)
6251                        for annot in annot_data
6252                    ]
6253                annotations = ["", "", seqs]
6254
6255            for individual in annotations[3]:
6256                for behavior in annotations[3][individual]:
6257                    intervals = annotations[3][individual][behavior]
6258                    self._extract_videos(
6259                        video_file,
6260                        intervals,
6261                        behavior,
6262                        individual,
6263                        save_path,
6264                        resolution=(640, 480),
6265                        fps=30,
6266                    )
def create_metadata_backup(self) -> None:
6327    def create_metadata_backup(self) -> None:
6328        """Create a copy of the meta files."""
6329        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6330        meta_path = os.path.join(self.project_path, "meta")
6331        if os.path.exists(meta_copy_path):
6332            shutil.rmtree(meta_copy_path)
6333        os.mkdir(meta_copy_path)
6334        for file in os.listdir(meta_path):
6335            if file == "backup":
6336                continue
6337            if os.path.isdir(os.path.join(meta_path, file)):
6338                continue
6339            shutil.copy(
6340                os.path.join(meta_path, file), os.path.join(meta_copy_path, file)
6341            )

Create a copy of the meta files.

def load_metadata_backup(self) -> None:
6343    def load_metadata_backup(self) -> None:
6344        """Load from previously created meta data backup (in case of corruption)."""
6345        meta_copy_path = os.path.join(self.project_path, "meta", "backup")
6346        meta_path = os.path.join(self.project_path, "meta")
6347        for file in os.listdir(meta_copy_path):
6348            shutil.copy(
6349                os.path.join(meta_copy_path, file), os.path.join(meta_path, file)
6350            )

Load from previously created meta data backup (in case of corruption).

def get_behavior_dictionary(self, episode_name: str) -> Dict:
6352    def get_behavior_dictionary(self, episode_name: str) -> Dict:
6353        """Get the behavior dictionary for an episode.
6354
6355        Parameters
6356        ----------
6357        episode_name : str
6358            the name of the episode
6359
6360        Returns
6361        -------
6362        behaviors_dictionary : dict
6363            a dictionary where keys are label indices and values are label names
6364
6365        """
6366        return self._episode(episode_name).get_behaviors_dict()

Get the behavior dictionary for an episode.

Parameters

episode_name : str the name of the episode

Returns

behaviors_dictionary : dict a dictionary where keys are label indices and values are label names

def import_episodes( self, episodes_directory: str, name_map: Dict = None, repeat_policy: str = 'error') -> None:
6368    def import_episodes(
6369        self,
6370        episodes_directory: str,
6371        name_map: Dict = None,
6372        repeat_policy: str = "error",
6373    ) -> None:
6374        """Import episodes exported with `Project.export_episodes`.
6375
6376        Parameters
6377        ----------
6378        episodes_directory : str
6379            the path to the exported episodes directory
6380        name_map : dict, optional
6381            a name change dictionary for the episodes: keys are old names, values are new names
6382        repeat_policy : {'error', 'skip', 'force'}, default 'error'
6383            the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates,
6384            'force' overwrites existing episodes
6385
6386        """
6387        if name_map is None:
6388            name_map = {}
6389        episodes = pd.read_pickle(os.path.join(episodes_directory, "episodes.pickle"))
6390        to_remove = []
6391        import_string = "Imported episodes: "
6392        for episode_name in episodes.index:
6393            if episode_name in name_map:
6394                import_string += f"{episode_name} "
6395                episode_name = name_map[episode_name]
6396                import_string += f"({episode_name}), "
6397            else:
6398                import_string += f"{episode_name}, "
6399            try:
6400                self._check_episode_validity(episode_name, allow_doublecolon=True)
6401            except ValueError as e:
6402                if str(e).endswith("is already taken!"):
6403                    if repeat_policy == "skip":
6404                        to_remove.append(episode_name)
6405                    elif repeat_policy == "force":
6406                        self.remove_episode(episode_name)
6407                    elif repeat_policy == "error":
6408                        raise ValueError(
6409                            f"The {episode_name} episode name is already taken; please use the name_map parameter to rename it"
6410                        )
6411                    else:
6412                        raise ValueError(
6413                            f"The {repeat_policy} repeat policy is not recognized; please choose from ['skip', 'force' and 'error']"
6414                        )
6415        episodes = episodes.drop(index=to_remove)
6416        self._episodes().update(
6417            episodes,
6418            name_map=name_map,
6419            force=(repeat_policy == "force"),
6420            data_path=self.data_path,
6421            annotation_path=self.annotation_path,
6422        )
6423        for episode_name in episodes.index:
6424            if episode_name in name_map:
6425                new_episode_name = name_map[episode_name]
6426            else:
6427                new_episode_name = episode_name
6428            model_dir = os.path.join(
6429                self.project_path, "results", "model", new_episode_name
6430            )
6431            old_model_dir = os.path.join(episodes_directory, "model", episode_name)
6432            if os.path.exists(model_dir):
6433                shutil.rmtree(model_dir)
6434            os.mkdir(model_dir)
6435            for file in os.listdir(old_model_dir):
6436                shutil.copyfile(
6437                    os.path.join(old_model_dir, file), os.path.join(model_dir, file)
6438                )
6439            log_file = os.path.join(
6440                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6441            )
6442            old_log_file = os.path.join(
6443                episodes_directory, "logs", f"{episode_name}.txt"
6444            )
6445            shutil.copyfile(old_log_file, log_file)
6446        print(import_string)
6447        print("\n")

Import episodes exported with Project.export_episodes.

Parameters

episodes_directory : str the path to the exported episodes directory name_map : dict, optional a name change dictionary for the episodes: keys are old names, values are new names repeat_policy : {'error', 'skip', 'force'}, default 'error' the policy for repeated episode names: 'error' raises an error, 'skip' skips duplicates, 'force' overwrites existing episodes

def export_episodes( self, episode_names: List, output_directory: str, name: str = None) -> None:
6449    def export_episodes(
6450        self, episode_names: List, output_directory: str, name: str = None
6451    ) -> None:
6452        """Save selected episodes as a file that can be imported into another project with `Project.import_episodes`.
6453
6454        Parameters
6455        ----------
6456        episode_names : list
6457            a list of string episode names
6458        output_directory : str
6459            the path to the directory where the episodes will be saved
6460        name : str, optional
6461            the name of the episodes directory (by default `exported_episodes`)
6462
6463        """
6464        if name is None:
6465            name = "exported_episodes"
6466        if os.path.exists(
6467            os.path.join(output_directory, name + ".zip")
6468        ) or os.path.exists(os.path.join(output_directory, name)):
6469            i = 1
6470            while os.path.exists(
6471                os.path.join(output_directory, name + f"_{i}.zip")
6472            ) or os.path.exists(os.path.join(output_directory, name + f"_{i}")):
6473                i += 1
6474            name = name + f"_{i}"
6475        dest_dir = os.path.join(output_directory, name)
6476        os.mkdir(dest_dir)
6477        os.mkdir(os.path.join(dest_dir, "model"))
6478        os.mkdir(os.path.join(dest_dir, "logs"))
6479        runs = []
6480        for episode in episode_names:
6481            runs += self._episodes().get_runs(episode)
6482        for run in runs:
6483            shutil.copytree(
6484                os.path.join(self.project_path, "results", "model", run),
6485                os.path.join(dest_dir, "model", run),
6486            )
6487            shutil.copyfile(
6488                os.path.join(self.project_path, "results", "logs", f"{run}.txt"),
6489                os.path.join(dest_dir, "logs", f"{run}.txt"),
6490            )
6491        data = self._episodes().get_subset(runs)
6492        data.to_pickle(os.path.join(dest_dir, "episodes.pickle"))

Save selected episodes as a file that can be imported into another project with Project.import_episodes.

Parameters

episode_names : list a list of string episode names output_directory : str the path to the directory where the episodes will be saved name : str, optional the name of the episodes directory (by default exported_episodes)

def get_results_table( self, episode_names: List, metrics: List = None, mode: str = 'mean', print_results: bool = True, classes: List = None):
    def get_results_table(
        self,
        episode_names: List,
        metrics: List = None,
        mode: str = "mean",  # Choose between ["mean", "statistics", "detail"]
        print_results: bool = True,
        classes: List = None,
    ) -> pd.DataFrame:
        """Generate a `pandas` dataframe with a summary of episode results.

        Rows of the table are behaviors (plus an "average" row) and columns
        are episode-metric combinations. Values are aggregated over the runs
        of each episode according to `mode`.

        Parameters
        ----------
        episode_names : list
            a list of names of episodes to include
        metrics : list, optional
            a list of metric names to include (a results column is kept if the part of its
            name before the first underscore is in this list)
        mode : {"mean", "statistics", "detail"}, default "mean"
            the mode of the results table: "mean" adds one `"{episode} {metric}"` column with the
            mean over runs, "statistics" adds `"... mean"` and `"... std"` columns and "detail" adds
            one `"{episode}::{i} {metric}"` column per run; any other value adds no columns
        print_results : bool, optional
            if True, the results will be printed to the console, by default True
        classes : list, optional
            a list of names of classes to include (by default all are included)

        Returns
        -------
        results : pd.DataFrame
            a table with the results

        """
        # collect all runs of the requested episodes and load their records
        run_names = []
        for episode in episode_names:
            run_names += self._episodes().get_runs(episode)
        episodes = self.list_episodes(run_names, print_results=False)
        # result values are stored in columns whose first level is "results"
        metric_columns = [x for x in episodes.columns if x[0] == "results"]
        results_df = pd.DataFrame()
        if metrics is not None:
            metric_columns = [
                x for x in metric_columns if x[1].split("_")[0] in metrics
            ]
        for episode in episode_names:
            # one {behavior: {metric: value}} dictionary per run
            results = []
            metric_set = set()
            for run in self._episodes().get_runs(episode):
                beh_dict = self.get_behavior_dictionary(run)
                res_dict = defaultdict(lambda: {})
                for column in metric_columns:
                    if np.isnan(episodes.loc[run, column]):
                        continue
                    split = column[1].split("_")
                    if split[-1].isnumeric():
                        # a numeric ending is a behavior index: per-class value
                        beh_ind = int(split[-1])
                        metric_name = "_".join(split[:-1])
                        beh = beh_dict[beh_ind]
                    else:
                        # no behavior index: the value is stored as the average
                        beh = "average"
                        metric_name = column[1]
                    res_dict[beh][metric_name] = episodes.loc[run, column]
                    metric_set.add(metric_name)
                if "average" not in res_dict:
                    res_dict["average"] = {}
                # fill in missing averages with the mean over the behaviors
                for metric in metric_set:
                    if metric not in res_dict["average"]:
                        arr = [
                            res_dict[beh][metric]
                            for beh in res_dict
                            if metric in res_dict[beh]
                        ]
                        res_dict["average"][metric] = np.mean(arr)
                results.append(res_dict)
            # aggregate the values of every (behavior, metric) pair over the runs
            episode_results = {}
            for metric in metric_set:
                # the behaviors of the first run define the rows for this episode
                for beh in results[0].keys():
                    if classes is not None and beh not in classes:
                        continue
                    arr = []
                    for res_dict in results:
                        if metric in res_dict[beh]:
                            arr.append(res_dict[beh][metric])
                    if len(arr) > 0:
                        if mode == "statistics":
                            episode_results[(beh, f"{episode} {metric} mean")] = (
                                np.mean(arr)
                            )
                            episode_results[(beh, f"{episode} {metric} std")] = np.std(
                                arr
                            )
                        elif mode == "mean":
                            episode_results[(beh, f"{episode} {metric}")] = np.mean(arr)
                        elif mode == "detail":
                            for i, val in enumerate(arr):
                                episode_results[(beh, f"{episode}::{i} {metric}")] = val
            # keys are (row, column) pairs of the output table
            for key, value in episode_results.items():
                results_df.loc[key[0], key[1]] = value
        if print_results:
            print(f"RESULTS:")
            print(results_df)
            print("\n")
        return results_df

Generate a pandas dataframe with a summary of episode results.

Parameters

episode_names : list a list of names of episodes to include metrics : list, optional a list of metric names to include mode : str, optional the mode of the results table, choose between ["mean", "statistics", "detail"], by default "mean" print_results : bool, optional if True, the results will be printed to the console, by default True classes : list, optional a list of names of classes to include (by default all are included)

Returns

results : pd.DataFrame a table with the results

def episode_exists(self, episode_name: str) -> bool:
6593    def episode_exists(self, episode_name: str) -> bool:
6594        """Check if an episode already exists.
6595
6596        Parameters
6597        ----------
6598        episode_name : str
6599            the episode name
6600
6601        Returns
6602        -------
6603        exists : bool
6604            `True` if the episode exists
6605
6606        """
6607        return self._episodes().check_name_validity(episode_name)

Check if an episode already exists.

Parameters

episode_name : str the episode name

Returns

exists : bool True if the episode exists

def search_exists(self, search_name: str) -> bool:
6609    def search_exists(self, search_name: str) -> bool:
6610        """Check if a search already exists.
6611
6612        Parameters
6613        ----------
6614        search_name : str
6615            the search name
6616
6617        Returns
6618        -------
6619        exists : bool
6620            `True` if the search exists
6621
6622        """
6623        return self._searches().check_name_validity(search_name)

Check if a search already exists.

Parameters

search_name : str the search name

Returns

exists : bool True if the search exists

def prediction_exists(self, prediction_name: str) -> bool:
6625    def prediction_exists(self, prediction_name: str) -> bool:
6626        """Check if a prediction already exists.
6627
6628        Parameters
6629        ----------
6630        prediction_name : str
6631            the prediction name
6632
6633        Returns
6634        -------
6635        exists : bool
6636            `True` if the prediction exists
6637
6638        """
6639        return self._predictions().check_name_validity(prediction_name)

Check if a prediction already exists.

Parameters

prediction_name : str the prediction name

Returns

exists : bool True if the prediction exists

@staticmethod
def project_name_available(projects_path: str, project_name: str):
6641    @staticmethod
6642    def project_name_available(projects_path: str, project_name: str):
6643        """Check if a project name is available.
6644
6645        Parameters
6646        ----------
6647        projects_path : str
6648            the path to the projects directory
6649        project_name : str
6650            the name of the project to check
6651
6652        Returns
6653        -------
6654        available : bool
6655            `True` if the project name is available
6656
6657        """
6658        if projects_path is None:
6659            projects_path = os.path.join(str(Path.home()), "DLC2Action")
6660        return not os.path.exists(os.path.join(projects_path, project_name))

Check if a project name is available.

Parameters

projects_path : str the path to the projects directory project_name : str the name of the project to check

Returns

available : bool True if the project name is available

def rename_episode(self, episode_name: str, new_episode_name: str):
6675    def rename_episode(self, episode_name: str, new_episode_name: str):
6676        """Rename an episode.
6677
6678        Parameters
6679        ----------
6680        episode_name : str
6681            the current episode name
6682        new_episode_name : str
6683            the new episode name
6684
6685        """
6686        shutil.move(
6687            os.path.join(self.project_path, "results", "model", episode_name),
6688            os.path.join(self.project_path, "results", "model", new_episode_name),
6689        )
6690        shutil.move(
6691            os.path.join(self.project_path, "results", "logs", f"{episode_name}.txt"),
6692            os.path.join(
6693                self.project_path, "results", "logs", f"{new_episode_name}.txt"
6694            ),
6695        )
6696        self._episodes().rename_episode(episode_name, new_episode_name)

Rename an episode.

Parameters

episode_name : str the current episode name new_episode_name : str the new episode name