dlc2action.task.task_dispatcher

Class that provides an interface for dlc2action.task.universal_task.Task.
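A minimal usage sketch (the parameter groups mirror what `TaskDispatcher.__init__` reads in the source below; the concrete model, data, loss and metric names are illustrative assumptions, not verified defaults — the valid options live in the `dlc2action.options` registries):

    from dlc2action.task.task_dispatcher import TaskDispatcher

    # Hypothetical configuration: every name marked "assumed" must match an
    # entry in the corresponding dlc2action.options registry.
    parameters = {
        "general": {
            "model_name": "ms_tcn3",            # a key in options.models (assumed)
            "data_type": "dlc_track",           # a dataset input type (assumed)
            "annotation_type": "csv",           # a dataset annotation type (assumed)
            "feature_extraction": "kinematic",  # a key in options.feature_extractors (assumed)
            "loss_function": "ms_tcn",          # a key in options.losses (assumed)
            "metric_functions": ["f1"],         # keys in options.metrics (assumed)
            "exclusive": True,                  # classes are mutually exclusive
            "only_load_annotated": True,
        },
        "data": {"data_path": "/path/to/data", "annotation_path": "/path/to/annotations"},
        "training": {"lr": 1e-4, "num_epochs": 100, "batch_size": 32, "val_frac": 0.2},
        "model": {}, "losses": {}, "metrics": {},
        "ssl": {}, "augmentations": {}, "features": {},
    }

    dispatcher = TaskDispatcher(parameters)     # builds the underlying Task
    loss_log, metrics_log = dispatcher.train()  # documented return: loss list, metric log dict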

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""Class that provides an interface for `dlc2action.task.universal_task.Task`."""
   8
   9import inspect
  10import warnings
  11from collections.abc import Iterable, Mapping
  12from copy import deepcopy
  13from typing import Callable, Dict, List, Set, Tuple, Union
  14
  15import numpy as np
  16import torch
  17from dlc2action import options
  18from dlc2action.data.dataset import BehaviorDataset
  19from dlc2action.metric.base_metric import Metric
  20from dlc2action.model.base_model import LoadedModel, Model
  21from dlc2action.ssl.base_ssl import EmptySSL, SSLConstructor
  22from dlc2action.task.universal_task import Task
  23from dlc2action.transformer.base_transformer import Transformer
  24from dlc2action.utils import PostProcessor
  25from optuna.trial import Trial
  26from torch.optim import Optimizer
  27from torch.utils.data import DataLoader
  28
  29
  30class TaskDispatcher:
  31    """A class that manages the interactions between config dictionaries and a Task."""
  32
  33    def __init__(self, parameters: Dict) -> None:
  34        """Initialize the `TaskDispatcher`.
  35
  36        Parameters
  37        ----------
  38        parameters : dict
  39            a dictionary of task parameters
  40
  41        """
  42        pars = deepcopy(parameters)
  43        self.class_weights = None
  44        self.general_parameters = pars.get("general", {})
  45        self.data_parameters = pars.get("data", {})
  46        self.model_parameters = pars.get("model", {})
  47        self.training_parameters = pars.get("training", {})
  48        self.loss_parameters = pars.get("losses", {})
  49        self.metric_parameters = pars.get("metrics", {})
  50        self.ssl_parameters = pars.get("ssl", {})
  51        self.aug_parameters = pars.get("augmentations", {})
  52        self.feature_parameters = pars.get("features", {})
  53        self.blanks = {blank: [] for blank in options.blanks}
  54
  55        self.task = None
  56        self._initialize_task()
  57        self._print_behaviors()
  58
  59    @staticmethod
   60    def complete_function_parameters(parameters: Dict, function: Callable, general_dicts: List) -> Dict:
  61        """Complete a parameter dictionary with values from other dictionaries if required by a function.
  62
  63        Parameters
  64        ----------
  65        parameters : dict
  66            the function parameters dictionary
  67        function : callable
  68            the function to be inspected
  69        general_dicts : list
  70            a list of dictionaries where the missing values will be pulled from
  71
  72        Returns
  73        -------
  74        parameters : dict
  75            the updated parameter dictionary
  76
  77        """
  78        parameter_names = inspect.getfullargspec(function).args
  79        for param in parameter_names:
  80            for dic in general_dicts:
  81                if param not in parameters and param in dic:
  82                    parameters[param] = dic[param]
  83        return parameters
  84
  85    @staticmethod
  86    def complete_dataset_parameters(
  87        parameters: dict,
  88        general_dict: dict,
  89        data_type: str,
  90        annotation_type: str,
  91    ) -> Dict:
  92        """Complete a parameter dictionary with values from other dictionaries if required by a dataset.
  93
  94        Parameters
  95        ----------
  96        parameters : dict
  97            the function parameters dictionary
  98        general_dict : dict
  99            the dictionary where the missing values will be pulled from
 100        data_type : str
 101            the input type of the dataset
 102        annotation_type : str
 103            the annotation type of the dataset
 104
 105        Returns
 106        -------
 107        parameters : dict
 108            the updated parameter dictionary
 109
 110        """
 111        params = deepcopy(parameters)
 112        parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type)
 113        for param in parameter_names:
 114            if param not in params and param in general_dict:
 115                params[param] = general_dict[param]
 116        return params
 117
 118    @staticmethod
 119    def check(parameters: Dict, name: str) -> bool:
 120        """Check whether there is a non-`None` value under the name key in the parameters dictionary.
 121
 122        Parameters
 123        ----------
 124        parameters : dict
 125            the dictionary to check
 126        name : str
 127            the key to check
 128
 129        Returns
 130        -------
 131        result : bool
 132            True if a non-`None` value exists
 133
 134        """
  135        return name in parameters and parameters[name] is not None
 139
 140    @staticmethod
 141    def get(parameters: Dict, name: str, default):
 142        """Get the value under the name key or the default if it is `None` or does not exist.
 143
 144        Parameters
 145        ----------
 146        parameters : dict
 147            the dictionary to check
 148        name : str
 149            the key to check
 150        default
 151            the default value to return
 152
 153        Returns
 154        -------
 155        value
 156            the resulting value
 157
 158        """
 159        if TaskDispatcher.check(parameters, name):
 160            return parameters[name]
 161        else:
 162            return default
 163
 164    @staticmethod
 165    def make_dataloader(
 166        dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False
 167    ) -> DataLoader:
 168        """Make a torch dataloader from a dataset.
 169
 170        Parameters
 171        ----------
 172        dataset : dlc2action.data.dataset.BehaviorDataset
 173            the dataset
 174        batch_size : int
 175            the batch size
 176        shuffle : bool
 177            whether to shuffle the dataset
 178
 179        Returns
 180        -------
 181        dataloader : DataLoader
  182            the dataloader (or `None` if the dataset is empty or `None`)
 183
 184        """
 185        if dataset is None or len(dataset) == 0:
 186            return None
 187        else:
 188            return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)
 189
 190    def _construct_ssl(self) -> List:
 191        """Generate SSL constructors."""
 192        ssl_list = deepcopy(self.general_parameters.get("ssl", None))
 193        model_name = self.general_parameters.get("model_name", "")
 194        # ssl_constructors = options.ssl_constructors if not "tcn" in model_name else options.ssl_constructors_tcn
 195        ssl_constructors = options.ssl_constructors
 196        if not isinstance(ssl_list, Iterable):
 197            ssl_list = [ssl_list]
 198        for i, ssl in enumerate(ssl_list):
  199            if isinstance(ssl, str):
 200                if ssl in ssl_constructors:
 201                    pars = self.get(self.ssl_parameters, ssl, default={})
 202                    pars = self.complete_function_parameters(
 203                        parameters=pars,
 204                        function=ssl_constructors[ssl],
 205                        general_dicts=[
 206                            self.model_parameters,
 207                            self.data_parameters,
 208                            self.general_parameters,
 209                        ],
 210                    )
 211                    ssl_list[i] = ssl_constructors[ssl](**pars)
 212                else:
 213                    raise ValueError(
 214                        f"The {ssl} SSL is not available, please choose from {list(ssl_constructors.keys())}"
 215                    )
 216            elif ssl is None:
 217                ssl_list[i] = EmptySSL()
 218            elif not isinstance(ssl, SSLConstructor):
 219                raise TypeError(
 220                    f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}"
 221                )
 222        return ssl_list
 223
 224    def _construct_model(self) -> Model:
 225        """Generate a model."""
 226        if self.check(self.general_parameters, "model"):
 227            pars = self.complete_function_parameters(
 228                function=LoadedModel,
 229                parameters=self.model_parameters,
 230                general_dicts=[self.general_parameters],
 231            )
 232            model = LoadedModel(**pars)
 233        elif self.check(self.general_parameters, "model_name"):
 234            name = self.general_parameters["model_name"]
 235            if name in options.models:
 236                pars = self.complete_function_parameters(
 237                    function=options.models[name],
 238                    parameters=self.model_parameters,
 239                    general_dicts=[self.general_parameters],
 240                )
 241                model = options.models[name](**pars)
 242            else:
 243                raise ValueError(
 244                    f"The {name} model is not available, please choose from {list(options.models.keys())}"
 245                )
 246        else:
 247            raise ValueError(
 248                "You need to provide either a model or its name in the model_parameters!"
 249            )
 250
 251        if self.get(self.training_parameters, "freeze_features", False):
 252            model.freeze_feature_extractor()
 253        return model
 254
 255    def _construct_dataset(self) -> BehaviorDataset:
 256        """
 257        Generate a dataset
 258        """
 259        data_type = self.general_parameters.get("data_type", None)
 260        if data_type is None:
 261            raise ValueError(
 262                "You need to provide the data_type parameter in the data parameters!"
 263            )
 264        annotation_type = self.get(self.general_parameters, "annotation_type", "none")
  265        feature_extraction = self.general_parameters.get("feature_extraction", None)
 266        if feature_extraction is None:
 267            raise ValueError(
 268                "You need to provide the feature_extraction parameter in the data parameters!"
 269            )
 270        feature_extraction_pars = self.complete_function_parameters(
 271            self.feature_parameters,
 272            options.feature_extractors[feature_extraction],
 273            [self.general_parameters, self.data_parameters],
 274        )
 275
 276        pars = self.complete_dataset_parameters(
 277            self.data_parameters,
 278            self.general_parameters,
 279            data_type=data_type,
 280            annotation_type=annotation_type,
 281        )
 282        pars["feature_extraction_pars"] = feature_extraction_pars
 283        dataset = BehaviorDataset(**pars)
 284
 285        if self.get(self.general_parameters, "save_dataset", default=False):
 286            save_data_path = self.data_parameters.get("saved_data_path", None)
 287            dataset.save(save_path=save_data_path)
 288
 289        return dataset
 290
 291    def _construct_transformer(self) -> Transformer:
 292        """Generate a transformer."""
 293        features = self.general_parameters["feature_extraction"]
 294        name = options.extractor_to_transformer[features]
 295        if name in options.transformers:
 296            transformer_class = options.transformers[name]
 297            pars = self.complete_function_parameters(
 298                function=transformer_class,
 299                parameters=self.aug_parameters,
 300                general_dicts=[self.general_parameters],
 301            )
 302            transformer = transformer_class(**pars)
 303        else:
 304            raise ValueError(f"The {name} transformer is not available")
 305        return transformer
 306
 307    def _construct_loss(self) -> torch.nn.Module:
 308        """Generate a loss function."""
 309        if "loss_function" not in self.general_parameters:
 310            raise ValueError(
 311                'Please add a "loss_function" key to the parameters["general"] dictionary (either a name '
 312                f"from {list(options.losses.keys())} or a function)"
 313            )
 314        else:
 315            loss_function = self.general_parameters["loss_function"]
  316        if isinstance(loss_function, str):
 317            if loss_function in options.losses:
 318                pars = self.get(self.loss_parameters, loss_function, default={})
 319                pars = self._set_loss_weights(pars)
 320                pars = self.complete_function_parameters(
 321                    function=options.losses[loss_function],
 322                    parameters=pars,
 323                    general_dicts=[self.general_parameters],
 324                )
 325                loss = options.losses[loss_function](**pars)
 326            else:
 327                raise ValueError(
 328                    f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}"
 329                )
 330        else:
 331            loss = loss_function
 332        return loss
 333
 334    def _construct_metrics(self) -> List:
 335        """Generate the metric."""
 336        metric_functions = self.get(
 337            self.general_parameters, "metric_functions", default={}
 338        )
  339        if isinstance(metric_functions, Iterable) and not isinstance(metric_functions, dict):  # a dict is also Iterable
 340            metrics = {}
 341            for func in metric_functions:
 342                if isinstance(func, str):
 343                    if func in options.metrics:
 344                        pars = self.get(self.metric_parameters, func, default={})
 345                        pars = self.complete_function_parameters(
 346                            function=options.metrics[func],
 347                            parameters=pars,
 348                            general_dicts=[self.general_parameters],
 349                        )
 350                        metrics[func] = options.metrics[func](**pars)
 351                    else:
 352                        raise ValueError(
 353                            f"The {func} metric is not available, please choose from {list(options.metrics.keys())}"
 354                        )
 355                elif isinstance(func, Metric):
 356                    name = "function_1"
 357                    i = 1
 358                    while name in metrics:
 359                        i += 1
 360                        name = f"function_{i}"
 361                    metrics[name] = func
 362                else:
 363                    raise TypeError(
 364                        'The elements of parameters["general"]["metric_functions"] have to be either strings '
 365                        f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead"
 366                    )
 367        elif isinstance(metric_functions, dict):
 368            metrics = metric_functions
 369        else:
 370            raise TypeError(
 371                'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;'
 372                f"got {type(metric_functions)} instead"
 373            )
 374        return metrics
 375
 376    def _construct_optimizer(self) -> Optimizer:
 377        """Generate an optimizer."""
 378        if "optimizer" in self.training_parameters:
 379            name = self.training_parameters["optimizer"]
 380            if name in options.optimizers:
 381                optimizer = options.optimizers[name]
 382            else:
 383                raise ValueError(
 384                    f"The {name} optimizer is not available, please choose from {list(options.optimizers.keys())}"
 385                )
 386        else:
 387            optimizer = None
 388        return optimizer
 389
 390    def _construct_predict_functions(self) -> Tuple[Callable, Callable]:
 391        """Construct predict functions."""
 392        predict_function = self.training_parameters.get("predict_function", None)
 393        primary_predict_function = self.training_parameters.get(
 394            "primary_predict_function", None
 395        )
 396        model_name = self.general_parameters.get("model_name", "")
 397        threshold = self.training_parameters.get("hard_threshold", 0.5)
 398        if not isinstance(predict_function, Callable):
 399            if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]:
 400                if self.general_parameters["exclusive"]:
 401                    func = lambda x: torch.softmax(x, dim=1)
 402                else:
 403                    func = lambda x: torch.sigmoid(x)
 404
 405                def primary_predict_function(x):
 406                    if len(x) == 1:
 407                        return func(x)
 408                    else:
 409                        if len(x.shape) != 4:
 410                            x = x.reshape((4, -1, x.shape[-2], x.shape[-1]))
 411                        weights = [1, 1, 1, 1]
 412                        ensemble_prob = func(x[0]) * weights[0] / sum(weights)
 413                        for i, outp_ele in enumerate(x[1:]):
 414                            ensemble_prob = ensemble_prob + func(outp_ele) * weights[
 415                                i + 1
 416                            ] / sum(weights)
 417                        return ensemble_prob
 418
 419            else:
 420                if model_name.startswith("ms_tcn") or model_name in [
 421                    "asformer",
 422                    "transformer",
 423                    "c3d_ms",
 424                    "transformer_ms",
 425                ]:
 426                    f = lambda x: x[-1] if len(x.shape) == 4 else x
 427                elif model_name == "asrf":
 428
 429                    def f(x):
 430                        x = x[-1]
 431                        # bounds = x[:, 0, :].unsqueeze(1)
 432                        cls = x[:, 1:, :]
 433                        # device = x.device
 434                        # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy())
 435                        # x = torch.tensor(x).to(device)
 436                        return cls
 437
 438                else:
 439                    f = lambda x: x
 440                if self.general_parameters["exclusive"]:
 441                    primary_predict_function = lambda x: torch.softmax(f(x), dim=1)
 442                else:
 443                    primary_predict_function = lambda x: torch.sigmoid(f(x))
 444            if self.general_parameters["exclusive"]:
 445                predict_function = lambda x: torch.max(x.data, dim=1)[1]
 446            else:
 447                predict_function = lambda x: (x > threshold).int()
 448        return primary_predict_function, predict_function
 449
 450    def _get_parameters_from_training(self) -> Dict:
 451        """Get the training parameters that need to be passed to the Task."""
 452        task_training_par_names = [
 453            "lr",
 454            "parallel",
 455            "device",
 456            "verbose",
 457            "log_file",
 458            "augment_train",
 459            "augment_val",
 460            "hard_threshold",
 461            "ssl_losses",
 462            "model_save_path",
 463            "model_save_epochs",
 464            "pseudolabel",
 465            "pseudolabel_start",
 466            "correction_interval",
 467            "pseudolabel_alpha_f",
 468            "alpha_growth_stop",
 469            "num_epochs",
 470            "validation_interval",
 471            "ignore_tags",
 472            "skip_metrics",
 473        ]
 474        task_training_pars = {
 475            name: self.training_parameters[name]
 476            for name in task_training_par_names
 477            if self.check(self.training_parameters, name)
 478        }
 479        if self.check(self.general_parameters, "ssl"):
 480            ssl_weights = [
 481                self.training_parameters["ssl_weights"][x]
 482                for x in self.general_parameters["ssl"]
 483            ]
 484            task_training_pars["ssl_weights"] = ssl_weights
 485        return task_training_pars
 486
 487    def _update_parameters_from_ssl(self, ssl_list: list) -> None:
 488        """Update the necessary parameters given the list of SSL constructors."""
 489        if self.task is not None:
 490            self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 491            self.task.set_ssl_losses([ssl.loss for ssl in ssl_list])
 492            self.task.set_keep_target_none(
 493                [ssl.type in ["contrastive"] for ssl in ssl_list]
 494            )
 495            self.task.set_generate_ssl_input(
 496                [ssl.type == "contrastive" for ssl in ssl_list]
 497            )
 498        self.data_parameters["ssl_transformations"] = [
 499            ssl.transformation for ssl in ssl_list
 500        ]
 501        self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list]
 502        self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list]
 503        self.model_parameters["ssl_modules"] = [
 504            ssl.construct_module() for ssl in ssl_list
 505        ]
 506        self.aug_parameters["generate_ssl_input"] = [
 507            x.type == "contrastive" for x in ssl_list
 508        ]
 509        self.aug_parameters["keep_target_none"] = [
 510            x.type == "contrastive" for x in ssl_list
 511        ]
 512
 513    def _set_loss_weights(self, parameters):
 514        """Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values."""
 515        for k in list(parameters.keys()):
 516            if parameters[k] in [
 517                "dataset_inverse_weights",
 518                "dataset_proportional_weights",
 519            ]:
 520                if parameters[k] == "dataset_inverse_weights":
 521                    parameters[k] = self.class_weights
 522                else:
 523                    parameters[k] = self.proportional_class_weights
 524                print("Initializing class weights:")
 525                string = "    "
 526                if isinstance(parameters[k], Mapping):
 527                    for key, val in parameters[k].items():
 528                        string += ": ".join(
 529                            (
 530                                " " + str(key),
 531                                ", ".join((map(lambda x: str(np.round(x, 3)), val))),
 532                            )
 533                        )
 534                else:
 535                    string += ", ".join(
 536                        (map(lambda x: str(np.round(x, 3)), parameters[k]))
 537                    )
 538                print(string)
 539        return parameters
 540
 541    def _partition_dataset(
 542        self, dataset: BehaviorDataset
 543    ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]:
 544        """Partition the dataset into train, validation and test subsamples."""
 545        use_test = self.get(self.training_parameters, "use_test", 0)
 546        split_path = self.training_parameters.get("split_path", None)
 547        partition_method = self.training_parameters.get("partition_method", "random")
 548        val_frac = self.get(self.training_parameters, "val_frac", 0)
 549        test_frac = self.get(self.training_parameters, "test_frac", 0)
 550        save_split = self.get(self.training_parameters, "save_split", True)
 551        normalize = self.get(self.training_parameters, "normalize", False)
 552        skip_normalization_keys = self.training_parameters.get(
 553            "skip_normalization_keys"
 554        )
 555        stats = self.training_parameters.get("stats")
 556        train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val(
 557            use_test,
 558            split_path,
 559            partition_method,
 560            val_frac,
 561            test_frac,
 562            save_split,
 563            normalize,
 564            skip_normalization_keys,
 565            stats,
 566        )
 567        bs = int(self.training_parameters.get("batch_size", 32))
 568        train_dataloader, test_dataloader, val_dataloader = (
 569            self.make_dataloader(train_dataset, batch_size=bs, shuffle=True),
 570            self.make_dataloader(test_dataset, batch_size=bs, shuffle=False),
 571            self.make_dataloader(val_dataset, batch_size=bs, shuffle=False),
 572        )
 573        return train_dataloader, test_dataloader, val_dataloader
 574
 575    def _initialize_task(self):
 576        """Create a `dlc2action.task.universal_task.Task` instance."""
 577        dataset = self._construct_dataset()
 578        self._update_data_blanks(dataset)
 579        model = self._construct_model()
 580        self._update_model_blanks(model)
 581        ssl_list = self._construct_ssl()
 582        self._update_parameters_from_ssl(ssl_list)
 583        model.set_ssl(ssl_constructors=ssl_list)
 584        dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 585        transformer = self._construct_transformer()
 586        metrics = self._construct_metrics()
 587        optimizer = self._construct_optimizer()
 588        primary_predict_function, predict_function = self._construct_predict_functions()
 589
 590        task_training_pars = self._get_parameters_from_training()
 591        train_dataloader, test_dataloader, val_dataloader = self._partition_dataset(
 592            dataset
 593        )
 594        self.class_weights = train_dataloader.dataset.class_weights()
 595        self._update_num_classes_parameter(dataset)
 596        self.proportional_class_weights = train_dataloader.dataset.class_weights(True)
 597        loss = self._construct_loss()
 598        exclusive = self.general_parameters["exclusive"]
 599
 600        task_pars = {
 601            "train_dataloader": train_dataloader,
 602            "model": model,
 603            "loss": loss,
 604            "transformer": transformer,
 605            "metrics": metrics,
 606            "val_dataloader": val_dataloader,
 607            "test_dataloader": test_dataloader,
 608            "exclusive": exclusive,
 609            "optimizer": optimizer,
 610            "predict_function": predict_function,
 611            "primary_predict_function": primary_predict_function,
 612        }
 613        task_pars.update(task_training_pars)
 614        self.task = Task(**task_pars)
 615        checkpoint_path = self.training_parameters.get("checkpoint_path", None)
 616        if checkpoint_path is not None:
 617            only_model = self.get(self.training_parameters, "only_load_model", False)
 618            load_strict = self.get(self.training_parameters, "load_strict", True)
 619            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
 620        if (
 621            self.general_parameters["only_load_annotated"]
 622            and self.general_parameters.get("ssl") is not None
 623        ):
 624            warnings.warn(
 625                "Note that you are using SSL modules but only loading annotated files! Set "
 626                "general/only_load_annotated to False to change that"
 627            )
 628
 629    def _update_data_blanks(
 630        self, dataset: BehaviorDataset = None, remember: bool = False
 631    ) -> None:
 632        """Update all blanks from a dataset."""
 633        if dataset is None:
 634            dataset = self.dataset()
 635        self._update_dim_parameter(dataset, remember)
 636        self._update_bodyparts_parameter(dataset, remember)
 637        self._update_num_classes_parameter(dataset, remember)
 638        self._update_len_segment_parameter(dataset, remember)
 639        self._update_boundary_parameter(dataset, remember)
 640
 641    def _update_model_blanks(self, model: Model, remember: bool = False) -> None:
 642        """Update blanks related to model parameters."""
 643        self._update_features_parameter(model, remember)
 644
 645    def _update_parameter(self, blank_name: str, value, remember: bool = False):
 646        """Update a single blank parameter."""
 647        parameters = [
 648            self.model_parameters,
 649            self.ssl_parameters,
 650            self.general_parameters,
 651            self.feature_parameters,
 652            self.data_parameters,
 653            self.training_parameters,
 654            self.metric_parameters,
 655            self.loss_parameters,
 656            self.aug_parameters,
 657        ]
 658        par_names = [
 659            "model",
 660            "ssl",
 661            "general",
 662            "feature",
 663            "data",
 664            "training",
 665            "metrics",
 666            "losses",
 667            "augmentations",
 668        ]
 669        for names in self.blanks[blank_name]:
 670            group = names[0]
 671            key = names[1]
 672            ind = par_names.index(group)
 673            if len(names) == 3:
 674                if names[2] in parameters[ind][key]:
 675                    parameters[ind][key][names[2]] = value
 676            else:
 677                if key in parameters[ind]:
 678                    parameters[ind][key] = value
 679        for name, dic in zip(par_names, parameters):
 680            for k, v in dic.items():
 681                if v == blank_name:
 682                    dic[k] = value
 683                    if [name, k] not in self.blanks[blank_name]:
 684                        self.blanks[blank_name].append([name, k])
 685                elif isinstance(v, Mapping):
 686                    for kk, vv in v.items():
 687                        if vv == blank_name:
 688                            dic[k][kk] = value
 689                            if [name, k, kk] not in self.blanks[blank_name]:
 690                                self.blanks[blank_name].append([name, k, kk])
 691
 692    def _update_features_parameter(self, model: Model, remember: bool = False) -> None:
 693        """Fill the `"model_features"` blank."""
 694        value = model.features_shape()
 695        self._update_parameter("model_features", value, remember)
 696
 697    def _update_bodyparts_parameter(
 698        self, dataset: BehaviorDataset, remember: bool = False
 699    ) -> None:
 700        """Fill the `"dataset_bodyparts"` blank."""
 701        value = dataset.bodyparts_order()
 702        self._update_parameter("dataset_bodyparts", value, remember)
 703
 704    def _update_dim_parameter(
 705        self, dataset: BehaviorDataset, remember: bool = False
 706    ) -> None:
 707        """Fill the `"dataset_features"` blank."""
 708        value = dataset.features_shape()
 709        self._update_parameter("dataset_features", value, remember)
 710
 711    def _update_boundary_parameter(
 712        self, dataset: BehaviorDataset, remember: bool = False
 713    ) -> None:
 714        """Fill the `"dataset_features"` blank."""
 715        value = dataset._boundary_class_weight()
 716        self._update_parameter("dataset_boundary_weight", value, remember)
 717
 718    def _update_num_classes_parameter(
 719        self, dataset: BehaviorDataset, remember: bool = False
 720    ) -> None:
 721        """Fill in the `"dataset_classes"` blank."""
 722        value = dataset.num_classes()
 723        self._update_parameter("dataset_classes", value, remember)
 724
 725    def _update_len_segment_parameter(
 726        self, dataset: BehaviorDataset, remember: bool = False
 727    ) -> None:
 728        """Fill in the `"dataset_len_segment"` blank."""
 729        value = dataset.len_segment()
 730        self._update_parameter("dataset_len_segment", value, remember)
 731
 732    def _print_behaviors(self):
 733        behavior_set = self.behaviors_dict()
 734        print(f"Behavior indices:")
 735        for key, value in sorted(behavior_set.items()):
 736            print(f"    {key}: {value}")
 737
 738    def update_task(self, parameters: Dict) -> None:
 739        """Update the `dlc2action.task.universal_task.Task` instance given the parameter updates.
 740
 741        Parameters
 742        ----------
 743        parameters : dict
 744            the dictionary of parameter updates
 745
 746        """
 747        pars = deepcopy(parameters)
 748        # for blank_name in self.blanks:
 749        #     for names in self.blanks[blank_name]:
 750        #         group = names[0]
 751        #         key = names[1]
 752        #         if len(names) == 3:
 753        #             if (
 754        #                 group in pars
 755        #                 and key in pars[group]
 756        #                 and names[2] in pars[group][key]
 757        #             ):
 758        #                 pars[group][key].pop(names[2])
 759        #         else:
 760        #             if group in pars and key in pars[group]:
 761        #                 pars[group].pop(key)
 762        stay = False
 763        if "ssl" in pars:
 764            for key in pars["ssl"]:
 765                if key in self.ssl_parameters:
 766                    self.ssl_parameters[key].update(pars["ssl"][key])
 767                else:
 768                    self.ssl_parameters[key] = pars["ssl"][key]
 769
 770        if "general" in pars:
 773            if (
 774                "model_name" in pars["general"]
 775                and pars["general"]["model_name"]
 776                != self.general_parameters["model_name"]
 777            ):
 778                if "model" not in pars:
 779                    raise ValueError(
 780                        "When updating a task with a new model name you need to pass the parameters for the "
 781                        "new model"
 782                    )
 783                self.model_parameters = {}
 784            self.general_parameters.update(pars["general"])
 785            data_related = [
 786                "num_classes",
 787                "exclusive",
 788                "data_type",
 789                "annotation_type",
 790            ]
 791            ssl_related = ["ssl", "exclusive", "num_classes"]
 792            loss_related = ["num_classes", "loss_function", "exclusive"]
 793            augmentation_related = ["augmentation_type"]
 794            metric_related = ["metric_functions"]
 795            related_lists = [
 796                data_related,
 797                ssl_related,
 798                loss_related,
 799                augmentation_related,
 800                metric_related,
 801            ]
 802            names = ["data", "ssl", "losses", "augmentations", "metrics"]
 803            for related_list, name in zip(related_lists, names):
 804                if (
 805                    any([x in pars["general"] for x in related_list])
 806                    and name not in pars
 807                ):
 808                    pars[name] = {}
 809
 810        if "training" in pars:
 811            if "data" not in pars or not stay:
 812                for x in [
 813                    "to_ram",
 814                    "use_test",
 815                    "partition_method",
 816                    "val_frac",
 817                    "test_frac",
 818                    "save_split",
 819                    "batch_size",
 820                    "save_split",
 821                ]:
 822                    if (
 823                        x in pars["training"]
 824                        and pars["training"][x] != self.training_parameters[x]
 825                    ):
 826                        if "data" not in pars:
 827                            pars["data"] = {}
 828                        stay = True
 829            self.training_parameters.update(pars["training"])
 830            self.task.update_parameters(self._get_parameters_from_training())
 831
 832        if "data" in pars or "features" in pars:
 833            for k, v in pars["data"].items():
 834                if k not in self.data_parameters or v != self.data_parameters[k]:
 835                    stay = True
  836            for k, v in pars.get("features", {}).items():
 837                if k not in self.feature_parameters or v != self.feature_parameters[k]:
 838                    stay = True
 839            if stay:
  840                self.data_parameters.update(pars.get("data", {}))
  841                self.feature_parameters.update(pars.get("features", {}))
 842                dataset = self._construct_dataset()
 843                (
 844                    train_dataloader,
 845                    test_dataloader,
 846                    val_dataloader,
 847                ) = self._partition_dataset(dataset)
 848                self.task.set_dataloaders(
 849                    train_dataloader, val_dataloader, test_dataloader
 850                )
 851                self.class_weights = train_dataloader.dataset.class_weights()
 852                self.proportional_class_weights = (
 853                    train_dataloader.dataset.class_weights(True)
 854                )
 855                if "losses" not in pars:
 856                    pars["losses"] = {}
 857
 858        if "model" in pars:
 859            self.model_parameters.update(pars["model"])
 860
 861        self._update_data_blanks()
 862
 863        if "augmentations" in pars:
 864            self.aug_parameters.update(pars["augmentations"])
 865            transformer = self._construct_transformer()
 866            self.task.set_transformer(transformer)
 867
 868        if "losses" in pars:
 869            for key in pars["losses"]:
 870                if key in self.loss_parameters:
 871                    self.loss_parameters[key].update(pars["losses"][key])
 872                else:
 873                    self.loss_parameters[key] = pars["losses"][key]
 875            loss = self._construct_loss()
 876            self.task.set_loss(loss)
 877
 878        if "metrics" in pars:
 879            for key in pars["metrics"]:
 880                if key in self.metric_parameters:
 881                    self.metric_parameters[key].update(pars["metrics"][key])
 882                else:
 883                    self.metric_parameters[key] = pars["metrics"][key]
 884            metrics = self._construct_metrics()
 885            self.task.set_metrics(metrics)
 886
 887        self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
 888        self._set_loss_weights(
 889            pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
 890        )
 891        model = self._construct_model()
 892        predict_functions = self._construct_predict_functions()
 893        self.task.set_predict_functions(*predict_functions)
 894        self._update_model_blanks(model)
 895        ssl_list = self._construct_ssl()
 896        self._update_parameters_from_ssl(ssl_list)
 897        model.set_ssl(ssl_constructors=ssl_list)
 898        self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 899        self.task.set_model(model)
 900        if "training" in pars and "checkpoint_path" in pars["training"]:
 901            checkpoint_path = pars["training"]["checkpoint_path"]
 902            only_model = pars["training"].get("only_load_model", False)
 903            load_strict = pars["training"].get("load_strict", True)
 904            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
 905        if (
 906            self.general_parameters["only_load_annotated"]
 907            and self.general_parameters.get("ssl") is not None
 908        ):
 909            warnings.warn(
 910                "Note that you are using SSL modules but only loading annotated files! Set "
 911                "general/only_load_annotated to False to change that"
 912            )
 913        if self.task.dataset("train").annotation_class() != "none":
 914            self._print_behaviors()
 915
 916    def train(
 917        self,
 918        trial: Trial = None,
 919        optimized_metric: str = None,
 920        autostop_metric: str = None,
 921        autostop_interval: int = 10,
 922        autostop_threshold: float = 0.001,
 923        loading_bar: bool = False,
 924    ) -> Tuple:
 925        """Train the task and return a log of epoch-average loss and metric.
 926
  927        You can use the autostop parameters to stop training early when the model is no longer improving. Training
  928        is stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
  929        the average over the previous `autostop_interval` epochs plus `autostop_threshold`. For example, if the
  930        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 are compared.
 931
 932        Parameters
 933        ----------
 934        trial : Trial
 935            an `optuna` trial (for hyperparameter searches)
 936        optimized_metric : str
 937            the name of the metric being optimized (for hyperparameter searches)
 938        autostop_metric : str, optional
  939            the autostop metric (can be any one of the tracked metrics or `'loss'`)
  940        autostop_interval : int, default 10
 941            the number of epochs to average the autostop metric over
 942        autostop_threshold : float, default 0.001
 943            the autostop difference threshold
 944        loading_bar : bool, default False
 945            whether to show a loading bar
 946
 947        Returns
 948        -------
 949        loss_log: list
 950            a list of float loss function values for each epoch
 951        metrics_log: dict
 952            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
 953            names, values are lists of function values)
 954
 955        """
 956        to_ram = self.training_parameters.get("to_ram", False)
 957        logs = self.task.train(
 958            trial,
 959            optimized_metric,
 960            to_ram,
 961            autostop_metric=autostop_metric,
 962            autostop_interval=autostop_interval,
 963            autostop_threshold=autostop_threshold,
 964            main_task_on=self.training_parameters.get("main_task_on", True),
 965            ssl_on=self.training_parameters.get("ssl_on", True),
 966            temporal_subsampling_size=self.training_parameters.get(
 967                "temporal_subsampling_size"
 968            ),
 969            loading_bar=loading_bar,
 970        )
 971        return logs
 972
 973    def save_model(self, save_path: str) -> None:
 974        """Save the model of the `dlc2action.task.universal_task.Task` instance.
 975
 976        Parameters
 977        ----------
 978        save_path : str
 979            the path to the saved file
 980
 981        """
 982        self.task.save_model(save_path)
 983
 984    def evaluate(
 985        self,
 986        data: Union[DataLoader, BehaviorDataset, str] = None,
 987        augment_n: int = 0,
 988        verbose: bool = True,
 989    ) -> Tuple:
 990        """Evaluate the Task model.
 991
 992        Parameters
 993        ----------
  994        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
 995            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 996        augment_n : int, default 0
 997            the number of augmentations to average results over
 998        verbose : bool, default True
 999            if True, the process is reported to standard output
1000
1001        Returns
1002        -------
1003        loss : float
1004            the average value of the loss function
1005        ssl_loss : float
1006            the average value of the SSL loss function
1007        metric : dict
1008            a dictionary of average values of metric functions
1009
1010        """
1011        res = self.task.evaluate(
1012            data,
1013            augment_n,
1014            int(self.training_parameters.get("batch_size", 32)),
1015            verbose,
1016        )
1017        return res
1018
1019    def evaluate_prediction(
1020        self,
1021        prediction: torch.Tensor,
1022        data: Union[DataLoader, BehaviorDataset, str] = None,
1023        indices: Union[List[int], None] = None,
1024    ) -> Tuple:
1025        """Compute metrics for a prediction.
1026
1027        Parameters
1028        ----------
1029        prediction : torch.Tensor
1030            the prediction
 1031        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
1032            the data the prediction was made for (if not provided, take the validation dataset)
1033
1034        Returns
1035        -------
1036        loss : float
1037            the average value of the loss function
1038        metric : dict
1039            a dictionary of average values of metric functions
1040        """
1041        return self.task.evaluate_prediction(
1042            prediction, data, int(self.training_parameters.get("batch_size", 32)), indices
1043        )
1044
1045    def predict(
1046        self,
1047        data: Union[DataLoader, BehaviorDataset, str],
1048        raw_output: bool = False,
1049        apply_primary_function: bool = True,
1050        augment_n: int = 0,
1051        embedding: bool = False,
1052    ) -> torch.Tensor:
1053        """Make a prediction with the Task model.
1054
1055        Parameters
1056        ----------
 1057        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str
 1058            the data to make a prediction for
1059        raw_output : bool, default False
1060            if `True`, the raw predicted probabilities are returned
1061        apply_primary_function : bool, default True
1062            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
1063            the input)
1064        augment_n : int, default 0
1065            the number of augmentations to average results over
1066        embedding : bool, default False
1067            if `True`, the embedding is returned instead of the prediction
1068
1069        Returns
1070        -------
1071        prediction : torch.Tensor
1072            a prediction for the input data
1073
1074        """
1075        to_ram = self.training_parameters.get("to_ram", False)
1076        return self.task.predict(
1077            data,
1078            raw_output,
1079            apply_primary_function,
1080            augment_n,
1081            int(self.training_parameters.get("batch_size", 32)),
1082            to_ram,
1083            embedding=embedding,
1084        )
1085
1086    def dataset(self, mode: str = "train") -> BehaviorDataset:
1087        """Get a dataset.
1088
1089        Parameters
1090        ----------
 1091        mode : {'train', 'val', 'test'}, default 'train'
1092            the dataset to get
1093
1094        Returns
1095        -------
1096        dataset : dlc2action.data.dataset.BehaviorDataset
1097            the dataset
1098
1099        """
1100        return self.task.dataset(mode)
1101
1102    def generate_full_length_prediction(
1103        self,
1104        dataset: Union[BehaviorDataset, str] = None,
1105        augment_n: int = 10,
1106    ) -> Dict:
1107        """Compile a prediction for the original input sequences.
1108
1109        Parameters
1110        ----------
1111        dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
1112            the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task`
1113            instance validation dataset)
1114        augment_n : int, default 10
1115            the number of augmentations to average results over
1116
1117        Returns
1118        -------
1119        prediction : dict
1120            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1121            are prediction tensors
1122
1123        """
1124        return self.task.generate_full_length_prediction(
1125            dataset, int(self.training_parameters.get("batch_size", 32)), augment_n
1126        )
1127
1128    def generate_submission(
1129        self,
1130        frame_number_map_file: str,
1131        dataset: Union[BehaviorDataset, str] = None,
1132        augment_n: int = 10,
1133    ) -> Dict:
1134        """Generate a MABe-22 style submission dictionary.
1135
1136        Parameters
1137        ----------
1138        frame_number_map_file : str
1139            path to the frame number map file
 1140        dataset : BehaviorDataset | str, optional
1141            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1142        augment_n : int, default 10
1143            the number of augmentations to average results over
1144
1145        Returns
1146        -------
1147        submission : dict
1148            a dictionary with frame number mapping and embeddings
1149
1150        """
1151        return self.task.generate_submission(
1152            frame_number_map_file,
1153            dataset,
1154            int(self.training_parameters.get("batch_size", 32)),
1155            augment_n,
1156        )
1157
 1158    def behaviors_dict(self) -> Dict:
1159        """Get a behavior dictionary.
1160
1161        Keys are label indices and values are label names.
1162
1163        Returns
1164        -------
1165        behaviors_dict : dict
1166            behavior dictionary
1167
1168        """
1169        return self.task.behaviors_dict()
1170
1171    def count_classes(self, bouts: bool = False) -> Dict:
1172        """Get a dictionary of class counts in different modes.
1173
1174        Parameters
1175        ----------
1176        bouts : bool, default False
1177            if `True`, instead of frame counts segment counts are returned
1178
1179        Returns
1180        -------
1181        class_counts : dict
1182            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1183            class names and values are class counts (in frames)
1184
1185        """
1186        return self.task.count_classes(bouts)
1187
1188    def _visualize_results_label(
1189        self,
1190        label: str,
1191        save_path: str = None,
1192        add_legend: bool = True,
1193        ground_truth: bool = True,
1194        hide_axes: bool = False,
1195        width: int = 10,
1196        whole_video: bool = False,
1197        transparent: bool = False,
1198        dataset: BehaviorDataset = None,
1199        smooth_interval: int = 0,
1200        title: str = None,
1201    ):
1202        return self.task._visualize_results_single(
1203            label,
1204            save_path,
1205            add_legend,
1206            ground_truth,
1207            hide_axes,
1208            width,
1209            whole_video,
1210            transparent,
1211            dataset,
1212            smooth_interval=smooth_interval,
1213            title=title,
1214        )
1215
1216    def visualize_results(
1217        self,
1218        save_path: str = None,
1219        add_legend: bool = True,
1220        ground_truth: bool = True,
1221        colormap: str = "viridis",
1222        hide_axes: bool = False,
1223        min_classes: int = 1,
1224        width: float = 10,
1225        whole_video: bool = False,
1226        transparent: bool = False,
1227        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1228        drop_classes: Set = None,
1229        search_classes: Set = None,
1230        smooth_interval_prediction: int = None,
1231        font_size: float = None,
 1232        num_plots: int = 1,
 1233        window_size: int = 400,
1234    ) -> None:
1235        """Visualize random predictions.
1236
1237        Parameters
1238        ----------
1239        save_path : str, optional
1240            the path where the plot will be saved
1241        add_legend : bool, default True
1242            if True, legend will be added to the plot
1243        ground_truth : bool, default True
1244            if True, ground truth will be added to the plot
 1245        colormap : str, default 'viridis'
1246            the `matplotlib` colormap to use
 1247        hide_axes : bool, default False
1248            if `True`, the axes will be hidden on the plot
1249        min_classes : int, default 1
1250            the minimum number of classes in a displayed interval
1251        width : float, default 10
1252            the width of the plot
1253        whole_video : bool, default False
1254            if `True`, whole videos are plotted instead of segments
1255        transparent : bool, default False
1256            if `True`, the background on the plot is transparent
1257        dataset : BehaviorDataset | DataLoader | str | None, optional
1258            the dataset to make the prediction for (if not provided, the validation dataset is used)
1259        drop_classes : set, optional
1260            a set of class names to not be displayed
1261        search_classes : set, optional
1262            if given, only intervals where at least one of the classes is in ground truth will be shown
1263        smooth_interval_prediction : int, optional
1264            if given, the prediction will be smoothed over the given number of frames
1265
1266        """
1267        return self.task.visualize_results(
1268            save_path,
1269            add_legend,
1270            ground_truth,
1271            colormap,
1272            hide_axes,
1273            min_classes,
1274            width,
1275            whole_video,
1276            transparent,
1277            dataset,
1278            drop_classes,
1279            search_classes,
1280            font_size=font_size,
1281            smooth_interval_prediction=smooth_interval_prediction,
1282            num_plots=num_plots,
1283            window_size=window_size
1284        )
1285
1286    def generate_uncertainty_score(
1287        self,
1288        classes: List,
1289        augment_n: int = 0,
1290        method: str = "least_confidence",
1291        predicted: torch.Tensor = None,
1292        behaviors_dict: Dict = None,
1293    ) -> Dict:
1294        """Generate frame-wise scores for active learning.
1295
1296        Parameters
1297        ----------
1298        classes : list
1299            a list of class names or indices; their confidence scores will be computed separately and stacked
1300        augment_n : int, default 0
1301            the number of augmentations to average over
1302        method : {"least_confidence", "entropy"}
1303            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1304            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
1305        predicted : torch.Tensor, optional
1306            if given, the predictions will be used instead of the model's predictions
1307        behaviors_dict : dict, optional
1308            if given, the behaviors dictionary will be used instead of the model's behaviors dictionary
1309
1310        Returns
1311        -------
1312        score_dicts : dict
1313            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1314            are score tensors
1315
1316        """
1317        return self.task.generate_uncertainty_score(
1318            classes,
1319            augment_n,
1320            int(self.training_parameters.get("batch_size", 32)),
1321            method,
1322            predicted,
1323            behaviors_dict,
1324        )
1325
1326    def generate_bald_score(
1327        self,
1328        classes: List,
1329        augment_n: int = 0,
1330        num_models: int = 10,
1331        kernel_size: int = 11,
1332    ) -> Dict:
1333        """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
1334
1335        Parameters
1336        ----------
1337        classes : list
1338            a list of class names or indices; their confidence scores will be computed separately and stacked
1339        augment_n : int, default 0
1340            the number of augmentations to average over
1341        num_models : int, default 10
1342            the number of dropout masks to apply
1343        kernel_size : int, default 11
1344            the size of the smoothing Gaussian kernel
1345
1346        Returns
1347        -------
1348        score_dicts : dict
1349            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1350            are score tensors
1351
1352        """
1353        return self.task.generate_bald_score(
1354            classes,
1355            augment_n,
1356            int(self.training_parameters.get("batch_size", 32)),
1357            num_models,
1358            kernel_size,
1359        )
1360
1361    def exists(self, mode: str) -> bool:
1362        """Check whether the task has a train/test/validation subset.
1363
1364        Parameters
1365        ----------
1366        mode : {"train", "val", "test"}
1367            the name of the subset to check for
1368
1369        Returns
1370        -------
1371        exists : bool
1372            `True` if the subset exists
1373
1374        """
1375        return self.task.dataloader(mode) is not None
1376
1377    def get_normalization_stats(self) -> Dict:
1378        """Get the pre-computed normalization stats.
1379
1380        Returns
1381        -------
1382        normalization_stats : dict
1383            a dictionary of means and stds
1384
1385        """
1386        return self.task.get_normalization_stats()
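
Usage examples

The sketches below are illustrative only: the exact keys of the `parameters`
dictionary depend on your project, data and model choices, and the specific
names used here (model, feature extractor, loss, metric and class names) are
hypothetical placeholders rather than values taken from this module.

    from dlc2action.task.task_dispatcher import TaskDispatcher

    # A minimal sketch, assuming a complete config dictionary with the
    # "general", "data", "training" (etc.) groups that TaskDispatcher expects;
    # all string values are placeholders
    parameters = {
        "general": {
            "model_name": "ms_tcn3",            # placeholder model name
            "exclusive": True,                  # single-label classification
            "data_type": "dlc_track",           # placeholder data type
            "annotation_type": "csv",           # placeholder annotation type
            "feature_extraction": "kinematic",  # placeholder extractor name
            "loss_function": "ms_tcn",          # placeholder loss name
            "metric_functions": ["f1"],         # placeholder metric name
            "only_load_annotated": True,
        },
        "data": {},  # dataset-specific paths and options (not shown here)
        "training": {"num_epochs": 100, "batch_size": 32, "val_frac": 0.2},
    }

    dispatcher = TaskDispatcher(parameters)
    loss_log, metrics_log = dispatcher.train()

    # `exists` guards against evaluating on a subset that was never created;
    # the type hint on `evaluate` suggests string subset names are accepted
    if dispatcher.exists("test"):
        loss, ssl_loss, metrics = dispatcher.evaluate(data="test", augment_n=5)

    # Visualize a few random prediction intervals; note that the signature
    # also accepts font_size, num_plots and window_size, which the docstring
    # above omits
    dispatcher.visualize_results(save_path="results.png", min_classes=1)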
class TaskDispatcher:
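The `"least_confidence"` and `"entropy"` options of `generate_uncertainty_score`
reduce per-frame class probabilities to scalar uncertainty scores. The snippet
below is a self-contained illustration of the two formulas quoted in the
docstring, not the method's internal implementation (which also handles
batching, augmentation and multiple classes):

    import torch

    # toy per-frame probabilities for a single class
    p = torch.tensor([0.05, 0.40, 0.50, 0.95])

    # "least_confidence": 1 - p_i if p_i > 0.5, p_i otherwise
    least_confidence = torch.where(p > 0.5, 1 - p, p)

    # "entropy": binary entropy of the prediction
    entropy = -p * torch.log(p) - (1 - p) * torch.log(1 - p)

    print(least_confidence)  # tensor([0.0500, 0.4000, 0.5000, 0.0500])
    print(entropy)           # peaks near p = 0.5, approaches 0 near 0 and 1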
  31class TaskDispatcher:
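`generate_bald_score` follows the same calling pattern but measures
disagreement between several stochastic forward passes (dropout masks) rather
than plain confidence. A hedged usage sketch, assuming `dispatcher` is a
trained `TaskDispatcher` and `"grooming"` is one of your class names:

    # more dropout masks (num_models) give a more stable disagreement
    # estimate at a higher compute cost; kernel_size controls the temporal
    # Gaussian smoothing of the resulting scores
    bald_scores = dispatcher.generate_bald_score(
        classes=["grooming"],  # placeholder class name
        augment_n=5,
        num_models=10,
        kernel_size=11,
    )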
  32    """A class that manages the interactions between config dictionaries and a Task."""
  33
  34    def __init__(self, parameters: Dict) -> None:
  35        """Initialize the `TaskDispatcher`.
  36
  37        Parameters
  38        ----------
  39        parameters : dict
  40            a dictionary of task parameters
  41
  42        """
  43        pars = deepcopy(parameters)
  44        self.class_weights = None
  45        self.general_parameters = pars.get("general", {})
  46        self.data_parameters = pars.get("data", {})
  47        self.model_parameters = pars.get("model", {})
  48        self.training_parameters = pars.get("training", {})
  49        self.loss_parameters = pars.get("losses", {})
  50        self.metric_parameters = pars.get("metrics", {})
  51        self.ssl_parameters = pars.get("ssl", {})
  52        self.aug_parameters = pars.get("augmentations", {})
  53        self.feature_parameters = pars.get("features", {})
  54        self.blanks = {blank: [] for blank in options.blanks}
  55
  56        self.task = None
  57        self._initialize_task()
  58        self._print_behaviors()
  59
  60    @staticmethod
  61    def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
  62        """Complete a parameter dictionary with values from other dictionaries if required by a function.
  63
  64        Parameters
  65        ----------
  66        parameters : dict
  67            the function parameters dictionary
  68        function : callable
  69            the function to be inspected
  70        general_dicts : list
  71            a list of dictionaries where the missing values will be pulled from
  72
  73        Returns
  74        -------
  75        parameters : dict
  76            the updated parameter dictionary
  77
  78        """
  79        parameter_names = inspect.getfullargspec(function).args
  80        for param in parameter_names:
  81            for dic in general_dicts:
  82                if param not in parameters and param in dic:
  83                    parameters[param] = dic[param]
  84        return parameters
  85
  86    @staticmethod
  87    def complete_dataset_parameters(
  88        parameters: dict,
  89        general_dict: dict,
  90        data_type: str,
  91        annotation_type: str,
  92    ) -> Dict:
  93        """Complete a parameter dictionary with values from other dictionaries if required by a dataset.
  94
  95        Parameters
  96        ----------
  97        parameters : dict
  98            the function parameters dictionary
  99        general_dict : dict
 100            the dictionary where the missing values will be pulled from
 101        data_type : str
 102            the input type of the dataset
 103        annotation_type : str
 104            the annotation type of the dataset
 105
 106        Returns
 107        -------
 108        parameters : dict
 109            the updated parameter dictionary
 110
 111        """
 112        params = deepcopy(parameters)
 113        parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type)
 114        for param in parameter_names:
 115            if param not in params and param in general_dict:
 116                params[param] = general_dict[param]
 117        return params
 118
 119    @staticmethod
 120    def check(parameters: Dict, name: str) -> bool:
 121        """Check whether there is a non-`None` value under the name key in the parameters dictionary.
 122
 123        Parameters
 124        ----------
 125        parameters : dict
 126            the dictionary to check
 127        name : str
 128            the key to check
 129
 130        Returns
 131        -------
 132        result : bool
 133            True if a non-`None` value exists
 134
 135        """
 136        if name in parameters and parameters[name] is not None:
 137            return True
 138        else:
 139            return False
 140
 141    @staticmethod
 142    def get(parameters: Dict, name: str, default):
 143        """Get the value under the name key or the default if it is `None` or does not exist.
 144
 145        Parameters
 146        ----------
 147        parameters : dict
 148            the dictionary to check
 149        name : str
 150            the key to check
 151        default
 152            the default value to return
 153
 154        Returns
 155        -------
 156        value
 157            the resulting value
 158
 159        """
 160        if TaskDispatcher.check(parameters, name):
 161            return parameters[name]
 162        else:
 163            return default
 164
 165    @staticmethod
 166    def make_dataloader(
 167        dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False
 168    ) -> DataLoader:
 169        """Make a torch dataloader from a dataset.
 170
 171        Parameters
 172        ----------
 173        dataset : dlc2action.data.dataset.BehaviorDataset
 174            the dataset
 175        batch_size : int
 176            the batch size
 177        shuffle : bool
 178            whether to shuffle the dataset
 179
 180        Returns
 181        -------
 182        dataloader : DataLoader
 183            the dataloader (or `None` if the length of the dataset is 0)
 184
 185        """
 186        if dataset is None or len(dataset) == 0:
 187            return None
 188        else:
 189            return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)
 190
 191    def _construct_ssl(self) -> List:
 192        """Generate SSL constructors."""
 193        ssl_list = deepcopy(self.general_parameters.get("ssl", None))
 194        model_name = self.general_parameters.get("model_name", "")
 195        # ssl_constructors = options.ssl_constructors if not "tcn" in model_name else options.ssl_constructors_tcn
 196        ssl_constructors = options.ssl_constructors
 197        if not isinstance(ssl_list, Iterable):
 198            ssl_list = [ssl_list]
 199        for i, ssl in enumerate(ssl_list):
 200            if type(ssl) is str:
 201                if ssl in ssl_constructors:
 202                    pars = self.get(self.ssl_parameters, ssl, default={})
 203                    pars = self.complete_function_parameters(
 204                        parameters=pars,
 205                        function=ssl_constructors[ssl],
 206                        general_dicts=[
 207                            self.model_parameters,
 208                            self.data_parameters,
 209                            self.general_parameters,
 210                        ],
 211                    )
 212                    ssl_list[i] = ssl_constructors[ssl](**pars)
 213                else:
 214                    raise ValueError(
 215                        f"The {ssl} SSL is not available, please choose from {list(ssl_constructors.keys())}"
 216                    )
 217            elif ssl is None:
 218                ssl_list[i] = EmptySSL()
 219            elif not isinstance(ssl, SSLConstructor):
 220                raise TypeError(
 221                    f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}"
 222                )
 223        return ssl_list
 224
 225    def _construct_model(self) -> Model:
 226        """Generate a model."""
 227        if self.check(self.general_parameters, "model"):
 228            pars = self.complete_function_parameters(
 229                function=LoadedModel,
 230                parameters=self.model_parameters,
 231                general_dicts=[self.general_parameters],
 232            )
 233            model = LoadedModel(**pars)
 234        elif self.check(self.general_parameters, "model_name"):
 235            name = self.general_parameters["model_name"]
 236            if name in options.models:
 237                pars = self.complete_function_parameters(
 238                    function=options.models[name],
 239                    parameters=self.model_parameters,
 240                    general_dicts=[self.general_parameters],
 241                )
 242                model = options.models[name](**pars)
 243            else:
 244                raise ValueError(
 245                    f"The {name} model is not available, please choose from {list(options.models.keys())}"
 246                )
 247        else:
 248            raise ValueError(
 249                "You need to provide either a model or its name in the model_parameters!"
 250            )
 251
 252        if self.get(self.training_parameters, "freeze_features", False):
 253            model.freeze_feature_extractor()
 254        return model
 255
 256    def _construct_dataset(self) -> BehaviorDataset:
 257        """
 258        Generate a dataset
 259        """
 260        data_type = self.general_parameters.get("data_type", None)
 261        if data_type is None:
 262            raise ValueError(
 263                "You need to provide the data_type parameter in the data parameters!"
 264            )
 265        annotation_type = self.get(self.general_parameters, "annotation_type", "none")
 266        feature_extraction = self.general_parameters.get("feature_extraction", "none")
 267        if feature_extraction is None:
 268            raise ValueError(
 269                "You need to provide the feature_extraction parameter in the data parameters!"
 270            )
 271        feature_extraction_pars = self.complete_function_parameters(
 272            self.feature_parameters,
 273            options.feature_extractors[feature_extraction],
 274            [self.general_parameters, self.data_parameters],
 275        )
 276
 277        pars = self.complete_dataset_parameters(
 278            self.data_parameters,
 279            self.general_parameters,
 280            data_type=data_type,
 281            annotation_type=annotation_type,
 282        )
 283        pars["feature_extraction_pars"] = feature_extraction_pars
 284        dataset = BehaviorDataset(**pars)
 285
 286        if self.get(self.general_parameters, "save_dataset", default=False):
 287            save_data_path = self.data_parameters.get("saved_data_path", None)
 288            dataset.save(save_path=save_data_path)
 289
 290        return dataset
 291
 292    def _construct_transformer(self) -> Transformer:
 293        """Generate a transformer."""
 294        features = self.general_parameters["feature_extraction"]
 295        name = options.extractor_to_transformer[features]
 296        if name in options.transformers:
 297            transformer_class = options.transformers[name]
 298            pars = self.complete_function_parameters(
 299                function=transformer_class,
 300                parameters=self.aug_parameters,
 301                general_dicts=[self.general_parameters],
 302            )
 303            transformer = transformer_class(**pars)
 304        else:
 305            raise ValueError(f"The {name} transformer is not available")
 306        return transformer
 307
 308    def _construct_loss(self) -> torch.nn.Module:
 309        """Generate a loss function."""
 310        if "loss_function" not in self.general_parameters:
 311            raise ValueError(
 312                'Please add a "loss_function" key to the parameters["general"] dictionary (either a name '
 313                f"from {list(options.losses.keys())} or a function)"
 314            )
 315        else:
 316            loss_function = self.general_parameters["loss_function"]
 317        if type(loss_function) is str:
 318            if loss_function in options.losses:
 319                pars = self.get(self.loss_parameters, loss_function, default={})
 320                pars = self._set_loss_weights(pars)
 321                pars = self.complete_function_parameters(
 322                    function=options.losses[loss_function],
 323                    parameters=pars,
 324                    general_dicts=[self.general_parameters],
 325                )
 326                loss = options.losses[loss_function](**pars)
 327            else:
 328                raise ValueError(
 329                    f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}"
 330                )
 331        else:
 332            loss = loss_function
 333        return loss
 334
 335    def _construct_metrics(self) -> List:
 336        """Generate the metric."""
 337        metric_functions = self.get(
 338            self.general_parameters, "metric_functions", default={}
 339        )
 340        if isinstance(metric_functions, Iterable):
 341            metrics = {}
 342            for func in metric_functions:
 343                if isinstance(func, str):
 344                    if func in options.metrics:
 345                        pars = self.get(self.metric_parameters, func, default={})
 346                        pars = self.complete_function_parameters(
 347                            function=options.metrics[func],
 348                            parameters=pars,
 349                            general_dicts=[self.general_parameters],
 350                        )
 351                        metrics[func] = options.metrics[func](**pars)
 352                    else:
 353                        raise ValueError(
 354                            f"The {func} metric is not available, please choose from {list(options.metrics.keys())}"
 355                        )
 356                elif isinstance(func, Metric):
 357                    name = "function_1"
 358                    i = 1
 359                    while name in metrics:
 360                        i += 1
 361                        name = f"function_{i}"
 362                    metrics[name] = func
 363                else:
 364                    raise TypeError(
 365                        'The elements of parameters["general"]["metric_functions"] have to be either strings '
 366                        f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead"
 367                    )
 368        elif isinstance(metric_functions, dict):
 369            metrics = metric_functions
 370        else:
 371            raise TypeError(
 372                'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;'
 373                f"got {type(metric_functions)} instead"
 374            )
 375        return metrics
 376
 377    def _construct_optimizer(self) -> Optimizer:
 378        """Generate an optimizer."""
 379        if "optimizer" in self.training_parameters:
 380            name = self.training_parameters["optimizer"]
 381            if name in options.optimizers:
 382                optimizer = options.optimizers[name]
 383            else:
 384                raise ValueError(
 385                    f"The {name} optimizer is not available, please choose from {list(options.optimizers.keys())}"
 386                )
 387        else:
 388            optimizer = None
 389        return optimizer
 390
 391    def _construct_predict_functions(self) -> Tuple[Callable, Callable]:
 392        """Construct predict functions."""
 393        predict_function = self.training_parameters.get("predict_function", None)
 394        primary_predict_function = self.training_parameters.get(
 395            "primary_predict_function", None
 396        )
 397        model_name = self.general_parameters.get("model_name", "")
 398        threshold = self.training_parameters.get("hard_threshold", 0.5)
 399        if not isinstance(predict_function, Callable):
 400            if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]:
 401                if self.general_parameters["exclusive"]:
 402                    func = lambda x: torch.softmax(x, dim=1)
 403                else:
 404                    func = lambda x: torch.sigmoid(x)
 405
 406                def primary_predict_function(x):
 407                    if len(x) == 1:
 408                        return func(x)
 409                    else:
 410                        if len(x.shape) != 4:
 411                            x = x.reshape((4, -1, x.shape[-2], x.shape[-1]))
 412                        weights = [1, 1, 1, 1]
 413                        ensemble_prob = func(x[0]) * weights[0] / sum(weights)
 414                        for i, outp_ele in enumerate(x[1:]):
 415                            ensemble_prob = ensemble_prob + func(outp_ele) * weights[
 416                                i + 1
 417                            ] / sum(weights)
 418                        return ensemble_prob
 419
 420            else:
 421                if model_name.startswith("ms_tcn") or model_name in [
 422                    "asformer",
 423                    "transformer",
 424                    "c3d_ms",
 425                    "transformer_ms",
 426                ]:
 427                    f = lambda x: x[-1] if len(x.shape) == 4 else x
 428                elif model_name == "asrf":
 429
 430                    def f(x):
 431                        x = x[-1]
 432                        # bounds = x[:, 0, :].unsqueeze(1)
 433                        cls = x[:, 1:, :]
 434                        # device = x.device
 435                        # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy())
 436                        # x = torch.tensor(x).to(device)
 437                        return cls
 438
 439                else:
 440                    f = lambda x: x
 441                if self.general_parameters["exclusive"]:
 442                    primary_predict_function = lambda x: torch.softmax(f(x), dim=1)
 443                else:
 444                    primary_predict_function = lambda x: torch.sigmoid(f(x))
 445            if self.general_parameters["exclusive"]:
 446                predict_function = lambda x: torch.max(x.data, dim=1)[1]
 447            else:
 448                predict_function = lambda x: (x > threshold).int()
 449        return primary_predict_function, predict_function
 450
 451    def _get_parameters_from_training(self) -> Dict:
 452        """Get the training parameters that need to be passed to the Task."""
 453        task_training_par_names = [
 454            "lr",
 455            "parallel",
 456            "device",
 457            "verbose",
 458            "log_file",
 459            "augment_train",
 460            "augment_val",
 461            "hard_threshold",
 462            "ssl_losses",
 463            "model_save_path",
 464            "model_save_epochs",
 465            "pseudolabel",
 466            "pseudolabel_start",
 467            "correction_interval",
 468            "pseudolabel_alpha_f",
 469            "alpha_growth_stop",
 470            "num_epochs",
 471            "validation_interval",
 472            "ignore_tags",
 473            "skip_metrics",
 474        ]
 475        task_training_pars = {
 476            name: self.training_parameters[name]
 477            for name in task_training_par_names
 478            if self.check(self.training_parameters, name)
 479        }
 480        if self.check(self.general_parameters, "ssl"):
 481            ssl_weights = [
 482                self.training_parameters["ssl_weights"][x]
 483                for x in self.general_parameters["ssl"]
 484            ]
 485            task_training_pars["ssl_weights"] = ssl_weights
 486        return task_training_pars
 487
 488    def _update_parameters_from_ssl(self, ssl_list: list) -> None:
 489        """Update the necessary parameters given the list of SSL constructors."""
 490        if self.task is not None:
 491            self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 492            self.task.set_ssl_losses([ssl.loss for ssl in ssl_list])
 493            self.task.set_keep_target_none(
 494                [ssl.type in ["contrastive"] for ssl in ssl_list]
 495            )
 496            self.task.set_generate_ssl_input(
 497                [ssl.type == "contrastive" for ssl in ssl_list]
 498            )
 499        self.data_parameters["ssl_transformations"] = [
 500            ssl.transformation for ssl in ssl_list
 501        ]
 502        self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list]
 503        self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list]
 504        self.model_parameters["ssl_modules"] = [
 505            ssl.construct_module() for ssl in ssl_list
 506        ]
 507        self.aug_parameters["generate_ssl_input"] = [
 508            x.type == "contrastive" for x in ssl_list
 509        ]
 510        self.aug_parameters["keep_target_none"] = [
 511            x.type == "contrastive" for x in ssl_list
 512        ]
 513
 514    def _set_loss_weights(self, parameters):
 515        """Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values."""
 516        for k in list(parameters.keys()):
 517            if parameters[k] in [
 518                "dataset_inverse_weights",
 519                "dataset_proportional_weights",
 520            ]:
 521                if parameters[k] == "dataset_inverse_weights":
 522                    parameters[k] = self.class_weights
 523                else:
 524                    parameters[k] = self.proportional_class_weights
 525                print("Initializing class weights:")
 526                string = "    "
 527                if isinstance(parameters[k], Mapping):
 528                    for key, val in parameters[k].items():
 529                        string += ": ".join(
 530                            (
 531                                " " + str(key),
 532                                ", ".join((map(lambda x: str(np.round(x, 3)), val))),
 533                            )
 534                        )
 535                else:
 536                    string += ", ".join(
 537                        (map(lambda x: str(np.round(x, 3)), parameters[k]))
 538                    )
 539                print(string)
 540        return parameters
 541
 542    def _partition_dataset(
 543        self, dataset: BehaviorDataset
 544    ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]:
 545        """Partition the dataset into train, validation and test subsamples."""
 546        use_test = self.get(self.training_parameters, "use_test", 0)
 547        split_path = self.training_parameters.get("split_path", None)
 548        partition_method = self.training_parameters.get("partition_method", "random")
 549        val_frac = self.get(self.training_parameters, "val_frac", 0)
 550        test_frac = self.get(self.training_parameters, "test_frac", 0)
 551        save_split = self.get(self.training_parameters, "save_split", True)
 552        normalize = self.get(self.training_parameters, "normalize", False)
 553        skip_normalization_keys = self.training_parameters.get(
 554            "skip_normalization_keys"
 555        )
 556        stats = self.training_parameters.get("stats")
 557        train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val(
 558            use_test,
 559            split_path,
 560            partition_method,
 561            val_frac,
 562            test_frac,
 563            save_split,
 564            normalize,
 565            skip_normalization_keys,
 566            stats,
 567        )
 568        bs = int(self.training_parameters.get("batch_size", 32))
 569        train_dataloader, test_dataloader, val_dataloader = (
 570            self.make_dataloader(train_dataset, batch_size=bs, shuffle=True),
 571            self.make_dataloader(test_dataset, batch_size=bs, shuffle=False),
 572            self.make_dataloader(val_dataset, batch_size=bs, shuffle=False),
 573        )
 574        return train_dataloader, test_dataloader, val_dataloader
 575
 576    def _initialize_task(self):
 577        """Create a `dlc2action.task.universal_task.Task` instance."""
 578        dataset = self._construct_dataset()
 579        self._update_data_blanks(dataset)
 580        model = self._construct_model()
 581        self._update_model_blanks(model)
 582        ssl_list = self._construct_ssl()
 583        self._update_parameters_from_ssl(ssl_list)
 584        model.set_ssl(ssl_constructors=ssl_list)
 585        dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 586        transformer = self._construct_transformer()
 587        metrics = self._construct_metrics()
 588        optimizer = self._construct_optimizer()
 589        primary_predict_function, predict_function = self._construct_predict_functions()
 590
 591        task_training_pars = self._get_parameters_from_training()
 592        train_dataloader, test_dataloader, val_dataloader = self._partition_dataset(
 593            dataset
 594        )
 595        self.class_weights = train_dataloader.dataset.class_weights()
 596        self._update_num_classes_parameter(dataset)
 597        self.proportional_class_weights = train_dataloader.dataset.class_weights(True)
 598        loss = self._construct_loss()
 599        exclusive = self.general_parameters["exclusive"]
 600
 601        task_pars = {
 602            "train_dataloader": train_dataloader,
 603            "model": model,
 604            "loss": loss,
 605            "transformer": transformer,
 606            "metrics": metrics,
 607            "val_dataloader": val_dataloader,
 608            "test_dataloader": test_dataloader,
 609            "exclusive": exclusive,
 610            "optimizer": optimizer,
 611            "predict_function": predict_function,
 612            "primary_predict_function": primary_predict_function,
 613        }
 614        task_pars.update(task_training_pars)
 615        self.task = Task(**task_pars)
 616        checkpoint_path = self.training_parameters.get("checkpoint_path", None)
 617        if checkpoint_path is not None:
 618            only_model = self.get(self.training_parameters, "only_load_model", False)
 619            load_strict = self.get(self.training_parameters, "load_strict", True)
 620            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
 621        if (
 622            self.general_parameters["only_load_annotated"]
 623            and self.general_parameters.get("ssl") is not None
 624        ):
 625            warnings.warn(
 626                "Note that you are using SSL modules but only loading annotated files! Set "
 627                "general/only_load_annotated to False to change that"
 628            )
 629
 630    def _update_data_blanks(
 631        self, dataset: BehaviorDataset = None, remember: bool = False
 632    ) -> None:
 633        """Update all blanks from a dataset."""
 634        if dataset is None:
 635            dataset = self.dataset()
 636        self._update_dim_parameter(dataset, remember)
 637        self._update_bodyparts_parameter(dataset, remember)
 638        self._update_num_classes_parameter(dataset, remember)
 639        self._update_len_segment_parameter(dataset, remember)
 640        self._update_boundary_parameter(dataset, remember)
 641
 642    def _update_model_blanks(self, model: Model, remember: bool = False) -> None:
 643        """Update blanks related to model parameters."""
 644        self._update_features_parameter(model, remember)
 645
 646    def _update_parameter(self, blank_name: str, value, remember: bool = False):
 647        """Update a single blank parameter."""
 648        parameters = [
 649            self.model_parameters,
 650            self.ssl_parameters,
 651            self.general_parameters,
 652            self.feature_parameters,
 653            self.data_parameters,
 654            self.training_parameters,
 655            self.metric_parameters,
 656            self.loss_parameters,
 657            self.aug_parameters,
 658        ]
 659        par_names = [
 660            "model",
 661            "ssl",
 662            "general",
 663            "feature",
 664            "data",
 665            "training",
 666            "metrics",
 667            "losses",
 668            "augmentations",
 669        ]
 670        for names in self.blanks[blank_name]:
 671            group = names[0]
 672            key = names[1]
 673            ind = par_names.index(group)
 674            if len(names) == 3:
 675                if names[2] in parameters[ind][key]:
 676                    parameters[ind][key][names[2]] = value
 677            else:
 678                if key in parameters[ind]:
 679                    parameters[ind][key] = value
 680        for name, dic in zip(par_names, parameters):
 681            for k, v in dic.items():
 682                if v == blank_name:
 683                    dic[k] = value
 684                    if [name, k] not in self.blanks[blank_name]:
 685                        self.blanks[blank_name].append([name, k])
 686                elif isinstance(v, Mapping):
 687                    for kk, vv in v.items():
 688                        if vv == blank_name:
 689                            dic[k][kk] = value
 690                            if [name, k, kk] not in self.blanks[blank_name]:
 691                                self.blanks[blank_name].append([name, k, kk])
 692
 693    def _update_features_parameter(self, model: Model, remember: bool = False) -> None:
 694        """Fill the `"model_features"` blank."""
 695        value = model.features_shape()
 696        self._update_parameter("model_features", value, remember)
 697
 698    def _update_bodyparts_parameter(
 699        self, dataset: BehaviorDataset, remember: bool = False
 700    ) -> None:
 701        """Fill the `"dataset_bodyparts"` blank."""
 702        value = dataset.bodyparts_order()
 703        self._update_parameter("dataset_bodyparts", value, remember)
 704
 705    def _update_dim_parameter(
 706        self, dataset: BehaviorDataset, remember: bool = False
 707    ) -> None:
 708        """Fill the `"dataset_features"` blank."""
 709        value = dataset.features_shape()
 710        self._update_parameter("dataset_features", value, remember)
 711
 712    def _update_boundary_parameter(
 713        self, dataset: BehaviorDataset, remember: bool = False
 714    ) -> None:
 715        """Fill the `"dataset_features"` blank."""
 716        value = dataset._boundary_class_weight()
 717        self._update_parameter("dataset_boundary_weight", value, remember)
 718
 719    def _update_num_classes_parameter(
 720        self, dataset: BehaviorDataset, remember: bool = False
 721    ) -> None:
 722        """Fill in the `"dataset_classes"` blank."""
 723        value = dataset.num_classes()
 724        self._update_parameter("dataset_classes", value, remember)
 725
 726    def _update_len_segment_parameter(
 727        self, dataset: BehaviorDataset, remember: bool = False
 728    ) -> None:
 729        """Fill in the `"dataset_len_segment"` blank."""
 730        value = dataset.len_segment()
 731        self._update_parameter("dataset_len_segment", value, remember)
 732
 733    def _print_behaviors(self):
 734        behavior_set = self.behaviors_dict()
 735        print(f"Behavior indices:")
 736        for key, value in sorted(behavior_set.items()):
 737            print(f"    {key}: {value}")
 738
 739    def update_task(self, parameters: Dict) -> None:
 740        """Update the `dlc2action.task.universal_task.Task` instance given the parameter updates.
 741
 742        Parameters
 743        ----------
 744        parameters : dict
 745            the dictionary of parameter updates
 746
 747        """
 748        pars = deepcopy(parameters)
 749        # for blank_name in self.blanks:
 750        #     for names in self.blanks[blank_name]:
 751        #         group = names[0]
 752        #         key = names[1]
 753        #         if len(names) == 3:
 754        #             if (
 755        #                 group in pars
 756        #                 and key in pars[group]
 757        #                 and names[2] in pars[group][key]
 758        #             ):
 759        #                 pars[group][key].pop(names[2])
 760        #         else:
 761        #             if group in pars and key in pars[group]:
 762        #                 pars[group].pop(key)
 763        stay = False
 764        if "ssl" in pars:
 765            for key in pars["ssl"]:
 766                if key in self.ssl_parameters:
 767                    self.ssl_parameters[key].update(pars["ssl"][key])
 768                else:
 769                    self.ssl_parameters[key] = pars["ssl"][key]
 770
 771        if "general" in pars:
 772            if stay:
 773                stay = False
 774            if (
 775                "model_name" in pars["general"]
 776                and pars["general"]["model_name"]
 777                != self.general_parameters["model_name"]
 778            ):
 779                if "model" not in pars:
 780                    raise ValueError(
 781                        "When updating a task with a new model name you need to pass the parameters for the "
 782                        "new model"
 783                    )
 784                self.model_parameters = {}
 785            self.general_parameters.update(pars["general"])
 786            data_related = [
 787                "num_classes",
 788                "exclusive",
 789                "data_type",
 790                "annotation_type",
 791            ]
 792            ssl_related = ["ssl", "exclusive", "num_classes"]
 793            loss_related = ["num_classes", "loss_function", "exclusive"]
 794            augmentation_related = ["augmentation_type"]
 795            metric_related = ["metric_functions"]
 796            related_lists = [
 797                data_related,
 798                ssl_related,
 799                loss_related,
 800                augmentation_related,
 801                metric_related,
 802            ]
 803            names = ["data", "ssl", "losses", "augmentations", "metrics"]
 804            for related_list, name in zip(related_lists, names):
 805                if (
 806                    any([x in pars["general"] for x in related_list])
 807                    and name not in pars
 808                ):
 809                    pars[name] = {}
 810
 811        if "training" in pars:
 812            if "data" not in pars or not stay:
 813                for x in [
 814                    "to_ram",
 815                    "use_test",
 816                    "partition_method",
 817                    "val_frac",
 818                    "test_frac",
 819                    "save_split",
 820                    "batch_size",
 821                    "save_split",
 822                ]:
 823                    if (
 824                        x in pars["training"]
 825                        and pars["training"][x] != self.training_parameters[x]
 826                    ):
 827                        if "data" not in pars:
 828                            pars["data"] = {}
 829                        stay = True
 830            self.training_parameters.update(pars["training"])
 831            self.task.update_parameters(self._get_parameters_from_training())
 832
 833        if "data" in pars or "features" in pars:
 834            for k, v in pars["data"].items():
 835                if k not in self.data_parameters or v != self.data_parameters[k]:
 836                    stay = True
 837            for k, v in pars["features"].items():
 838                if k not in self.feature_parameters or v != self.feature_parameters[k]:
 839                    stay = True
 840            if stay:
 841                self.data_parameters.update(pars["data"])
 842                self.feature_parameters.update(pars["features"])
 843                dataset = self._construct_dataset()
 844                (
 845                    train_dataloader,
 846                    test_dataloader,
 847                    val_dataloader,
 848                ) = self._partition_dataset(dataset)
 849                self.task.set_dataloaders(
 850                    train_dataloader, val_dataloader, test_dataloader
 851                )
 852                self.class_weights = train_dataloader.dataset.class_weights()
 853                self.proportional_class_weights = (
 854                    train_dataloader.dataset.class_weights(True)
 855                )
 856                if "losses" not in pars:
 857                    pars["losses"] = {}
 858
 859        if "model" in pars:
 860            self.model_parameters.update(pars["model"])
 861
 862        self._update_data_blanks()
 863
 864        if "augmentations" in pars:
 865            self.aug_parameters.update(pars["augmentations"])
 866            transformer = self._construct_transformer()
 867            self.task.set_transformer(transformer)
 868
 869        if "losses" in pars:
 870            for key in pars["losses"]:
 871                if key in self.loss_parameters:
 872                    self.loss_parameters[key].update(pars["losses"][key])
 873                else:
 874                    self.loss_parameters[key] = pars["losses"][key]
 875            self.loss_parameters.update(pars["losses"])
 876            loss = self._construct_loss()
 877            self.task.set_loss(loss)
 878
 879        if "metrics" in pars:
 880            for key in pars["metrics"]:
 881                if key in self.metric_parameters:
 882                    self.metric_parameters[key].update(pars["metrics"][key])
 883                else:
 884                    self.metric_parameters[key] = pars["metrics"][key]
 885            metrics = self._construct_metrics()
 886            self.task.set_metrics(metrics)
 887
 888        self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
 889        self._set_loss_weights(
 890            pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
 891        )
 892        model = self._construct_model()
 893        predict_functions = self._construct_predict_functions()
 894        self.task.set_predict_functions(*predict_functions)
 895        self._update_model_blanks(model)
 896        ssl_list = self._construct_ssl()
 897        self._update_parameters_from_ssl(ssl_list)
 898        model.set_ssl(ssl_constructors=ssl_list)
 899        self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 900        self.task.set_model(model)
 901        if "training" in pars and "checkpoint_path" in pars["training"]:
 902            checkpoint_path = pars["training"]["checkpoint_path"]
 903            only_model = pars["training"].get("only_load_model", False)
 904            load_strict = pars["training"].get("load_strict", True)
 905            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
 906        if (
 907            self.general_parameters["only_load_annotated"]
 908            and self.general_parameters.get("ssl") is not None
 909        ):
 910            warnings.warn(
 911                "Note that you are using SSL modules but only loading annotated files! Set "
 912                "general/only_load_annotated to False to change that"
 913            )
 914        if self.task.dataset("train").annotation_class() != "none":
 915            self._print_behaviors()
 916
 917    def train(
 918        self,
 919        trial: Trial = None,
 920        optimized_metric: str = None,
 921        autostop_metric: str = None,
 922        autostop_interval: int = 10,
 923        autostop_threshold: float = 0.001,
 924        loading_bar: bool = False,
 925    ) -> Tuple:
 926        """Train the task and return a log of epoch-average loss and metric.
 927
 928        You can use the autostop parameters to finish training when the parameters are not improving. It will be
 929        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 930        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 931        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 932
 933        Parameters
 934        ----------
 935        trial : Trial
 936            an `optuna` trial (for hyperparameter searches)
 937        optimized_metric : str
 938            the name of the metric being optimized (for hyperparameter searches)
 939        autostop_metric : str, optional
 940            the autostop metric (can be any one of the tracked metrics of `'loss'`)
 941        autostop_interval : int, default 50
 942            the number of epochs to average the autostop metric over
 943        autostop_threshold : float, default 0.001
 944            the autostop difference threshold
 945        loading_bar : bool, default False
 946            whether to show a loading bar
 947
 948        Returns
 949        -------
 950        loss_log: list
 951            a list of float loss function values for each epoch
 952        metrics_log: dict
 953            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
 954            names, values are lists of function values)
 955
 956        """
 957        to_ram = self.training_parameters.get("to_ram", False)
 958        logs = self.task.train(
 959            trial,
 960            optimized_metric,
 961            to_ram,
 962            autostop_metric=autostop_metric,
 963            autostop_interval=autostop_interval,
 964            autostop_threshold=autostop_threshold,
 965            main_task_on=self.training_parameters.get("main_task_on", True),
 966            ssl_on=self.training_parameters.get("ssl_on", True),
 967            temporal_subsampling_size=self.training_parameters.get(
 968                "temporal_subsampling_size"
 969            ),
 970            loading_bar=loading_bar,
 971        )
 972        return logs
 973
 974    def save_model(self, save_path: str) -> None:
 975        """Save the model of the `dlc2action.task.universal_task.Task` instance.
 976
 977        Parameters
 978        ----------
 979        save_path : str
 980            the path to the saved file
 981
 982        """
 983        self.task.save_model(save_path)
 984
 985    def evaluate(
 986        self,
 987        data: Union[DataLoader, BehaviorDataset, str] = None,
 988        augment_n: int = 0,
 989        verbose: bool = True,
 990    ) -> Tuple:
 991        """Evaluate the Task model.
 992
 993        Parameters
 994        ----------
 995        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
 996            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 997        augment_n : int, default 0
 998            the number of augmentations to average results over
 999        verbose : bool, default True
1000            if True, the process is reported to standard output
1001
1002        Returns
1003        -------
1004        loss : float
1005            the average value of the loss function
1006        ssl_loss : float
1007            the average value of the SSL loss function
1008        metric : dict
1009            a dictionary of average values of metric functions
1010
1011        """
1012        res = self.task.evaluate(
1013            data,
1014            augment_n,
1015            int(self.training_parameters.get("batch_size", 32)),
1016            verbose,
1017        )
1018        return res
1019
1020    def evaluate_prediction(
1021        self,
1022        prediction: torch.Tensor,
1023        data: Union[DataLoader, BehaviorDataset, str] = None,
1024        indices: Union[List[int], None] = None,
1025    ) -> Tuple:
1026        """Compute metrics for a prediction.
1027
1028        Parameters
1029        ----------
1030        prediction : torch.Tensor
1031            the prediction
1032        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1033            the data the prediction was made for (if not provided, take the validation dataset)
1034
1035        Returns
1036        -------
1037        loss : float
1038            the average value of the loss function
1039        metric : dict
1040            a dictionary of average values of metric functions
1041        """
1042        return self.task.evaluate_prediction(
1043            prediction, data, int(self.training_parameters.get("batch_size", 32)), indices
1044        )
1045
1046    def predict(
1047        self,
1048        data: Union[DataLoader, BehaviorDataset, str],
1049        raw_output: bool = False,
1050        apply_primary_function: bool = True,
1051        augment_n: int = 0,
1052        embedding: bool = False,
1053    ) -> torch.Tensor:
1054        """Make a prediction with the Task model.
1055
1056        Parameters
1057        ----------
1058        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1059            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
1060        raw_output : bool, default False
1061            if `True`, the raw predicted probabilities are returned
1062        apply_primary_function : bool, default True
1063            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
1064            the input)
1065        augment_n : int, default 0
1066            the number of augmentations to average results over
1067        embedding : bool, default False
1068            if `True`, the embedding is returned instead of the prediction
1069
1070        Returns
1071        -------
1072        prediction : torch.Tensor
1073            a prediction for the input data
1074
1075        """
1076        to_ram = self.training_parameters.get("to_ram", False)
1077        return self.task.predict(
1078            data,
1079            raw_output,
1080            apply_primary_function,
1081            augment_n,
1082            int(self.training_parameters.get("batch_size", 32)),
1083            to_ram,
1084            embedding=embedding,
1085        )
1086
1087    def dataset(self, mode: str = "train") -> BehaviorDataset:
1088        """Get a dataset.
1089
1090        Parameters
1091        ----------
1092        mode : {'train', 'val', 'test'}
1093            the dataset to get
1094
1095        Returns
1096        -------
1097        dataset : dlc2action.data.dataset.BehaviorDataset
1098            the dataset
1099
1100        """
1101        return self.task.dataset(mode)
1102
1103    def generate_full_length_prediction(
1104        self,
1105        dataset: Union[BehaviorDataset, str] = None,
1106        augment_n: int = 10,
1107    ) -> Dict:
1108        """Compile a prediction for the original input sequences.
1109
1110        Parameters
1111        ----------
1112        dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
1113            the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task`
1114            instance validation dataset)
1115        augment_n : int, default 10
1116            the number of augmentations to average results over
1117
1118        Returns
1119        -------
1120        prediction : dict
1121            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1122            are prediction tensors
1123
1124        """
1125        return self.task.generate_full_length_prediction(
1126            dataset, int(self.training_parameters.get("batch_size", 32)), augment_n
1127        )
1128
1129    def generate_submission(
1130        self,
1131        frame_number_map_file: str,
1132        dataset: Union[BehaviorDataset, str] = None,
1133        augment_n: int = 10,
1134    ) -> Dict:
1135        """Generate a MABe-22 style submission dictionary.
1136
1137        Parameters
1138        ----------
1139        frame_number_map_file : str
1140            path to the frame number map file
1141        dataset : BehaviorDataset, optional
1142            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1143        augment_n : int, default 10
1144            the number of augmentations to average results over
1145
1146        Returns
1147        -------
1148        submission : dict
1149            a dictionary with frame number mapping and embeddings
1150
1151        """
1152        return self.task.generate_submission(
1153            frame_number_map_file,
1154            dataset,
1155            int(self.training_parameters.get("batch_size", 32)),
1156            augment_n,
1157        )
1158
1159    def behaviors_dict(self):
1160        """Get a behavior dictionary.
1161
1162        Keys are label indices and values are label names.
1163
1164        Returns
1165        -------
1166        behaviors_dict : dict
1167            behavior dictionary
1168
1169        """
1170        return self.task.behaviors_dict()
1171
1172    def count_classes(self, bouts: bool = False) -> Dict:
1173        """Get a dictionary of class counts in different modes.
1174
1175        Parameters
1176        ----------
1177        bouts : bool, default False
1178            if `True`, instead of frame counts segment counts are returned
1179
1180        Returns
1181        -------
1182        class_counts : dict
1183            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1184            class names and values are class counts (in frames)
1185
1186        """
1187        return self.task.count_classes(bouts)
1188
1189    def _visualize_results_label(
1190        self,
1191        label: str,
1192        save_path: str = None,
1193        add_legend: bool = True,
1194        ground_truth: bool = True,
1195        hide_axes: bool = False,
1196        width: int = 10,
1197        whole_video: bool = False,
1198        transparent: bool = False,
1199        dataset: BehaviorDataset = None,
1200        smooth_interval: int = 0,
1201        title: str = None,
1202    ):
1203        return self.task._visualize_results_single(
1204            label,
1205            save_path,
1206            add_legend,
1207            ground_truth,
1208            hide_axes,
1209            width,
1210            whole_video,
1211            transparent,
1212            dataset,
1213            smooth_interval=smooth_interval,
1214            title=title,
1215        )
1216
1217    def visualize_results(
1218        self,
1219        save_path: str = None,
1220        add_legend: bool = True,
1221        ground_truth: bool = True,
1222        colormap: str = "viridis",
1223        hide_axes: bool = False,
1224        min_classes: int = 1,
1225        width: float = 10,
1226        whole_video: bool = False,
1227        transparent: bool = False,
1228        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1229        drop_classes: Set = None,
1230        search_classes: Set = None,
1231        smooth_interval_prediction: int = None,
1232        font_size: float = None,
1233        num_plots:int = 1,
1234        window_size:int=400,
1235    ) -> None:
1236        """Visualize random predictions.
1237
1238        Parameters
1239        ----------
1240        save_path : str, optional
1241            the path where the plot will be saved
1242        add_legend : bool, default True
1243            if True, legend will be added to the plot
1244        ground_truth : bool, default True
1245            if True, ground truth will be added to the plot
1246        colormap : str, default 'Accent'
1247            the `matplotlib` colormap to use
1248        hide_axes : bool, default True
1249            if `True`, the axes will be hidden on the plot
1250        min_classes : int, default 1
1251            the minimum number of classes in a displayed interval
1252        width : float, default 10
1253            the width of the plot
1254        whole_video : bool, default False
1255            if `True`, whole videos are plotted instead of segments
1256        transparent : bool, default False
1257            if `True`, the background on the plot is transparent
1258        dataset : BehaviorDataset | DataLoader | str | None, optional
1259            the dataset to make the prediction for (if not provided, the validation dataset is used)
1260        drop_classes : set, optional
1261            a set of class names to not be displayed
1262        search_classes : set, optional
1263            if given, only intervals where at least one of the classes is in ground truth will be shown
1264        smooth_interval_prediction : int, optional
1265            if given, the prediction will be smoothed over the given number of frames
1266
1267        """
1268        return self.task.visualize_results(
1269            save_path,
1270            add_legend,
1271            ground_truth,
1272            colormap,
1273            hide_axes,
1274            min_classes,
1275            width,
1276            whole_video,
1277            transparent,
1278            dataset,
1279            drop_classes,
1280            search_classes,
1281            font_size=font_size,
1282            smooth_interval_prediction=smooth_interval_prediction,
1283            num_plots=num_plots,
1284            window_size=window_size
1285        )
1286
1287    def generate_uncertainty_score(
1288        self,
1289        classes: List,
1290        augment_n: int = 0,
1291        method: str = "least_confidence",
1292        predicted: torch.Tensor = None,
1293        behaviors_dict: Dict = None,
1294    ) -> Dict:
1295        """Generate frame-wise scores for active learning.
1296
1297        Parameters
1298        ----------
1299        classes : list
1300            a list of class names or indices; their confidence scores will be computed separately and stacked
1301        augment_n : int, default 0
1302            the number of augmentations to average over
1303        method : {"least_confidence", "entropy"}
1304            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1305            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
1306        predicted : torch.Tensor, optional
1307            if given, the predictions will be used instead of the model's predictions
1308        behaviors_dict : dict, optional
1309            if given, the behaviors dictionary will be used instead of the model's behaviors dictionary
1310
1311        Returns
1312        -------
1313        score_dicts : dict
1314            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1315            are score tensors
1316
1317        """
1318        return self.task.generate_uncertainty_score(
1319            classes,
1320            augment_n,
1321            int(self.training_parameters.get("batch_size", 32)),
1322            method,
1323            predicted,
1324            behaviors_dict,
1325        )
1326
1327    def generate_bald_score(
1328        self,
1329        classes: List,
1330        augment_n: int = 0,
1331        num_models: int = 10,
1332        kernel_size: int = 11,
1333    ) -> Dict:
1334        """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
1335
1336        Parameters
1337        ----------
1338        classes : list
1339            a list of class names or indices; their confidence scores will be computed separately and stacked
1340        augment_n : int, default 0
1341            the number of augmentations to average over
1342        num_models : int, default 10
1343            the number of dropout masks to apply
1344        kernel_size : int, default 11
1345            the size of the smoothing Gaussian kernel
1346
1347        Returns
1348        -------
1349        score_dicts : dict
1350            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1351            are score tensors
1352
1353        """
1354        return self.task.generate_bald_score(
1355            classes,
1356            augment_n,
1357            int(self.training_parameters.get("batch_size", 32)),
1358            num_models,
1359            kernel_size,
1360        )
1361
1362    def exists(self, mode) -> bool:
1363        """Check whether the task has a train/test/validation subset.
1364
1365        Parameters
1366        ----------
1367        mode : {"train", "val", "test"}
1368            the name of the subset to check for
1369
1370        Returns
1371        -------
1372        exists : bool
1373            `True` if the subset exists
1374        """
1375        dl = self.task.dataloader(mode)
1376        if dl is None:
1377            return False
1378        else:
1379            return True
1380
1381    def get_normalization_stats(self) -> Dict:
1382        """Get the pre-computed normalization stats.
1383
1384        Returns
1385        -------
1386        normalization_stats : dict
1387            a dictionary of means and stds
1388
1389        """
1390        return self.task.get_normalization_stats()

A class that manages the interactions between config dictionaries and a Task.

TaskDispatcher(parameters: Dict)
34    def __init__(self, parameters: Dict) -> None:
35        """Initialize the `TaskDispatcher`.
36
37        Parameters
38        ----------
39        parameters : dict
40            a dictionary of task parameters
41
42        """
43        pars = deepcopy(parameters)
44        self.class_weights = None
45        self.general_parameters = pars.get("general", {})
46        self.data_parameters = pars.get("data", {})
47        self.model_parameters = pars.get("model", {})
48        self.training_parameters = pars.get("training", {})
49        self.loss_parameters = pars.get("losses", {})
50        self.metric_parameters = pars.get("metrics", {})
51        self.ssl_parameters = pars.get("ssl", {})
52        self.aug_parameters = pars.get("augmentations", {})
53        self.feature_parameters = pars.get("features", {})
54        self.blanks = {blank: [] for blank in options.blanks}
55
56        self.task = None
57        self._initialize_task()
58        self._print_behaviors()

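A minimal construction sketch is shown below. The top-level group names are the keys read in `__init__`; the nested values are illustrative placeholders, not a validated configuration.

    from dlc2action.task.task_dispatcher import TaskDispatcher

    # Top-level groups mirror the keys read in __init__; nested values are
    # placeholders (assumed names), not a complete working configuration.
    parameters = {
        "general": {"model_name": "ms_tcn3", "loss_function": "ms_tcn",
                    "num_classes": 4, "exclusive": True},
        "data": {},  # dataset input and annotation settings
        "model": {},  # architecture hyperparameters
        "training": {"batch_size": 32, "to_ram": False},
        "losses": {},
        "metrics": {},
        "ssl": {},
        "augmentations": {},
        "features": {},
    }
    dispatcher = TaskDispatcher(parameters)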

Instance variables:

class_weights
general_parameters
data_parameters
model_parameters
training_parameters
loss_parameters
metric_parameters
ssl_parameters
aug_parameters
feature_parameters
blanks
task
@staticmethod
def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
60    @staticmethod
61    def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
62        """Complete a parameter dictionary with values from other dictionaries if required by a function.
63
64        Parameters
65        ----------
66        parameters : dict
67            the function parameters dictionary
68        function : callable
69            the function to be inspected
70        general_dicts : list
71            a list of dictionaries where the missing values will be pulled from
72
73        Returns
74        -------
75        parameters : dict
76            the updated parameter dictionary
77
78        """
79        parameter_names = inspect.getfullargspec(function).args
80        for param in parameter_names:
81            for dic in general_dicts:
82                if param not in parameters and param in dic:
83                    parameters[param] = dic[param]
84        return parameters

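The lookup order is easiest to see on a toy function (both the function and the dictionaries below are invented for the example; existing keys are never overwritten, and the dictionaries are searched in order):

    def make_plot(width, height, dpi=100):
        return width, height, dpi

    params = {"width": 8}
    general = [{"height": 5, "title": "unused"}, {"height": 7, "dpi": 200}]
    completed = TaskDispatcher.complete_function_parameters(params, make_plot, general)
    # completed == {"width": 8, "height": 5, "dpi": 200}: "height" comes from
    # the first dictionary that defines it, "title" is skipped because
    # make_plot does not accept it, and "width" keeps its original value.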

@staticmethod
def complete_dataset_parameters( parameters: dict, general_dict: dict, data_type: str, annotation_type: str) -> Dict:
 86    @staticmethod
 87    def complete_dataset_parameters(
 88        parameters: dict,
 89        general_dict: dict,
 90        data_type: str,
 91        annotation_type: str,
 92    ) -> Dict:
 93        """Complete a parameter dictionary with values from other dictionaries if required by a dataset.
 94
 95        Parameters
 96        ----------
 97        parameters : dict
 98            the function parameters dictionary
 99        general_dict : dict
100            the dictionary where the missing values will be pulled from
101        data_type : str
102            the input type of the dataset
103        annotation_type : str
104            the annotation type of the dataset
105
106        Returns
107        -------
108        parameters : dict
109            the updated parameter dictionary
110
111        """
112        params = deepcopy(parameters)
113        parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type)
114        for param in parameter_names:
115            if param not in params and param in general_dict:
116                params[param] = general_dict[param]
117        return params


@staticmethod
def check(parameters: Dict, name: str) -> bool:
119    @staticmethod
120    def check(parameters: Dict, name: str) -> bool:
121        """Check whether there is a non-`None` value under the name key in the parameters dictionary.
122
123        Parameters
124        ----------
125        parameters : dict
126            the dictionary to check
127        name : str
128            the key to check
129
130        Returns
131        -------
132        result : bool
133            True if a non-`None` value exists
134
135        """
136        if name in parameters and parameters[name] is not None:
137            return True
138        else:
139            return False


@staticmethod
def get(parameters: Dict, name: str, default):
141    @staticmethod
142    def get(parameters: Dict, name: str, default):
143        """Get the value under the name key or the default if it is `None` or does not exist.
144
145        Parameters
146        ----------
147        parameters : dict
148            the dictionary to check
149        name : str
150            the key to check
151        default
152            the default value to return
153
154        Returns
155        -------
156        value
157            the resulting value
158
159        """
160        if TaskDispatcher.check(parameters, name):
161            return parameters[name]
162        else:
163            return default

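Both helpers treat an explicit `None` like a missing key, which a toy dictionary makes concrete:

    params = {"batch_size": None, "lr": 1e-3}     # toy dictionary
    TaskDispatcher.check(params, "batch_size")    # False: present but None
    TaskDispatcher.check(params, "lr")            # True
    TaskDispatcher.get(params, "batch_size", 32)  # 32: None falls back to the default
    TaskDispatcher.get(params, "lr", 1e-2)        # 0.001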

@staticmethod
def make_dataloader( dataset: dlc2action.data.dataset.BehaviorDataset, batch_size: int = 32, shuffle: bool = False) -> torch.utils.data.dataloader.DataLoader:
165    @staticmethod
166    def make_dataloader(
167        dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False
168    ) -> DataLoader:
169        """Make a torch dataloader from a dataset.
170
171        Parameters
172        ----------
173        dataset : dlc2action.data.dataset.BehaviorDataset
174            the dataset
175        batch_size : int
176            the batch size
177        shuffle : bool
178            whether to shuffle the dataset
179
180        Returns
181        -------
182        dataloader : DataLoader
183            the dataloader (or `None` if the length of the dataset is 0)
184
185        """
186        if dataset is None or len(dataset) == 0:
187            return None
188        else:
189            return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)

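A usage sketch, assuming `dataset` is an already constructed `BehaviorDataset`:

    loader = TaskDispatcher.make_dataloader(dataset, batch_size=64, shuffle=True)
    if loader is not None:  # an empty dataset yields no loader
        for batch in loader:
            pass  # training or inference step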

def update_task(self, parameters: Dict) -> None:
739    def update_task(self, parameters: Dict) -> None:
740        """Update the `dlc2action.task.universal_task.Task` instance given the parameter updates.
741
742        Parameters
743        ----------
744        parameters : dict
745            the dictionary of parameter updates
746
747        """
748        pars = deepcopy(parameters)
749        # for blank_name in self.blanks:
750        #     for names in self.blanks[blank_name]:
751        #         group = names[0]
752        #         key = names[1]
753        #         if len(names) == 3:
754        #             if (
755        #                 group in pars
756        #                 and key in pars[group]
757        #                 and names[2] in pars[group][key]
758        #             ):
759        #                 pars[group][key].pop(names[2])
760        #         else:
761        #             if group in pars and key in pars[group]:
762        #                 pars[group].pop(key)
763        stay = False
764        if "ssl" in pars:
765            for key in pars["ssl"]:
766                if key in self.ssl_parameters:
767                    self.ssl_parameters[key].update(pars["ssl"][key])
768                else:
769                    self.ssl_parameters[key] = pars["ssl"][key]
770
771        if "general" in pars:
772            if stay:
773                stay = False
774            if (
775                "model_name" in pars["general"]
776                and pars["general"]["model_name"]
777                != self.general_parameters["model_name"]
778            ):
779                if "model" not in pars:
780                    raise ValueError(
781                        "When updating a task with a new model name you need to pass the parameters for the "
782                        "new model"
783                    )
784                self.model_parameters = {}
785            self.general_parameters.update(pars["general"])
786            data_related = [
787                "num_classes",
788                "exclusive",
789                "data_type",
790                "annotation_type",
791            ]
792            ssl_related = ["ssl", "exclusive", "num_classes"]
793            loss_related = ["num_classes", "loss_function", "exclusive"]
794            augmentation_related = ["augmentation_type"]
795            metric_related = ["metric_functions"]
796            related_lists = [
797                data_related,
798                ssl_related,
799                loss_related,
800                augmentation_related,
801                metric_related,
802            ]
803            names = ["data", "ssl", "losses", "augmentations", "metrics"]
804            for related_list, name in zip(related_lists, names):
805                if (
806                    any([x in pars["general"] for x in related_list])
807                    and name not in pars
808                ):
809                    pars[name] = {}
810
811        if "training" in pars:
812            if "data" not in pars or not stay:
813                for x in [
814                    "to_ram",
815                    "use_test",
816                    "partition_method",
817                    "val_frac",
818                    "test_frac",
819                    "save_split",
820                    "batch_size",
821                    "save_split",
822                ]:
823                    if (
824                        x in pars["training"]
825                        and pars["training"][x] != self.training_parameters[x]
826                    ):
827                        if "data" not in pars:
828                            pars["data"] = {}
829                        stay = True
830            self.training_parameters.update(pars["training"])
831            self.task.update_parameters(self._get_parameters_from_training())
832
833        if "data" in pars or "features" in pars:
834            for k, v in pars.get("data", {}).items():
835                if k not in self.data_parameters or v != self.data_parameters[k]:
836                    stay = True
837            for k, v in pars.get("features", {}).items():
838                if k not in self.feature_parameters or v != self.feature_parameters[k]:
839                    stay = True
840            if stay:
841                self.data_parameters.update(pars.get("data", {}))
842                self.feature_parameters.update(pars.get("features", {}))
843                dataset = self._construct_dataset()
844                (
845                    train_dataloader,
846                    test_dataloader,
847                    val_dataloader,
848                ) = self._partition_dataset(dataset)
849                self.task.set_dataloaders(
850                    train_dataloader, val_dataloader, test_dataloader
851                )
852                self.class_weights = train_dataloader.dataset.class_weights()
853                self.proportional_class_weights = (
854                    train_dataloader.dataset.class_weights(True)
855                )
856                if "losses" not in pars:
857                    pars["losses"] = {}
858
859        if "model" in pars:
860            self.model_parameters.update(pars["model"])
861
862        self._update_data_blanks()
863
864        if "augmentations" in pars:
865            self.aug_parameters.update(pars["augmentations"])
866            transformer = self._construct_transformer()
867            self.task.set_transformer(transformer)
868
869        if "losses" in pars:
870            for key in pars["losses"]:
871                if key in self.loss_parameters:
872                    self.loss_parameters[key].update(pars["losses"][key])
873                else:
874                    self.loss_parameters[key] = pars["losses"][key]
875            self.loss_parameters.update(pars["losses"])
876            loss = self._construct_loss()
877            self.task.set_loss(loss)
878
879        if "metrics" in pars:
880            for key in pars["metrics"]:
881                if key in self.metric_parameters:
882                    self.metric_parameters[key].update(pars["metrics"][key])
883                else:
884                    self.metric_parameters[key] = pars["metrics"][key]
885            metrics = self._construct_metrics()
886            self.task.set_metrics(metrics)
887
888        self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
889        self._set_loss_weights(
890            pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
891        )
892        model = self._construct_model()
893        predict_functions = self._construct_predict_functions()
894        self.task.set_predict_functions(*predict_functions)
895        self._update_model_blanks(model)
896        ssl_list = self._construct_ssl()
897        self._update_parameters_from_ssl(ssl_list)
898        model.set_ssl(ssl_constructors=ssl_list)
899        self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
900        self.task.set_model(model)
901        if "training" in pars and "checkpoint_path" in pars["training"]:
902            checkpoint_path = pars["training"]["checkpoint_path"]
903            only_model = pars["training"].get("only_load_model", False)
904            load_strict = pars["training"].get("load_strict", True)
905            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
906        if (
907            self.general_parameters["only_load_annotated"]
908            and self.general_parameters.get("ssl") is not None
909        ):
910            warnings.warn(
911                "Note that you are using SSL modules but only loading annotated files! Set "
912                "general/only_load_annotated to False to change that"
913            )
914        if self.task.dataset("train").annotation_class() != "none":
915            self._print_behaviors()

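For example, updating training settings on an existing dispatcher (both keys below appear elsewhere in this module; other groups follow the same layout as the constructor). Note that passing a new `model_name` under `"general"` without a matching `"model"` section raises a `ValueError`.

    dispatcher.update_task({"training": {"batch_size": 64, "to_ram": True}})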

def train( self, trial: optuna.trial._trial.Trial = None, optimized_metric: str = None, autostop_metric: str = None, autostop_interval: int = 10, autostop_threshold: float = 0.001, loading_bar: bool = False) -> Tuple:
917    def train(
918        self,
919        trial: Trial = None,
920        optimized_metric: str = None,
921        autostop_metric: str = None,
922        autostop_interval: int = 10,
923        autostop_threshold: float = 0.001,
924        loading_bar: bool = False,
925    ) -> Tuple:
926        """Train the task and return a log of epoch-average loss and metric.
927
928        You can use the autostop parameters to finish training when the parameters are not improving. It will be
929        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
930        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
931        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
932
933        Parameters
934        ----------
935        trial : Trial
936            an `optuna` trial (for hyperparameter searches)
937        optimized_metric : str
938            the name of the metric being optimized (for hyperparameter searches)
939        autostop_metric : str, optional
940            the autostop metric (can be any one of the tracked metrics or `'loss'`)
941        autostop_interval : int, default 10
942            the number of epochs to average the autostop metric over
943        autostop_threshold : float, default 0.001
944            the autostop difference threshold
945        loading_bar : bool, default False
946            whether to show a loading bar
947
948        Returns
949        -------
950        loss_log: list
951            a list of float loss function values for each epoch
952        metrics_log: dict
953            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
954            names, values are lists of function values)
955
956        """
957        to_ram = self.training_parameters.get("to_ram", False)
958        logs = self.task.train(
959            trial,
960            optimized_metric,
961            to_ram,
962            autostop_metric=autostop_metric,
963            autostop_interval=autostop_interval,
964            autostop_threshold=autostop_threshold,
965            main_task_on=self.training_parameters.get("main_task_on", True),
966            ssl_on=self.training_parameters.get("ssl_on", True),
967            temporal_subsampling_size=self.training_parameters.get(
968                "temporal_subsampling_size"
969            ),
970            loading_bar=loading_bar,
971        )
972        return logs

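A typical call with autostop enabled; `"f1"` is an assumed metric name, so substitute one you actually configured:

    loss_log, metrics_log = dispatcher.train(
        autostop_metric="loss",
        autostop_interval=10,
        autostop_threshold=0.001,
    )
    best_val_f1 = max(metrics_log["val"]["f1"])  # "f1" is an assumed metric name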

def save_model(self, save_path: str) -> None:
974    def save_model(self, save_path: str) -> None:
975        """Save the model of the `dlc2action.task.universal_task.Task` instance.
976
977        Parameters
978        ----------
979        save_path : str
980            the path to the saved file
981
982        """
983        self.task.save_model(save_path)

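For instance:

    dispatcher.save_model("checkpoints/task_model.pt")  # hypothetical path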

def evaluate( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 0, verbose: bool = True) -> Tuple:
 985    def evaluate(
 986        self,
 987        data: Union[DataLoader, BehaviorDataset, str] = None,
 988        augment_n: int = 0,
 989        verbose: bool = True,
 990    ) -> Tuple:
 991        """Evaluate the Task model.
 992
 993        Parameters
 994        ----------
 995        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
 996            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 997        augment_n : int, default 0
 998            the number of augmentations to average results over
 999        verbose : bool, default True
1000            if True, the process is reported to standard output
1001
1002        Returns
1003        -------
1004        loss : float
1005            the average value of the loss function
1006        ssl_loss : float
1007            the average value of the SSL loss function
1008        metric : dict
1009            a dictionary of average values of metric functions
1010
1011        """
1012        res = self.task.evaluate(
1013            data,
1014            augment_n,
1015            int(self.training_parameters.get("batch_size", 32)),
1016            verbose,
1017        )
1018        return res

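Called without `data`, it scores the validation set:

    loss, ssl_loss, metrics = dispatcher.evaluate(augment_n=5)
    print(f"validation loss: {loss:.4f}", metrics)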

def evaluate_prediction( self, prediction: torch.Tensor, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, indices: Optional[List[int]] = None) -> Tuple:
1020    def evaluate_prediction(
1021        self,
1022        prediction: torch.Tensor,
1023        data: Union[DataLoader, BehaviorDataset, str] = None,
1024        indices: Union[List[int], None] = None,
1025    ) -> Tuple:
1026        """Compute metrics for a prediction.
1027
1028        Parameters
1029        ----------
1030        prediction : torch.Tensor
1031            the prediction
1032        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1033            the data the prediction was made for (if not provided, take the validation dataset)
1034
1035        Returns
1036        -------
1037        loss : float
1038            the average value of the loss function
1039        metric : dict
1040            a dictionary of average values of metric functions
1041        """
1042        return self.task.evaluate_prediction(
1043            prediction, data, int(self.training_parameters.get("batch_size", 32)), indices
1044        )

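A sketch that scores a stored prediction against the validation set (the prediction must have the shape the task expects):

    pred = dispatcher.predict(dispatcher.dataset("val"), raw_output=True)
    loss, metrics = dispatcher.evaluate_prediction(pred)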

def predict( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str], raw_output: bool = False, apply_primary_function: bool = True, augment_n: int = 0, embedding: bool = False) -> torch.Tensor:
1046    def predict(
1047        self,
1048        data: Union[DataLoader, BehaviorDataset, str],
1049        raw_output: bool = False,
1050        apply_primary_function: bool = True,
1051        augment_n: int = 0,
1052        embedding: bool = False,
1053    ) -> torch.Tensor:
1054        """Make a prediction with the Task model.
1055
1056        Parameters
1057        ----------
1058        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str
1059            the data to make a prediction for
1060        raw_output : bool, default False
1061            if `True`, the raw predicted probabilities are returned
1062        apply_primary_function : bool, default True
1063            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
1064            the input)
1065        augment_n : int, default 0
1066            the number of augmentations to average results over
1067        embedding : bool, default False
1068            if `True`, the embedding is returned instead of the prediction
1069
1070        Returns
1071        -------
1072        prediction : torch.Tensor
1073            a prediction for the input data
1074
1075        """
1076        to_ram = self.training_parameters.get("to_ram", False)
1077        return self.task.predict(
1078            data,
1079            raw_output,
1080            apply_primary_function,
1081            augment_n,
1082            int(self.training_parameters.get("batch_size", 32)),
1083            to_ram,
1084            embedding=embedding,
1085        )

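For example, raw probabilities averaged over five augmentations of the validation data:

    probs = dispatcher.predict(
        dispatcher.dataset("val"),
        raw_output=True,
        augment_n=5,
    )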

def dataset(self, mode: str = 'train') -> dlc2action.data.dataset.BehaviorDataset:
1087    def dataset(self, mode: str = "train") -> BehaviorDataset:
1088        """Get a dataset.
1089
1090        Parameters
1091        ----------
1092        mode : {'train', 'val', 'test'}
1093            the dataset to get
1094
1095        Returns
1096        -------
1097        dataset : dlc2action.data.dataset.BehaviorDataset
1098            the dataset
1099
1100        """
1101        return self.task.dataset(mode)


def generate_full_length_prediction( self, dataset: Union[dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 10) -> Dict:
1103    def generate_full_length_prediction(
1104        self,
1105        dataset: Union[BehaviorDataset, str] = None,
1106        augment_n: int = 10,
1107    ) -> Dict:
1108        """Compile a prediction for the original input sequences.
1109
1110        Parameters
1111        ----------
1112        dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
1113            the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task`
1114            instance validation dataset)
1115        augment_n : int, default 10
1116            the number of augmentations to average results over
1117
1118        Returns
1119        -------
1120        prediction : dict
1121            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1122            are prediction tensors
1123
1124        """
1125        return self.task.generate_full_length_prediction(
1126            dataset, int(self.training_parameters.get("batch_size", 32)), augment_n
1127        )

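The nested result can be walked directly:

    prediction = dispatcher.generate_full_length_prediction(augment_n=10)
    for video_id, clips in prediction.items():
        for clip_id, scores in clips.items():
            print(video_id, clip_id, scores.shape)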

def generate_submission( self, frame_number_map_file: str, dataset: Union[dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 10) -> Dict:
1129    def generate_submission(
1130        self,
1131        frame_number_map_file: str,
1132        dataset: Union[BehaviorDataset, str] = None,
1133        augment_n: int = 10,
1134    ) -> Dict:
1135        """Generate a MABe-22 style submission dictionary.
1136
1137        Parameters
1138        ----------
1139        frame_number_map_file : str
1140            path to the frame number map file
1141        dataset : BehaviorDataset, optional
1142            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1143        augment_n : int, default 10
1144            the number of augmentations to average results over
1145
1146        Returns
1147        -------
1148        submission : dict
1149            a dictionary with frame number mapping and embeddings
1150
1151        """
1152        return self.task.generate_submission(
1153            frame_number_map_file,
1154            dataset,
1155            int(self.training_parameters.get("batch_size", 32)),
1156            augment_n,
1157        )

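A sketch; the file name is a placeholder for the frame number map distributed with the MABe-22 data:

    submission = dispatcher.generate_submission(
        frame_number_map_file="frame_number_map.npy",  # placeholder path
        augment_n=10,
    )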

def behaviors_dict(self):
1159    def behaviors_dict(self):
1160        """Get a behavior dictionary.
1161
1162        Keys are label indices and values are label names.
1163
1164        Returns
1165        -------
1166        behaviors_dict : dict
1167            behavior dictionary
1168
1169        """
1170        return self.task.behaviors_dict()


def count_classes(self, bouts: bool = False) -> Dict:
1172    def count_classes(self, bouts: bool = False) -> Dict:
1173        """Get a dictionary of class counts in different modes.
1174
1175        Parameters
1176        ----------
1177        bouts : bool, default False
1178            if `True`, instead of frame counts segment counts are returned
1179
1180        Returns
1181        -------
1182        class_counts : dict
1183            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1184            class names and values are class counts (in frames)
1185
1186        """
1187        return self.task.count_classes(bouts)

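The two methods above are often combined to report label statistics:

    names = dispatcher.behaviors_dict()         # {index: label name}
    frame_counts = dispatcher.count_classes()   # per-subset frame counts
    bout_counts = dispatcher.count_classes(bouts=True)
    for split, counts in frame_counts.items():  # "train", "val", "test"
        print(split, counts)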

def visualize_results( self, save_path: str = None, add_legend: bool = True, ground_truth: bool = True, colormap: str = 'viridis', hide_axes: bool = False, min_classes: int = 1, width: float = 10, whole_video: bool = False, transparent: bool = False, dataset: Union[dlc2action.data.dataset.BehaviorDataset, torch.utils.data.dataloader.DataLoader, str, NoneType] = None, drop_classes: Set = None, search_classes: Set = None, smooth_interval_prediction: int = None, font_size: float = None, num_plots: int = 1, window_size: int = 400) -> None:
1217    def visualize_results(
1218        self,
1219        save_path: str = None,
1220        add_legend: bool = True,
1221        ground_truth: bool = True,
1222        colormap: str = "viridis",
1223        hide_axes: bool = False,
1224        min_classes: int = 1,
1225        width: float = 10,
1226        whole_video: bool = False,
1227        transparent: bool = False,
1228        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1229        drop_classes: Set = None,
1230        search_classes: Set = None,
1231        smooth_interval_prediction: int = None,
1232        font_size: float = None,
1233        num_plots: int = 1,
1234        window_size: int = 400,
1235    ) -> None:
1236        """Visualize random predictions.
1237
1238        Parameters
1239        ----------
1240        save_path : str, optional
1241            the path where the plot will be saved
1242        add_legend : bool, default True
1243            if True, legend will be added to the plot
1244        ground_truth : bool, default True
1245            if True, ground truth will be added to the plot
1246        colormap : str, default 'viridis'
1247            the `matplotlib` colormap to use
1248        hide_axes : bool, default False
1249            if `True`, the axes will be hidden on the plot
1250        min_classes : int, default 1
1251            the minimum number of classes in a displayed interval
1252        width : float, default 10
1253            the width of the plot
1254        whole_video : bool, default False
1255            if `True`, whole videos are plotted instead of segments
1256        transparent : bool, default False
1257            if `True`, the background on the plot is transparent
1258        dataset : BehaviorDataset | DataLoader | str | None, optional
1259            the dataset to make the prediction for (if not provided, the validation dataset is used)
1260        drop_classes : set, optional
1261            a set of class names to not be displayed
1262        search_classes : set, optional
1263            if given, only intervals where at least one of the classes is in ground truth will be shown
1264        smooth_interval_prediction : int, optional
1265            if given, the prediction will be smoothed over the given number of frames
1266
1267        """
1268        return self.task.visualize_results(
1269            save_path,
1270            add_legend,
1271            ground_truth,
1272            colormap,
1273            hide_axes,
1274            min_classes,
1275            width,
1276            whole_video,
1277            transparent,
1278            dataset,
1279            drop_classes,
1280            search_classes,
1281            font_size=font_size,
1282            smooth_interval_prediction=smooth_interval_prediction,
1283            num_plots=num_plots,
1284            window_size=window_size
1285        )

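A sketch (the class name in `search_classes` is a placeholder; `font_size`, `num_plots` and `window_size` are not documented above and are assumed, from their names and defaults, to control the label font size, the number of generated plots and the displayed window length in frames):

    dispatcher.visualize_results(
        save_path="example_plot.png",   # placeholder path
        colormap="viridis",
        min_classes=2,
        search_classes={"attack"},      # placeholder class name
        smooth_interval_prediction=5,
        num_plots=3,
        window_size=400,
    )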

def generate_uncertainty_score( self, classes: List, augment_n: int = 0, method: str = 'least_confidence', predicted: torch.Tensor = None, behaviors_dict: Dict = None) -> Dict:
1287    def generate_uncertainty_score(
1288        self,
1289        classes: List,
1290        augment_n: int = 0,
1291        method: str = "least_confidence",
1292        predicted: torch.Tensor = None,
1293        behaviors_dict: Dict = None,
1294    ) -> Dict:
1295        """Generate frame-wise scores for active learning.
1296
1297        Parameters
1298        ----------
1299        classes : list
1300            a list of class names or indices; their confidence scores will be computed separately and stacked
1301        augment_n : int, default 0
1302            the number of augmentations to average over
1303        method : {"least_confidence", "entropy"}
1304            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1305            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
1306        predicted : torch.Tensor, optional
1307            if given, the predictions will be used instead of the model's predictions
1308        behaviors_dict : dict, optional
1309            if given, the behaviors dictionary will be used instead of the model's behaviors dictionary
1310
1311        Returns
1312        -------
1313        score_dicts : dict
1314            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1315            are score tensors
1316
1317        """
1318        return self.task.generate_uncertainty_score(
1319            classes,
1320            augment_n,
1321            int(self.training_parameters.get("batch_size", 32)),
1322            method,
1323            predicted,
1324            behaviors_dict,
1325        )

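For instance, entropy scores for two placeholder behavior names:

    scores = dispatcher.generate_uncertainty_score(
        classes=["walking", "grooming"],  # placeholder class names
        augment_n=5,
        method="entropy",
    )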

def generate_bald_score( self, classes: List, augment_n: int = 0, num_models: int = 10, kernel_size: int = 11) -> Dict:
1327    def generate_bald_score(
1328        self,
1329        classes: List,
1330        augment_n: int = 0,
1331        num_models: int = 10,
1332        kernel_size: int = 11,
1333    ) -> Dict:
1334        """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
1335
1336        Parameters
1337        ----------
1338        classes : list
1339            a list of class names or indices; their confidence scores will be computed separately and stacked
1340        augment_n : int, default 0
1341            the number of augmentations to average over
1342        num_models : int, default 10
1343            the number of dropout masks to apply
1344        kernel_size : int, default 11
1345            the size of the smoothing Gaussian kernel
1346
1347        Returns
1348        -------
1349        score_dicts : dict
1350            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1351            are score tensors
1352
1353        """
1354        return self.task.generate_bald_score(
1355            classes,
1356            augment_n,
1357            int(self.training_parameters.get("batch_size", 32)),
1358            num_models,
1359            kernel_size,
1360        )

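Analogously, with placeholder class indices:

    bald_scores = dispatcher.generate_bald_score(
        classes=[0, 1],   # placeholder class indices
        num_models=20,
        kernel_size=11,
    )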

def exists(self, mode) -> bool:
1362    def exists(self, mode) -> bool:
1363        """Check whether the task has a train/test/validation subset.
1364
1365        Parameters
1366        ----------
1367        mode : {"train", "val", "test"}
1368            the name of the subset to check for
1369
1370        Returns
1371        -------
1372        exists : bool
1373            `True` if the subset exists
1374        """
1375        dl = self.task.dataloader(mode)
1376        if dl is None:
1377            return False
1378        else:
1379            return True

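Useful as a guard before touching an optional subset:

    if dispatcher.exists("test"):
        loss, ssl_loss, metrics = dispatcher.evaluate(dispatcher.dataset("test"))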

def get_normalization_stats(self) -> Dict:
1381    def get_normalization_stats(self) -> Dict:
1382        """Get the pre-computed normalization stats.
1383
1384        Returns
1385        -------
1386        normalization_stats : dict
1387            a dictionary of means and stds
1388
1389        """
1390        return self.task.get_normalization_stats()
