dlc2action.task.task_dispatcher

Class that provides an interface for dlc2action.task.universal_task.Task

   1#
   2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
   5#
   6"""
   7Class that provides an interface for `dlc2action.task.universal_task.Task`
   8"""
   9
  10import inspect
  11from typing import Dict, Union, Tuple, List, Callable, Set
  12from torch.utils.data import DataLoader
  13from torch.optim import Optimizer
  14from collections.abc import Iterable, Mapping
  15import torch
  16from copy import deepcopy
  17import numpy as np
  18from optuna.trial import Trial
  19import warnings
  20
  21from dlc2action.data.dataset import BehaviorDataset
  22from dlc2action.task.universal_task import Task
  23from dlc2action.transformer.base_transformer import Transformer
  24from dlc2action.ssl.base_ssl import SSLConstructor
  25from dlc2action.ssl.base_ssl import EmptySSL
  26from dlc2action.model.base_model import LoadedModel, Model
  27from dlc2action.metric.base_metric import Metric
  28from dlc2action.utils import PostProcessor
  29
  30from dlc2action import options
  31
  32
  33class TaskDispatcher:
  34    """
  35    A class that manages the interactions between config dictionaries and a Task
  36    """
  37
  38    def __init__(self, parameters: Dict) -> None:
  39        """
  40        Parameters
  41        ----------
  42        parameters : dict
  43            a dictionary of task parameters
  44        """
  45
  46        pars = deepcopy(parameters)
  47        self.class_weights = None
  48        self.general_parameters = pars.get("general", {})
  49        self.data_parameters = pars.get("data", {})
  50        self.model_parameters = pars.get("model", {})
  51        self.training_parameters = pars.get("training", {})
  52        self.loss_parameters = pars.get("losses", {})
  53        self.metric_parameters = pars.get("metrics", {})
  54        self.ssl_parameters = pars.get("ssl", {})
  55        self.aug_parameters = pars.get("augmentations", {})
  56        self.feature_parameters = pars.get("features", {})
  57        self.blanks = {blank: [] for blank in options.blanks}
  58
  59        self.task = None
  60        self._initialize_task()
  61        self._print_behaviors()
  62
  63    @staticmethod
  64    def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
  65        """
  66        Complete a parameter dictionary with values from other dictionaries if required by a function
  67
  68        Parameters
  69        ----------
  70        parameters : dict
  71            the function parameters dictionary
  72        function : callable
  73            the function to be inspected
  74        general_dicts : list
  75            a list of dictionaries where the missing values will be pulled from
  76        """
  77
  78        parameter_names = inspect.getfullargspec(function).args
  79        for param in parameter_names:
  80            for dic in general_dicts:
  81                if param not in parameters and param in dic:
  82                    parameters[param] = dic[param]
  83        return parameters
  84
  85    @staticmethod
  86    def complete_dataset_parameters(
  87        parameters: dict,
  88        general_dict: dict,
  89        data_type: str,
  90        annotation_type: str,
  91    ) -> Dict:
  92        """
  93        Complete a parameter dictionary with values from other dictionaries if required by a dataset
  94
  95        Parameters
  96        ----------
  97        parameters : dict
  98            the function parameters dictionary
  99        general_dict : dict
 100            the dictionary where the missing values will be pulled from
 101        data_type : str
 102            the input type of the dataset
 103        annotation_type : str
 104            the annotation type of the dataset
 105
 106        Returns
 107        -------
 108        parameters : dict
 109            the updated parameter dictionary
 110        """
 111
 112        params = deepcopy(parameters)
 113        parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type)
 114        for param in parameter_names:
 115            if param not in params and param in general_dict:
 116                params[param] = general_dict[param]
 117        return params
 118
 119    @staticmethod
 120    def check(parameters: Dict, name: str) -> bool:
 121        """
 122        Check whether there is a non-`None` value under the name key in the parameters dictionary
 123
 124        Parameters
 125        ----------
 126        parameters : dict
 127            the dictionary to check
 128        name : str
 129            the key to check
 130
 131        Returns
 132        -------
 133        result : bool
 134            True if a non-`None` value exists
 135        """
 136
 137        if name in parameters and parameters[name] is not None:
 138            return True
 139        else:
 140            return False
 141
 142    @staticmethod
 143    def get(parameters: Dict, name: str, default):
 144        """
 145        Get the value under the name key or the default if it is `None` or does not exist
 146
 147        Parameters
 148        ----------
 149        parameters : dict
 150            the dictionary to check
 151        name : str
 152            the key to check
 153        default
 154            the default value to return
 155
 156        Returns
 157        -------
 158        value
 159            the resulting value
 160        """
 161
 162        if TaskDispatcher.check(parameters, name):
 163            return parameters[name]
 164        else:
 165            return default
 166
 167    @staticmethod
 168    def make_dataloader(
 169        dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False
 170    ) -> DataLoader:
 171        """
 172        Make a torch dataloader from a dataset
 173
 174        Parameters
 175        ----------
 176        dataset : dlc2action.data.dataset.BehaviorDataset
 177            the dataset
 178        batch_size : int
 179            the batch size
 180
 181        Returns
 182        -------
 183        dataloader : DataLoader
 184            the dataloader (or `None` if the length of the dataset is 0)
 185        """
 186
 187        if dataset is None or len(dataset) == 0:
 188            return None
 189        else:
 190            return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)
 191
 192    def _construct_ssl(self) -> List:
 193        """
 194        Generate SSL constructors
 195        """
 196
 197        ssl_list = deepcopy(self.general_parameters.get("ssl", None))
 198        if not isinstance(ssl_list, Iterable):
 199            ssl_list = [ssl_list]
 200        for i, ssl in enumerate(ssl_list):
 201            if type(ssl) is str:
 202                if ssl in options.ssl_constructors:
 203                    pars = self.get(self.ssl_parameters, ssl, default={})
 204                    pars = self.complete_function_parameters(
 205                        parameters=pars,
 206                        function=options.ssl_constructors[ssl],
 207                        general_dicts=[
 208                            self.model_parameters,
 209                            self.data_parameters,
 210                            self.general_parameters,
 211                        ],
 212                    )
 213                    ssl_list[i] = options.ssl_constructors[ssl](**pars)
 214                else:
 215                    raise ValueError(
 216                        f"The {ssl} SSL is not available, please choose from {list(options.ssl_constructors.keys())}"
 217                    )
 218            elif ssl is None:
 219                ssl_list[i] = EmptySSL()
 220            elif not isinstance(ssl, SSLConstructor):
 221                raise TypeError(
 222                    f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}"
 223                )
 224        return ssl_list
 225
 226    def _construct_model(self) -> Model:
 227        """
 228        Generate a model
 229        """
 230
 231        if self.check(self.general_parameters, "model"):
 232            pars = self.complete_function_parameters(
 233                function=LoadedModel,
 234                parameters=self.model_parameters,
 235                general_dicts=[self.general_parameters],
 236            )
 237            model = LoadedModel(**pars)
 238        elif self.check(self.general_parameters, "model_name"):
 239            name = self.general_parameters["model_name"]
 240            if name in options.models:
 241                pars = self.complete_function_parameters(
 242                    function=options.models[name],
 243                    parameters=self.model_parameters,
 244                    general_dicts=[self.general_parameters],
 245                )
 246                model = options.models[name](**pars)
 247            else:
 248                raise ValueError(
 249                    f"The {name} model is not available, please choose from {list(options.models.keys())}"
 250                )
 251        else:
 252            raise ValueError(
 253                "You need to provide either a model or its name in the model_parameters!"
 254            )
 255
 256        if self.get(self.training_parameters, "freeze_features", False):
 257            model.freeze_feature_extractor()
 258        return model
 259
 260    def _construct_dataset(self) -> BehaviorDataset:
 261        """
 262        Generate a dataset
 263        """
 264
 265        data_type = self.general_parameters.get("data_type", None)
 266        if data_type is None:
 267            raise ValueError(
 268                "You need to provide the data_type parameter in the data parameters!"
 269            )
 270        annotation_type = self.get(self.general_parameters, "annotation_type", "none")
 271        feature_extraction = self.general_parameters.get("feature_extraction", "none")
 272        if feature_extraction is None:
 273            raise ValueError(
 274                "You need to provide the feature_extraction parameter in the data parameters!"
 275            )
 276        feature_extraction_pars = self.complete_function_parameters(
 277            self.feature_parameters,
 278            options.feature_extractors[feature_extraction],
 279            [self.general_parameters, self.data_parameters],
 280        )
 281
 282        pars = self.complete_dataset_parameters(
 283            self.data_parameters,
 284            self.general_parameters,
 285            data_type=data_type,
 286            annotation_type=annotation_type,
 287        )
 288        pars["feature_extraction_pars"] = feature_extraction_pars
 289        dataset = BehaviorDataset(**pars)
 290
 291        if self.get(self.general_parameters, "save_dataset", default=False):
 292            save_data_path = self.data_parameters.get("saved_data_path", None)
 293            dataset.save(save_path=save_data_path)
 294
 295        return dataset
 296
 297    def _construct_transformer(self) -> Transformer:
 298        """
 299        Generate a transformer
 300        """
 301
 302        features = self.general_parameters["feature_extraction"]
 303        name = options.extractor_to_transformer[features]
 304        if name in options.transformers:
 305            transformer_class = options.transformers[name]
 306            pars = self.complete_function_parameters(
 307                function=transformer_class,
 308                parameters=self.aug_parameters,
 309                general_dicts=[self.general_parameters],
 310            )
 311            transformer = transformer_class(**pars)
 312        else:
 313            raise ValueError(f"The {name} transformer is not available")
 314        return transformer
 315
 316    def _construct_loss(self) -> torch.nn.Module:
 317        """
 318        Generate a loss function
 319        """
 320
 321        if "loss_function" not in self.general_parameters:
 322            raise ValueError(
 323                'Please add a "loss_function" key to the parameters["general"] dictionary (either a name '
 324                f"from {list(options.losses.keys())} or a function)"
 325            )
 326        else:
 327            loss_function = self.general_parameters["loss_function"]
 328        if type(loss_function) is str:
 329            if loss_function in options.losses:
 330                pars = self.get(self.loss_parameters, loss_function, default={})
 331                pars = self._set_loss_weights(pars)
 332                pars = self.complete_function_parameters(
 333                    function=options.losses[loss_function],
 334                    parameters=pars,
 335                    general_dicts=[self.general_parameters],
 336                )
 337                loss = options.losses[loss_function](**pars)
 338            else:
 339                raise ValueError(
 340                    f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}"
 341                )
 342        else:
 343            loss = loss_function
 344        return loss
 345
 346    def _construct_metrics(self) -> List:
 347        """
 348        Generate the metric
 349        """
 350
 351        metric_functions = self.get(
 352            self.general_parameters, "metric_functions", default={}
 353        )
 354        if isinstance(metric_functions, Iterable):
 355            metrics = {}
 356            for func in metric_functions:
 357                if isinstance(func, str):
 358                    if func in options.metrics:
 359                        pars = self.get(self.metric_parameters, func, default={})
 360                        pars = self.complete_function_parameters(
 361                            function=options.metrics[func],
 362                            parameters=pars,
 363                            general_dicts=[self.general_parameters],
 364                        )
 365                        metrics[func] = options.metrics[func](**pars)
 366                    else:
 367                        raise ValueError(
 368                            f"The {func} metric is not available, please choose from {list(options.metrics.keys())}"
 369                        )
 370                elif isinstance(func, Metric):
 371                    name = "function_1"
 372                    i = 1
 373                    while name in metrics:
 374                        i += 1
 375                        name = f"function_{i}"
 376                    metrics[name] = func
 377                else:
 378                    raise TypeError(
 379                        'The elements of parameters["general"]["metric_functions"] have to be either strings '
 380                        f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead"
 381                    )
 382        elif isinstance(metric_functions, dict):
 383            metrics = metric_functions
 384        else:
 385            raise TypeError(
 386                'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;'
 387                f"got {type(metric_functions)} instead"
 388            )
 389        return metrics
 390
 391    def _construct_optimizer(self) -> Optimizer:
 392        """
 393        Generate an optimizer
 394        """
 395
 396        if "optimizer" in self.training_parameters:
 397            name = self.training_parameters["optimizer"]
 398            if name in options.optimizers:
 399                optimizer = options.optimizers[name]
 400            else:
 401                raise ValueError(
 402                    f"The {name} optimizer is not available, please choose from {list(options.optimizers.keys())}"
 403                )
 404        else:
 405            optimizer = None
 406        return optimizer
 407
    def _construct_predict_functions(self) -> Tuple[Callable, Callable]:
        """
        Construct predict functions

        If the training parameters hold a callable under `"predict_function"`, both values
        are returned exactly as they are stored in the training parameters (the primary
        function can then be `None`). Otherwise both functions are generated here from the
        `"model_name"` and `"exclusive"` general parameters, and a `"primary_predict_function"`
        from the training parameters is overwritten.

        Returns
        -------
        primary_predict_function : callable
            the function applied to the model output (softmax over dimension 1 when
            `"exclusive"` is set and sigmoid otherwise, after a model-specific selection step)
        predict_function : callable
            the function applied to the result of the primary function (the index of the
            maximum over dimension 1 when `"exclusive"` is set, thresholding at
            `"hard_threshold"` otherwise)
        """

        predict_function = self.training_parameters.get("predict_function", None)
        primary_predict_function = self.training_parameters.get(
            "primary_predict_function", None
        )
        model_name = self.general_parameters.get("model_name", "")
        threshold = self.training_parameters.get("hard_threshold", 0.5)
        if not isinstance(predict_function, Callable):
            if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]:
                if self.general_parameters["exclusive"]:
                    func = lambda x: torch.softmax(x, dim=1)
                else:
                    func = lambda x: torch.sigmoid(x)

                def primary_predict_function(x):
                    # the output is treated as four stacked predictions; anything that is
                    # not 4-dimensional is reshaped so that the first axis has length 4
                    if len(x.shape) != 4:
                        x = x.reshape((4, -1, x.shape[-2], x.shape[-1]))
                    # equal weights: the result is the mean of the four probabilities
                    weights = [1, 1, 1, 1]
                    ensemble_prob = func(x[0]) * weights[0] / sum(weights)
                    for i, outp_ele in enumerate(x[1:]):
                        ensemble_prob = ensemble_prob + func(outp_ele) * weights[
                            i + 1
                        ] / sum(weights)
                    return ensemble_prob

            else:
                # f picks the part of the model output that is turned into probabilities
                if model_name.startswith("ms_tcn") or model_name in [
                    "asformer",
                    "transformer",
                    "c3d_ms",
                    "transformer_ms",
                ]:
                    # for a 4-dimensional output only the last element is used
                    f = lambda x: x[-1] if len(x.shape) == 4 else x
                elif model_name == "asrf":

                    def f(x):
                        x = x[-1]
                        # bounds = x[:, 0, :].unsqueeze(1)
                        # channel 0 is dropped, the remaining channels are returned
                        cls = x[:, 1:, :]
                        # device = x.device
                        # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy())
                        # x = torch.tensor(x).to(device)
                        return cls

                elif model_name == "actionclip":

                    def f(x):
                        video_embedding, text_embedding, logit_scale = (
                            x["video"],
                            x["text"],
                            x["logit_scale"],
                        )
                        B, Ff, T = video_embedding.shape
                        video_embedding = video_embedding.permute(0, 2, 1).reshape(
                            (B * T, -1)
                        )
                        # NOTE(review): both normalizations are in-place operations;
                        # text_embedding is the object stored in x["text"] — confirm
                        # that modifying the model output is intended
                        video_embedding /= video_embedding.norm(dim=-1, keepdim=True)
                        text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
                        similarity = logit_scale * video_embedding @ text_embedding.T
                        # back to a (B, -1, T) layout
                        similarity = similarity.reshape((B, T, -1)).permute(0, 2, 1)
                        return similarity

                else:
                    f = lambda x: x
                if self.general_parameters["exclusive"]:
                    primary_predict_function = lambda x: torch.softmax(f(x), dim=1)
                else:
                    primary_predict_function = lambda x: torch.sigmoid(f(x))
            if self.general_parameters["exclusive"]:
                # index of the maximum along dimension 1
                predict_function = lambda x: torch.max(x.data, dim=1)[1]
            else:
                predict_function = lambda x: (x > threshold).int()
        return primary_predict_function, predict_function
 485
 486    def _get_parameters_from_training(self) -> Dict:
 487        """
 488        Get the training parameters that need to be passed to the Task
 489        """
 490
 491        task_training_par_names = [
 492            "lr",
 493            "parallel",
 494            "device",
 495            "verbose",
 496            "log_file",
 497            "augment_train",
 498            "augment_val",
 499            "hard_threshold",
 500            "ssl_losses",
 501            "model_save_path",
 502            "model_save_epochs",
 503            "pseudolabel",
 504            "pseudolabel_start",
 505            "correction_interval",
 506            "pseudolabel_alpha_f",
 507            "alpha_growth_stop",
 508            "num_epochs",
 509            "validation_interval",
 510            "ignore_tags",
 511            "skip_metrics",
 512        ]
 513        task_training_pars = {
 514            name: self.training_parameters[name]
 515            for name in task_training_par_names
 516            if self.check(self.training_parameters, name)
 517        }
 518        if self.check(self.general_parameters, "ssl"):
 519            ssl_weights = [
 520                self.training_parameters["ssl_weights"][x]
 521                for x in self.general_parameters["ssl"]
 522            ]
 523            task_training_pars["ssl_weights"] = ssl_weights
 524        return task_training_pars
 525
 526    def _update_parameters_from_ssl(self, ssl_list: list) -> None:
 527        """
 528        Update the necessary parameters given the list of SSL constructors
 529        """
 530
 531        if self.task is not None:
 532            self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 533            self.task.set_ssl_losses([ssl.loss for ssl in ssl_list])
 534            self.task.set_keep_target_none(
 535                [ssl.type in ["contrastive"] for ssl in ssl_list]
 536            )
 537            self.task.set_generate_ssl_input(
 538                [ssl.type == "contrastive" for ssl in ssl_list]
 539            )
 540        self.data_parameters["ssl_transformations"] = [
 541            ssl.transformation for ssl in ssl_list
 542        ]
 543        self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list]
 544        self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list]
 545        self.model_parameters["ssl_modules"] = [
 546            ssl.construct_module() for ssl in ssl_list
 547        ]
 548        self.aug_parameters["generate_ssl_input"] = [
 549            x.type == "contrastive" for x in ssl_list
 550        ]
 551        self.aug_parameters["keep_target_none"] = [
 552            x.type == "contrastive" for x in ssl_list
 553        ]
 554
 555    def _set_loss_weights(self, parameters):
 556        """
 557        Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values
 558        """
 559
 560        for k in list(parameters.keys()):
 561            if parameters[k] in [
 562                "dataset_inverse_weights",
 563                "dataset_proportional_weights",
 564            ]:
 565                if parameters[k] == "dataset_inverse_weights":
 566                    parameters[k] = self.class_weights
 567                else:
 568                    parameters[k] = self.proportional_class_weights
 569                print("Initializing class weights:")
 570                string = "    "
 571                if isinstance(parameters[k], Mapping):
 572                    for key, val in parameters[k].items():
 573                        string += ": ".join(
 574                            (
 575                                " " + str(key),
 576                                ", ".join((map(lambda x: str(np.round(x, 3)), val))),
 577                            )
 578                        )
 579                else:
 580                    string += ", ".join(
 581                        (map(lambda x: str(np.round(x, 3)), parameters[k]))
 582                    )
 583                print(string)
 584        return parameters
 585
 586    def _partition_dataset(
 587        self, dataset: BehaviorDataset
 588    ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]:
 589        """
 590        Partition the dataset into train, validation and test subsamples
 591        """
 592
 593        use_test = self.get(self.training_parameters, "use_test", 0)
 594        split_path = self.training_parameters.get("split_path", None)
 595        partition_method = self.training_parameters.get("partition_method", "random")
 596        val_frac = self.get(self.training_parameters, "val_frac", 0)
 597        test_frac = self.get(self.training_parameters, "test_frac", 0)
 598        save_split = self.get(self.training_parameters, "save_split", True)
 599        normalize = self.get(self.training_parameters, "normalize", False)
 600        skip_normalization_keys = self.training_parameters.get(
 601            "skip_normalization_keys"
 602        )
 603        stats = self.training_parameters.get("stats")
 604        train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val(
 605            use_test,
 606            split_path,
 607            partition_method,
 608            val_frac,
 609            test_frac,
 610            save_split,
 611            normalize,
 612            skip_normalization_keys,
 613            stats,
 614        )
 615        bs = int(self.training_parameters.get("batch_size", 32))
 616        train_dataloader, test_dataloader, val_dataloader = (
 617            self.make_dataloader(train_dataset, batch_size=bs, shuffle=True),
 618            self.make_dataloader(test_dataset, batch_size=bs, shuffle=False),
 619            self.make_dataloader(val_dataset, batch_size=bs, shuffle=False),
 620        )
 621        return train_dataloader, test_dataloader, val_dataloader
 622
    def _initialize_task(self) -> None:
        """
        Create a `dlc2action.task.universal_task.Task` instance

        The dataset, model, SSL constructors, transformer, metrics, optimizer, predict
        functions, dataloaders and loss are constructed from the parameter dictionaries
        (in that order) and passed to `Task`. The result is stored at `self.task`, and a
        checkpoint is loaded if `"checkpoint_path"` is set in the training parameters.
        """

        dataset = self._construct_dataset()
        # the dataset-related blanks have to be filled before the model is created
        self._update_data_blanks(dataset)
        model = self._construct_model()
        self._update_model_blanks(model)
        ssl_list = self._construct_ssl()
        self._update_parameters_from_ssl(ssl_list)
        model.set_ssl(ssl_constructors=ssl_list)
        dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
        transformer = self._construct_transformer()
        metrics = self._construct_metrics()
        optimizer = self._construct_optimizer()
        primary_predict_function, predict_function = self._construct_predict_functions()

        task_training_pars = self._get_parameters_from_training()
        train_dataloader, test_dataloader, val_dataloader = self._partition_dataset(
            dataset
        )
        # NOTE(review): make_dataloader returns None for an empty dataset, and the next
        # line would fail in that case — confirm the training subset cannot be empty
        self.class_weights = train_dataloader.dataset.class_weights()
        self.proportional_class_weights = train_dataloader.dataset.class_weights(True)
        # the loss is created last because it can use the class weights set above
        loss = self._construct_loss()
        exclusive = self.general_parameters["exclusive"]

        task_pars = {
            "train_dataloader": train_dataloader,
            "model": model,
            "loss": loss,
            "transformer": transformer,
            "metrics": metrics,
            "val_dataloader": val_dataloader,
            "test_dataloader": test_dataloader,
            "exclusive": exclusive,
            "optimizer": optimizer,
            "predict_function": predict_function,
            "primary_predict_function": primary_predict_function,
        }
        # on a key clash the values taken from the training parameters win
        task_pars.update(task_training_pars)

        self.task = Task(**task_pars)
        checkpoint_path = self.training_parameters.get("checkpoint_path", None)
        if checkpoint_path is not None:
            only_model = self.get(self.training_parameters, "only_load_model", False)
            load_strict = self.get(self.training_parameters, "load_strict", True)
            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
        if (
            self.general_parameters["only_load_annotated"]
            and self.general_parameters.get("ssl") is not None
        ):
            warnings.warn(
                "Note that you are using SSL modules but only loading annotated files! Set "
                "general/only_load_annotated to False to change that"
            )
 679
 680    def _update_data_blanks(
 681        self, dataset: BehaviorDataset = None, remember: bool = False
 682    ) -> None:
 683        """
 684        Update all blanks from a dataset
 685        """
 686
 687        if dataset is None:
 688            dataset = self.dataset()
 689        self._update_dim_parameter(dataset, remember)
 690        self._update_bodyparts_parameter(dataset, remember)
 691        self._update_num_classes_parameter(dataset, remember)
 692        self._update_len_segment_parameter(dataset, remember)
 693        self._update_boundary_parameter(dataset, remember)
 694
 695    def _update_model_blanks(self, model: Model, remember: bool = False) -> None:
 696        self._update_features_parameter(model, remember)
 697
 698    def _update_parameter(self, blank_name: str, value, remember: bool = False):
 699        parameters = [
 700            self.model_parameters,
 701            self.ssl_parameters,
 702            self.general_parameters,
 703            self.feature_parameters,
 704            self.data_parameters,
 705            self.training_parameters,
 706            self.metric_parameters,
 707            self.loss_parameters,
 708            self.aug_parameters,
 709        ]
 710        par_names = [
 711            "model",
 712            "ssl",
 713            "general",
 714            "feature",
 715            "data",
 716            "training",
 717            "metrics",
 718            "losses",
 719            "augmentations",
 720        ]
 721        for names in self.blanks[blank_name]:
 722            group = names[0]
 723            key = names[1]
 724            ind = par_names.index(group)
 725            if len(names) == 3:
 726                if names[2] in parameters[ind][key]:
 727                    parameters[ind][key][names[2]] = value
 728            else:
 729                if key in parameters[ind]:
 730                    parameters[ind][key] = value
 731        for name, dic in zip(par_names, parameters):
 732            for k, v in dic.items():
 733                if v == blank_name:
 734                    dic[k] = value
 735                    if [name, k] not in self.blanks[blank_name]:
 736                        self.blanks[blank_name].append([name, k])
 737                elif isinstance(v, Mapping):
 738                    for kk, vv in v.items():
 739                        if vv == blank_name:
 740                            dic[k][kk] = value
 741                            if [name, k, kk] not in self.blanks[blank_name]:
 742                                self.blanks[blank_name].append([name, k, kk])
 743
 744    def _update_features_parameter(self, model: Model, remember: bool = False) -> None:
 745        """
 746        Fill the `"model_features"` blank
 747        """
 748
 749        value = model.features_shape()
 750        self._update_parameter("model_features", value, remember)
 751
 752    def _update_bodyparts_parameter(
 753        self, dataset: BehaviorDataset, remember: bool = False
 754    ) -> None:
 755        """
 756        Fill the `"dataset_bodyparts"` blank
 757        """
 758
 759        value = dataset.bodyparts_order()
 760        self._update_parameter("dataset_bodyparts", value, remember)
 761
 762    def _update_dim_parameter(
 763        self, dataset: BehaviorDataset, remember: bool = False
 764    ) -> None:
 765        """
 766        Fill the `"dataset_features"` blank
 767        """
 768
 769        value = dataset.features_shape()
 770        self._update_parameter("dataset_features", value, remember)
 771
 772    def _update_boundary_parameter(
 773        self, dataset: BehaviorDataset, remember: bool = False
 774    ) -> None:
 775        """
 776        Fill the `"dataset_features"` blank
 777        """
 778
 779        value = dataset.boundary_class_weight()
 780        self._update_parameter("dataset_boundary_weight", value, remember)
 781
 782    def _update_num_classes_parameter(
 783        self, dataset: BehaviorDataset, remember: bool = False
 784    ) -> None:
 785        """
 786        Fill in the `"dataset_classes"` blank
 787        """
 788
 789        value = dataset.num_classes()
 790        self._update_parameter("dataset_classes", value, remember)
 791
 792    def _update_len_segment_parameter(
 793        self, dataset: BehaviorDataset, remember: bool = False
 794    ) -> None:
 795        """
 796        Fill in the `"dataset_len_segment"` blank
 797        """
 798
 799        value = dataset.len_segment()
 800        self._update_parameter("dataset_len_segment", value, remember)
 801
 802    def _print_behaviors(self):
 803        behavior_set = self.behaviors_dict()
 804        print(f"Behavior indices:")
 805        for key, value in sorted(behavior_set.items()):
 806            print(f"    {key}: {value}")
 807
 808    def update_task(self, parameters: Dict) -> None:
 809        """
 810        Update the `dlc2action.task.universal_task.Task` instance given the parameter updates
 811
 812        Parameters
 813        ----------
 814        parameters : dict
 815            the dictionary of parameter updates
 816        """
 817
 818        pars = deepcopy(parameters)
 819        # for blank_name in self.blanks:
 820        #     for names in self.blanks[blank_name]:
 821        #         group = names[0]
 822        #         key = names[1]
 823        #         if len(names) == 3:
 824        #             if (
 825        #                 group in pars
 826        #                 and key in pars[group]
 827        #                 and names[2] in pars[group][key]
 828        #             ):
 829        #                 pars[group][key].pop(names[2])
 830        #         else:
 831        #             if group in pars and key in pars[group]:
 832        #                 pars[group].pop(key)
 833        stay = False
 834        if "ssl" in pars:
 835            for key in pars["ssl"]:
 836                if key in self.ssl_parameters:
 837                    self.ssl_parameters[key].update(pars["ssl"][key])
 838                else:
 839                    self.ssl_parameters[key] = pars["ssl"][key]
 840
 841        if "general" in pars:
 842            if stay:
 843                stay = False
 844            if (
 845                "model_name" in pars["general"]
 846                and pars["general"]["model_name"]
 847                != self.general_parameters["model_name"]
 848            ):
 849                if "model" not in pars:
 850                    raise ValueError(
 851                        "When updating a task with a new model name you need to pass the parameters for the "
 852                        "new model"
 853                    )
 854                self.model_parameters = {}
 855            self.general_parameters.update(pars["general"])
 856            data_related = [
 857                "num_classes",
 858                "exclusive",
 859                "data_type",
 860                "annotation_type",
 861            ]
 862            ssl_related = ["ssl", "exclusive", "num_classes"]
 863            loss_related = ["num_classes", "loss_function", "exclusive"]
 864            augmentation_related = ["augmentation_type"]
 865            metric_related = ["metric_functions"]
 866            related_lists = [
 867                data_related,
 868                ssl_related,
 869                loss_related,
 870                augmentation_related,
 871                metric_related,
 872            ]
 873            names = ["data", "ssl", "losses", "augmentations", "metrics"]
 874            for related_list, name in zip(related_lists, names):
 875                if (
 876                    any([x in pars["general"] for x in related_list])
 877                    and name not in pars
 878                ):
 879                    pars[name] = {}
 880
 881        if "training" in pars:
 882            if "data" not in pars or not stay:
 883                for x in [
 884                    "to_ram",
 885                    "use_test",
 886                    "partition_method",
 887                    "val_frac",
 888                    "test_frac",
 889                    "save_split",
 890                    "batch_size",
 891                    "save_split",
 892                ]:
 893                    if (
 894                        x in pars["training"]
 895                        and pars["training"][x] != self.training_parameters[x]
 896                    ):
 897                        if "data" not in pars:
 898                            pars["data"] = {}
 899                        stay = True
 900            self.training_parameters.update(pars["training"])
 901            self.task.update_parameters(self._get_parameters_from_training())
 902
 903        if "data" in pars or "features" in pars:
 904            for k, v in pars["data"].items():
 905                if k not in self.data_parameters or v != self.data_parameters[k]:
 906                    stay = True
 907            for k, v in pars["features"].items():
 908                if k not in self.feature_parameters or v != self.feature_parameters[k]:
 909                    stay = True
 910            if stay:
 911                self.data_parameters.update(pars["data"])
 912                self.feature_parameters.update(pars["features"])
 913                dataset = self._construct_dataset()
 914                (
 915                    train_dataloader,
 916                    test_dataloader,
 917                    val_dataloader,
 918                ) = self._partition_dataset(dataset)
 919                self.task.set_dataloaders(
 920                    train_dataloader, val_dataloader, test_dataloader
 921                )
 922                self.class_weights = train_dataloader.dataset.class_weights()
 923                self.proportional_class_weights = (
 924                    train_dataloader.dataset.class_weights(True)
 925                )
 926                if "losses" not in pars:
 927                    pars["losses"] = {}
 928
 929        if "model" in pars:
 930            self.model_parameters.update(pars["model"])
 931
 932        self._update_data_blanks()
 933
 934        if "augmentations" in pars:
 935            self.aug_parameters.update(pars["augmentations"])
 936            transformer = self._construct_transformer()
 937            self.task.set_transformer(transformer)
 938
 939        if "losses" in pars:
 940            for key in pars["losses"]:
 941                if key in self.loss_parameters:
 942                    self.loss_parameters[key].update(pars["losses"][key])
 943                else:
 944                    self.loss_parameters[key] = pars["losses"][key]
 945            self.loss_parameters.update(pars["losses"])
 946            loss = self._construct_loss()
 947            self.task.set_loss(loss)
 948
 949        if "metrics" in pars:
 950            for key in pars["metrics"]:
 951                if key in self.metric_parameters:
 952                    self.metric_parameters[key].update(pars["metrics"][key])
 953                else:
 954                    self.metric_parameters[key] = pars["metrics"][key]
 955            metrics = self._construct_metrics()
 956            self.task.set_metrics(metrics)
 957
 958        self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
 959        self._set_loss_weights(
 960            pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
 961        )
 962        model = self._construct_model()
 963        predict_functions = self._construct_predict_functions()
 964        self.task.set_predict_functions(*predict_functions)
 965        self._update_model_blanks(model)
 966        ssl_list = self._construct_ssl()
 967        self._update_parameters_from_ssl(ssl_list)
 968        model.set_ssl(ssl_constructors=ssl_list)
 969        self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 970        self.task.set_model(model)
 971        if "training" in pars and "checkpoint_path" in pars["training"]:
 972            checkpoint_path = pars["training"]["checkpoint_path"]
 973            only_model = pars["training"].get("only_load_model", False)
 974            load_strict = pars["training"].get("load_strict", True)
 975            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
 976        if (
 977            self.general_parameters["only_load_annotated"]
 978            and self.general_parameters.get("ssl") is not None
 979        ):
 980            warnings.warn(
 981                "Note that you are using SSL modules but only loading annotated files! Set "
 982                "general/only_load_annotated to False to change that"
 983            )
 984        if self.task.dataset("train").annotation_class() != "none":
 985            self._print_behaviors()
 986
 987    def train(
 988        self,
 989        trial: Trial = None,
 990        optimized_metric: str = None,
 991        autostop_metric: str = None,
 992        autostop_interval: int = 10,
 993        autostop_threshold: float = 0.001,
 994        loading_bar: bool = False,
 995    ) -> Tuple:
 996        """
 997        Train the task and return a log of epoch-average loss and metric
 998
 999        You can use the autostop parameters to finish training when the parameters are not improving. It will be
1000        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
1001        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
1002        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
1003
1004        Parameters
1005        ----------
1006        trial : Trial
1007            an `optuna` trial (for hyperparameter searches)
1008        optimized_metric : str
1009            the name of the metric being optimized (for hyperparameter searches)
1010        to_ram : bool, default False
1011            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
1012            if the dataset is too large)
1013        autostop_interval : int, default 50
1014            the number of epochs to average the autostop metric over
1015        autostop_threshold : float, default 0.001
1016            the autostop difference threshold
1017        autostop_metric : str, optional
1018            the autostop metric (can be any one of the tracked metrics of `'loss'`)
1019        main_task_on : bool, default True
1020            if `False`, the main task (action segmentation) will not be used in training
1021        ssl_on : bool, default True
1022            if `False`, the SSL task will not be used in training
1023
1024        Returns
1025        -------
1026        loss_log: list
1027            a list of float loss function values for each epoch
1028        metrics_log: dict
1029            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
1030            names, values are lists of function values)
1031        """
1032
1033        to_ram = self.training_parameters.get("to_ram", False)
1034        logs = self.task.train(
1035            trial,
1036            optimized_metric,
1037            to_ram,
1038            autostop_metric=autostop_metric,
1039            autostop_interval=autostop_interval,
1040            autostop_threshold=autostop_threshold,
1041            main_task_on=self.training_parameters.get("main_task_on", True),
1042            ssl_on=self.training_parameters.get("ssl_on", True),
1043            temporal_subsampling_size=self.training_parameters.get(
1044                "temporal_subsampling_size"
1045            ),
1046            loading_bar=loading_bar,
1047        )
1048        return logs
1049
1050    def save_model(self, save_path: str) -> None:
1051        """
1052        Save the model of the `dlc2action.task.universal_task.Task` instance
1053
1054        Parameters
1055        ----------
1056        save_path : str
1057            the path to the saved file
1058        """
1059
1060        self.task.save_model(save_path)
1061
1062    def evaluate(
1063        self,
1064        data: Union[DataLoader, BehaviorDataset, str] = None,
1065        augment_n: int = 0,
1066        verbose: bool = True,
1067    ) -> Tuple:
1068        """
1069        Evaluate the Task model
1070
1071        Parameters
1072        ----------
1073        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1074            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
1075        augment_n : int, default 0
1076            the number of augmentations to average results over
1077        verbose : bool, default True
1078            if True, the process is reported to standard output
1079
1080        Returns
1081        -------
1082        loss : float
1083            the average value of the loss function
1084        ssl_loss : float
1085            the average value of the SSL loss function
1086        metric : dict
1087            a dictionary of average values of metric functions
1088        """
1089
1090        res = self.task.evaluate(
1091            data,
1092            augment_n,
1093            int(self.training_parameters.get("batch_size", 32)),
1094            verbose,
1095        )
1096        return res
1097
1098    def evaluate_prediction(
1099        self,
1100        prediction: torch.Tensor,
1101        data: Union[DataLoader, BehaviorDataset, str] = None,
1102    ) -> Tuple:
1103        """
1104        Compute metrics for a prediction
1105
1106        Parameters
1107        ----------
1108        prediction : torch.Tensor
1109            the prediction
1110        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1111            the data the prediction was made for (if not provided, take the validation dataset)
1112
1113        Returns
1114        -------
1115        loss : float
1116            the average value of the loss function
1117        metric : dict
1118            a dictionary of average values of metric functions
1119        """
1120
1121        return self.task.evaluate_prediction(
1122            prediction, data, int(self.training_parameters.get("batch_size", 32))
1123        )
1124
1125    def predict(
1126        self,
1127        data: Union[DataLoader, BehaviorDataset, str],
1128        raw_output: bool = False,
1129        apply_primary_function: bool = True,
1130        augment_n: int = 0,
1131        embedding: bool = False,
1132    ) -> torch.Tensor:
1133        """
1134        Make a prediction with the Task model
1135
1136        Parameters
1137        ----------
1138        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1139            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
1140        raw_output : bool, default False
1141            if `True`, the raw predicted probabilities are returned
1142        apply_primary_function : bool, default True
1143            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
1144            the input)
1145        augment_n : int, default 0
1146            the number of augmentations to average results over
1147
1148        Returns
1149        -------
1150        prediction : torch.Tensor
1151            a prediction for the input data
1152        """
1153
1154        to_ram = self.training_parameters.get("to_ram", False)
1155        return self.task.predict(
1156            data,
1157            raw_output,
1158            apply_primary_function,
1159            augment_n,
1160            int(self.training_parameters.get("batch_size", 32)),
1161            to_ram,
1162            embedding=embedding,
1163        )
1164
1165    def dataset(self, mode: str = "train") -> BehaviorDataset:
1166        """
1167        Get a dataset
1168
1169        Parameters
1170        ----------
1171        mode : {'train', 'val', 'test'}
1172            the dataset to get
1173
1174        Returns
1175        -------
1176        dataset : dlc2action.data.dataset.BehaviorDataset
1177            the dataset
1178        """
1179
1180        return self.task.dataset(mode)
1181
1182    def generate_full_length_prediction(
1183        self,
1184        dataset: Union[BehaviorDataset, str] = None,
1185        augment_n: int = 10,
1186    ) -> Dict:
1187        """
1188        Compile a prediction for the original input sequences
1189
1190        Parameters
1191        ----------
1192        dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
1193            the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task`
1194            instance validation dataset)
1195        augment_n : int, default 10
1196            the number of augmentations to average results over
1197
1198        Returns
1199        -------
1200        prediction : dict
1201            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1202            are prediction tensors
1203        """
1204
1205        return self.task.generate_full_length_prediction(
1206            dataset, int(self.training_parameters.get("batch_size", 32)), augment_n
1207        )
1208
1209    def generate_submission(
1210        self,
1211        frame_number_map_file: str,
1212        dataset: Union[BehaviorDataset, str] = None,
1213        augment_n: int = 10,
1214    ) -> Dict:
1215        """
1216        Generate a MABe-22 style submission dictionary
1217
1218        Parameters
1219        ----------
1220        frame_number_map_file : str
1221            path to the frame number map file
1222        dataset : BehaviorDataset, optional
1223            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1224        augment_n : int, default 10
1225            the number of augmentations to average results over
1226
1227        Returns
1228        -------
1229        submission : dict
1230            a dictionary with frame number mapping and embeddings
1231        """
1232
1233        return self.task.generate_submission(
1234            frame_number_map_file,
1235            dataset,
1236            int(self.training_parameters.get("batch_size", 32)),
1237            augment_n,
1238        )
1239
1240    def behaviors_dict(self):
1241        """
1242        Get a behavior dictionary
1243
1244        Keys are label indices and values are label names.
1245
1246        Returns
1247        -------
1248        behaviors_dict : dict
1249            behavior dictionary
1250        """
1251
1252        return self.task.behaviors_dict()
1253
1254    def count_classes(self, bouts: bool = False) -> Dict:
1255        """
1256        Get a dictionary of class counts in different modes
1257
1258        Parameters
1259        ----------
1260        bouts : bool, default False
1261            if `True`, instead of frame counts segment counts are returned
1262
1263        Returns
1264        -------
1265        class_counts : dict
1266            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1267            class names and values are class counts (in frames)
1268        """
1269
1270        return self.task.count_classes(bouts)
1271
1272    def _visualize_results_label(
1273        self,
1274        label: str,
1275        save_path: str = None,
1276        add_legend: bool = True,
1277        ground_truth: bool = True,
1278        hide_axes: bool = False,
1279        width: int = 10,
1280        whole_video: bool = False,
1281        transparent: bool = False,
1282        dataset: BehaviorDataset = None,
1283        smooth_interval: int = 0,
1284        title: str = None,
1285    ):
1286        return self.task._visualize_results_label(
1287            label,
1288            save_path,
1289            add_legend,
1290            ground_truth,
1291            hide_axes,
1292            width,
1293            whole_video,
1294            transparent,
1295            dataset,
1296            smooth_interval=smooth_interval,
1297            title=title,
1298        )
1299
1300    def visualize_results(
1301        self,
1302        save_path: str = None,
1303        add_legend: bool = True,
1304        ground_truth: bool = True,
1305        colormap: str = "viridis",
1306        hide_axes: bool = False,
1307        min_classes: int = 1,
1308        width: float = 10,
1309        whole_video: bool = False,
1310        transparent: bool = False,
1311        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1312        drop_classes: Set = None,
1313        search_classes: Set = None,
1314        smooth_interval_prediction: int = None,
1315    ) -> None:
1316        """
1317        Visualize random predictions
1318
1319        Parameters
1320        ----------
1321        save_path : str, optional
1322            the path where the plot will be saved
1323        add_legend : bool, default True
1324            if True, legend will be added to the plot
1325        ground_truth : bool, default True
1326            if True, ground truth will be added to the plot
1327        colormap : str, default 'Accent'
1328            the `matplotlib` colormap to use
1329        hide_axes : bool, default True
1330            if `True`, the axes will be hidden on the plot
1331        min_classes : int, default 1
1332            the minimum number of classes in a displayed interval
1333        width : float, default 10
1334            the width of the plot
1335        whole_video : bool, default False
1336            if `True`, whole videos are plotted instead of segments
1337        transparent : bool, default False
1338            if `True`, the background on the plot is transparent
1339        dataset : BehaviorDataset | DataLoader | str | None, optional
1340            the dataset to make the prediction for (if not provided, the validation dataset is used)
1341        drop_classes : set, optional
1342            a set of class names to not be displayed
1343        search_classes : set, optional
1344            if given, only intervals where at least one of the classes is in ground truth will be shown
1345        """
1346
1347        return self.task.visualize_results(
1348            save_path,
1349            add_legend,
1350            ground_truth,
1351            colormap,
1352            hide_axes,
1353            min_classes,
1354            width,
1355            whole_video,
1356            transparent,
1357            dataset,
1358            drop_classes,
1359            search_classes,
1360            smooth_interval_prediction=smooth_interval_prediction,
1361        )
1362
1363    def generate_uncertainty_score(
1364        self,
1365        classes: List,
1366        augment_n: int = 0,
1367        method: str = "least_confidence",
1368        predicted: torch.Tensor = None,
1369        behaviors_dict: Dict = None,
1370    ) -> Dict:
1371        """
1372        Generate frame-wise scores for active learning
1373
1374        Parameters
1375        ----------
1376        classes : list
1377            a list of class names or indices; their confidence scores will be computed separately and stacked
1378        augment_n : int, default 0
1379            the number of augmentations to average over
1380        method : {"least_confidence", "entropy"}
1381            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1382            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
1383
1384        Returns
1385        -------
1386        score_dicts : dict
1387            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1388            are score tensors
1389        """
1390
1391        return self.task.generate_uncertainty_score(
1392            classes,
1393            augment_n,
1394            int(self.training_parameters.get("batch_size", 32)),
1395            method,
1396            predicted,
1397            behaviors_dict,
1398        )
1399
1400    def generate_bald_score(
1401        self,
1402        classes: List,
1403        augment_n: int = 0,
1404        num_models: int = 10,
1405        kernel_size: int = 11,
1406    ) -> Dict:
1407        """
1408        Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
1409
1410        Parameters
1411        ----------
1412        classes : list
1413            a list of class names or indices; their confidence scores will be computed separately and stacked
1414        augment_n : int, default 0
1415            the number of augmentations to average over
1416        num_models : int, default 10
1417            the number of dropout masks to apply
1418        kernel_size : int, default 11
1419            the size of the smoothing gaussian kernel
1420
1421        Returns
1422        -------
1423        score_dicts : dict
1424            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1425            are score tensors
1426        """
1427
1428        return self.task.generate_bald_score(
1429            classes,
1430            augment_n,
1431            int(self.training_parameters.get("batch_size", 32)),
1432            num_models,
1433            kernel_size,
1434        )
1435
1436    def get_normalization_stats(self) -> Dict:
1437        """
1438        Get the pre-computed normalization stats
1439
1440        Returns
1441        -------
1442        normalization_stats : dict
1443            a dictionary of means and stds
1444        """
1445
1446        return self.task.get_normalization_stats()
1447
1448    def exists(self, mode) -> bool:
1449        """
1450        Check whether the task has a train/test/validation subset
1451
1452        Parameters
1453        ----------
1454        mode : {"train", "val", "test"}
1455            the name of the subset to check for
1456
1457        Returns
1458        -------
1459        exists : bool
1460            `True` if the subset exists
1461        """
1462
1463        dl = self.task.dataloader(mode)
1464        if dl is None:
1465            return False
1466        else:
1467            return True
class TaskDispatcher:
  34class TaskDispatcher:
  35    """
  36    A class that manages the interactions between config dictionaries and a Task
  37    """
  38
  39    def __init__(self, parameters: Dict) -> None:
  40        """
  41        Parameters
  42        ----------
  43        parameters : dict
  44            a dictionary of task parameters
  45        """
  46
  47        pars = deepcopy(parameters)
  48        self.class_weights = None
  49        self.general_parameters = pars.get("general", {})
  50        self.data_parameters = pars.get("data", {})
  51        self.model_parameters = pars.get("model", {})
  52        self.training_parameters = pars.get("training", {})
  53        self.loss_parameters = pars.get("losses", {})
  54        self.metric_parameters = pars.get("metrics", {})
  55        self.ssl_parameters = pars.get("ssl", {})
  56        self.aug_parameters = pars.get("augmentations", {})
  57        self.feature_parameters = pars.get("features", {})
  58        self.blanks = {blank: [] for blank in options.blanks}
  59
  60        self.task = None
  61        self._initialize_task()
  62        self._print_behaviors()
  63
  64    @staticmethod
  65    def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
  66        """
  67        Complete a parameter dictionary with values from other dictionaries if required by a function
  68
  69        Parameters
  70        ----------
  71        parameters : dict
  72            the function parameters dictionary
  73        function : callable
  74            the function to be inspected
  75        general_dicts : list
  76            a list of dictionaries where the missing values will be pulled from
  77        """
  78
  79        parameter_names = inspect.getfullargspec(function).args
  80        for param in parameter_names:
  81            for dic in general_dicts:
  82                if param not in parameters and param in dic:
  83                    parameters[param] = dic[param]
  84        return parameters
  85
  86    @staticmethod
  87    def complete_dataset_parameters(
  88        parameters: dict,
  89        general_dict: dict,
  90        data_type: str,
  91        annotation_type: str,
  92    ) -> Dict:
  93        """
  94        Complete a parameter dictionary with values from other dictionaries if required by a dataset
  95
  96        Parameters
  97        ----------
  98        parameters : dict
  99            the function parameters dictionary
 100        general_dict : dict
 101            the dictionary where the missing values will be pulled from
 102        data_type : str
 103            the input type of the dataset
 104        annotation_type : str
 105            the annotation type of the dataset
 106
 107        Returns
 108        -------
 109        parameters : dict
 110            the updated parameter dictionary
 111        """
 112
 113        params = deepcopy(parameters)
 114        parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type)
 115        for param in parameter_names:
 116            if param not in params and param in general_dict:
 117                params[param] = general_dict[param]
 118        return params
 119
 120    @staticmethod
 121    def check(parameters: Dict, name: str) -> bool:
 122        """
 123        Check whether there is a non-`None` value under the name key in the parameters dictionary
 124
 125        Parameters
 126        ----------
 127        parameters : dict
 128            the dictionary to check
 129        name : str
 130            the key to check
 131
 132        Returns
 133        -------
 134        result : bool
 135            True if a non-`None` value exists
 136        """
 137
 138        if name in parameters and parameters[name] is not None:
 139            return True
 140        else:
 141            return False
 142
 143    @staticmethod
 144    def get(parameters: Dict, name: str, default):
 145        """
 146        Get the value under the name key or the default if it is `None` or does not exist
 147
 148        Parameters
 149        ----------
 150        parameters : dict
 151            the dictionary to check
 152        name : str
 153            the key to check
 154        default
 155            the default value to return
 156
 157        Returns
 158        -------
 159        value
 160            the resulting value
 161        """
 162
 163        if TaskDispatcher.check(parameters, name):
 164            return parameters[name]
 165        else:
 166            return default
 167
 168    @staticmethod
 169    def make_dataloader(
 170        dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False
 171    ) -> DataLoader:
 172        """
 173        Make a torch dataloader from a dataset
 174
 175        Parameters
 176        ----------
 177        dataset : dlc2action.data.dataset.BehaviorDataset
 178            the dataset
 179        batch_size : int
 180            the batch size
 181
 182        Returns
 183        -------
 184        dataloader : DataLoader
 185            the dataloader (or `None` if the length of the dataset is 0)
 186        """
 187
 188        if dataset is None or len(dataset) == 0:
 189            return None
 190        else:
 191            return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)
 192
 193    def _construct_ssl(self) -> List:
 194        """
 195        Generate SSL constructors
 196        """
 197
 198        ssl_list = deepcopy(self.general_parameters.get("ssl", None))
 199        if not isinstance(ssl_list, Iterable):
 200            ssl_list = [ssl_list]
 201        for i, ssl in enumerate(ssl_list):
 202            if type(ssl) is str:
 203                if ssl in options.ssl_constructors:
 204                    pars = self.get(self.ssl_parameters, ssl, default={})
 205                    pars = self.complete_function_parameters(
 206                        parameters=pars,
 207                        function=options.ssl_constructors[ssl],
 208                        general_dicts=[
 209                            self.model_parameters,
 210                            self.data_parameters,
 211                            self.general_parameters,
 212                        ],
 213                    )
 214                    ssl_list[i] = options.ssl_constructors[ssl](**pars)
 215                else:
 216                    raise ValueError(
 217                        f"The {ssl} SSL is not available, please choose from {list(options.ssl_constructors.keys())}"
 218                    )
 219            elif ssl is None:
 220                ssl_list[i] = EmptySSL()
 221            elif not isinstance(ssl, SSLConstructor):
 222                raise TypeError(
 223                    f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}"
 224                )
 225        return ssl_list
 226
 227    def _construct_model(self) -> Model:
 228        """
 229        Generate a model
 230        """
 231
 232        if self.check(self.general_parameters, "model"):
 233            pars = self.complete_function_parameters(
 234                function=LoadedModel,
 235                parameters=self.model_parameters,
 236                general_dicts=[self.general_parameters],
 237            )
 238            model = LoadedModel(**pars)
 239        elif self.check(self.general_parameters, "model_name"):
 240            name = self.general_parameters["model_name"]
 241            if name in options.models:
 242                pars = self.complete_function_parameters(
 243                    function=options.models[name],
 244                    parameters=self.model_parameters,
 245                    general_dicts=[self.general_parameters],
 246                )
 247                model = options.models[name](**pars)
 248            else:
 249                raise ValueError(
 250                    f"The {name} model is not available, please choose from {list(options.models.keys())}"
 251                )
 252        else:
 253            raise ValueError(
 254                "You need to provide either a model or its name in the model_parameters!"
 255            )
 256
 257        if self.get(self.training_parameters, "freeze_features", False):
 258            model.freeze_feature_extractor()
 259        return model
 260
 261    def _construct_dataset(self) -> BehaviorDataset:
 262        """
 263        Generate a dataset
 264        """
 265
 266        data_type = self.general_parameters.get("data_type", None)
 267        if data_type is None:
 268            raise ValueError(
 269                "You need to provide the data_type parameter in the data parameters!"
 270            )
 271        annotation_type = self.get(self.general_parameters, "annotation_type", "none")
 272        feature_extraction = self.general_parameters.get("feature_extraction", "none")
 273        if feature_extraction is None:
 274            raise ValueError(
 275                "You need to provide the feature_extraction parameter in the data parameters!"
 276            )
 277        feature_extraction_pars = self.complete_function_parameters(
 278            self.feature_parameters,
 279            options.feature_extractors[feature_extraction],
 280            [self.general_parameters, self.data_parameters],
 281        )
 282
 283        pars = self.complete_dataset_parameters(
 284            self.data_parameters,
 285            self.general_parameters,
 286            data_type=data_type,
 287            annotation_type=annotation_type,
 288        )
 289        pars["feature_extraction_pars"] = feature_extraction_pars
 290        dataset = BehaviorDataset(**pars)
 291
 292        if self.get(self.general_parameters, "save_dataset", default=False):
 293            save_data_path = self.data_parameters.get("saved_data_path", None)
 294            dataset.save(save_path=save_data_path)
 295
 296        return dataset
 297
 298    def _construct_transformer(self) -> Transformer:
 299        """
 300        Generate a transformer
 301        """
 302
 303        features = self.general_parameters["feature_extraction"]
 304        name = options.extractor_to_transformer[features]
 305        if name in options.transformers:
 306            transformer_class = options.transformers[name]
 307            pars = self.complete_function_parameters(
 308                function=transformer_class,
 309                parameters=self.aug_parameters,
 310                general_dicts=[self.general_parameters],
 311            )
 312            transformer = transformer_class(**pars)
 313        else:
 314            raise ValueError(f"The {name} transformer is not available")
 315        return transformer
 316
 317    def _construct_loss(self) -> torch.nn.Module:
 318        """
 319        Generate a loss function
 320        """
 321
 322        if "loss_function" not in self.general_parameters:
 323            raise ValueError(
 324                'Please add a "loss_function" key to the parameters["general"] dictionary (either a name '
 325                f"from {list(options.losses.keys())} or a function)"
 326            )
 327        else:
 328            loss_function = self.general_parameters["loss_function"]
 329        if type(loss_function) is str:
 330            if loss_function in options.losses:
 331                pars = self.get(self.loss_parameters, loss_function, default={})
 332                pars = self._set_loss_weights(pars)
 333                pars = self.complete_function_parameters(
 334                    function=options.losses[loss_function],
 335                    parameters=pars,
 336                    general_dicts=[self.general_parameters],
 337                )
 338                loss = options.losses[loss_function](**pars)
 339            else:
 340                raise ValueError(
 341                    f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}"
 342                )
 343        else:
 344            loss = loss_function
 345        return loss
 346
 347    def _construct_metrics(self) -> List:
 348        """
 349        Generate the metric
 350        """
 351
 352        metric_functions = self.get(
 353            self.general_parameters, "metric_functions", default={}
 354        )
 355        if isinstance(metric_functions, Iterable):
 356            metrics = {}
 357            for func in metric_functions:
 358                if isinstance(func, str):
 359                    if func in options.metrics:
 360                        pars = self.get(self.metric_parameters, func, default={})
 361                        pars = self.complete_function_parameters(
 362                            function=options.metrics[func],
 363                            parameters=pars,
 364                            general_dicts=[self.general_parameters],
 365                        )
 366                        metrics[func] = options.metrics[func](**pars)
 367                    else:
 368                        raise ValueError(
 369                            f"The {func} metric is not available, please choose from {list(options.metrics.keys())}"
 370                        )
 371                elif isinstance(func, Metric):
 372                    name = "function_1"
 373                    i = 1
 374                    while name in metrics:
 375                        i += 1
 376                        name = f"function_{i}"
 377                    metrics[name] = func
 378                else:
 379                    raise TypeError(
 380                        'The elements of parameters["general"]["metric_functions"] have to be either strings '
 381                        f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead"
 382                    )
 383        elif isinstance(metric_functions, dict):
 384            metrics = metric_functions
 385        else:
 386            raise TypeError(
 387                'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;'
 388                f"got {type(metric_functions)} instead"
 389            )
 390        return metrics
 391
 392    def _construct_optimizer(self) -> Optimizer:
 393        """
 394        Generate an optimizer
 395        """
 396
 397        if "optimizer" in self.training_parameters:
 398            name = self.training_parameters["optimizer"]
 399            if name in options.optimizers:
 400                optimizer = options.optimizers[name]
 401            else:
 402                raise ValueError(
 403                    f"The {name} optimizer is not available, please choose from {list(options.optimizers.keys())}"
 404                )
 405        else:
 406            optimizer = None
 407        return optimizer
 408
    def _construct_predict_functions(self) -> Tuple[Callable, Callable]:
        """
        Construct predict functions

        The `"predict_function"` and `"primary_predict_function"` training parameters are used as they
        are only when `"predict_function"` is callable. Otherwise both functions are built here (any
        configured `"primary_predict_function"` is replaced), based on the general `"model_name"` and
        `"exclusive"` parameters and the `"hard_threshold"` training parameter (default 0.5).

        Returns
        -------
        primary_predict_function : callable
            a function applied to the model output; the built-in versions apply softmax over
            dimension 1 (exclusive) or sigmoid (non-exclusive) after a model-specific selection step
        predict_function : callable
            a function applied to the result of `primary_predict_function`; the built-in versions take
            the argmax over dimension 1 (exclusive) or compare with the threshold (non-exclusive)
        """

        predict_function = self.training_parameters.get("predict_function", None)
        primary_predict_function = self.training_parameters.get(
            "primary_predict_function", None
        )
        model_name = self.general_parameters.get("model_name", "")
        threshold = self.training_parameters.get("hard_threshold", 0.5)
        if not isinstance(predict_function, Callable):
            if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]:
                # activation applied to each of the stacked outputs before they are averaged
                if self.general_parameters["exclusive"]:
                    func = lambda x: torch.softmax(x, dim=1)
                else:
                    func = lambda x: torch.sigmoid(x)

                def primary_predict_function(x):
                    # an input that is not 4-dimensional is reshaped so that its first axis has length 4
                    if len(x.shape) != 4:
                        x = x.reshape((4, -1, x.shape[-2], x.shape[-1]))
                    # all weights are equal, so this is a plain average of the activated slices;
                    # only four weights exist, so a first axis longer than 4 would fail on weights[i + 1]
                    weights = [1, 1, 1, 1]
                    ensemble_prob = func(x[0]) * weights[0] / sum(weights)
                    for i, outp_ele in enumerate(x[1:]):
                        ensemble_prob = ensemble_prob + func(outp_ele) * weights[
                            i + 1
                        ] / sum(weights)
                    return ensemble_prob

            else:
                # f selects the part of the raw model output that the activation is applied to
                if model_name.startswith("ms_tcn") or model_name in [
                    "asformer",
                    "transformer",
                    "c3d_ms",
                    "transformer_ms",
                ]:
                    # a 4-dimensional output is reduced to its last element along the first axis
                    f = lambda x: x[-1] if len(x.shape) == 4 else x
                elif model_name == "asrf":

                    def f(x):
                        # take the last element along the first axis and drop channel 0
                        # (presumably the boundary channel, judging by the disabled code below — TODO confirm)
                        x = x[-1]
                        # bounds = x[:, 0, :].unsqueeze(1)
                        cls = x[:, 1:, :]
                        # device = x.device
                        # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy())
                        # x = torch.tensor(x).to(device)
                        return cls

                elif model_name == "actionclip":

                    def f(x):
                        # x is indexed with the "video", "text" and "logit_scale" keys
                        video_embedding, text_embedding, logit_scale = (
                            x["video"],
                            x["text"],
                            x["logit_scale"],
                        )
                        # Ff is unpacked but not used afterwards
                        B, Ff, T = video_embedding.shape
                        video_embedding = video_embedding.permute(0, 2, 1).reshape(
                            (B * T, -1)
                        )
                        # NOTE(review): `/=` is an in-place operation; for torch tensors this also changes
                        # the object stored under x["text"] — confirm that this is intended
                        video_embedding /= video_embedding.norm(dim=-1, keepdim=True)
                        text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
                        similarity = logit_scale * video_embedding @ text_embedding.T
                        # back to one row per batch element, with the last two axes swapped
                        similarity = similarity.reshape((B, T, -1)).permute(0, 2, 1)
                        return similarity

                else:
                    # any other model: the output is used as it is
                    f = lambda x: x
                if self.general_parameters["exclusive"]:
                    primary_predict_function = lambda x: torch.softmax(f(x), dim=1)
                else:
                    primary_predict_function = lambda x: torch.sigmoid(f(x))
            # hard predictions: index of the maximum over dimension 1, or thresholding
            if self.general_parameters["exclusive"]:
                predict_function = lambda x: torch.max(x.data, dim=1)[1]
            else:
                predict_function = lambda x: (x > threshold).int()
        return primary_predict_function, predict_function
 486
 487    def _get_parameters_from_training(self) -> Dict:
 488        """
 489        Get the training parameters that need to be passed to the Task
 490        """
 491
 492        task_training_par_names = [
 493            "lr",
 494            "parallel",
 495            "device",
 496            "verbose",
 497            "log_file",
 498            "augment_train",
 499            "augment_val",
 500            "hard_threshold",
 501            "ssl_losses",
 502            "model_save_path",
 503            "model_save_epochs",
 504            "pseudolabel",
 505            "pseudolabel_start",
 506            "correction_interval",
 507            "pseudolabel_alpha_f",
 508            "alpha_growth_stop",
 509            "num_epochs",
 510            "validation_interval",
 511            "ignore_tags",
 512            "skip_metrics",
 513        ]
 514        task_training_pars = {
 515            name: self.training_parameters[name]
 516            for name in task_training_par_names
 517            if self.check(self.training_parameters, name)
 518        }
 519        if self.check(self.general_parameters, "ssl"):
 520            ssl_weights = [
 521                self.training_parameters["ssl_weights"][x]
 522                for x in self.general_parameters["ssl"]
 523            ]
 524            task_training_pars["ssl_weights"] = ssl_weights
 525        return task_training_pars
 526
 527    def _update_parameters_from_ssl(self, ssl_list: list) -> None:
 528        """
 529        Update the necessary parameters given the list of SSL constructors
 530        """
 531
 532        if self.task is not None:
 533            self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 534            self.task.set_ssl_losses([ssl.loss for ssl in ssl_list])
 535            self.task.set_keep_target_none(
 536                [ssl.type in ["contrastive"] for ssl in ssl_list]
 537            )
 538            self.task.set_generate_ssl_input(
 539                [ssl.type == "contrastive" for ssl in ssl_list]
 540            )
 541        self.data_parameters["ssl_transformations"] = [
 542            ssl.transformation for ssl in ssl_list
 543        ]
 544        self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list]
 545        self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list]
 546        self.model_parameters["ssl_modules"] = [
 547            ssl.construct_module() for ssl in ssl_list
 548        ]
 549        self.aug_parameters["generate_ssl_input"] = [
 550            x.type == "contrastive" for x in ssl_list
 551        ]
 552        self.aug_parameters["keep_target_none"] = [
 553            x.type == "contrastive" for x in ssl_list
 554        ]
 555
 556    def _set_loss_weights(self, parameters):
 557        """
 558        Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values
 559        """
 560
 561        for k in list(parameters.keys()):
 562            if parameters[k] in [
 563                "dataset_inverse_weights",
 564                "dataset_proportional_weights",
 565            ]:
 566                if parameters[k] == "dataset_inverse_weights":
 567                    parameters[k] = self.class_weights
 568                else:
 569                    parameters[k] = self.proportional_class_weights
 570                print("Initializing class weights:")
 571                string = "    "
 572                if isinstance(parameters[k], Mapping):
 573                    for key, val in parameters[k].items():
 574                        string += ": ".join(
 575                            (
 576                                " " + str(key),
 577                                ", ".join((map(lambda x: str(np.round(x, 3)), val))),
 578                            )
 579                        )
 580                else:
 581                    string += ", ".join(
 582                        (map(lambda x: str(np.round(x, 3)), parameters[k]))
 583                    )
 584                print(string)
 585        return parameters
 586
 587    def _partition_dataset(
 588        self, dataset: BehaviorDataset
 589    ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]:
 590        """
 591        Partition the dataset into train, validation and test subsamples
 592        """
 593
 594        use_test = self.get(self.training_parameters, "use_test", 0)
 595        split_path = self.training_parameters.get("split_path", None)
 596        partition_method = self.training_parameters.get("partition_method", "random")
 597        val_frac = self.get(self.training_parameters, "val_frac", 0)
 598        test_frac = self.get(self.training_parameters, "test_frac", 0)
 599        save_split = self.get(self.training_parameters, "save_split", True)
 600        normalize = self.get(self.training_parameters, "normalize", False)
 601        skip_normalization_keys = self.training_parameters.get(
 602            "skip_normalization_keys"
 603        )
 604        stats = self.training_parameters.get("stats")
 605        train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val(
 606            use_test,
 607            split_path,
 608            partition_method,
 609            val_frac,
 610            test_frac,
 611            save_split,
 612            normalize,
 613            skip_normalization_keys,
 614            stats,
 615        )
 616        bs = int(self.training_parameters.get("batch_size", 32))
 617        train_dataloader, test_dataloader, val_dataloader = (
 618            self.make_dataloader(train_dataset, batch_size=bs, shuffle=True),
 619            self.make_dataloader(test_dataset, batch_size=bs, shuffle=False),
 620            self.make_dataloader(val_dataset, batch_size=bs, shuffle=False),
 621        )
 622        return train_dataloader, test_dataloader, val_dataloader
 623
    def _initialize_task(self):
        """
        Create a `dlc2action.task.universal_task.Task` instance

        The dataset, model, SSL constructors, transformer, metrics, optimizer, predict functions,
        dataloaders and loss are constructed from the stored parameter dictionaries (in this order) and
        passed to `Task`. The result is stored in `self.task`; `self.class_weights` and
        `self.proportional_class_weights` are set along the way. If the `"checkpoint_path"` training
        parameter is not `None`, the task is loaded from that checkpoint.
        """

        dataset = self._construct_dataset()
        # dataset-dependent blanks are filled before the model is built from the model parameters
        self._update_data_blanks(dataset)
        model = self._construct_model()
        self._update_model_blanks(model)
        ssl_list = self._construct_ssl()
        self._update_parameters_from_ssl(ssl_list)
        model.set_ssl(ssl_constructors=ssl_list)
        dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
        transformer = self._construct_transformer()
        metrics = self._construct_metrics()
        optimizer = self._construct_optimizer()
        primary_predict_function, predict_function = self._construct_predict_functions()

        task_training_pars = self._get_parameters_from_training()
        train_dataloader, test_dataloader, val_dataloader = self._partition_dataset(
            dataset
        )
        # the class weights come from the training subset and have to be set before the loss is
        # constructed, because _construct_loss substitutes them into the loss parameters
        # NOTE(review): make_dataloader returns None for an empty subset, in which case the
        # attribute access below fails — confirm that an empty training set cannot occur here
        self.class_weights = train_dataloader.dataset.class_weights()
        self.proportional_class_weights = train_dataloader.dataset.class_weights(True)
        loss = self._construct_loss()
        exclusive = self.general_parameters["exclusive"]

        task_pars = {
            "train_dataloader": train_dataloader,
            "model": model,
            "loss": loss,
            "transformer": transformer,
            "metrics": metrics,
            "val_dataloader": val_dataloader,
            "test_dataloader": test_dataloader,
            "exclusive": exclusive,
            "optimizer": optimizer,
            "predict_function": predict_function,
            "primary_predict_function": primary_predict_function,
        }
        # training parameters are applied last, so they win over same-named keys above
        task_pars.update(task_training_pars)

        self.task = Task(**task_pars)
        checkpoint_path = self.training_parameters.get("checkpoint_path", None)
        if checkpoint_path is not None:
            only_model = self.get(self.training_parameters, "only_load_model", False)
            load_strict = self.get(self.training_parameters, "load_strict", True)
            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
        # "only_load_annotated" is read without a default, so the key has to be present
        if (
            self.general_parameters["only_load_annotated"]
            and self.general_parameters.get("ssl") is not None
        ):
            warnings.warn(
                "Note that you are using SSL modules but only loading annotated files! Set "
                "general/only_load_annotated to False to change that"
            )
 680
 681    def _update_data_blanks(
 682        self, dataset: BehaviorDataset = None, remember: bool = False
 683    ) -> None:
 684        """
 685        Update all blanks from a dataset
 686        """
 687
 688        if dataset is None:
 689            dataset = self.dataset()
 690        self._update_dim_parameter(dataset, remember)
 691        self._update_bodyparts_parameter(dataset, remember)
 692        self._update_num_classes_parameter(dataset, remember)
 693        self._update_len_segment_parameter(dataset, remember)
 694        self._update_boundary_parameter(dataset, remember)
 695
 696    def _update_model_blanks(self, model: Model, remember: bool = False) -> None:
 697        self._update_features_parameter(model, remember)
 698
 699    def _update_parameter(self, blank_name: str, value, remember: bool = False):
 700        parameters = [
 701            self.model_parameters,
 702            self.ssl_parameters,
 703            self.general_parameters,
 704            self.feature_parameters,
 705            self.data_parameters,
 706            self.training_parameters,
 707            self.metric_parameters,
 708            self.loss_parameters,
 709            self.aug_parameters,
 710        ]
 711        par_names = [
 712            "model",
 713            "ssl",
 714            "general",
 715            "feature",
 716            "data",
 717            "training",
 718            "metrics",
 719            "losses",
 720            "augmentations",
 721        ]
 722        for names in self.blanks[blank_name]:
 723            group = names[0]
 724            key = names[1]
 725            ind = par_names.index(group)
 726            if len(names) == 3:
 727                if names[2] in parameters[ind][key]:
 728                    parameters[ind][key][names[2]] = value
 729            else:
 730                if key in parameters[ind]:
 731                    parameters[ind][key] = value
 732        for name, dic in zip(par_names, parameters):
 733            for k, v in dic.items():
 734                if v == blank_name:
 735                    dic[k] = value
 736                    if [name, k] not in self.blanks[blank_name]:
 737                        self.blanks[blank_name].append([name, k])
 738                elif isinstance(v, Mapping):
 739                    for kk, vv in v.items():
 740                        if vv == blank_name:
 741                            dic[k][kk] = value
 742                            if [name, k, kk] not in self.blanks[blank_name]:
 743                                self.blanks[blank_name].append([name, k, kk])
 744
 745    def _update_features_parameter(self, model: Model, remember: bool = False) -> None:
 746        """
 747        Fill the `"model_features"` blank
 748        """
 749
 750        value = model.features_shape()
 751        self._update_parameter("model_features", value, remember)
 752
 753    def _update_bodyparts_parameter(
 754        self, dataset: BehaviorDataset, remember: bool = False
 755    ) -> None:
 756        """
 757        Fill the `"dataset_bodyparts"` blank
 758        """
 759
 760        value = dataset.bodyparts_order()
 761        self._update_parameter("dataset_bodyparts", value, remember)
 762
 763    def _update_dim_parameter(
 764        self, dataset: BehaviorDataset, remember: bool = False
 765    ) -> None:
 766        """
 767        Fill the `"dataset_features"` blank
 768        """
 769
 770        value = dataset.features_shape()
 771        self._update_parameter("dataset_features", value, remember)
 772
 773    def _update_boundary_parameter(
 774        self, dataset: BehaviorDataset, remember: bool = False
 775    ) -> None:
 776        """
 777        Fill the `"dataset_features"` blank
 778        """
 779
 780        value = dataset.boundary_class_weight()
 781        self._update_parameter("dataset_boundary_weight", value, remember)
 782
 783    def _update_num_classes_parameter(
 784        self, dataset: BehaviorDataset, remember: bool = False
 785    ) -> None:
 786        """
 787        Fill in the `"dataset_classes"` blank
 788        """
 789
 790        value = dataset.num_classes()
 791        self._update_parameter("dataset_classes", value, remember)
 792
 793    def _update_len_segment_parameter(
 794        self, dataset: BehaviorDataset, remember: bool = False
 795    ) -> None:
 796        """
 797        Fill in the `"dataset_len_segment"` blank
 798        """
 799
 800        value = dataset.len_segment()
 801        self._update_parameter("dataset_len_segment", value, remember)
 802
 803    def _print_behaviors(self):
 804        behavior_set = self.behaviors_dict()
 805        print(f"Behavior indices:")
 806        for key, value in sorted(behavior_set.items()):
 807            print(f"    {key}: {value}")
 808
 809    def update_task(self, parameters: Dict) -> None:
 810        """
 811        Update the `dlc2action.task.universal_task.Task` instance given the parameter updates
 812
 813        Parameters
 814        ----------
 815        parameters : dict
 816            the dictionary of parameter updates
 817        """
 818
 819        pars = deepcopy(parameters)
 820        # for blank_name in self.blanks:
 821        #     for names in self.blanks[blank_name]:
 822        #         group = names[0]
 823        #         key = names[1]
 824        #         if len(names) == 3:
 825        #             if (
 826        #                 group in pars
 827        #                 and key in pars[group]
 828        #                 and names[2] in pars[group][key]
 829        #             ):
 830        #                 pars[group][key].pop(names[2])
 831        #         else:
 832        #             if group in pars and key in pars[group]:
 833        #                 pars[group].pop(key)
 834        stay = False
 835        if "ssl" in pars:
 836            for key in pars["ssl"]:
 837                if key in self.ssl_parameters:
 838                    self.ssl_parameters[key].update(pars["ssl"][key])
 839                else:
 840                    self.ssl_parameters[key] = pars["ssl"][key]
 841
 842        if "general" in pars:
 843            if stay:
 844                stay = False
 845            if (
 846                "model_name" in pars["general"]
 847                and pars["general"]["model_name"]
 848                != self.general_parameters["model_name"]
 849            ):
 850                if "model" not in pars:
 851                    raise ValueError(
 852                        "When updating a task with a new model name you need to pass the parameters for the "
 853                        "new model"
 854                    )
 855                self.model_parameters = {}
 856            self.general_parameters.update(pars["general"])
 857            data_related = [
 858                "num_classes",
 859                "exclusive",
 860                "data_type",
 861                "annotation_type",
 862            ]
 863            ssl_related = ["ssl", "exclusive", "num_classes"]
 864            loss_related = ["num_classes", "loss_function", "exclusive"]
 865            augmentation_related = ["augmentation_type"]
 866            metric_related = ["metric_functions"]
 867            related_lists = [
 868                data_related,
 869                ssl_related,
 870                loss_related,
 871                augmentation_related,
 872                metric_related,
 873            ]
 874            names = ["data", "ssl", "losses", "augmentations", "metrics"]
 875            for related_list, name in zip(related_lists, names):
 876                if (
 877                    any([x in pars["general"] for x in related_list])
 878                    and name not in pars
 879                ):
 880                    pars[name] = {}
 881
 882        if "training" in pars:
 883            if "data" not in pars or not stay:
 884                for x in [
 885                    "to_ram",
 886                    "use_test",
 887                    "partition_method",
 888                    "val_frac",
 889                    "test_frac",
 890                    "save_split",
 891                    "batch_size",
 892                    "save_split",
 893                ]:
 894                    if (
 895                        x in pars["training"]
 896                        and pars["training"][x] != self.training_parameters[x]
 897                    ):
 898                        if "data" not in pars:
 899                            pars["data"] = {}
 900                        stay = True
 901            self.training_parameters.update(pars["training"])
 902            self.task.update_parameters(self._get_parameters_from_training())
 903
 904        if "data" in pars or "features" in pars:
 905            for k, v in pars["data"].items():
 906                if k not in self.data_parameters or v != self.data_parameters[k]:
 907                    stay = True
 908            for k, v in pars["features"].items():
 909                if k not in self.feature_parameters or v != self.feature_parameters[k]:
 910                    stay = True
 911            if stay:
 912                self.data_parameters.update(pars["data"])
 913                self.feature_parameters.update(pars["features"])
 914                dataset = self._construct_dataset()
 915                (
 916                    train_dataloader,
 917                    test_dataloader,
 918                    val_dataloader,
 919                ) = self._partition_dataset(dataset)
 920                self.task.set_dataloaders(
 921                    train_dataloader, val_dataloader, test_dataloader
 922                )
 923                self.class_weights = train_dataloader.dataset.class_weights()
 924                self.proportional_class_weights = (
 925                    train_dataloader.dataset.class_weights(True)
 926                )
 927                if "losses" not in pars:
 928                    pars["losses"] = {}
 929
 930        if "model" in pars:
 931            self.model_parameters.update(pars["model"])
 932
 933        self._update_data_blanks()
 934
 935        if "augmentations" in pars:
 936            self.aug_parameters.update(pars["augmentations"])
 937            transformer = self._construct_transformer()
 938            self.task.set_transformer(transformer)
 939
 940        if "losses" in pars:
 941            for key in pars["losses"]:
 942                if key in self.loss_parameters:
 943                    self.loss_parameters[key].update(pars["losses"][key])
 944                else:
 945                    self.loss_parameters[key] = pars["losses"][key]
 946            self.loss_parameters.update(pars["losses"])
 947            loss = self._construct_loss()
 948            self.task.set_loss(loss)
 949
 950        if "metrics" in pars:
 951            for key in pars["metrics"]:
 952                if key in self.metric_parameters:
 953                    self.metric_parameters[key].update(pars["metrics"][key])
 954                else:
 955                    self.metric_parameters[key] = pars["metrics"][key]
 956            metrics = self._construct_metrics()
 957            self.task.set_metrics(metrics)
 958
 959        self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
 960        self._set_loss_weights(
 961            pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
 962        )
 963        model = self._construct_model()
 964        predict_functions = self._construct_predict_functions()
 965        self.task.set_predict_functions(*predict_functions)
 966        self._update_model_blanks(model)
 967        ssl_list = self._construct_ssl()
 968        self._update_parameters_from_ssl(ssl_list)
 969        model.set_ssl(ssl_constructors=ssl_list)
 970        self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
 971        self.task.set_model(model)
 972        if "training" in pars and "checkpoint_path" in pars["training"]:
 973            checkpoint_path = pars["training"]["checkpoint_path"]
 974            only_model = pars["training"].get("only_load_model", False)
 975            load_strict = pars["training"].get("load_strict", True)
 976            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
 977        if (
 978            self.general_parameters["only_load_annotated"]
 979            and self.general_parameters.get("ssl") is not None
 980        ):
 981            warnings.warn(
 982                "Note that you are using SSL modules but only loading annotated files! Set "
 983                "general/only_load_annotated to False to change that"
 984            )
 985        if self.task.dataset("train").annotation_class() != "none":
 986            self._print_behaviors()
 987
 988    def train(
 989        self,
 990        trial: Trial = None,
 991        optimized_metric: str = None,
 992        autostop_metric: str = None,
 993        autostop_interval: int = 10,
 994        autostop_threshold: float = 0.001,
 995        loading_bar: bool = False,
 996    ) -> Tuple:
 997        """
 998        Train the task and return a log of epoch-average loss and metric
 999
1000        You can use the autostop parameters to finish training when the parameters are not improving. It will be
1001        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
1002        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
1003        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
1004
1005        Parameters
1006        ----------
1007        trial : Trial
1008            an `optuna` trial (for hyperparameter searches)
1009        optimized_metric : str
1010            the name of the metric being optimized (for hyperparameter searches)
1011        to_ram : bool, default False
1012            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
1013            if the dataset is too large)
1014        autostop_interval : int, default 50
1015            the number of epochs to average the autostop metric over
1016        autostop_threshold : float, default 0.001
1017            the autostop difference threshold
1018        autostop_metric : str, optional
1019            the autostop metric (can be any one of the tracked metrics of `'loss'`)
1020        main_task_on : bool, default True
1021            if `False`, the main task (action segmentation) will not be used in training
1022        ssl_on : bool, default True
1023            if `False`, the SSL task will not be used in training
1024
1025        Returns
1026        -------
1027        loss_log: list
1028            a list of float loss function values for each epoch
1029        metrics_log: dict
1030            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
1031            names, values are lists of function values)
1032        """
1033
1034        to_ram = self.training_parameters.get("to_ram", False)
1035        logs = self.task.train(
1036            trial,
1037            optimized_metric,
1038            to_ram,
1039            autostop_metric=autostop_metric,
1040            autostop_interval=autostop_interval,
1041            autostop_threshold=autostop_threshold,
1042            main_task_on=self.training_parameters.get("main_task_on", True),
1043            ssl_on=self.training_parameters.get("ssl_on", True),
1044            temporal_subsampling_size=self.training_parameters.get(
1045                "temporal_subsampling_size"
1046            ),
1047            loading_bar=loading_bar,
1048        )
1049        return logs
1050
1051    def save_model(self, save_path: str) -> None:
1052        """
1053        Save the model of the `dlc2action.task.universal_task.Task` instance
1054
1055        Parameters
1056        ----------
1057        save_path : str
1058            the path to the saved file
1059        """
1060
1061        self.task.save_model(save_path)
1062
1063    def evaluate(
1064        self,
1065        data: Union[DataLoader, BehaviorDataset, str] = None,
1066        augment_n: int = 0,
1067        verbose: bool = True,
1068    ) -> Tuple:
1069        """
1070        Evaluate the Task model
1071
1072        Parameters
1073        ----------
1074        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1075            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
1076        augment_n : int, default 0
1077            the number of augmentations to average results over
1078        verbose : bool, default True
1079            if True, the process is reported to standard output
1080
1081        Returns
1082        -------
1083        loss : float
1084            the average value of the loss function
1085        ssl_loss : float
1086            the average value of the SSL loss function
1087        metric : dict
1088            a dictionary of average values of metric functions
1089        """
1090
1091        res = self.task.evaluate(
1092            data,
1093            augment_n,
1094            int(self.training_parameters.get("batch_size", 32)),
1095            verbose,
1096        )
1097        return res
1098
1099    def evaluate_prediction(
1100        self,
1101        prediction: torch.Tensor,
1102        data: Union[DataLoader, BehaviorDataset, str] = None,
1103    ) -> Tuple:
1104        """
1105        Compute metrics for a prediction
1106
1107        Parameters
1108        ----------
1109        prediction : torch.Tensor
1110            the prediction
1111        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1112            the data the prediction was made for (if not provided, take the validation dataset)
1113
1114        Returns
1115        -------
1116        loss : float
1117            the average value of the loss function
1118        metric : dict
1119            a dictionary of average values of metric functions
1120        """
1121
1122        return self.task.evaluate_prediction(
1123            prediction, data, int(self.training_parameters.get("batch_size", 32))
1124        )
1125
1126    def predict(
1127        self,
1128        data: Union[DataLoader, BehaviorDataset, str],
1129        raw_output: bool = False,
1130        apply_primary_function: bool = True,
1131        augment_n: int = 0,
1132        embedding: bool = False,
1133    ) -> torch.Tensor:
1134        """
1135        Make a prediction with the Task model
1136
1137        Parameters
1138        ----------
1139        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1140            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
1141        raw_output : bool, default False
1142            if `True`, the raw predicted probabilities are returned
1143        apply_primary_function : bool, default True
1144            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
1145            the input)
1146        augment_n : int, default 0
1147            the number of augmentations to average results over
1148
1149        Returns
1150        -------
1151        prediction : torch.Tensor
1152            a prediction for the input data
1153        """
1154
1155        to_ram = self.training_parameters.get("to_ram", False)
1156        return self.task.predict(
1157            data,
1158            raw_output,
1159            apply_primary_function,
1160            augment_n,
1161            int(self.training_parameters.get("batch_size", 32)),
1162            to_ram,
1163            embedding=embedding,
1164        )
1165
1166    def dataset(self, mode: str = "train") -> BehaviorDataset:
1167        """
1168        Get a dataset
1169
1170        Parameters
1171        ----------
1172        mode : {'train', 'val', 'test'}
1173            the dataset to get
1174
1175        Returns
1176        -------
1177        dataset : dlc2action.data.dataset.BehaviorDataset
1178            the dataset
1179        """
1180
1181        return self.task.dataset(mode)
1182
1183    def generate_full_length_prediction(
1184        self,
1185        dataset: Union[BehaviorDataset, str] = None,
1186        augment_n: int = 10,
1187    ) -> Dict:
1188        """
1189        Compile a prediction for the original input sequences
1190
1191        Parameters
1192        ----------
1193        dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
1194            the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task`
1195            instance validation dataset)
1196        augment_n : int, default 10
1197            the number of augmentations to average results over
1198
1199        Returns
1200        -------
1201        prediction : dict
1202            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1203            are prediction tensors
1204        """
1205
1206        return self.task.generate_full_length_prediction(
1207            dataset, int(self.training_parameters.get("batch_size", 32)), augment_n
1208        )
1209
1210    def generate_submission(
1211        self,
1212        frame_number_map_file: str,
1213        dataset: Union[BehaviorDataset, str] = None,
1214        augment_n: int = 10,
1215    ) -> Dict:
1216        """
1217        Generate a MABe-22 style submission dictionary
1218
1219        Parameters
1220        ----------
1221        frame_number_map_file : str
1222            path to the frame number map file
1223        dataset : BehaviorDataset, optional
1224            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1225        augment_n : int, default 10
1226            the number of augmentations to average results over
1227
1228        Returns
1229        -------
1230        submission : dict
1231            a dictionary with frame number mapping and embeddings
1232        """
1233
1234        return self.task.generate_submission(
1235            frame_number_map_file,
1236            dataset,
1237            int(self.training_parameters.get("batch_size", 32)),
1238            augment_n,
1239        )
1240
1241    def behaviors_dict(self):
1242        """
1243        Get a behavior dictionary
1244
1245        Keys are label indices and values are label names.
1246
1247        Returns
1248        -------
1249        behaviors_dict : dict
1250            behavior dictionary
1251        """
1252
1253        return self.task.behaviors_dict()
1254
1255    def count_classes(self, bouts: bool = False) -> Dict:
1256        """
1257        Get a dictionary of class counts in different modes
1258
1259        Parameters
1260        ----------
1261        bouts : bool, default False
1262            if `True`, instead of frame counts segment counts are returned
1263
1264        Returns
1265        -------
1266        class_counts : dict
1267            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1268            class names and values are class counts (in frames)
1269        """
1270
1271        return self.task.count_classes(bouts)
1272
1273    def _visualize_results_label(
1274        self,
1275        label: str,
1276        save_path: str = None,
1277        add_legend: bool = True,
1278        ground_truth: bool = True,
1279        hide_axes: bool = False,
1280        width: int = 10,
1281        whole_video: bool = False,
1282        transparent: bool = False,
1283        dataset: BehaviorDataset = None,
1284        smooth_interval: int = 0,
1285        title: str = None,
1286    ):
1287        return self.task._visualize_results_label(
1288            label,
1289            save_path,
1290            add_legend,
1291            ground_truth,
1292            hide_axes,
1293            width,
1294            whole_video,
1295            transparent,
1296            dataset,
1297            smooth_interval=smooth_interval,
1298            title=title,
1299        )
1300
1301    def visualize_results(
1302        self,
1303        save_path: str = None,
1304        add_legend: bool = True,
1305        ground_truth: bool = True,
1306        colormap: str = "viridis",
1307        hide_axes: bool = False,
1308        min_classes: int = 1,
1309        width: float = 10,
1310        whole_video: bool = False,
1311        transparent: bool = False,
1312        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1313        drop_classes: Set = None,
1314        search_classes: Set = None,
1315        smooth_interval_prediction: int = None,
1316    ) -> None:
1317        """
1318        Visualize random predictions
1319
1320        Parameters
1321        ----------
1322        save_path : str, optional
1323            the path where the plot will be saved
1324        add_legend : bool, default True
1325            if True, legend will be added to the plot
1326        ground_truth : bool, default True
1327            if True, ground truth will be added to the plot
1328        colormap : str, default 'Accent'
1329            the `matplotlib` colormap to use
1330        hide_axes : bool, default True
1331            if `True`, the axes will be hidden on the plot
1332        min_classes : int, default 1
1333            the minimum number of classes in a displayed interval
1334        width : float, default 10
1335            the width of the plot
1336        whole_video : bool, default False
1337            if `True`, whole videos are plotted instead of segments
1338        transparent : bool, default False
1339            if `True`, the background on the plot is transparent
1340        dataset : BehaviorDataset | DataLoader | str | None, optional
1341            the dataset to make the prediction for (if not provided, the validation dataset is used)
1342        drop_classes : set, optional
1343            a set of class names to not be displayed
1344        search_classes : set, optional
1345            if given, only intervals where at least one of the classes is in ground truth will be shown
1346        """
1347
1348        return self.task.visualize_results(
1349            save_path,
1350            add_legend,
1351            ground_truth,
1352            colormap,
1353            hide_axes,
1354            min_classes,
1355            width,
1356            whole_video,
1357            transparent,
1358            dataset,
1359            drop_classes,
1360            search_classes,
1361            smooth_interval_prediction=smooth_interval_prediction,
1362        )
1363
1364    def generate_uncertainty_score(
1365        self,
1366        classes: List,
1367        augment_n: int = 0,
1368        method: str = "least_confidence",
1369        predicted: torch.Tensor = None,
1370        behaviors_dict: Dict = None,
1371    ) -> Dict:
1372        """
1373        Generate frame-wise scores for active learning
1374
1375        Parameters
1376        ----------
1377        classes : list
1378            a list of class names or indices; their confidence scores will be computed separately and stacked
1379        augment_n : int, default 0
1380            the number of augmentations to average over
1381        method : {"least_confidence", "entropy"}
1382            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1383            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
1384
1385        Returns
1386        -------
1387        score_dicts : dict
1388            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1389            are score tensors
1390        """
1391
1392        return self.task.generate_uncertainty_score(
1393            classes,
1394            augment_n,
1395            int(self.training_parameters.get("batch_size", 32)),
1396            method,
1397            predicted,
1398            behaviors_dict,
1399        )
1400
1401    def generate_bald_score(
1402        self,
1403        classes: List,
1404        augment_n: int = 0,
1405        num_models: int = 10,
1406        kernel_size: int = 11,
1407    ) -> Dict:
1408        """
1409        Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
1410
1411        Parameters
1412        ----------
1413        classes : list
1414            a list of class names or indices; their confidence scores will be computed separately and stacked
1415        augment_n : int, default 0
1416            the number of augmentations to average over
1417        num_models : int, default 10
1418            the number of dropout masks to apply
1419        kernel_size : int, default 11
1420            the size of the smoothing gaussian kernel
1421
1422        Returns
1423        -------
1424        score_dicts : dict
1425            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1426            are score tensors
1427        """
1428
1429        return self.task.generate_bald_score(
1430            classes,
1431            augment_n,
1432            int(self.training_parameters.get("batch_size", 32)),
1433            num_models,
1434            kernel_size,
1435        )
1436
1437    def get_normalization_stats(self) -> Dict:
1438        """
1439        Get the pre-computed normalization stats
1440
1441        Returns
1442        -------
1443        normalization_stats : dict
1444            a dictionary of means and stds
1445        """
1446
1447        return self.task.get_normalization_stats()
1448
1449    def exists(self, mode) -> bool:
1450        """
1451        Check whether the task has a train/test/validation subset
1452
1453        Parameters
1454        ----------
1455        mode : {"train", "val", "test"}
1456            the name of the subset to check for
1457
1458        Returns
1459        -------
1460        exists : bool
1461            `True` if the subset exists
1462        """
1463
1464        dl = self.task.dataloader(mode)
1465        if dl is None:
1466            return False
1467        else:
1468            return True

A class that manages the interactions between config dictionaries and a Task

TaskDispatcher(parameters: Dict)
39    def __init__(self, parameters: Dict) -> None:
40        """
41        Parameters
42        ----------
43        parameters : dict
44            a dictionary of task parameters
45        """
46
47        pars = deepcopy(parameters)
48        self.class_weights = None
49        self.general_parameters = pars.get("general", {})
50        self.data_parameters = pars.get("data", {})
51        self.model_parameters = pars.get("model", {})
52        self.training_parameters = pars.get("training", {})
53        self.loss_parameters = pars.get("losses", {})
54        self.metric_parameters = pars.get("metrics", {})
55        self.ssl_parameters = pars.get("ssl", {})
56        self.aug_parameters = pars.get("augmentations", {})
57        self.feature_parameters = pars.get("features", {})
58        self.blanks = {blank: [] for blank in options.blanks}
59
60        self.task = None
61        self._initialize_task()
62        self._print_behaviors()

Parameters

parameters : dict a dictionary of task parameters

@staticmethod
def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
64    @staticmethod
65    def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
66        """
67        Complete a parameter dictionary with values from other dictionaries if required by a function
68
69        Parameters
70        ----------
71        parameters : dict
72            the function parameters dictionary
73        function : callable
74            the function to be inspected
75        general_dicts : list
76            a list of dictionaries where the missing values will be pulled from
77        """
78
79        parameter_names = inspect.getfullargspec(function).args
80        for param in parameter_names:
81            for dic in general_dicts:
82                if param not in parameters and param in dic:
83                    parameters[param] = dic[param]
84        return parameters

Complete a parameter dictionary with values from other dictionaries if required by a function

Parameters

parameters : dict — the function parameters dictionary; function : callable — the function to be inspected; general_dicts : list — a list of dictionaries where the missing values will be pulled from

@staticmethod
def complete_dataset_parameters( parameters: dict, general_dict: dict, data_type: str, annotation_type: str) -> Dict:
 86    @staticmethod
 87    def complete_dataset_parameters(
 88        parameters: dict,
 89        general_dict: dict,
 90        data_type: str,
 91        annotation_type: str,
 92    ) -> Dict:
 93        """
 94        Complete a parameter dictionary with values from other dictionaries if required by a dataset
 95
 96        Parameters
 97        ----------
 98        parameters : dict
 99            the function parameters dictionary
100        general_dict : dict
101            the dictionary where the missing values will be pulled from
102        data_type : str
103            the input type of the dataset
104        annotation_type : str
105            the annotation type of the dataset
106
107        Returns
108        -------
109        parameters : dict
110            the updated parameter dictionary
111        """
112
113        params = deepcopy(parameters)
114        parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type)
115        for param in parameter_names:
116            if param not in params and param in general_dict:
117                params[param] = general_dict[param]
118        return params

Complete a parameter dictionary with values from other dictionaries if required by a dataset

Parameters

parameters : dict — the function parameters dictionary; general_dict : dict — the dictionary where the missing values will be pulled from; data_type : str — the input type of the dataset; annotation_type : str — the annotation type of the dataset

Returns

parameters : dict the updated parameter dictionary

@staticmethod
def check(parameters: Dict, name: str) -> bool:
120    @staticmethod
121    def check(parameters: Dict, name: str) -> bool:
122        """
123        Check whether there is a non-`None` value under the name key in the parameters dictionary
124
125        Parameters
126        ----------
127        parameters : dict
128            the dictionary to check
129        name : str
130            the key to check
131
132        Returns
133        -------
134        result : bool
135            True if a non-`None` value exists
136        """
137
138        if name in parameters and parameters[name] is not None:
139            return True
140        else:
141            return False

Check whether there is a non-None value under the name key in the parameters dictionary

Parameters

parameters : dict the dictionary to check name : str the key to check

Returns

result : bool True if a non-None value exists

@staticmethod
def get(parameters: Dict, name: str, default)
143    @staticmethod
144    def get(parameters: Dict, name: str, default):
145        """
146        Get the value under the name key or the default if it is `None` or does not exist
147
148        Parameters
149        ----------
150        parameters : dict
151            the dictionary to check
152        name : str
153            the key to check
154        default
155            the default value to return
156
157        Returns
158        -------
159        value
160            the resulting value
161        """
162
163        if TaskDispatcher.check(parameters, name):
164            return parameters[name]
165        else:
166            return default

Get the value under the name key or the default if it is None or does not exist

Parameters

parameters : dict the dictionary to check name : str the key to check default the default value to return

Returns

value the resulting value

@staticmethod
def make_dataloader( dataset: dlc2action.data.dataset.BehaviorDataset, batch_size: int = 32, shuffle: bool = False) -> torch.utils.data.dataloader.DataLoader:
168    @staticmethod
169    def make_dataloader(
170        dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False
171    ) -> DataLoader:
172        """
173        Make a torch dataloader from a dataset
174
175        Parameters
176        ----------
177        dataset : dlc2action.data.dataset.BehaviorDataset
178            the dataset
179        batch_size : int
180            the batch size
181
182        Returns
183        -------
184        dataloader : DataLoader
185            the dataloader (or `None` if the length of the dataset is 0)
186        """
187
188        if dataset is None or len(dataset) == 0:
189            return None
190        else:
191            return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)

Make a torch dataloader from a dataset

Parameters

dataset : dlc2action.data.dataset.BehaviorDataset the dataset batch_size : int the batch size

Returns

dataloader : DataLoader the dataloader (or None if the length of the dataset is 0)

def update_task(self, parameters: Dict) -> None:
809    def update_task(self, parameters: Dict) -> None:
810        """
811        Update the `dlc2action.task.universal_task.Task` instance given the parameter updates
812
813        Parameters
814        ----------
815        parameters : dict
816            the dictionary of parameter updates
817        """
818
819        pars = deepcopy(parameters)
820        # for blank_name in self.blanks:
821        #     for names in self.blanks[blank_name]:
822        #         group = names[0]
823        #         key = names[1]
824        #         if len(names) == 3:
825        #             if (
826        #                 group in pars
827        #                 and key in pars[group]
828        #                 and names[2] in pars[group][key]
829        #             ):
830        #                 pars[group][key].pop(names[2])
831        #         else:
832        #             if group in pars and key in pars[group]:
833        #                 pars[group].pop(key)
834        stay = False
835        if "ssl" in pars:
836            for key in pars["ssl"]:
837                if key in self.ssl_parameters:
838                    self.ssl_parameters[key].update(pars["ssl"][key])
839                else:
840                    self.ssl_parameters[key] = pars["ssl"][key]
841
842        if "general" in pars:
843            if stay:
844                stay = False
845            if (
846                "model_name" in pars["general"]
847                and pars["general"]["model_name"]
848                != self.general_parameters["model_name"]
849            ):
850                if "model" not in pars:
851                    raise ValueError(
852                        "When updating a task with a new model name you need to pass the parameters for the "
853                        "new model"
854                    )
855                self.model_parameters = {}
856            self.general_parameters.update(pars["general"])
857            data_related = [
858                "num_classes",
859                "exclusive",
860                "data_type",
861                "annotation_type",
862            ]
863            ssl_related = ["ssl", "exclusive", "num_classes"]
864            loss_related = ["num_classes", "loss_function", "exclusive"]
865            augmentation_related = ["augmentation_type"]
866            metric_related = ["metric_functions"]
867            related_lists = [
868                data_related,
869                ssl_related,
870                loss_related,
871                augmentation_related,
872                metric_related,
873            ]
874            names = ["data", "ssl", "losses", "augmentations", "metrics"]
875            for related_list, name in zip(related_lists, names):
876                if (
877                    any([x in pars["general"] for x in related_list])
878                    and name not in pars
879                ):
880                    pars[name] = {}
881
882        if "training" in pars:
883            if "data" not in pars or not stay:
884                for x in [
885                    "to_ram",
886                    "use_test",
887                    "partition_method",
888                    "val_frac",
889                    "test_frac",
890                    "save_split",
891                    "batch_size",
892                    "save_split",
893                ]:
894                    if (
895                        x in pars["training"]
896                        and pars["training"][x] != self.training_parameters[x]
897                    ):
898                        if "data" not in pars:
899                            pars["data"] = {}
900                        stay = True
901            self.training_parameters.update(pars["training"])
902            self.task.update_parameters(self._get_parameters_from_training())
903
904        if "data" in pars or "features" in pars:
905            for k, v in pars["data"].items():
906                if k not in self.data_parameters or v != self.data_parameters[k]:
907                    stay = True
908            for k, v in pars["features"].items():
909                if k not in self.feature_parameters or v != self.feature_parameters[k]:
910                    stay = True
911            if stay:
912                self.data_parameters.update(pars["data"])
913                self.feature_parameters.update(pars["features"])
914                dataset = self._construct_dataset()
915                (
916                    train_dataloader,
917                    test_dataloader,
918                    val_dataloader,
919                ) = self._partition_dataset(dataset)
920                self.task.set_dataloaders(
921                    train_dataloader, val_dataloader, test_dataloader
922                )
923                self.class_weights = train_dataloader.dataset.class_weights()
924                self.proportional_class_weights = (
925                    train_dataloader.dataset.class_weights(True)
926                )
927                if "losses" not in pars:
928                    pars["losses"] = {}
929
930        if "model" in pars:
931            self.model_parameters.update(pars["model"])
932
933        self._update_data_blanks()
934
935        if "augmentations" in pars:
936            self.aug_parameters.update(pars["augmentations"])
937            transformer = self._construct_transformer()
938            self.task.set_transformer(transformer)
939
940        if "losses" in pars:
941            for key in pars["losses"]:
942                if key in self.loss_parameters:
943                    self.loss_parameters[key].update(pars["losses"][key])
944                else:
945                    self.loss_parameters[key] = pars["losses"][key]
946            self.loss_parameters.update(pars["losses"])
947            loss = self._construct_loss()
948            self.task.set_loss(loss)
949
950        if "metrics" in pars:
951            for key in pars["metrics"]:
952                if key in self.metric_parameters:
953                    self.metric_parameters[key].update(pars["metrics"][key])
954                else:
955                    self.metric_parameters[key] = pars["metrics"][key]
956            metrics = self._construct_metrics()
957            self.task.set_metrics(metrics)
958
959        self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
960        self._set_loss_weights(
961            pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
962        )
963        model = self._construct_model()
964        predict_functions = self._construct_predict_functions()
965        self.task.set_predict_functions(*predict_functions)
966        self._update_model_blanks(model)
967        ssl_list = self._construct_ssl()
968        self._update_parameters_from_ssl(ssl_list)
969        model.set_ssl(ssl_constructors=ssl_list)
970        self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
971        self.task.set_model(model)
972        if "training" in pars and "checkpoint_path" in pars["training"]:
973            checkpoint_path = pars["training"]["checkpoint_path"]
974            only_model = pars["training"].get("only_load_model", False)
975            load_strict = pars["training"].get("load_strict", True)
976            self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
977        if (
978            self.general_parameters["only_load_annotated"]
979            and self.general_parameters.get("ssl") is not None
980        ):
981            warnings.warn(
982                "Note that you are using SSL modules but only loading annotated files! Set "
983                "general/only_load_annotated to False to change that"
984            )
985        if self.task.dataset("train").annotation_class() != "none":
986            self._print_behaviors()

Update the dlc2action.task.universal_task.Task instance given the parameter updates

Parameters

parameters : dict the dictionary of parameter updates

def train( self, trial: optuna.trial._trial.Trial = None, optimized_metric: str = None, autostop_metric: str = None, autostop_interval: int = 10, autostop_threshold: float = 0.001, loading_bar: bool = False) -> Tuple:
 988    def train(
 989        self,
 990        trial: Trial = None,
 991        optimized_metric: str = None,
 992        autostop_metric: str = None,
 993        autostop_interval: int = 10,
 994        autostop_threshold: float = 0.001,
 995        loading_bar: bool = False,
 996    ) -> Tuple:
 997        """
 998        Train the task and return a log of epoch-average loss and metric
 999
1000        You can use the autostop parameters to finish training when the parameters are not improving. It will be
1001        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
1002        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
1003        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
1004
1005        Parameters
1006        ----------
1007        trial : Trial
1008            an `optuna` trial (for hyperparameter searches)
1009        optimized_metric : str
1010            the name of the metric being optimized (for hyperparameter searches)
1011        to_ram : bool, default False
1012            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
1013            if the dataset is too large)
1014        autostop_interval : int, default 50
1015            the number of epochs to average the autostop metric over
1016        autostop_threshold : float, default 0.001
1017            the autostop difference threshold
1018        autostop_metric : str, optional
1019            the autostop metric (can be any one of the tracked metrics of `'loss'`)
1020        main_task_on : bool, default True
1021            if `False`, the main task (action segmentation) will not be used in training
1022        ssl_on : bool, default True
1023            if `False`, the SSL task will not be used in training
1024
1025        Returns
1026        -------
1027        loss_log: list
1028            a list of float loss function values for each epoch
1029        metrics_log: dict
1030            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
1031            names, values are lists of function values)
1032        """
1033
1034        to_ram = self.training_parameters.get("to_ram", False)
1035        logs = self.task.train(
1036            trial,
1037            optimized_metric,
1038            to_ram,
1039            autostop_metric=autostop_metric,
1040            autostop_interval=autostop_interval,
1041            autostop_threshold=autostop_threshold,
1042            main_task_on=self.training_parameters.get("main_task_on", True),
1043            ssl_on=self.training_parameters.get("ssl_on", True),
1044            temporal_subsampling_size=self.training_parameters.get(
1045                "temporal_subsampling_size"
1046            ),
1047            loading_bar=loading_bar,
1048        )
1049        return logs

Train the task and return a log of epoch-average loss and metric

You can use the autostop parameters to finish training when the parameters are not improving. It will be stopped if the average value of autostop_metric over the last autostop_interval epochs is smaller than the average over the previous autostop_interval epochs + autostop_threshold. For example, if the current epoch is 120 and autostop_interval is 50, the averages over epochs 70-120 and 20-70 will be compared.

Parameters

trial : Trial an optuna trial (for hyperparameter searches) optimized_metric : str the name of the metric being optimized (for hyperparameter searches) to_ram : bool, default False if True, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes if the dataset is too large) autostop_interval : int, default 10 the number of epochs to average the autostop metric over autostop_threshold : float, default 0.001 the autostop difference threshold autostop_metric : str, optional the autostop metric (can be any one of the tracked metrics or 'loss') main_task_on : bool, default True if False, the main task (action segmentation) will not be used in training ssl_on : bool, default True if False, the SSL task will not be used in training

Returns

loss_log: list a list of float loss function values for each epoch metrics_log: dict a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric names, values are lists of function values)

def save_model(self, save_path: str) -> None:
1051    def save_model(self, save_path: str) -> None:
1052        """
1053        Save the model of the `dlc2action.task.universal_task.Task` instance
1054
1055        Parameters
1056        ----------
1057        save_path : str
1058            the path to the saved file
1059        """
1060
1061        self.task.save_model(save_path)

Save the model of the dlc2action.task.universal_task.Task instance

Parameters

save_path : str the path to the saved file

def evaluate( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 0, verbose: bool = True) -> Tuple:
1063    def evaluate(
1064        self,
1065        data: Union[DataLoader, BehaviorDataset, str] = None,
1066        augment_n: int = 0,
1067        verbose: bool = True,
1068    ) -> Tuple:
1069        """
1070        Evaluate the Task model
1071
1072        Parameters
1073        ----------
1074        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1075            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
1076        augment_n : int, default 0
1077            the number of augmentations to average results over
1078        verbose : bool, default True
1079            if True, the process is reported to standard output
1080
1081        Returns
1082        -------
1083        loss : float
1084            the average value of the loss function
1085        ssl_loss : float
1086            the average value of the SSL loss function
1087        metric : dict
1088            a dictionary of average values of metric functions
1089        """
1090
1091        res = self.task.evaluate(
1092            data,
1093            augment_n,
1094            int(self.training_parameters.get("batch_size", 32)),
1095            verbose,
1096        )
1097        return res

Evaluate the Task model

Parameters

data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data to evaluate on (if not provided, evaluate on the Task validation dataset) augment_n : int, default 0 the number of augmentations to average results over verbose : bool, default True if True, the process is reported to standard output

Returns

loss : float the average value of the loss function ssl_loss : float the average value of the SSL loss function metric : dict a dictionary of average values of metric functions

def evaluate_prediction( self, prediction: torch.Tensor, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None) -> Tuple:
1099    def evaluate_prediction(
1100        self,
1101        prediction: torch.Tensor,
1102        data: Union[DataLoader, BehaviorDataset, str] = None,
1103    ) -> Tuple:
1104        """
1105        Compute metrics for a prediction
1106
1107        Parameters
1108        ----------
1109        prediction : torch.Tensor
1110            the prediction
1111        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1112            the data the prediction was made for (if not provided, take the validation dataset)
1113
1114        Returns
1115        -------
1116        loss : float
1117            the average value of the loss function
1118        metric : dict
1119            a dictionary of average values of metric functions
1120        """
1121
1122        return self.task.evaluate_prediction(
1123            prediction, data, int(self.training_parameters.get("batch_size", 32))
1124        )

Compute metrics for a prediction

Parameters

prediction : torch.Tensor the prediction data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data the prediction was made for (if not provided, take the validation dataset)

Returns

loss : float the average value of the loss function metric : dict a dictionary of average values of metric functions

def predict( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str], raw_output: bool = False, apply_primary_function: bool = True, augment_n: int = 0, embedding: bool = False) -> torch.Tensor:
1126    def predict(
1127        self,
1128        data: Union[DataLoader, BehaviorDataset, str],
1129        raw_output: bool = False,
1130        apply_primary_function: bool = True,
1131        augment_n: int = 0,
1132        embedding: bool = False,
1133    ) -> torch.Tensor:
1134        """
1135        Make a prediction with the Task model
1136
1137        Parameters
1138        ----------
1139        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
1140            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
1141        raw_output : bool, default False
1142            if `True`, the raw predicted probabilities are returned
1143        apply_primary_function : bool, default True
1144            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
1145            the input)
1146        augment_n : int, default 0
1147            the number of augmentations to average results over
1148
1149        Returns
1150        -------
1151        prediction : torch.Tensor
1152            a prediction for the input data
1153        """
1154
1155        to_ram = self.training_parameters.get("to_ram", False)
1156        return self.task.predict(
1157            data,
1158            raw_output,
1159            apply_primary_function,
1160            augment_n,
1161            int(self.training_parameters.get("batch_size", 32)),
1162            to_ram,
1163            embedding=embedding,
1164        )

Make a prediction with the Task model

Parameters

data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data to evaluate on (if not provided, evaluate on the Task validation dataset) raw_output : bool, default False if True, the raw predicted probabilities are returned apply_primary_function : bool, default True if True, the primary predict function is applied (to map the model output into a shape corresponding to the input) augment_n : int, default 0 the number of augmentations to average results over

Returns

prediction : torch.Tensor a prediction for the input data

def dataset(self, mode: str = 'train') -> dlc2action.data.dataset.BehaviorDataset:
1166    def dataset(self, mode: str = "train") -> BehaviorDataset:
1167        """
1168        Get a dataset
1169
1170        Parameters
1171        ----------
1172        mode : {'train', 'val', 'test'}
1173            the dataset to get
1174
1175        Returns
1176        -------
1177        dataset : dlc2action.data.dataset.BehaviorDataset
1178            the dataset
1179        """
1180
1181        return self.task.dataset(mode)

Get a dataset

Parameters

mode : {'train', 'val', 'test'} the dataset to get

Returns

dataset : dlc2action.data.dataset.BehaviorDataset the dataset

def generate_full_length_prediction( self, dataset: Union[dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 10) -> Dict:
1183    def generate_full_length_prediction(
1184        self,
1185        dataset: Union[BehaviorDataset, str] = None,
1186        augment_n: int = 10,
1187    ) -> Dict:
1188        """
1189        Compile a prediction for the original input sequences
1190
1191        Parameters
1192        ----------
1193        dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
1194            the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task`
1195            instance validation dataset)
1196        augment_n : int, default 10
1197            the number of augmentations to average results over
1198
1199        Returns
1200        -------
1201        prediction : dict
1202            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1203            are prediction tensors
1204        """
1205
1206        return self.task.generate_full_length_prediction(
1207            dataset, int(self.training_parameters.get("batch_size", 32)), augment_n
1208        )

Compile a prediction for the original input sequences

Parameters

dataset : dlc2action.data.dataset.BehaviorDataset | str, optional the dataset to generate a prediction for (if None, generate for the dlc2action.task.universal_task.Task instance validation dataset) augment_n : int, default 10 the number of augmentations to average results over

Returns

prediction : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are prediction tensors

def generate_submission( self, frame_number_map_file: str, dataset: Union[dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 10) -> Dict:
1210    def generate_submission(
1211        self,
1212        frame_number_map_file: str,
1213        dataset: Union[BehaviorDataset, str] = None,
1214        augment_n: int = 10,
1215    ) -> Dict:
1216        """
1217        Generate a MABe-22 style submission dictionary
1218
1219        Parameters
1220        ----------
1221        frame_number_map_file : str
1222            path to the frame number map file
1223        dataset : BehaviorDataset, optional
1224            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1225        augment_n : int, default 10
1226            the number of augmentations to average results over
1227
1228        Returns
1229        -------
1230        submission : dict
1231            a dictionary with frame number mapping and embeddings
1232        """
1233
1234        return self.task.generate_submission(
1235            frame_number_map_file,
1236            dataset,
1237            int(self.training_parameters.get("batch_size", 32)),
1238            augment_n,
1239        )

Generate a MABe-22 style submission dictionary

Parameters

frame_number_map_file : str path to the frame number map file dataset : BehaviorDataset, optional the dataset to generate a prediction for (if None, generate for the validation dataset) augment_n : int, default 10 the number of augmentations to average results over

Returns

submission : dict a dictionary with frame number mapping and embeddings

def behaviors_dict(self)
1241    def behaviors_dict(self):
1242        """
1243        Get a behavior dictionary
1244
1245        Keys are label indices and values are label names.
1246
1247        Returns
1248        -------
1249        behaviors_dict : dict
1250            behavior dictionary
1251        """
1252
1253        return self.task.behaviors_dict()

Get a behavior dictionary

Keys are label indices and values are label names.

Returns

behaviors_dict : dict behavior dictionary

def count_classes(self, bouts: bool = False) -> Dict:
1255    def count_classes(self, bouts: bool = False) -> Dict:
1256        """
1257        Get a dictionary of class counts in different modes
1258
1259        Parameters
1260        ----------
1261        bouts : bool, default False
1262            if `True`, instead of frame counts segment counts are returned
1263
1264        Returns
1265        -------
1266        class_counts : dict
1267            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1268            class names and values are class counts (in frames)
1269        """
1270
1271        return self.task.count_classes(bouts)

Get a dictionary of class counts in different modes

Parameters

bouts : bool, default False if True, instead of frame counts segment counts are returned

Returns

class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)

def visualize_results( self, save_path: str = None, add_legend: bool = True, ground_truth: bool = True, colormap: str = 'viridis', hide_axes: bool = False, min_classes: int = 1, width: float = 10, whole_video: bool = False, transparent: bool = False, dataset: Union[dlc2action.data.dataset.BehaviorDataset, torch.utils.data.dataloader.DataLoader, str, NoneType] = None, drop_classes: Set = None, search_classes: Set = None, smooth_interval_prediction: int = None) -> None:
1301    def visualize_results(
1302        self,
1303        save_path: str = None,
1304        add_legend: bool = True,
1305        ground_truth: bool = True,
1306        colormap: str = "viridis",
1307        hide_axes: bool = False,
1308        min_classes: int = 1,
1309        width: float = 10,
1310        whole_video: bool = False,
1311        transparent: bool = False,
1312        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1313        drop_classes: Set = None,
1314        search_classes: Set = None,
1315        smooth_interval_prediction: int = None,
1316    ) -> None:
1317        """
1318        Visualize random predictions
1319
1320        Parameters
1321        ----------
1322        save_path : str, optional
1323            the path where the plot will be saved
1324        add_legend : bool, default True
1325            if True, legend will be added to the plot
1326        ground_truth : bool, default True
1327            if True, ground truth will be added to the plot
1328        colormap : str, default 'Accent'
1329            the `matplotlib` colormap to use
1330        hide_axes : bool, default True
1331            if `True`, the axes will be hidden on the plot
1332        min_classes : int, default 1
1333            the minimum number of classes in a displayed interval
1334        width : float, default 10
1335            the width of the plot
1336        whole_video : bool, default False
1337            if `True`, whole videos are plotted instead of segments
1338        transparent : bool, default False
1339            if `True`, the background on the plot is transparent
1340        dataset : BehaviorDataset | DataLoader | str | None, optional
1341            the dataset to make the prediction for (if not provided, the validation dataset is used)
1342        drop_classes : set, optional
1343            a set of class names to not be displayed
1344        search_classes : set, optional
1345            if given, only intervals where at least one of the classes is in ground truth will be shown
1346        """
1347
1348        return self.task.visualize_results(
1349            save_path,
1350            add_legend,
1351            ground_truth,
1352            colormap,
1353            hide_axes,
1354            min_classes,
1355            width,
1356            whole_video,
1357            transparent,
1358            dataset,
1359            drop_classes,
1360            search_classes,
1361            smooth_interval_prediction=smooth_interval_prediction,
1362        )

Visualize random predictions

Parameters

save_path : str, optional the path where the plot will be saved add_legend : bool, default True if True, legend will be added to the plot ground_truth : bool, default True if True, ground truth will be added to the plot colormap : str, default 'viridis' the matplotlib colormap to use hide_axes : bool, default False if True, the axes will be hidden on the plot min_classes : int, default 1 the minimum number of classes in a displayed interval width : float, default 10 the width of the plot whole_video : bool, default False if True, whole videos are plotted instead of segments transparent : bool, default False if True, the background on the plot is transparent dataset : BehaviorDataset | DataLoader | str | None, optional the dataset to make the prediction for (if not provided, the validation dataset is used) drop_classes : set, optional a set of class names to not be displayed search_classes : set, optional if given, only intervals where at least one of the classes is in ground truth will be shown

def generate_uncertainty_score( self, classes: List, augment_n: int = 0, method: str = 'least_confidence', predicted: torch.Tensor = None, behaviors_dict: Dict = None) -> Dict:
1364    def generate_uncertainty_score(
1365        self,
1366        classes: List,
1367        augment_n: int = 0,
1368        method: str = "least_confidence",
1369        predicted: torch.Tensor = None,
1370        behaviors_dict: Dict = None,
1371    ) -> Dict:
1372        """
1373        Generate frame-wise scores for active learning
1374
1375        Parameters
1376        ----------
1377        classes : list
1378            a list of class names or indices; their confidence scores will be computed separately and stacked
1379        augment_n : int, default 0
1380            the number of augmentations to average over
1381        method : {"least_confidence", "entropy"}
1382            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1383            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
1384
1385        Returns
1386        -------
1387        score_dicts : dict
1388            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1389            are score tensors
1390        """
1391
1392        return self.task.generate_uncertainty_score(
1393            classes,
1394            augment_n,
1395            int(self.training_parameters.get("batch_size", 32)),
1396            method,
1397            predicted,
1398            behaviors_dict,
1399        )

Generate frame-wise scores for active learning

Parameters

classes : list a list of class names or indices; their confidence scores will be computed separately and stacked augment_n : int, default 0 the number of augmentations to average over method : {"least_confidence", "entropy"} the method used to calculate the scores from the probability predictions ("least_confidence": 1 - p_i if p_i > 0.5 or p_i if p_i < 0.5; "entropy": - p_i * log(p_i) - (1 - p_i) * log(1 - p_i))

Returns

score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors

def generate_bald_score( self, classes: List, augment_n: int = 0, num_models: int = 10, kernel_size: int = 11) -> Dict:
1401    def generate_bald_score(
1402        self,
1403        classes: List,
1404        augment_n: int = 0,
1405        num_models: int = 10,
1406        kernel_size: int = 11,
1407    ) -> Dict:
1408        """
1409        Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
1410
1411        Parameters
1412        ----------
1413        classes : list
1414            a list of class names or indices; their confidence scores will be computed separately and stacked
1415        augment_n : int, default 0
1416            the number of augmentations to average over
1417        num_models : int, default 10
1418            the number of dropout masks to apply
1419        kernel_size : int, default 11
1420            the size of the smoothing gaussian kernel
1421
1422        Returns
1423        -------
1424        score_dicts : dict
1425            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1426            are score tensors
1427        """
1428
1429        return self.task.generate_bald_score(
1430            classes,
1431            augment_n,
1432            int(self.training_parameters.get("batch_size", 32)),
1433            num_models,
1434            kernel_size,
1435        )

Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning

Parameters

classes : list — a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0 — the number of augmentations to average over
num_models : int, default 10 — the number of dropout masks to apply
kernel_size : int, default 11 — the size of the smoothing gaussian kernel

Returns

score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors

def get_normalization_stats(self) -> Dict:
1437    def get_normalization_stats(self) -> Dict:
1438        """
1439        Get the pre-computed normalization stats
1440
1441        Returns
1442        -------
1443        normalization_stats : dict
1444            a dictionary of means and stds
1445        """
1446
1447        return self.task.get_normalization_stats()

Get the pre-computed normalization stats

Returns

normalization_stats : dict a dictionary of means and stds

def exists(self, mode) -> bool:
1449    def exists(self, mode) -> bool:
1450        """
1451        Check whether the task has a train/test/validation subset
1452
1453        Parameters
1454        ----------
1455        mode : {"train", "val", "test"}
1456            the name of the subset to check for
1457
1458        Returns
1459        -------
1460        exists : bool
1461            `True` if the subset exists
1462        """
1463
1464        dl = self.task.dataloader(mode)
1465        if dl is None:
1466            return False
1467        else:
1468            return True

Check whether the task has a train/test/validation subset

Parameters

mode : {"train", "val", "test"} the name of the subset to check for

Returns

exists : bool True if the subset exists