dlc2action.task.task_dispatcher
Class that provides an interface for `dlc2action.task.universal_task.Task`.
1# 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved. 3# 4# This project and all its files are licensed under GNU AGPLv3 or later version. 5# A copy is included in dlc2action/LICENSE.AGPL. 6# 7"""Class that provides an interface for `dlc2action.task.universal_task.Task`.""" 8 9import inspect 10import warnings 11from collections.abc import Iterable, Mapping 12from copy import deepcopy 13from typing import Callable, Dict, List, Set, Tuple, Union 14 15import numpy as np 16import torch 17from dlc2action import options 18from dlc2action.data.dataset import BehaviorDataset 19from dlc2action.metric.base_metric import Metric 20from dlc2action.model.base_model import LoadedModel, Model 21from dlc2action.ssl.base_ssl import EmptySSL, SSLConstructor 22from dlc2action.task.universal_task import Task 23from dlc2action.transformer.base_transformer import Transformer 24from dlc2action.utils import PostProcessor 25from optuna.trial import Trial 26from torch.optim import Optimizer 27from torch.utils.data import DataLoader 28 29 30class TaskDispatcher: 31 """A class that manages the interactions between config dictionaries and a Task.""" 32 33 def __init__(self, parameters: Dict) -> None: 34 """Initialize the `TaskDispatcher`. 
35 36 Parameters 37 ---------- 38 parameters : dict 39 a dictionary of task parameters 40 41 """ 42 pars = deepcopy(parameters) 43 self.class_weights = None 44 self.general_parameters = pars.get("general", {}) 45 self.data_parameters = pars.get("data", {}) 46 self.model_parameters = pars.get("model", {}) 47 self.training_parameters = pars.get("training", {}) 48 self.loss_parameters = pars.get("losses", {}) 49 self.metric_parameters = pars.get("metrics", {}) 50 self.ssl_parameters = pars.get("ssl", {}) 51 self.aug_parameters = pars.get("augmentations", {}) 52 self.feature_parameters = pars.get("features", {}) 53 self.blanks = {blank: [] for blank in options.blanks} 54 55 self.task = None 56 self._initialize_task() 57 self._print_behaviors() 58 59 @staticmethod 60 def complete_function_parameters(parameters, function, general_dicts: List) -> Dict: 61 """Complete a parameter dictionary with values from other dictionaries if required by a function. 62 63 Parameters 64 ---------- 65 parameters : dict 66 the function parameters dictionary 67 function : callable 68 the function to be inspected 69 general_dicts : list 70 a list of dictionaries where the missing values will be pulled from 71 72 Returns 73 ------- 74 parameters : dict 75 the updated parameter dictionary 76 77 """ 78 parameter_names = inspect.getfullargspec(function).args 79 for param in parameter_names: 80 for dic in general_dicts: 81 if param not in parameters and param in dic: 82 parameters[param] = dic[param] 83 return parameters 84 85 @staticmethod 86 def complete_dataset_parameters( 87 parameters: dict, 88 general_dict: dict, 89 data_type: str, 90 annotation_type: str, 91 ) -> Dict: 92 """Complete a parameter dictionary with values from other dictionaries if required by a dataset. 
93 94 Parameters 95 ---------- 96 parameters : dict 97 the function parameters dictionary 98 general_dict : dict 99 the dictionary where the missing values will be pulled from 100 data_type : str 101 the input type of the dataset 102 annotation_type : str 103 the annotation type of the dataset 104 105 Returns 106 ------- 107 parameters : dict 108 the updated parameter dictionary 109 110 """ 111 params = deepcopy(parameters) 112 parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type) 113 for param in parameter_names: 114 if param not in params and param in general_dict: 115 params[param] = general_dict[param] 116 return params 117 118 @staticmethod 119 def check(parameters: Dict, name: str) -> bool: 120 """Check whether there is a non-`None` value under the name key in the parameters dictionary. 121 122 Parameters 123 ---------- 124 parameters : dict 125 the dictionary to check 126 name : str 127 the key to check 128 129 Returns 130 ------- 131 result : bool 132 True if a non-`None` value exists 133 134 """ 135 if name in parameters and parameters[name] is not None: 136 return True 137 else: 138 return False 139 140 @staticmethod 141 def get(parameters: Dict, name: str, default): 142 """Get the value under the name key or the default if it is `None` or does not exist. 143 144 Parameters 145 ---------- 146 parameters : dict 147 the dictionary to check 148 name : str 149 the key to check 150 default 151 the default value to return 152 153 Returns 154 ------- 155 value 156 the resulting value 157 158 """ 159 if TaskDispatcher.check(parameters, name): 160 return parameters[name] 161 else: 162 return default 163 164 @staticmethod 165 def make_dataloader( 166 dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False 167 ) -> DataLoader: 168 """Make a torch dataloader from a dataset. 
169 170 Parameters 171 ---------- 172 dataset : dlc2action.data.dataset.BehaviorDataset 173 the dataset 174 batch_size : int 175 the batch size 176 shuffle : bool 177 whether to shuffle the dataset 178 179 Returns 180 ------- 181 dataloader : DataLoader 182 the dataloader (or `None` if the length of the dataset is 0) 183 184 """ 185 if dataset is None or len(dataset) == 0: 186 return None 187 else: 188 return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle) 189 190 def _construct_ssl(self) -> List: 191 """Generate SSL constructors.""" 192 ssl_list = deepcopy(self.general_parameters.get("ssl", None)) 193 model_name = self.general_parameters.get("model_name", "") 194 # ssl_constructors = options.ssl_constructors if not "tcn" in model_name else options.ssl_constructors_tcn 195 ssl_constructors = options.ssl_constructors 196 if not isinstance(ssl_list, Iterable): 197 ssl_list = [ssl_list] 198 for i, ssl in enumerate(ssl_list): 199 if type(ssl) is str: 200 if ssl in ssl_constructors: 201 pars = self.get(self.ssl_parameters, ssl, default={}) 202 pars = self.complete_function_parameters( 203 parameters=pars, 204 function=ssl_constructors[ssl], 205 general_dicts=[ 206 self.model_parameters, 207 self.data_parameters, 208 self.general_parameters, 209 ], 210 ) 211 ssl_list[i] = ssl_constructors[ssl](**pars) 212 else: 213 raise ValueError( 214 f"The {ssl} SSL is not available, please choose from {list(ssl_constructors.keys())}" 215 ) 216 elif ssl is None: 217 ssl_list[i] = EmptySSL() 218 elif not isinstance(ssl, SSLConstructor): 219 raise TypeError( 220 f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}" 221 ) 222 return ssl_list 223 224 def _construct_model(self) -> Model: 225 """Generate a model.""" 226 if self.check(self.general_parameters, "model"): 227 pars = self.complete_function_parameters( 228 function=LoadedModel, 229 parameters=self.model_parameters, 230 general_dicts=[self.general_parameters], 
231 ) 232 model = LoadedModel(**pars) 233 elif self.check(self.general_parameters, "model_name"): 234 name = self.general_parameters["model_name"] 235 if name in options.models: 236 pars = self.complete_function_parameters( 237 function=options.models[name], 238 parameters=self.model_parameters, 239 general_dicts=[self.general_parameters], 240 ) 241 model = options.models[name](**pars) 242 else: 243 raise ValueError( 244 f"The {name} model is not available, please choose from {list(options.models.keys())}" 245 ) 246 else: 247 raise ValueError( 248 "You need to provide either a model or its name in the model_parameters!" 249 ) 250 251 if self.get(self.training_parameters, "freeze_features", False): 252 model.freeze_feature_extractor() 253 return model 254 255 def _construct_dataset(self) -> BehaviorDataset: 256 """ 257 Generate a dataset 258 """ 259 data_type = self.general_parameters.get("data_type", None) 260 if data_type is None: 261 raise ValueError( 262 "You need to provide the data_type parameter in the data parameters!" 263 ) 264 annotation_type = self.get(self.general_parameters, "annotation_type", "none") 265 feature_extraction = self.general_parameters.get("feature_extraction", "none") 266 if feature_extraction is None: 267 raise ValueError( 268 "You need to provide the feature_extraction parameter in the data parameters!" 
269 ) 270 feature_extraction_pars = self.complete_function_parameters( 271 self.feature_parameters, 272 options.feature_extractors[feature_extraction], 273 [self.general_parameters, self.data_parameters], 274 ) 275 276 pars = self.complete_dataset_parameters( 277 self.data_parameters, 278 self.general_parameters, 279 data_type=data_type, 280 annotation_type=annotation_type, 281 ) 282 pars["feature_extraction_pars"] = feature_extraction_pars 283 dataset = BehaviorDataset(**pars) 284 285 if self.get(self.general_parameters, "save_dataset", default=False): 286 save_data_path = self.data_parameters.get("saved_data_path", None) 287 dataset.save(save_path=save_data_path) 288 289 return dataset 290 291 def _construct_transformer(self) -> Transformer: 292 """Generate a transformer.""" 293 features = self.general_parameters["feature_extraction"] 294 name = options.extractor_to_transformer[features] 295 if name in options.transformers: 296 transformer_class = options.transformers[name] 297 pars = self.complete_function_parameters( 298 function=transformer_class, 299 parameters=self.aug_parameters, 300 general_dicts=[self.general_parameters], 301 ) 302 transformer = transformer_class(**pars) 303 else: 304 raise ValueError(f"The {name} transformer is not available") 305 return transformer 306 307 def _construct_loss(self) -> torch.nn.Module: 308 """Generate a loss function.""" 309 if "loss_function" not in self.general_parameters: 310 raise ValueError( 311 'Please add a "loss_function" key to the parameters["general"] dictionary (either a name ' 312 f"from {list(options.losses.keys())} or a function)" 313 ) 314 else: 315 loss_function = self.general_parameters["loss_function"] 316 if type(loss_function) is str: 317 if loss_function in options.losses: 318 pars = self.get(self.loss_parameters, loss_function, default={}) 319 pars = self._set_loss_weights(pars) 320 pars = self.complete_function_parameters( 321 function=options.losses[loss_function], 322 parameters=pars, 323 
general_dicts=[self.general_parameters], 324 ) 325 loss = options.losses[loss_function](**pars) 326 else: 327 raise ValueError( 328 f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}" 329 ) 330 else: 331 loss = loss_function 332 return loss 333 334 def _construct_metrics(self) -> List: 335 """Generate the metric.""" 336 metric_functions = self.get( 337 self.general_parameters, "metric_functions", default={} 338 ) 339 if isinstance(metric_functions, Iterable): 340 metrics = {} 341 for func in metric_functions: 342 if isinstance(func, str): 343 if func in options.metrics: 344 pars = self.get(self.metric_parameters, func, default={}) 345 pars = self.complete_function_parameters( 346 function=options.metrics[func], 347 parameters=pars, 348 general_dicts=[self.general_parameters], 349 ) 350 metrics[func] = options.metrics[func](**pars) 351 else: 352 raise ValueError( 353 f"The {func} metric is not available, please choose from {list(options.metrics.keys())}" 354 ) 355 elif isinstance(func, Metric): 356 name = "function_1" 357 i = 1 358 while name in metrics: 359 i += 1 360 name = f"function_{i}" 361 metrics[name] = func 362 else: 363 raise TypeError( 364 'The elements of parameters["general"]["metric_functions"] have to be either strings ' 365 f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead" 366 ) 367 elif isinstance(metric_functions, dict): 368 metrics = metric_functions 369 else: 370 raise TypeError( 371 'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;' 372 f"got {type(metric_functions)} instead" 373 ) 374 return metrics 375 376 def _construct_optimizer(self) -> Optimizer: 377 """Generate an optimizer.""" 378 if "optimizer" in self.training_parameters: 379 name = self.training_parameters["optimizer"] 380 if name in options.optimizers: 381 optimizer = options.optimizers[name] 382 else: 383 raise ValueError( 384 f"The {name} optimizer is not 
available, please choose from {list(options.optimizers.keys())}" 385 ) 386 else: 387 optimizer = None 388 return optimizer 389 390 def _construct_predict_functions(self) -> Tuple[Callable, Callable]: 391 """Construct predict functions.""" 392 predict_function = self.training_parameters.get("predict_function", None) 393 primary_predict_function = self.training_parameters.get( 394 "primary_predict_function", None 395 ) 396 model_name = self.general_parameters.get("model_name", "") 397 threshold = self.training_parameters.get("hard_threshold", 0.5) 398 if not isinstance(predict_function, Callable): 399 if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]: 400 if self.general_parameters["exclusive"]: 401 func = lambda x: torch.softmax(x, dim=1) 402 else: 403 func = lambda x: torch.sigmoid(x) 404 405 def primary_predict_function(x): 406 if len(x) == 1: 407 return func(x) 408 else: 409 if len(x.shape) != 4: 410 x = x.reshape((4, -1, x.shape[-2], x.shape[-1])) 411 weights = [1, 1, 1, 1] 412 ensemble_prob = func(x[0]) * weights[0] / sum(weights) 413 for i, outp_ele in enumerate(x[1:]): 414 ensemble_prob = ensemble_prob + func(outp_ele) * weights[ 415 i + 1 416 ] / sum(weights) 417 return ensemble_prob 418 419 else: 420 if model_name.startswith("ms_tcn") or model_name in [ 421 "asformer", 422 "transformer", 423 "c3d_ms", 424 "transformer_ms", 425 ]: 426 f = lambda x: x[-1] if len(x.shape) == 4 else x 427 elif model_name == "asrf": 428 429 def f(x): 430 x = x[-1] 431 # bounds = x[:, 0, :].unsqueeze(1) 432 cls = x[:, 1:, :] 433 # device = x.device 434 # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy()) 435 # x = torch.tensor(x).to(device) 436 return cls 437 438 else: 439 f = lambda x: x 440 if self.general_parameters["exclusive"]: 441 primary_predict_function = lambda x: torch.softmax(f(x), dim=1) 442 else: 443 primary_predict_function = lambda x: torch.sigmoid(f(x)) 444 if 
self.general_parameters["exclusive"]: 445 predict_function = lambda x: torch.max(x.data, dim=1)[1] 446 else: 447 predict_function = lambda x: (x > threshold).int() 448 return primary_predict_function, predict_function 449 450 def _get_parameters_from_training(self) -> Dict: 451 """Get the training parameters that need to be passed to the Task.""" 452 task_training_par_names = [ 453 "lr", 454 "parallel", 455 "device", 456 "verbose", 457 "log_file", 458 "augment_train", 459 "augment_val", 460 "hard_threshold", 461 "ssl_losses", 462 "model_save_path", 463 "model_save_epochs", 464 "pseudolabel", 465 "pseudolabel_start", 466 "correction_interval", 467 "pseudolabel_alpha_f", 468 "alpha_growth_stop", 469 "num_epochs", 470 "validation_interval", 471 "ignore_tags", 472 "skip_metrics", 473 ] 474 task_training_pars = { 475 name: self.training_parameters[name] 476 for name in task_training_par_names 477 if self.check(self.training_parameters, name) 478 } 479 if self.check(self.general_parameters, "ssl"): 480 ssl_weights = [ 481 self.training_parameters["ssl_weights"][x] 482 for x in self.general_parameters["ssl"] 483 ] 484 task_training_pars["ssl_weights"] = ssl_weights 485 return task_training_pars 486 487 def _update_parameters_from_ssl(self, ssl_list: list) -> None: 488 """Update the necessary parameters given the list of SSL constructors.""" 489 if self.task is not None: 490 self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 491 self.task.set_ssl_losses([ssl.loss for ssl in ssl_list]) 492 self.task.set_keep_target_none( 493 [ssl.type in ["contrastive"] for ssl in ssl_list] 494 ) 495 self.task.set_generate_ssl_input( 496 [ssl.type == "contrastive" for ssl in ssl_list] 497 ) 498 self.data_parameters["ssl_transformations"] = [ 499 ssl.transformation for ssl in ssl_list 500 ] 501 self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list] 502 self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list] 503 
self.model_parameters["ssl_modules"] = [ 504 ssl.construct_module() for ssl in ssl_list 505 ] 506 self.aug_parameters["generate_ssl_input"] = [ 507 x.type == "contrastive" for x in ssl_list 508 ] 509 self.aug_parameters["keep_target_none"] = [ 510 x.type == "contrastive" for x in ssl_list 511 ] 512 513 def _set_loss_weights(self, parameters): 514 """Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values.""" 515 for k in list(parameters.keys()): 516 if parameters[k] in [ 517 "dataset_inverse_weights", 518 "dataset_proportional_weights", 519 ]: 520 if parameters[k] == "dataset_inverse_weights": 521 parameters[k] = self.class_weights 522 else: 523 parameters[k] = self.proportional_class_weights 524 print("Initializing class weights:") 525 string = " " 526 if isinstance(parameters[k], Mapping): 527 for key, val in parameters[k].items(): 528 string += ": ".join( 529 ( 530 " " + str(key), 531 ", ".join((map(lambda x: str(np.round(x, 3)), val))), 532 ) 533 ) 534 else: 535 string += ", ".join( 536 (map(lambda x: str(np.round(x, 3)), parameters[k])) 537 ) 538 print(string) 539 return parameters 540 541 def _partition_dataset( 542 self, dataset: BehaviorDataset 543 ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]: 544 """Partition the dataset into train, validation and test subsamples.""" 545 use_test = self.get(self.training_parameters, "use_test", 0) 546 split_path = self.training_parameters.get("split_path", None) 547 partition_method = self.training_parameters.get("partition_method", "random") 548 val_frac = self.get(self.training_parameters, "val_frac", 0) 549 test_frac = self.get(self.training_parameters, "test_frac", 0) 550 save_split = self.get(self.training_parameters, "save_split", True) 551 normalize = self.get(self.training_parameters, "normalize", False) 552 skip_normalization_keys = self.training_parameters.get( 553 "skip_normalization_keys" 554 ) 555 stats = self.training_parameters.get("stats") 556 
train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val( 557 use_test, 558 split_path, 559 partition_method, 560 val_frac, 561 test_frac, 562 save_split, 563 normalize, 564 skip_normalization_keys, 565 stats, 566 ) 567 bs = int(self.training_parameters.get("batch_size", 32)) 568 train_dataloader, test_dataloader, val_dataloader = ( 569 self.make_dataloader(train_dataset, batch_size=bs, shuffle=True), 570 self.make_dataloader(test_dataset, batch_size=bs, shuffle=False), 571 self.make_dataloader(val_dataset, batch_size=bs, shuffle=False), 572 ) 573 return train_dataloader, test_dataloader, val_dataloader 574 575 def _initialize_task(self): 576 """Create a `dlc2action.task.universal_task.Task` instance.""" 577 dataset = self._construct_dataset() 578 self._update_data_blanks(dataset) 579 model = self._construct_model() 580 self._update_model_blanks(model) 581 ssl_list = self._construct_ssl() 582 self._update_parameters_from_ssl(ssl_list) 583 model.set_ssl(ssl_constructors=ssl_list) 584 dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 585 transformer = self._construct_transformer() 586 metrics = self._construct_metrics() 587 optimizer = self._construct_optimizer() 588 primary_predict_function, predict_function = self._construct_predict_functions() 589 590 task_training_pars = self._get_parameters_from_training() 591 train_dataloader, test_dataloader, val_dataloader = self._partition_dataset( 592 dataset 593 ) 594 self.class_weights = train_dataloader.dataset.class_weights() 595 self._update_num_classes_parameter(dataset) 596 self.proportional_class_weights = train_dataloader.dataset.class_weights(True) 597 loss = self._construct_loss() 598 exclusive = self.general_parameters["exclusive"] 599 600 task_pars = { 601 "train_dataloader": train_dataloader, 602 "model": model, 603 "loss": loss, 604 "transformer": transformer, 605 "metrics": metrics, 606 "val_dataloader": val_dataloader, 607 "test_dataloader": test_dataloader, 608 
"exclusive": exclusive, 609 "optimizer": optimizer, 610 "predict_function": predict_function, 611 "primary_predict_function": primary_predict_function, 612 } 613 task_pars.update(task_training_pars) 614 self.task = Task(**task_pars) 615 checkpoint_path = self.training_parameters.get("checkpoint_path", None) 616 if checkpoint_path is not None: 617 only_model = self.get(self.training_parameters, "only_load_model", False) 618 load_strict = self.get(self.training_parameters, "load_strict", True) 619 self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 620 if ( 621 self.general_parameters["only_load_annotated"] 622 and self.general_parameters.get("ssl") is not None 623 ): 624 warnings.warn( 625 "Note that you are using SSL modules but only loading annotated files! Set " 626 "general/only_load_annotated to False to change that" 627 ) 628 629 def _update_data_blanks( 630 self, dataset: BehaviorDataset = None, remember: bool = False 631 ) -> None: 632 """Update all blanks from a dataset.""" 633 if dataset is None: 634 dataset = self.dataset() 635 self._update_dim_parameter(dataset, remember) 636 self._update_bodyparts_parameter(dataset, remember) 637 self._update_num_classes_parameter(dataset, remember) 638 self._update_len_segment_parameter(dataset, remember) 639 self._update_boundary_parameter(dataset, remember) 640 641 def _update_model_blanks(self, model: Model, remember: bool = False) -> None: 642 """Update blanks related to model parameters.""" 643 self._update_features_parameter(model, remember) 644 645 def _update_parameter(self, blank_name: str, value, remember: bool = False): 646 """Update a single blank parameter.""" 647 parameters = [ 648 self.model_parameters, 649 self.ssl_parameters, 650 self.general_parameters, 651 self.feature_parameters, 652 self.data_parameters, 653 self.training_parameters, 654 self.metric_parameters, 655 self.loss_parameters, 656 self.aug_parameters, 657 ] 658 par_names = [ 659 "model", 660 "ssl", 661 "general", 662 
"feature", 663 "data", 664 "training", 665 "metrics", 666 "losses", 667 "augmentations", 668 ] 669 for names in self.blanks[blank_name]: 670 group = names[0] 671 key = names[1] 672 ind = par_names.index(group) 673 if len(names) == 3: 674 if names[2] in parameters[ind][key]: 675 parameters[ind][key][names[2]] = value 676 else: 677 if key in parameters[ind]: 678 parameters[ind][key] = value 679 for name, dic in zip(par_names, parameters): 680 for k, v in dic.items(): 681 if v == blank_name: 682 dic[k] = value 683 if [name, k] not in self.blanks[blank_name]: 684 self.blanks[blank_name].append([name, k]) 685 elif isinstance(v, Mapping): 686 for kk, vv in v.items(): 687 if vv == blank_name: 688 dic[k][kk] = value 689 if [name, k, kk] not in self.blanks[blank_name]: 690 self.blanks[blank_name].append([name, k, kk]) 691 692 def _update_features_parameter(self, model: Model, remember: bool = False) -> None: 693 """Fill the `"model_features"` blank.""" 694 value = model.features_shape() 695 self._update_parameter("model_features", value, remember) 696 697 def _update_bodyparts_parameter( 698 self, dataset: BehaviorDataset, remember: bool = False 699 ) -> None: 700 """Fill the `"dataset_bodyparts"` blank.""" 701 value = dataset.bodyparts_order() 702 self._update_parameter("dataset_bodyparts", value, remember) 703 704 def _update_dim_parameter( 705 self, dataset: BehaviorDataset, remember: bool = False 706 ) -> None: 707 """Fill the `"dataset_features"` blank.""" 708 value = dataset.features_shape() 709 self._update_parameter("dataset_features", value, remember) 710 711 def _update_boundary_parameter( 712 self, dataset: BehaviorDataset, remember: bool = False 713 ) -> None: 714 """Fill the `"dataset_features"` blank.""" 715 value = dataset._boundary_class_weight() 716 self._update_parameter("dataset_boundary_weight", value, remember) 717 718 def _update_num_classes_parameter( 719 self, dataset: BehaviorDataset, remember: bool = False 720 ) -> None: 721 """Fill in the 
`"dataset_classes"` blank.""" 722 value = dataset.num_classes() 723 self._update_parameter("dataset_classes", value, remember) 724 725 def _update_len_segment_parameter( 726 self, dataset: BehaviorDataset, remember: bool = False 727 ) -> None: 728 """Fill in the `"dataset_len_segment"` blank.""" 729 value = dataset.len_segment() 730 self._update_parameter("dataset_len_segment", value, remember) 731 732 def _print_behaviors(self): 733 behavior_set = self.behaviors_dict() 734 print(f"Behavior indices:") 735 for key, value in sorted(behavior_set.items()): 736 print(f" {key}: {value}") 737 738 def update_task(self, parameters: Dict) -> None: 739 """Update the `dlc2action.task.universal_task.Task` instance given the parameter updates. 740 741 Parameters 742 ---------- 743 parameters : dict 744 the dictionary of parameter updates 745 746 """ 747 pars = deepcopy(parameters) 748 # for blank_name in self.blanks: 749 # for names in self.blanks[blank_name]: 750 # group = names[0] 751 # key = names[1] 752 # if len(names) == 3: 753 # if ( 754 # group in pars 755 # and key in pars[group] 756 # and names[2] in pars[group][key] 757 # ): 758 # pars[group][key].pop(names[2]) 759 # else: 760 # if group in pars and key in pars[group]: 761 # pars[group].pop(key) 762 stay = False 763 if "ssl" in pars: 764 for key in pars["ssl"]: 765 if key in self.ssl_parameters: 766 self.ssl_parameters[key].update(pars["ssl"][key]) 767 else: 768 self.ssl_parameters[key] = pars["ssl"][key] 769 770 if "general" in pars: 771 if stay: 772 stay = False 773 if ( 774 "model_name" in pars["general"] 775 and pars["general"]["model_name"] 776 != self.general_parameters["model_name"] 777 ): 778 if "model" not in pars: 779 raise ValueError( 780 "When updating a task with a new model name you need to pass the parameters for the " 781 "new model" 782 ) 783 self.model_parameters = {} 784 self.general_parameters.update(pars["general"]) 785 data_related = [ 786 "num_classes", 787 "exclusive", 788 "data_type", 789 
"annotation_type", 790 ] 791 ssl_related = ["ssl", "exclusive", "num_classes"] 792 loss_related = ["num_classes", "loss_function", "exclusive"] 793 augmentation_related = ["augmentation_type"] 794 metric_related = ["metric_functions"] 795 related_lists = [ 796 data_related, 797 ssl_related, 798 loss_related, 799 augmentation_related, 800 metric_related, 801 ] 802 names = ["data", "ssl", "losses", "augmentations", "metrics"] 803 for related_list, name in zip(related_lists, names): 804 if ( 805 any([x in pars["general"] for x in related_list]) 806 and name not in pars 807 ): 808 pars[name] = {} 809 810 if "training" in pars: 811 if "data" not in pars or not stay: 812 for x in [ 813 "to_ram", 814 "use_test", 815 "partition_method", 816 "val_frac", 817 "test_frac", 818 "save_split", 819 "batch_size", 820 "save_split", 821 ]: 822 if ( 823 x in pars["training"] 824 and pars["training"][x] != self.training_parameters[x] 825 ): 826 if "data" not in pars: 827 pars["data"] = {} 828 stay = True 829 self.training_parameters.update(pars["training"]) 830 self.task.update_parameters(self._get_parameters_from_training()) 831 832 if "data" in pars or "features" in pars: 833 for k, v in pars["data"].items(): 834 if k not in self.data_parameters or v != self.data_parameters[k]: 835 stay = True 836 for k, v in pars["features"].items(): 837 if k not in self.feature_parameters or v != self.feature_parameters[k]: 838 stay = True 839 if stay: 840 self.data_parameters.update(pars["data"]) 841 self.feature_parameters.update(pars["features"]) 842 dataset = self._construct_dataset() 843 ( 844 train_dataloader, 845 test_dataloader, 846 val_dataloader, 847 ) = self._partition_dataset(dataset) 848 self.task.set_dataloaders( 849 train_dataloader, val_dataloader, test_dataloader 850 ) 851 self.class_weights = train_dataloader.dataset.class_weights() 852 self.proportional_class_weights = ( 853 train_dataloader.dataset.class_weights(True) 854 ) 855 if "losses" not in pars: 856 pars["losses"] = {} 
857 858 if "model" in pars: 859 self.model_parameters.update(pars["model"]) 860 861 self._update_data_blanks() 862 863 if "augmentations" in pars: 864 self.aug_parameters.update(pars["augmentations"]) 865 transformer = self._construct_transformer() 866 self.task.set_transformer(transformer) 867 868 if "losses" in pars: 869 for key in pars["losses"]: 870 if key in self.loss_parameters: 871 self.loss_parameters[key].update(pars["losses"][key]) 872 else: 873 self.loss_parameters[key] = pars["losses"][key] 874 self.loss_parameters.update(pars["losses"]) 875 loss = self._construct_loss() 876 self.task.set_loss(loss) 877 878 if "metrics" in pars: 879 for key in pars["metrics"]: 880 if key in self.metric_parameters: 881 self.metric_parameters[key].update(pars["metrics"][key]) 882 else: 883 self.metric_parameters[key] = pars["metrics"][key] 884 metrics = self._construct_metrics() 885 self.task.set_metrics(metrics) 886 887 self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"]) 888 self._set_loss_weights( 889 pars.get("losses", {}).get(self.general_parameters["loss_function"], {}) 890 ) 891 model = self._construct_model() 892 predict_functions = self._construct_predict_functions() 893 self.task.set_predict_functions(*predict_functions) 894 self._update_model_blanks(model) 895 ssl_list = self._construct_ssl() 896 self._update_parameters_from_ssl(ssl_list) 897 model.set_ssl(ssl_constructors=ssl_list) 898 self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 899 self.task.set_model(model) 900 if "training" in pars and "checkpoint_path" in pars["training"]: 901 checkpoint_path = pars["training"]["checkpoint_path"] 902 only_model = pars["training"].get("only_load_model", False) 903 load_strict = pars["training"].get("load_strict", True) 904 self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 905 if ( 906 self.general_parameters["only_load_annotated"] 907 and self.general_parameters.get("ssl") is not None 908 ): 
909 warnings.warn( 910 "Note that you are using SSL modules but only loading annotated files! Set " 911 "general/only_load_annotated to False to change that" 912 ) 913 if self.task.dataset("train").annotation_class() != "none": 914 self._print_behaviors() 915 916 def train( 917 self, 918 trial: Trial = None, 919 optimized_metric: str = None, 920 autostop_metric: str = None, 921 autostop_interval: int = 10, 922 autostop_threshold: float = 0.001, 923 loading_bar: bool = False, 924 ) -> Tuple: 925 """Train the task and return a log of epoch-average loss and metric. 926 927 You can use the autostop parameters to finish training when the parameters are not improving. It will be 928 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 929 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 930 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 
931 932 Parameters 933 ---------- 934 trial : Trial 935 an `optuna` trial (for hyperparameter searches) 936 optimized_metric : str 937 the name of the metric being optimized (for hyperparameter searches) 938 autostop_metric : str, optional 939 the autostop metric (can be any one of the tracked metrics of `'loss'`) 940 autostop_interval : int, default 50 941 the number of epochs to average the autostop metric over 942 autostop_threshold : float, default 0.001 943 the autostop difference threshold 944 loading_bar : bool, default False 945 whether to show a loading bar 946 947 Returns 948 ------- 949 loss_log: list 950 a list of float loss function values for each epoch 951 metrics_log: dict 952 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 953 names, values are lists of function values) 954 955 """ 956 to_ram = self.training_parameters.get("to_ram", False) 957 logs = self.task.train( 958 trial, 959 optimized_metric, 960 to_ram, 961 autostop_metric=autostop_metric, 962 autostop_interval=autostop_interval, 963 autostop_threshold=autostop_threshold, 964 main_task_on=self.training_parameters.get("main_task_on", True), 965 ssl_on=self.training_parameters.get("ssl_on", True), 966 temporal_subsampling_size=self.training_parameters.get( 967 "temporal_subsampling_size" 968 ), 969 loading_bar=loading_bar, 970 ) 971 return logs 972 973 def save_model(self, save_path: str) -> None: 974 """Save the model of the `dlc2action.task.universal_task.Task` instance. 975 976 Parameters 977 ---------- 978 save_path : str 979 the path to the saved file 980 981 """ 982 self.task.save_model(save_path) 983 984 def evaluate( 985 self, 986 data: Union[DataLoader, BehaviorDataset, str] = None, 987 augment_n: int = 0, 988 verbose: bool = True, 989 ) -> Tuple: 990 """Evaluate the Task model. 
991 992 Parameters 993 ---------- 994 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 995 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 996 augment_n : int, default 0 997 the number of augmentations to average results over 998 verbose : bool, default True 999 if True, the process is reported to standard output 1000 1001 Returns 1002 ------- 1003 loss : float 1004 the average value of the loss function 1005 ssl_loss : float 1006 the average value of the SSL loss function 1007 metric : dict 1008 a dictionary of average values of metric functions 1009 1010 """ 1011 res = self.task.evaluate( 1012 data, 1013 augment_n, 1014 int(self.training_parameters.get("batch_size", 32)), 1015 verbose, 1016 ) 1017 return res 1018 1019 def evaluate_prediction( 1020 self, 1021 prediction: torch.Tensor, 1022 data: Union[DataLoader, BehaviorDataset, str] = None, 1023 indices: Union[List[int], None] = None, 1024 ) -> Tuple: 1025 """Compute metrics for a prediction. 1026 1027 Parameters 1028 ---------- 1029 prediction : torch.Tensor 1030 the prediction 1031 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1032 the data the prediction was made for (if not provided, take the validation dataset) 1033 1034 Returns 1035 ------- 1036 loss : float 1037 the average value of the loss function 1038 metric : dict 1039 a dictionary of average values of metric functions 1040 """ 1041 return self.task.evaluate_prediction( 1042 prediction, data, int(self.training_parameters.get("batch_size", 32)), indices 1043 ) 1044 1045 def predict( 1046 self, 1047 data: Union[DataLoader, BehaviorDataset, str], 1048 raw_output: bool = False, 1049 apply_primary_function: bool = True, 1050 augment_n: int = 0, 1051 embedding: bool = False, 1052 ) -> torch.Tensor: 1053 """Make a prediction with the Task model. 
1054 1055 Parameters 1056 ---------- 1057 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1058 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 1059 raw_output : bool, default False 1060 if `True`, the raw predicted probabilities are returned 1061 apply_primary_function : bool, default True 1062 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 1063 the input) 1064 augment_n : int, default 0 1065 the number of augmentations to average results over 1066 embedding : bool, default False 1067 if `True`, the embedding is returned instead of the prediction 1068 1069 Returns 1070 ------- 1071 prediction : torch.Tensor 1072 a prediction for the input data 1073 1074 """ 1075 to_ram = self.training_parameters.get("to_ram", False) 1076 return self.task.predict( 1077 data, 1078 raw_output, 1079 apply_primary_function, 1080 augment_n, 1081 int(self.training_parameters.get("batch_size", 32)), 1082 to_ram, 1083 embedding=embedding, 1084 ) 1085 1086 def dataset(self, mode: str = "train") -> BehaviorDataset: 1087 """Get a dataset. 1088 1089 Parameters 1090 ---------- 1091 mode : {'train', 'val', 'test'} 1092 the dataset to get 1093 1094 Returns 1095 ------- 1096 dataset : dlc2action.data.dataset.BehaviorDataset 1097 the dataset 1098 1099 """ 1100 return self.task.dataset(mode) 1101 1102 def generate_full_length_prediction( 1103 self, 1104 dataset: Union[BehaviorDataset, str] = None, 1105 augment_n: int = 10, 1106 ) -> Dict: 1107 """Compile a prediction for the original input sequences. 
1108 1109 Parameters 1110 ---------- 1111 dataset : dlc2action.data.dataset.BehaviorDataset | str, optional 1112 the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task` 1113 instance validation dataset) 1114 augment_n : int, default 10 1115 the number of augmentations to average results over 1116 1117 Returns 1118 ------- 1119 prediction : dict 1120 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1121 are prediction tensors 1122 1123 """ 1124 return self.task.generate_full_length_prediction( 1125 dataset, int(self.training_parameters.get("batch_size", 32)), augment_n 1126 ) 1127 1128 def generate_submission( 1129 self, 1130 frame_number_map_file: str, 1131 dataset: Union[BehaviorDataset, str] = None, 1132 augment_n: int = 10, 1133 ) -> Dict: 1134 """Generate a MABe-22 style submission dictionary. 1135 1136 Parameters 1137 ---------- 1138 frame_number_map_file : str 1139 path to the frame number map file 1140 dataset : BehaviorDataset, optional 1141 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1142 augment_n : int, default 10 1143 the number of augmentations to average results over 1144 1145 Returns 1146 ------- 1147 submission : dict 1148 a dictionary with frame number mapping and embeddings 1149 1150 """ 1151 return self.task.generate_submission( 1152 frame_number_map_file, 1153 dataset, 1154 int(self.training_parameters.get("batch_size", 32)), 1155 augment_n, 1156 ) 1157 1158 def behaviors_dict(self): 1159 """Get a behavior dictionary. 1160 1161 Keys are label indices and values are label names. 1162 1163 Returns 1164 ------- 1165 behaviors_dict : dict 1166 behavior dictionary 1167 1168 """ 1169 return self.task.behaviors_dict() 1170 1171 def count_classes(self, bouts: bool = False) -> Dict: 1172 """Get a dictionary of class counts in different modes. 
1173 1174 Parameters 1175 ---------- 1176 bouts : bool, default False 1177 if `True`, instead of frame counts segment counts are returned 1178 1179 Returns 1180 ------- 1181 class_counts : dict 1182 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1183 class names and values are class counts (in frames) 1184 1185 """ 1186 return self.task.count_classes(bouts) 1187 1188 def _visualize_results_label( 1189 self, 1190 label: str, 1191 save_path: str = None, 1192 add_legend: bool = True, 1193 ground_truth: bool = True, 1194 hide_axes: bool = False, 1195 width: int = 10, 1196 whole_video: bool = False, 1197 transparent: bool = False, 1198 dataset: BehaviorDataset = None, 1199 smooth_interval: int = 0, 1200 title: str = None, 1201 ): 1202 return self.task._visualize_results_single( 1203 label, 1204 save_path, 1205 add_legend, 1206 ground_truth, 1207 hide_axes, 1208 width, 1209 whole_video, 1210 transparent, 1211 dataset, 1212 smooth_interval=smooth_interval, 1213 title=title, 1214 ) 1215 1216 def visualize_results( 1217 self, 1218 save_path: str = None, 1219 add_legend: bool = True, 1220 ground_truth: bool = True, 1221 colormap: str = "viridis", 1222 hide_axes: bool = False, 1223 min_classes: int = 1, 1224 width: float = 10, 1225 whole_video: bool = False, 1226 transparent: bool = False, 1227 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1228 drop_classes: Set = None, 1229 search_classes: Set = None, 1230 smooth_interval_prediction: int = None, 1231 font_size: float = None, 1232 num_plots:int = 1, 1233 window_size:int=400, 1234 ) -> None: 1235 """Visualize random predictions. 
1236 1237 Parameters 1238 ---------- 1239 save_path : str, optional 1240 the path where the plot will be saved 1241 add_legend : bool, default True 1242 if True, legend will be added to the plot 1243 ground_truth : bool, default True 1244 if True, ground truth will be added to the plot 1245 colormap : str, default 'Accent' 1246 the `matplotlib` colormap to use 1247 hide_axes : bool, default True 1248 if `True`, the axes will be hidden on the plot 1249 min_classes : int, default 1 1250 the minimum number of classes in a displayed interval 1251 width : float, default 10 1252 the width of the plot 1253 whole_video : bool, default False 1254 if `True`, whole videos are plotted instead of segments 1255 transparent : bool, default False 1256 if `True`, the background on the plot is transparent 1257 dataset : BehaviorDataset | DataLoader | str | None, optional 1258 the dataset to make the prediction for (if not provided, the validation dataset is used) 1259 drop_classes : set, optional 1260 a set of class names to not be displayed 1261 search_classes : set, optional 1262 if given, only intervals where at least one of the classes is in ground truth will be shown 1263 smooth_interval_prediction : int, optional 1264 if given, the prediction will be smoothed over the given number of frames 1265 1266 """ 1267 return self.task.visualize_results( 1268 save_path, 1269 add_legend, 1270 ground_truth, 1271 colormap, 1272 hide_axes, 1273 min_classes, 1274 width, 1275 whole_video, 1276 transparent, 1277 dataset, 1278 drop_classes, 1279 search_classes, 1280 font_size=font_size, 1281 smooth_interval_prediction=smooth_interval_prediction, 1282 num_plots=num_plots, 1283 window_size=window_size 1284 ) 1285 1286 def generate_uncertainty_score( 1287 self, 1288 classes: List, 1289 augment_n: int = 0, 1290 method: str = "least_confidence", 1291 predicted: torch.Tensor = None, 1292 behaviors_dict: Dict = None, 1293 ) -> Dict: 1294 """Generate frame-wise scores for active learning. 
1295 1296 Parameters 1297 ---------- 1298 classes : list 1299 a list of class names or indices; their confidence scores will be computed separately and stacked 1300 augment_n : int, default 0 1301 the number of augmentations to average over 1302 method : {"least_confidence", "entropy"} 1303 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1304 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1305 predicted : torch.Tensor, optional 1306 if given, the predictions will be used instead of the model's predictions 1307 behaviors_dict : dict, optional 1308 if given, the behaviors dictionary will be used instead of the model's behaviors dictionary 1309 1310 Returns 1311 ------- 1312 score_dicts : dict 1313 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1314 are score tensors 1315 1316 """ 1317 return self.task.generate_uncertainty_score( 1318 classes, 1319 augment_n, 1320 int(self.training_parameters.get("batch_size", 32)), 1321 method, 1322 predicted, 1323 behaviors_dict, 1324 ) 1325 1326 def generate_bald_score( 1327 self, 1328 classes: List, 1329 augment_n: int = 0, 1330 num_models: int = 10, 1331 kernel_size: int = 11, 1332 ) -> Dict: 1333 """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning. 
1334 1335 Parameters 1336 ---------- 1337 classes : list 1338 a list of class names or indices; their confidence scores will be computed separately and stacked 1339 augment_n : int, default 0 1340 the number of augmentations to average over 1341 num_models : int, default 10 1342 the number of dropout masks to apply 1343 kernel_size : int, default 11 1344 the size of the smoothing gaussian kernel 1345 1346 Returns 1347 ------- 1348 score_dicts : dict 1349 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1350 are score tensors 1351 1352 """ 1353 return self.task.generate_bald_score( 1354 classes, 1355 augment_n, 1356 int(self.training_parameters.get("batch_size", 32)), 1357 num_models, 1358 kernel_size, 1359 ) 1360 1361 def exists(self, mode) -> bool: 1362 """Check whether the task has a train/test/validation subset. 1363 1364 Parameters 1365 ---------- 1366 mode : {"train", "val", "test"} 1367 the name of the subset to check for 1368 Returns 1369 ------- 1370 exists : bool 1371 `True` if the subset exists 1372 1373 """ 1374 dl = self.task.dataloader(mode) 1375 if dl is None: 1376 return False 1377 else: 1378 return True 1379 1380 def get_normalization_stats(self) -> Dict: 1381 """Get the pre-computed normalization stats. 1382 1383 Returns 1384 ------- 1385 normalization_stats : dict 1386 a dictionary of means and stds 1387 1388 """ 1389 return self.task.get_normalization_stats()
31class TaskDispatcher: 32 """A class that manages the interactions between config dictionaries and a Task.""" 33 34 def __init__(self, parameters: Dict) -> None: 35 """Initialize the `TaskDispatcher`. 36 37 Parameters 38 ---------- 39 parameters : dict 40 a dictionary of task parameters 41 42 """ 43 pars = deepcopy(parameters) 44 self.class_weights = None 45 self.general_parameters = pars.get("general", {}) 46 self.data_parameters = pars.get("data", {}) 47 self.model_parameters = pars.get("model", {}) 48 self.training_parameters = pars.get("training", {}) 49 self.loss_parameters = pars.get("losses", {}) 50 self.metric_parameters = pars.get("metrics", {}) 51 self.ssl_parameters = pars.get("ssl", {}) 52 self.aug_parameters = pars.get("augmentations", {}) 53 self.feature_parameters = pars.get("features", {}) 54 self.blanks = {blank: [] for blank in options.blanks} 55 56 self.task = None 57 self._initialize_task() 58 self._print_behaviors() 59 60 @staticmethod 61 def complete_function_parameters(parameters, function, general_dicts: List) -> Dict: 62 """Complete a parameter dictionary with values from other dictionaries if required by a function. 63 64 Parameters 65 ---------- 66 parameters : dict 67 the function parameters dictionary 68 function : callable 69 the function to be inspected 70 general_dicts : list 71 a list of dictionaries where the missing values will be pulled from 72 73 Returns 74 ------- 75 parameters : dict 76 the updated parameter dictionary 77 78 """ 79 parameter_names = inspect.getfullargspec(function).args 80 for param in parameter_names: 81 for dic in general_dicts: 82 if param not in parameters and param in dic: 83 parameters[param] = dic[param] 84 return parameters 85 86 @staticmethod 87 def complete_dataset_parameters( 88 parameters: dict, 89 general_dict: dict, 90 data_type: str, 91 annotation_type: str, 92 ) -> Dict: 93 """Complete a parameter dictionary with values from other dictionaries if required by a dataset. 
94 95 Parameters 96 ---------- 97 parameters : dict 98 the function parameters dictionary 99 general_dict : dict 100 the dictionary where the missing values will be pulled from 101 data_type : str 102 the input type of the dataset 103 annotation_type : str 104 the annotation type of the dataset 105 106 Returns 107 ------- 108 parameters : dict 109 the updated parameter dictionary 110 111 """ 112 params = deepcopy(parameters) 113 parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type) 114 for param in parameter_names: 115 if param not in params and param in general_dict: 116 params[param] = general_dict[param] 117 return params 118 119 @staticmethod 120 def check(parameters: Dict, name: str) -> bool: 121 """Check whether there is a non-`None` value under the name key in the parameters dictionary. 122 123 Parameters 124 ---------- 125 parameters : dict 126 the dictionary to check 127 name : str 128 the key to check 129 130 Returns 131 ------- 132 result : bool 133 True if a non-`None` value exists 134 135 """ 136 if name in parameters and parameters[name] is not None: 137 return True 138 else: 139 return False 140 141 @staticmethod 142 def get(parameters: Dict, name: str, default): 143 """Get the value under the name key or the default if it is `None` or does not exist. 144 145 Parameters 146 ---------- 147 parameters : dict 148 the dictionary to check 149 name : str 150 the key to check 151 default 152 the default value to return 153 154 Returns 155 ------- 156 value 157 the resulting value 158 159 """ 160 if TaskDispatcher.check(parameters, name): 161 return parameters[name] 162 else: 163 return default 164 165 @staticmethod 166 def make_dataloader( 167 dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False 168 ) -> DataLoader: 169 """Make a torch dataloader from a dataset. 
170 171 Parameters 172 ---------- 173 dataset : dlc2action.data.dataset.BehaviorDataset 174 the dataset 175 batch_size : int 176 the batch size 177 shuffle : bool 178 whether to shuffle the dataset 179 180 Returns 181 ------- 182 dataloader : DataLoader 183 the dataloader (or `None` if the length of the dataset is 0) 184 185 """ 186 if dataset is None or len(dataset) == 0: 187 return None 188 else: 189 return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle) 190 191 def _construct_ssl(self) -> List: 192 """Generate SSL constructors.""" 193 ssl_list = deepcopy(self.general_parameters.get("ssl", None)) 194 model_name = self.general_parameters.get("model_name", "") 195 # ssl_constructors = options.ssl_constructors if not "tcn" in model_name else options.ssl_constructors_tcn 196 ssl_constructors = options.ssl_constructors 197 if not isinstance(ssl_list, Iterable): 198 ssl_list = [ssl_list] 199 for i, ssl in enumerate(ssl_list): 200 if type(ssl) is str: 201 if ssl in ssl_constructors: 202 pars = self.get(self.ssl_parameters, ssl, default={}) 203 pars = self.complete_function_parameters( 204 parameters=pars, 205 function=ssl_constructors[ssl], 206 general_dicts=[ 207 self.model_parameters, 208 self.data_parameters, 209 self.general_parameters, 210 ], 211 ) 212 ssl_list[i] = ssl_constructors[ssl](**pars) 213 else: 214 raise ValueError( 215 f"The {ssl} SSL is not available, please choose from {list(ssl_constructors.keys())}" 216 ) 217 elif ssl is None: 218 ssl_list[i] = EmptySSL() 219 elif not isinstance(ssl, SSLConstructor): 220 raise TypeError( 221 f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}" 222 ) 223 return ssl_list 224 225 def _construct_model(self) -> Model: 226 """Generate a model.""" 227 if self.check(self.general_parameters, "model"): 228 pars = self.complete_function_parameters( 229 function=LoadedModel, 230 parameters=self.model_parameters, 231 general_dicts=[self.general_parameters], 
232 ) 233 model = LoadedModel(**pars) 234 elif self.check(self.general_parameters, "model_name"): 235 name = self.general_parameters["model_name"] 236 if name in options.models: 237 pars = self.complete_function_parameters( 238 function=options.models[name], 239 parameters=self.model_parameters, 240 general_dicts=[self.general_parameters], 241 ) 242 model = options.models[name](**pars) 243 else: 244 raise ValueError( 245 f"The {name} model is not available, please choose from {list(options.models.keys())}" 246 ) 247 else: 248 raise ValueError( 249 "You need to provide either a model or its name in the model_parameters!" 250 ) 251 252 if self.get(self.training_parameters, "freeze_features", False): 253 model.freeze_feature_extractor() 254 return model 255 256 def _construct_dataset(self) -> BehaviorDataset: 257 """ 258 Generate a dataset 259 """ 260 data_type = self.general_parameters.get("data_type", None) 261 if data_type is None: 262 raise ValueError( 263 "You need to provide the data_type parameter in the data parameters!" 264 ) 265 annotation_type = self.get(self.general_parameters, "annotation_type", "none") 266 feature_extraction = self.general_parameters.get("feature_extraction", "none") 267 if feature_extraction is None: 268 raise ValueError( 269 "You need to provide the feature_extraction parameter in the data parameters!" 
270 ) 271 feature_extraction_pars = self.complete_function_parameters( 272 self.feature_parameters, 273 options.feature_extractors[feature_extraction], 274 [self.general_parameters, self.data_parameters], 275 ) 276 277 pars = self.complete_dataset_parameters( 278 self.data_parameters, 279 self.general_parameters, 280 data_type=data_type, 281 annotation_type=annotation_type, 282 ) 283 pars["feature_extraction_pars"] = feature_extraction_pars 284 dataset = BehaviorDataset(**pars) 285 286 if self.get(self.general_parameters, "save_dataset", default=False): 287 save_data_path = self.data_parameters.get("saved_data_path", None) 288 dataset.save(save_path=save_data_path) 289 290 return dataset 291 292 def _construct_transformer(self) -> Transformer: 293 """Generate a transformer.""" 294 features = self.general_parameters["feature_extraction"] 295 name = options.extractor_to_transformer[features] 296 if name in options.transformers: 297 transformer_class = options.transformers[name] 298 pars = self.complete_function_parameters( 299 function=transformer_class, 300 parameters=self.aug_parameters, 301 general_dicts=[self.general_parameters], 302 ) 303 transformer = transformer_class(**pars) 304 else: 305 raise ValueError(f"The {name} transformer is not available") 306 return transformer 307 308 def _construct_loss(self) -> torch.nn.Module: 309 """Generate a loss function.""" 310 if "loss_function" not in self.general_parameters: 311 raise ValueError( 312 'Please add a "loss_function" key to the parameters["general"] dictionary (either a name ' 313 f"from {list(options.losses.keys())} or a function)" 314 ) 315 else: 316 loss_function = self.general_parameters["loss_function"] 317 if type(loss_function) is str: 318 if loss_function in options.losses: 319 pars = self.get(self.loss_parameters, loss_function, default={}) 320 pars = self._set_loss_weights(pars) 321 pars = self.complete_function_parameters( 322 function=options.losses[loss_function], 323 parameters=pars, 324 
general_dicts=[self.general_parameters], 325 ) 326 loss = options.losses[loss_function](**pars) 327 else: 328 raise ValueError( 329 f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}" 330 ) 331 else: 332 loss = loss_function 333 return loss 334 335 def _construct_metrics(self) -> List: 336 """Generate the metric.""" 337 metric_functions = self.get( 338 self.general_parameters, "metric_functions", default={} 339 ) 340 if isinstance(metric_functions, Iterable): 341 metrics = {} 342 for func in metric_functions: 343 if isinstance(func, str): 344 if func in options.metrics: 345 pars = self.get(self.metric_parameters, func, default={}) 346 pars = self.complete_function_parameters( 347 function=options.metrics[func], 348 parameters=pars, 349 general_dicts=[self.general_parameters], 350 ) 351 metrics[func] = options.metrics[func](**pars) 352 else: 353 raise ValueError( 354 f"The {func} metric is not available, please choose from {list(options.metrics.keys())}" 355 ) 356 elif isinstance(func, Metric): 357 name = "function_1" 358 i = 1 359 while name in metrics: 360 i += 1 361 name = f"function_{i}" 362 metrics[name] = func 363 else: 364 raise TypeError( 365 'The elements of parameters["general"]["metric_functions"] have to be either strings ' 366 f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead" 367 ) 368 elif isinstance(metric_functions, dict): 369 metrics = metric_functions 370 else: 371 raise TypeError( 372 'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;' 373 f"got {type(metric_functions)} instead" 374 ) 375 return metrics 376 377 def _construct_optimizer(self) -> Optimizer: 378 """Generate an optimizer.""" 379 if "optimizer" in self.training_parameters: 380 name = self.training_parameters["optimizer"] 381 if name in options.optimizers: 382 optimizer = options.optimizers[name] 383 else: 384 raise ValueError( 385 f"The {name} optimizer is not 
available, please choose from {list(options.optimizers.keys())}" 386 ) 387 else: 388 optimizer = None 389 return optimizer 390 391 def _construct_predict_functions(self) -> Tuple[Callable, Callable]: 392 """Construct predict functions.""" 393 predict_function = self.training_parameters.get("predict_function", None) 394 primary_predict_function = self.training_parameters.get( 395 "primary_predict_function", None 396 ) 397 model_name = self.general_parameters.get("model_name", "") 398 threshold = self.training_parameters.get("hard_threshold", 0.5) 399 if not isinstance(predict_function, Callable): 400 if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]: 401 if self.general_parameters["exclusive"]: 402 func = lambda x: torch.softmax(x, dim=1) 403 else: 404 func = lambda x: torch.sigmoid(x) 405 406 def primary_predict_function(x): 407 if len(x) == 1: 408 return func(x) 409 else: 410 if len(x.shape) != 4: 411 x = x.reshape((4, -1, x.shape[-2], x.shape[-1])) 412 weights = [1, 1, 1, 1] 413 ensemble_prob = func(x[0]) * weights[0] / sum(weights) 414 for i, outp_ele in enumerate(x[1:]): 415 ensemble_prob = ensemble_prob + func(outp_ele) * weights[ 416 i + 1 417 ] / sum(weights) 418 return ensemble_prob 419 420 else: 421 if model_name.startswith("ms_tcn") or model_name in [ 422 "asformer", 423 "transformer", 424 "c3d_ms", 425 "transformer_ms", 426 ]: 427 f = lambda x: x[-1] if len(x.shape) == 4 else x 428 elif model_name == "asrf": 429 430 def f(x): 431 x = x[-1] 432 # bounds = x[:, 0, :].unsqueeze(1) 433 cls = x[:, 1:, :] 434 # device = x.device 435 # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy()) 436 # x = torch.tensor(x).to(device) 437 return cls 438 439 else: 440 f = lambda x: x 441 if self.general_parameters["exclusive"]: 442 primary_predict_function = lambda x: torch.softmax(f(x), dim=1) 443 else: 444 primary_predict_function = lambda x: torch.sigmoid(f(x)) 445 if 
self.general_parameters["exclusive"]: 446 predict_function = lambda x: torch.max(x.data, dim=1)[1] 447 else: 448 predict_function = lambda x: (x > threshold).int() 449 return primary_predict_function, predict_function 450 451 def _get_parameters_from_training(self) -> Dict: 452 """Get the training parameters that need to be passed to the Task.""" 453 task_training_par_names = [ 454 "lr", 455 "parallel", 456 "device", 457 "verbose", 458 "log_file", 459 "augment_train", 460 "augment_val", 461 "hard_threshold", 462 "ssl_losses", 463 "model_save_path", 464 "model_save_epochs", 465 "pseudolabel", 466 "pseudolabel_start", 467 "correction_interval", 468 "pseudolabel_alpha_f", 469 "alpha_growth_stop", 470 "num_epochs", 471 "validation_interval", 472 "ignore_tags", 473 "skip_metrics", 474 ] 475 task_training_pars = { 476 name: self.training_parameters[name] 477 for name in task_training_par_names 478 if self.check(self.training_parameters, name) 479 } 480 if self.check(self.general_parameters, "ssl"): 481 ssl_weights = [ 482 self.training_parameters["ssl_weights"][x] 483 for x in self.general_parameters["ssl"] 484 ] 485 task_training_pars["ssl_weights"] = ssl_weights 486 return task_training_pars 487 488 def _update_parameters_from_ssl(self, ssl_list: list) -> None: 489 """Update the necessary parameters given the list of SSL constructors.""" 490 if self.task is not None: 491 self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 492 self.task.set_ssl_losses([ssl.loss for ssl in ssl_list]) 493 self.task.set_keep_target_none( 494 [ssl.type in ["contrastive"] for ssl in ssl_list] 495 ) 496 self.task.set_generate_ssl_input( 497 [ssl.type == "contrastive" for ssl in ssl_list] 498 ) 499 self.data_parameters["ssl_transformations"] = [ 500 ssl.transformation for ssl in ssl_list 501 ] 502 self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list] 503 self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list] 504 
self.model_parameters["ssl_modules"] = [ 505 ssl.construct_module() for ssl in ssl_list 506 ] 507 self.aug_parameters["generate_ssl_input"] = [ 508 x.type == "contrastive" for x in ssl_list 509 ] 510 self.aug_parameters["keep_target_none"] = [ 511 x.type == "contrastive" for x in ssl_list 512 ] 513 514 def _set_loss_weights(self, parameters): 515 """Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values.""" 516 for k in list(parameters.keys()): 517 if parameters[k] in [ 518 "dataset_inverse_weights", 519 "dataset_proportional_weights", 520 ]: 521 if parameters[k] == "dataset_inverse_weights": 522 parameters[k] = self.class_weights 523 else: 524 parameters[k] = self.proportional_class_weights 525 print("Initializing class weights:") 526 string = " " 527 if isinstance(parameters[k], Mapping): 528 for key, val in parameters[k].items(): 529 string += ": ".join( 530 ( 531 " " + str(key), 532 ", ".join((map(lambda x: str(np.round(x, 3)), val))), 533 ) 534 ) 535 else: 536 string += ", ".join( 537 (map(lambda x: str(np.round(x, 3)), parameters[k])) 538 ) 539 print(string) 540 return parameters 541 542 def _partition_dataset( 543 self, dataset: BehaviorDataset 544 ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]: 545 """Partition the dataset into train, validation and test subsamples.""" 546 use_test = self.get(self.training_parameters, "use_test", 0) 547 split_path = self.training_parameters.get("split_path", None) 548 partition_method = self.training_parameters.get("partition_method", "random") 549 val_frac = self.get(self.training_parameters, "val_frac", 0) 550 test_frac = self.get(self.training_parameters, "test_frac", 0) 551 save_split = self.get(self.training_parameters, "save_split", True) 552 normalize = self.get(self.training_parameters, "normalize", False) 553 skip_normalization_keys = self.training_parameters.get( 554 "skip_normalization_keys" 555 ) 556 stats = self.training_parameters.get("stats") 557 
train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val( 558 use_test, 559 split_path, 560 partition_method, 561 val_frac, 562 test_frac, 563 save_split, 564 normalize, 565 skip_normalization_keys, 566 stats, 567 ) 568 bs = int(self.training_parameters.get("batch_size", 32)) 569 train_dataloader, test_dataloader, val_dataloader = ( 570 self.make_dataloader(train_dataset, batch_size=bs, shuffle=True), 571 self.make_dataloader(test_dataset, batch_size=bs, shuffle=False), 572 self.make_dataloader(val_dataset, batch_size=bs, shuffle=False), 573 ) 574 return train_dataloader, test_dataloader, val_dataloader 575 576 def _initialize_task(self): 577 """Create a `dlc2action.task.universal_task.Task` instance.""" 578 dataset = self._construct_dataset() 579 self._update_data_blanks(dataset) 580 model = self._construct_model() 581 self._update_model_blanks(model) 582 ssl_list = self._construct_ssl() 583 self._update_parameters_from_ssl(ssl_list) 584 model.set_ssl(ssl_constructors=ssl_list) 585 dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 586 transformer = self._construct_transformer() 587 metrics = self._construct_metrics() 588 optimizer = self._construct_optimizer() 589 primary_predict_function, predict_function = self._construct_predict_functions() 590 591 task_training_pars = self._get_parameters_from_training() 592 train_dataloader, test_dataloader, val_dataloader = self._partition_dataset( 593 dataset 594 ) 595 self.class_weights = train_dataloader.dataset.class_weights() 596 self._update_num_classes_parameter(dataset) 597 self.proportional_class_weights = train_dataloader.dataset.class_weights(True) 598 loss = self._construct_loss() 599 exclusive = self.general_parameters["exclusive"] 600 601 task_pars = { 602 "train_dataloader": train_dataloader, 603 "model": model, 604 "loss": loss, 605 "transformer": transformer, 606 "metrics": metrics, 607 "val_dataloader": val_dataloader, 608 "test_dataloader": test_dataloader, 609 
"exclusive": exclusive, 610 "optimizer": optimizer, 611 "predict_function": predict_function, 612 "primary_predict_function": primary_predict_function, 613 } 614 task_pars.update(task_training_pars) 615 self.task = Task(**task_pars) 616 checkpoint_path = self.training_parameters.get("checkpoint_path", None) 617 if checkpoint_path is not None: 618 only_model = self.get(self.training_parameters, "only_load_model", False) 619 load_strict = self.get(self.training_parameters, "load_strict", True) 620 self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 621 if ( 622 self.general_parameters["only_load_annotated"] 623 and self.general_parameters.get("ssl") is not None 624 ): 625 warnings.warn( 626 "Note that you are using SSL modules but only loading annotated files! Set " 627 "general/only_load_annotated to False to change that" 628 ) 629 630 def _update_data_blanks( 631 self, dataset: BehaviorDataset = None, remember: bool = False 632 ) -> None: 633 """Update all blanks from a dataset.""" 634 if dataset is None: 635 dataset = self.dataset() 636 self._update_dim_parameter(dataset, remember) 637 self._update_bodyparts_parameter(dataset, remember) 638 self._update_num_classes_parameter(dataset, remember) 639 self._update_len_segment_parameter(dataset, remember) 640 self._update_boundary_parameter(dataset, remember) 641 642 def _update_model_blanks(self, model: Model, remember: bool = False) -> None: 643 """Update blanks related to model parameters.""" 644 self._update_features_parameter(model, remember) 645 646 def _update_parameter(self, blank_name: str, value, remember: bool = False): 647 """Update a single blank parameter.""" 648 parameters = [ 649 self.model_parameters, 650 self.ssl_parameters, 651 self.general_parameters, 652 self.feature_parameters, 653 self.data_parameters, 654 self.training_parameters, 655 self.metric_parameters, 656 self.loss_parameters, 657 self.aug_parameters, 658 ] 659 par_names = [ 660 "model", 661 "ssl", 662 "general", 663 
"feature", 664 "data", 665 "training", 666 "metrics", 667 "losses", 668 "augmentations", 669 ] 670 for names in self.blanks[blank_name]: 671 group = names[0] 672 key = names[1] 673 ind = par_names.index(group) 674 if len(names) == 3: 675 if names[2] in parameters[ind][key]: 676 parameters[ind][key][names[2]] = value 677 else: 678 if key in parameters[ind]: 679 parameters[ind][key] = value 680 for name, dic in zip(par_names, parameters): 681 for k, v in dic.items(): 682 if v == blank_name: 683 dic[k] = value 684 if [name, k] not in self.blanks[blank_name]: 685 self.blanks[blank_name].append([name, k]) 686 elif isinstance(v, Mapping): 687 for kk, vv in v.items(): 688 if vv == blank_name: 689 dic[k][kk] = value 690 if [name, k, kk] not in self.blanks[blank_name]: 691 self.blanks[blank_name].append([name, k, kk]) 692 693 def _update_features_parameter(self, model: Model, remember: bool = False) -> None: 694 """Fill the `"model_features"` blank.""" 695 value = model.features_shape() 696 self._update_parameter("model_features", value, remember) 697 698 def _update_bodyparts_parameter( 699 self, dataset: BehaviorDataset, remember: bool = False 700 ) -> None: 701 """Fill the `"dataset_bodyparts"` blank.""" 702 value = dataset.bodyparts_order() 703 self._update_parameter("dataset_bodyparts", value, remember) 704 705 def _update_dim_parameter( 706 self, dataset: BehaviorDataset, remember: bool = False 707 ) -> None: 708 """Fill the `"dataset_features"` blank.""" 709 value = dataset.features_shape() 710 self._update_parameter("dataset_features", value, remember) 711 712 def _update_boundary_parameter( 713 self, dataset: BehaviorDataset, remember: bool = False 714 ) -> None: 715 """Fill the `"dataset_features"` blank.""" 716 value = dataset._boundary_class_weight() 717 self._update_parameter("dataset_boundary_weight", value, remember) 718 719 def _update_num_classes_parameter( 720 self, dataset: BehaviorDataset, remember: bool = False 721 ) -> None: 722 """Fill in the 
`"dataset_classes"` blank.""" 723 value = dataset.num_classes() 724 self._update_parameter("dataset_classes", value, remember) 725 726 def _update_len_segment_parameter( 727 self, dataset: BehaviorDataset, remember: bool = False 728 ) -> None: 729 """Fill in the `"dataset_len_segment"` blank.""" 730 value = dataset.len_segment() 731 self._update_parameter("dataset_len_segment", value, remember) 732 733 def _print_behaviors(self): 734 behavior_set = self.behaviors_dict() 735 print(f"Behavior indices:") 736 for key, value in sorted(behavior_set.items()): 737 print(f" {key}: {value}") 738 739 def update_task(self, parameters: Dict) -> None: 740 """Update the `dlc2action.task.universal_task.Task` instance given the parameter updates. 741 742 Parameters 743 ---------- 744 parameters : dict 745 the dictionary of parameter updates 746 747 """ 748 pars = deepcopy(parameters) 749 # for blank_name in self.blanks: 750 # for names in self.blanks[blank_name]: 751 # group = names[0] 752 # key = names[1] 753 # if len(names) == 3: 754 # if ( 755 # group in pars 756 # and key in pars[group] 757 # and names[2] in pars[group][key] 758 # ): 759 # pars[group][key].pop(names[2]) 760 # else: 761 # if group in pars and key in pars[group]: 762 # pars[group].pop(key) 763 stay = False 764 if "ssl" in pars: 765 for key in pars["ssl"]: 766 if key in self.ssl_parameters: 767 self.ssl_parameters[key].update(pars["ssl"][key]) 768 else: 769 self.ssl_parameters[key] = pars["ssl"][key] 770 771 if "general" in pars: 772 if stay: 773 stay = False 774 if ( 775 "model_name" in pars["general"] 776 and pars["general"]["model_name"] 777 != self.general_parameters["model_name"] 778 ): 779 if "model" not in pars: 780 raise ValueError( 781 "When updating a task with a new model name you need to pass the parameters for the " 782 "new model" 783 ) 784 self.model_parameters = {} 785 self.general_parameters.update(pars["general"]) 786 data_related = [ 787 "num_classes", 788 "exclusive", 789 "data_type", 790 
"annotation_type", 791 ] 792 ssl_related = ["ssl", "exclusive", "num_classes"] 793 loss_related = ["num_classes", "loss_function", "exclusive"] 794 augmentation_related = ["augmentation_type"] 795 metric_related = ["metric_functions"] 796 related_lists = [ 797 data_related, 798 ssl_related, 799 loss_related, 800 augmentation_related, 801 metric_related, 802 ] 803 names = ["data", "ssl", "losses", "augmentations", "metrics"] 804 for related_list, name in zip(related_lists, names): 805 if ( 806 any([x in pars["general"] for x in related_list]) 807 and name not in pars 808 ): 809 pars[name] = {} 810 811 if "training" in pars: 812 if "data" not in pars or not stay: 813 for x in [ 814 "to_ram", 815 "use_test", 816 "partition_method", 817 "val_frac", 818 "test_frac", 819 "save_split", 820 "batch_size", 821 "save_split", 822 ]: 823 if ( 824 x in pars["training"] 825 and pars["training"][x] != self.training_parameters[x] 826 ): 827 if "data" not in pars: 828 pars["data"] = {} 829 stay = True 830 self.training_parameters.update(pars["training"]) 831 self.task.update_parameters(self._get_parameters_from_training()) 832 833 if "data" in pars or "features" in pars: 834 for k, v in pars["data"].items(): 835 if k not in self.data_parameters or v != self.data_parameters[k]: 836 stay = True 837 for k, v in pars["features"].items(): 838 if k not in self.feature_parameters or v != self.feature_parameters[k]: 839 stay = True 840 if stay: 841 self.data_parameters.update(pars["data"]) 842 self.feature_parameters.update(pars["features"]) 843 dataset = self._construct_dataset() 844 ( 845 train_dataloader, 846 test_dataloader, 847 val_dataloader, 848 ) = self._partition_dataset(dataset) 849 self.task.set_dataloaders( 850 train_dataloader, val_dataloader, test_dataloader 851 ) 852 self.class_weights = train_dataloader.dataset.class_weights() 853 self.proportional_class_weights = ( 854 train_dataloader.dataset.class_weights(True) 855 ) 856 if "losses" not in pars: 857 pars["losses"] = {} 
858 859 if "model" in pars: 860 self.model_parameters.update(pars["model"]) 861 862 self._update_data_blanks() 863 864 if "augmentations" in pars: 865 self.aug_parameters.update(pars["augmentations"]) 866 transformer = self._construct_transformer() 867 self.task.set_transformer(transformer) 868 869 if "losses" in pars: 870 for key in pars["losses"]: 871 if key in self.loss_parameters: 872 self.loss_parameters[key].update(pars["losses"][key]) 873 else: 874 self.loss_parameters[key] = pars["losses"][key] 875 self.loss_parameters.update(pars["losses"]) 876 loss = self._construct_loss() 877 self.task.set_loss(loss) 878 879 if "metrics" in pars: 880 for key in pars["metrics"]: 881 if key in self.metric_parameters: 882 self.metric_parameters[key].update(pars["metrics"][key]) 883 else: 884 self.metric_parameters[key] = pars["metrics"][key] 885 metrics = self._construct_metrics() 886 self.task.set_metrics(metrics) 887 888 self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"]) 889 self._set_loss_weights( 890 pars.get("losses", {}).get(self.general_parameters["loss_function"], {}) 891 ) 892 model = self._construct_model() 893 predict_functions = self._construct_predict_functions() 894 self.task.set_predict_functions(*predict_functions) 895 self._update_model_blanks(model) 896 ssl_list = self._construct_ssl() 897 self._update_parameters_from_ssl(ssl_list) 898 model.set_ssl(ssl_constructors=ssl_list) 899 self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 900 self.task.set_model(model) 901 if "training" in pars and "checkpoint_path" in pars["training"]: 902 checkpoint_path = pars["training"]["checkpoint_path"] 903 only_model = pars["training"].get("only_load_model", False) 904 load_strict = pars["training"].get("load_strict", True) 905 self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 906 if ( 907 self.general_parameters["only_load_annotated"] 908 and self.general_parameters.get("ssl") is not None 909 ): 
910 warnings.warn( 911 "Note that you are using SSL modules but only loading annotated files! Set " 912 "general/only_load_annotated to False to change that" 913 ) 914 if self.task.dataset("train").annotation_class() != "none": 915 self._print_behaviors() 916 917 def train( 918 self, 919 trial: Trial = None, 920 optimized_metric: str = None, 921 autostop_metric: str = None, 922 autostop_interval: int = 10, 923 autostop_threshold: float = 0.001, 924 loading_bar: bool = False, 925 ) -> Tuple: 926 """Train the task and return a log of epoch-average loss and metric. 927 928 You can use the autostop parameters to finish training when the parameters are not improving. It will be 929 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 930 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 931 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 
932 933 Parameters 934 ---------- 935 trial : Trial 936 an `optuna` trial (for hyperparameter searches) 937 optimized_metric : str 938 the name of the metric being optimized (for hyperparameter searches) 939 autostop_metric : str, optional 940 the autostop metric (can be any one of the tracked metrics of `'loss'`) 941 autostop_interval : int, default 50 942 the number of epochs to average the autostop metric over 943 autostop_threshold : float, default 0.001 944 the autostop difference threshold 945 loading_bar : bool, default False 946 whether to show a loading bar 947 948 Returns 949 ------- 950 loss_log: list 951 a list of float loss function values for each epoch 952 metrics_log: dict 953 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 954 names, values are lists of function values) 955 956 """ 957 to_ram = self.training_parameters.get("to_ram", False) 958 logs = self.task.train( 959 trial, 960 optimized_metric, 961 to_ram, 962 autostop_metric=autostop_metric, 963 autostop_interval=autostop_interval, 964 autostop_threshold=autostop_threshold, 965 main_task_on=self.training_parameters.get("main_task_on", True), 966 ssl_on=self.training_parameters.get("ssl_on", True), 967 temporal_subsampling_size=self.training_parameters.get( 968 "temporal_subsampling_size" 969 ), 970 loading_bar=loading_bar, 971 ) 972 return logs 973 974 def save_model(self, save_path: str) -> None: 975 """Save the model of the `dlc2action.task.universal_task.Task` instance. 976 977 Parameters 978 ---------- 979 save_path : str 980 the path to the saved file 981 982 """ 983 self.task.save_model(save_path) 984 985 def evaluate( 986 self, 987 data: Union[DataLoader, BehaviorDataset, str] = None, 988 augment_n: int = 0, 989 verbose: bool = True, 990 ) -> Tuple: 991 """Evaluate the Task model. 
992 993 Parameters 994 ---------- 995 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 996 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 997 augment_n : int, default 0 998 the number of augmentations to average results over 999 verbose : bool, default True 1000 if True, the process is reported to standard output 1001 1002 Returns 1003 ------- 1004 loss : float 1005 the average value of the loss function 1006 ssl_loss : float 1007 the average value of the SSL loss function 1008 metric : dict 1009 a dictionary of average values of metric functions 1010 1011 """ 1012 res = self.task.evaluate( 1013 data, 1014 augment_n, 1015 int(self.training_parameters.get("batch_size", 32)), 1016 verbose, 1017 ) 1018 return res 1019 1020 def evaluate_prediction( 1021 self, 1022 prediction: torch.Tensor, 1023 data: Union[DataLoader, BehaviorDataset, str] = None, 1024 indices: Union[List[int], None] = None, 1025 ) -> Tuple: 1026 """Compute metrics for a prediction. 1027 1028 Parameters 1029 ---------- 1030 prediction : torch.Tensor 1031 the prediction 1032 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1033 the data the prediction was made for (if not provided, take the validation dataset) 1034 1035 Returns 1036 ------- 1037 loss : float 1038 the average value of the loss function 1039 metric : dict 1040 a dictionary of average values of metric functions 1041 """ 1042 return self.task.evaluate_prediction( 1043 prediction, data, int(self.training_parameters.get("batch_size", 32)), indices 1044 ) 1045 1046 def predict( 1047 self, 1048 data: Union[DataLoader, BehaviorDataset, str], 1049 raw_output: bool = False, 1050 apply_primary_function: bool = True, 1051 augment_n: int = 0, 1052 embedding: bool = False, 1053 ) -> torch.Tensor: 1054 """Make a prediction with the Task model. 
1055 1056 Parameters 1057 ---------- 1058 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1059 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 1060 raw_output : bool, default False 1061 if `True`, the raw predicted probabilities are returned 1062 apply_primary_function : bool, default True 1063 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 1064 the input) 1065 augment_n : int, default 0 1066 the number of augmentations to average results over 1067 embedding : bool, default False 1068 if `True`, the embedding is returned instead of the prediction 1069 1070 Returns 1071 ------- 1072 prediction : torch.Tensor 1073 a prediction for the input data 1074 1075 """ 1076 to_ram = self.training_parameters.get("to_ram", False) 1077 return self.task.predict( 1078 data, 1079 raw_output, 1080 apply_primary_function, 1081 augment_n, 1082 int(self.training_parameters.get("batch_size", 32)), 1083 to_ram, 1084 embedding=embedding, 1085 ) 1086 1087 def dataset(self, mode: str = "train") -> BehaviorDataset: 1088 """Get a dataset. 1089 1090 Parameters 1091 ---------- 1092 mode : {'train', 'val', 'test'} 1093 the dataset to get 1094 1095 Returns 1096 ------- 1097 dataset : dlc2action.data.dataset.BehaviorDataset 1098 the dataset 1099 1100 """ 1101 return self.task.dataset(mode) 1102 1103 def generate_full_length_prediction( 1104 self, 1105 dataset: Union[BehaviorDataset, str] = None, 1106 augment_n: int = 10, 1107 ) -> Dict: 1108 """Compile a prediction for the original input sequences. 
1109 1110 Parameters 1111 ---------- 1112 dataset : dlc2action.data.dataset.BehaviorDataset | str, optional 1113 the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task` 1114 instance validation dataset) 1115 augment_n : int, default 10 1116 the number of augmentations to average results over 1117 1118 Returns 1119 ------- 1120 prediction : dict 1121 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1122 are prediction tensors 1123 1124 """ 1125 return self.task.generate_full_length_prediction( 1126 dataset, int(self.training_parameters.get("batch_size", 32)), augment_n 1127 ) 1128 1129 def generate_submission( 1130 self, 1131 frame_number_map_file: str, 1132 dataset: Union[BehaviorDataset, str] = None, 1133 augment_n: int = 10, 1134 ) -> Dict: 1135 """Generate a MABe-22 style submission dictionary. 1136 1137 Parameters 1138 ---------- 1139 frame_number_map_file : str 1140 path to the frame number map file 1141 dataset : BehaviorDataset, optional 1142 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1143 augment_n : int, default 10 1144 the number of augmentations to average results over 1145 1146 Returns 1147 ------- 1148 submission : dict 1149 a dictionary with frame number mapping and embeddings 1150 1151 """ 1152 return self.task.generate_submission( 1153 frame_number_map_file, 1154 dataset, 1155 int(self.training_parameters.get("batch_size", 32)), 1156 augment_n, 1157 ) 1158 1159 def behaviors_dict(self): 1160 """Get a behavior dictionary. 1161 1162 Keys are label indices and values are label names. 1163 1164 Returns 1165 ------- 1166 behaviors_dict : dict 1167 behavior dictionary 1168 1169 """ 1170 return self.task.behaviors_dict() 1171 1172 def count_classes(self, bouts: bool = False) -> Dict: 1173 """Get a dictionary of class counts in different modes. 
1174 1175 Parameters 1176 ---------- 1177 bouts : bool, default False 1178 if `True`, instead of frame counts segment counts are returned 1179 1180 Returns 1181 ------- 1182 class_counts : dict 1183 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1184 class names and values are class counts (in frames) 1185 1186 """ 1187 return self.task.count_classes(bouts) 1188 1189 def _visualize_results_label( 1190 self, 1191 label: str, 1192 save_path: str = None, 1193 add_legend: bool = True, 1194 ground_truth: bool = True, 1195 hide_axes: bool = False, 1196 width: int = 10, 1197 whole_video: bool = False, 1198 transparent: bool = False, 1199 dataset: BehaviorDataset = None, 1200 smooth_interval: int = 0, 1201 title: str = None, 1202 ): 1203 return self.task._visualize_results_single( 1204 label, 1205 save_path, 1206 add_legend, 1207 ground_truth, 1208 hide_axes, 1209 width, 1210 whole_video, 1211 transparent, 1212 dataset, 1213 smooth_interval=smooth_interval, 1214 title=title, 1215 ) 1216 1217 def visualize_results( 1218 self, 1219 save_path: str = None, 1220 add_legend: bool = True, 1221 ground_truth: bool = True, 1222 colormap: str = "viridis", 1223 hide_axes: bool = False, 1224 min_classes: int = 1, 1225 width: float = 10, 1226 whole_video: bool = False, 1227 transparent: bool = False, 1228 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1229 drop_classes: Set = None, 1230 search_classes: Set = None, 1231 smooth_interval_prediction: int = None, 1232 font_size: float = None, 1233 num_plots:int = 1, 1234 window_size:int=400, 1235 ) -> None: 1236 """Visualize random predictions. 
1237 1238 Parameters 1239 ---------- 1240 save_path : str, optional 1241 the path where the plot will be saved 1242 add_legend : bool, default True 1243 if True, legend will be added to the plot 1244 ground_truth : bool, default True 1245 if True, ground truth will be added to the plot 1246 colormap : str, default 'Accent' 1247 the `matplotlib` colormap to use 1248 hide_axes : bool, default True 1249 if `True`, the axes will be hidden on the plot 1250 min_classes : int, default 1 1251 the minimum number of classes in a displayed interval 1252 width : float, default 10 1253 the width of the plot 1254 whole_video : bool, default False 1255 if `True`, whole videos are plotted instead of segments 1256 transparent : bool, default False 1257 if `True`, the background on the plot is transparent 1258 dataset : BehaviorDataset | DataLoader | str | None, optional 1259 the dataset to make the prediction for (if not provided, the validation dataset is used) 1260 drop_classes : set, optional 1261 a set of class names to not be displayed 1262 search_classes : set, optional 1263 if given, only intervals where at least one of the classes is in ground truth will be shown 1264 smooth_interval_prediction : int, optional 1265 if given, the prediction will be smoothed over the given number of frames 1266 1267 """ 1268 return self.task.visualize_results( 1269 save_path, 1270 add_legend, 1271 ground_truth, 1272 colormap, 1273 hide_axes, 1274 min_classes, 1275 width, 1276 whole_video, 1277 transparent, 1278 dataset, 1279 drop_classes, 1280 search_classes, 1281 font_size=font_size, 1282 smooth_interval_prediction=smooth_interval_prediction, 1283 num_plots=num_plots, 1284 window_size=window_size 1285 ) 1286 1287 def generate_uncertainty_score( 1288 self, 1289 classes: List, 1290 augment_n: int = 0, 1291 method: str = "least_confidence", 1292 predicted: torch.Tensor = None, 1293 behaviors_dict: Dict = None, 1294 ) -> Dict: 1295 """Generate frame-wise scores for active learning. 
1296 1297 Parameters 1298 ---------- 1299 classes : list 1300 a list of class names or indices; their confidence scores will be computed separately and stacked 1301 augment_n : int, default 0 1302 the number of augmentations to average over 1303 method : {"least_confidence", "entropy"} 1304 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1305 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1306 predicted : torch.Tensor, optional 1307 if given, the predictions will be used instead of the model's predictions 1308 behaviors_dict : dict, optional 1309 if given, the behaviors dictionary will be used instead of the model's behaviors dictionary 1310 1311 Returns 1312 ------- 1313 score_dicts : dict 1314 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1315 are score tensors 1316 1317 """ 1318 return self.task.generate_uncertainty_score( 1319 classes, 1320 augment_n, 1321 int(self.training_parameters.get("batch_size", 32)), 1322 method, 1323 predicted, 1324 behaviors_dict, 1325 ) 1326 1327 def generate_bald_score( 1328 self, 1329 classes: List, 1330 augment_n: int = 0, 1331 num_models: int = 10, 1332 kernel_size: int = 11, 1333 ) -> Dict: 1334 """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning. 
1335 1336 Parameters 1337 ---------- 1338 classes : list 1339 a list of class names or indices; their confidence scores will be computed separately and stacked 1340 augment_n : int, default 0 1341 the number of augmentations to average over 1342 num_models : int, default 10 1343 the number of dropout masks to apply 1344 kernel_size : int, default 11 1345 the size of the smoothing gaussian kernel 1346 1347 Returns 1348 ------- 1349 score_dicts : dict 1350 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1351 are score tensors 1352 1353 """ 1354 return self.task.generate_bald_score( 1355 classes, 1356 augment_n, 1357 int(self.training_parameters.get("batch_size", 32)), 1358 num_models, 1359 kernel_size, 1360 ) 1361 1362 def exists(self, mode) -> bool: 1363 """Check whether the task has a train/test/validation subset. 1364 1365 Parameters 1366 ---------- 1367 mode : {"train", "val", "test"} 1368 the name of the subset to check for 1369 Returns 1370 ------- 1371 exists : bool 1372 `True` if the subset exists 1373 1374 """ 1375 dl = self.task.dataloader(mode) 1376 if dl is None: 1377 return False 1378 else: 1379 return True 1380 1381 def get_normalization_stats(self) -> Dict: 1382 """Get the pre-computed normalization stats. 1383 1384 Returns 1385 ------- 1386 normalization_stats : dict 1387 a dictionary of means and stds 1388 1389 """ 1390 return self.task.get_normalization_stats()
A class that manages the interactions between config dictionaries and a Task.
def __init__(self, parameters: Dict) -> None:
    """Initialize the `TaskDispatcher`.

    The configuration is deep-copied, split into per-group dictionaries stored on the instance,
    and the task is then initialized and the behavior indices are printed.

    Parameters
    ----------
    parameters : dict
        a dictionary of task parameters (the groups read here are `"general"`, `"data"`, `"model"`,
        `"training"`, `"losses"`, `"metrics"`, `"ssl"`, `"augmentations"` and `"features"`)

    """
    # work on a copy so that later in-place updates never touch the caller's dictionary
    config = deepcopy(parameters)
    self.class_weights = None
    # (attribute name, config group) pairs; a missing group falls back to a fresh empty dictionary
    group_attributes = (
        ("general_parameters", "general"),
        ("data_parameters", "data"),
        ("model_parameters", "model"),
        ("training_parameters", "training"),
        ("loss_parameters", "losses"),
        ("metric_parameters", "metrics"),
        ("ssl_parameters", "ssl"),
        ("aug_parameters", "augmentations"),
        ("feature_parameters", "features"),
    )
    for attribute, group in group_attributes:
        setattr(self, attribute, config.get(group, {}))
    # every known blank starts with an empty list of locations
    self.blanks = {blank: [] for blank in options.blanks}

    self.task = None
    self._initialize_task()
    self._print_behaviors()
60 @staticmethod 61 def complete_function_parameters(parameters, function, general_dicts: List) -> Dict: 62 """Complete a parameter dictionary with values from other dictionaries if required by a function. 63 64 Parameters 65 ---------- 66 parameters : dict 67 the function parameters dictionary 68 function : callable 69 the function to be inspected 70 general_dicts : list 71 a list of dictionaries where the missing values will be pulled from 72 73 Returns 74 ------- 75 parameters : dict 76 the updated parameter dictionary 77 78 """ 79 parameter_names = inspect.getfullargspec(function).args 80 for param in parameter_names: 81 for dic in general_dicts: 82 if param not in parameters and param in dic: 83 parameters[param] = dic[param] 84 return parameters
Complete a parameter dictionary with values from other dictionaries if required by a function.
Parameters
parameters : dict the function parameters dictionary function : callable the function to be inspected general_dicts : list a list of dictionaries where the missing values will be pulled from
Returns
parameters : dict the updated parameter dictionary
@staticmethod
def complete_dataset_parameters(
    parameters: dict,
    general_dict: dict,
    data_type: str,
    annotation_type: str,
) -> Dict:
    """Fill in missing dataset parameters from a fallback dictionary.

    The parameter names are obtained from `BehaviorDataset.get_parameters(data_type, annotation_type)`.
    The input dictionary is left untouched; a completed deep copy is returned.

    Parameters
    ----------
    parameters : dict
        the dataset parameters dictionary
    general_dict : dict
        the dictionary where the missing values will be pulled from
    data_type : str
        the input type of the dataset
    annotation_type : str
        the annotation type of the dataset

    Returns
    -------
    parameters : dict
        the updated parameter dictionary

    """
    completed = deepcopy(parameters)
    expected = BehaviorDataset.get_parameters(data_type, annotation_type)
    # only names that are both missing and available in the fallback dictionary are added
    completed.update(
        {
            name: general_dict[name]
            for name in expected
            if name not in completed and name in general_dict
        }
    )
    return completed
Complete a parameter dictionary with values from other dictionaries if required by a dataset.
Parameters
parameters : dict the dataset parameters dictionary general_dict : dict the dictionary where the missing values will be pulled from data_type : str the input type of the dataset annotation_type : str the annotation type of the dataset
Returns
parameters : dict the updated parameter dictionary
119 @staticmethod 120 def check(parameters: Dict, name: str) -> bool: 121 """Check whether there is a non-`None` value under the name key in the parameters dictionary. 122 123 Parameters 124 ---------- 125 parameters : dict 126 the dictionary to check 127 name : str 128 the key to check 129 130 Returns 131 ------- 132 result : bool 133 True if a non-`None` value exists 134 135 """ 136 if name in parameters and parameters[name] is not None: 137 return True 138 else: 139 return False
Check whether there is a non-None
value under the name key in the parameters dictionary.
Parameters
parameters : dict the dictionary to check name : str the key to check
Returns
result : bool
True if a non-None
value exists
141 @staticmethod 142 def get(parameters: Dict, name: str, default): 143 """Get the value under the name key or the default if it is `None` or does not exist. 144 145 Parameters 146 ---------- 147 parameters : dict 148 the dictionary to check 149 name : str 150 the key to check 151 default 152 the default value to return 153 154 Returns 155 ------- 156 value 157 the resulting value 158 159 """ 160 if TaskDispatcher.check(parameters, name): 161 return parameters[name] 162 else: 163 return default
Get the value under the name key or the default if it is None
or does not exist.
Parameters
parameters : dict the dictionary to check name : str the key to check default the default value to return
Returns
value the resulting value
165 @staticmethod 166 def make_dataloader( 167 dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False 168 ) -> DataLoader: 169 """Make a torch dataloader from a dataset. 170 171 Parameters 172 ---------- 173 dataset : dlc2action.data.dataset.BehaviorDataset 174 the dataset 175 batch_size : int 176 the batch size 177 shuffle : bool 178 whether to shuffle the dataset 179 180 Returns 181 ------- 182 dataloader : DataLoader 183 the dataloader (or `None` if the length of the dataset is 0) 184 185 """ 186 if dataset is None or len(dataset) == 0: 187 return None 188 else: 189 return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)
Make a torch dataloader from a dataset.
Parameters
dataset : dlc2action.data.dataset.BehaviorDataset the dataset batch_size : int the batch size shuffle : bool whether to shuffle the dataset
Returns
dataloader : DataLoader
the dataloader (or None
if the dataset is None or its length is 0)
def update_task(self, parameters: Dict) -> None:
    """Update the `dlc2action.task.universal_task.Task` instance given the parameter updates.

    Only the groups present in `parameters` are merged into the stored configuration. The dataset is
    rebuilt and re-partitioned when a data or feature parameter, or one of the tracked training
    parameters, differs from the stored value. The model, the predict functions and the SSL
    constructors are reconstructed on every call.

    Parameters
    ----------
    parameters : dict
        the dictionary of parameter updates (it is deep-copied, the caller's object is not modified)

    Raises
    ------
    ValueError
        if `general/model_name` changes and no `"model"` group is passed

    """
    pars = deepcopy(parameters)
    # becomes True as soon as a change that requires rebuilding the dataset is detected
    rebuild_data = False

    if "ssl" in pars:
        for key, value in pars["ssl"].items():
            if key in self.ssl_parameters:
                self.ssl_parameters[key].update(value)
            else:
                self.ssl_parameters[key] = value

    if "general" in pars:
        general_update = pars["general"]
        if (
            "model_name" in general_update
            and general_update["model_name"] != self.general_parameters["model_name"]
        ):
            if "model" not in pars:
                raise ValueError(
                    "When updating a task with a new model name you need to pass the parameters for the "
                    "new model"
                )
            # parameters of the previous model do not apply to the new one
            self.model_parameters = {}
        self.general_parameters.update(general_update)
        # general keys each group depends on: touching one of them forces the group to be
        # reprocessed below even if the update does not mention the group itself
        dependent_groups = {
            "data": ["num_classes", "exclusive", "data_type", "annotation_type"],
            "ssl": ["ssl", "exclusive", "num_classes"],
            "losses": ["num_classes", "loss_function", "exclusive"],
            "augmentations": ["augmentation_type"],
            "metrics": ["metric_functions"],
        }
        for group, related in dependent_groups.items():
            if group not in pars and any(x in general_update for x in related):
                pars[group] = {}

    if "training" in pars:
        training_update = pars["training"]
        # training parameters whose change requires rebuilding the dataset / dataloaders
        data_affecting = (
            "to_ram",
            "use_test",
            "partition_method",
            "val_frac",
            "test_frac",
            "save_split",
            "batch_size",
        )
        if any(
            key in training_update
            and (
                # a key that was never stored counts as changed instead of raising KeyError
                key not in self.training_parameters
                or training_update[key] != self.training_parameters[key]
            )
            for key in data_affecting
        ):
            pars.setdefault("data", {})
            rebuild_data = True
        self.training_parameters.update(training_update)
        self.task.update_parameters(self._get_parameters_from_training())

    if "data" in pars or "features" in pars:
        # either group may be absent (e.g. only "features" was passed, or "data" was added above)
        data_update = pars.get("data", {})
        feature_update = pars.get("features", {})
        if any(
            k not in self.data_parameters or v != self.data_parameters[k]
            for k, v in data_update.items()
        ):
            rebuild_data = True
        if any(
            k not in self.feature_parameters or v != self.feature_parameters[k]
            for k, v in feature_update.items()
        ):
            rebuild_data = True
        if rebuild_data:
            self.data_parameters.update(data_update)
            self.feature_parameters.update(feature_update)
            dataset = self._construct_dataset()
            (
                train_dataloader,
                test_dataloader,
                val_dataloader,
            ) = self._partition_dataset(dataset)
            self.task.set_dataloaders(
                train_dataloader, val_dataloader, test_dataloader
            )
            self.class_weights = train_dataloader.dataset.class_weights()
            self.proportional_class_weights = (
                train_dataloader.dataset.class_weights(True)
            )
            # the loss is reconstructed below whenever the data was rebuilt
            pars.setdefault("losses", {})

    if "model" in pars:
        self.model_parameters.update(pars["model"])

    self._update_data_blanks()

    if "augmentations" in pars:
        self.aug_parameters.update(pars["augmentations"])
        transformer = self._construct_transformer()
        self.task.set_transformer(transformer)

    if "losses" in pars:
        for key, value in pars["losses"].items():
            if key in self.loss_parameters:
                self.loss_parameters[key].update(value)
            else:
                self.loss_parameters[key] = value
        # NOTE(review): this call replaces every per-loss dictionary merged just above with the
        # dictionary from the update, which discards the merge; the metrics branch below has no
        # equivalent call. Kept as is to preserve behavior - confirm which one is intended.
        self.loss_parameters.update(pars["losses"])
        loss = self._construct_loss()
        self.task.set_loss(loss)

    if "metrics" in pars:
        for key, value in pars["metrics"].items():
            if key in self.metric_parameters:
                self.metric_parameters[key].update(value)
            else:
                self.metric_parameters[key] = value
        metrics = self._construct_metrics()
        self.task.set_metrics(metrics)

    self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
    self._set_loss_weights(
        pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
    )
    model = self._construct_model()
    predict_functions = self._construct_predict_functions()
    self.task.set_predict_functions(*predict_functions)
    self._update_model_blanks(model)
    ssl_list = self._construct_ssl()
    self._update_parameters_from_ssl(ssl_list)
    model.set_ssl(ssl_constructors=ssl_list)
    # overrides the transformations set above with the ones of the freshly built SSL constructors
    self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
    self.task.set_model(model)
    # a checkpoint is only loaded when the update itself provides a path
    if "training" in pars and "checkpoint_path" in pars["training"]:
        checkpoint_path = pars["training"]["checkpoint_path"]
        only_model = pars["training"].get("only_load_model", False)
        load_strict = pars["training"].get("load_strict", True)
        self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
    if (
        self.general_parameters["only_load_annotated"]
        and self.general_parameters.get("ssl") is not None
    ):
        warnings.warn(
            "Note that you are using SSL modules but only loading annotated files! Set "
            "general/only_load_annotated to False to change that"
        )
    if self.task.dataset("train").annotation_class() != "none":
        self._print_behaviors()
Update the dlc2action.task.universal_task.Task
instance given the parameter updates.
Parameters
parameters : dict the dictionary of parameter updates
def train(
    self,
    trial: Trial = None,
    optimized_metric: str = None,
    autostop_metric: str = None,
    autostop_interval: int = 10,
    autostop_threshold: float = 0.001,
    loading_bar: bool = False,
) -> Tuple:
    """Train the task and return a log of epoch-average loss and metric.

    You can use the autostop parameters to finish training when the parameters are not improving. It will be
    stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
    the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
    current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.

    The `to_ram`, `main_task_on`, `ssl_on` and `temporal_subsampling_size` options are read from the stored
    training parameters (defaults: `False`, `True`, `True` and `None`).

    Parameters
    ----------
    trial : Trial, optional
        an `optuna` trial (for hyperparameter searches)
    optimized_metric : str, optional
        the name of the metric being optimized (for hyperparameter searches)
    autostop_metric : str, optional
        the autostop metric (can be any one of the tracked metrics or `'loss'`)
    autostop_interval : int, default 10
        the number of epochs to average the autostop metric over
    autostop_threshold : float, default 0.001
        the autostop difference threshold
    loading_bar : bool, default False
        whether to show a loading bar

    Returns
    -------
    loss_log: list
        a list of float loss function values for each epoch
    metrics_log: dict
        a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
        names, values are lists of function values)

    """
    training = self.training_parameters
    return self.task.train(
        trial,
        optimized_metric,
        training.get("to_ram", False),
        autostop_metric=autostop_metric,
        autostop_interval=autostop_interval,
        autostop_threshold=autostop_threshold,
        main_task_on=training.get("main_task_on", True),
        ssl_on=training.get("ssl_on", True),
        temporal_subsampling_size=training.get("temporal_subsampling_size"),
        loading_bar=loading_bar,
    )
Train the task and return a log of epoch-average loss and metric.
You can use the autostop parameters to finish training when the parameters are not improving. It will be
stopped if the average value of autostop_metric
over the last autostop_interval
epochs is smaller than
the average over the previous autostop_interval
epochs + autostop_threshold
. For example, if the
current epoch is 120 and autostop_interval
is 50, the averages over epochs 70-120 and 20-70 will be compared.
Parameters
trial : Trial
an optuna
trial (for hyperparameter searches)
optimized_metric : str
the name of the metric being optimized (for hyperparameter searches)
autostop_metric : str, optional
the autostop metric (can be any one of the tracked metrics or 'loss')
autostop_interval : int, default 10
the number of epochs to average the autostop metric over
autostop_threshold : float, default 0.001
the autostop difference threshold
loading_bar : bool, default False
whether to show a loading bar
Returns
loss_log: list a list of float loss function values for each epoch metrics_log: dict a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric names, values are lists of function values)
def save_model(self, save_path: str) -> None:
    """Save the model of the wrapped `dlc2action.task.universal_task.Task` instance.

    Parameters
    ----------
    save_path : str
        the path of the file the model is written to

    """
    task = self.task
    task.save_model(save_path)
Save the model of the dlc2action.task.universal_task.Task
instance.
Parameters
save_path : str the path to the saved file
def evaluate(
    self,
    data: Union[DataLoader, BehaviorDataset, str] = None,
    augment_n: int = 0,
    verbose: bool = True,
) -> Tuple:
    """Run an evaluation of the Task model.

    The batch size is taken from the stored training parameters (`batch_size`, 32 if not set).

    Parameters
    ----------
    data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
        the data to evaluate on (if not provided, evaluate on the Task validation dataset)
    augment_n : int, default 0
        the number of augmentations to average results over
    verbose : bool, default True
        if True, the process is reported to standard output

    Returns
    -------
    loss : float
        the average value of the loss function
    ssl_loss : float
        the average value of the SSL loss function
    metric : dict
        a dictionary of average values of metric functions

    """
    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.evaluate(data, augment_n, batch_size, verbose)
Evaluate the Task model.
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data to evaluate on (if not provided, evaluate on the Task validation dataset) augment_n : int, default 0 the number of augmentations to average results over verbose : bool, default True if True, the process is reported to standard output
Returns
loss : float the average value of the loss function ssl_loss : float the average value of the SSL loss function metric : dict a dictionary of average values of metric functions
def evaluate_prediction(
    self,
    prediction: torch.Tensor,
    data: Union[DataLoader, BehaviorDataset, str] = None,
    indices: Union[List[int], None] = None,
) -> Tuple:
    """Compute the loss and the metrics for an existing prediction.

    The batch size is taken from the stored training parameters (`batch_size`, 32 if not set).

    Parameters
    ----------
    prediction : torch.Tensor
        the prediction
    data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
        the data the prediction was made for (if not provided, take the validation dataset)
    indices : list of int, optional
        forwarded unchanged to `Task.evaluate_prediction` (see that method for its meaning)

    Returns
    -------
    loss : float
        the average value of the loss function
    metric : dict
        a dictionary of average values of metric functions

    """
    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.evaluate_prediction(prediction, data, batch_size, indices)
Compute metrics for a prediction.
Parameters
prediction : torch.Tensor the prediction data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data the prediction was made for (if not provided, take the validation dataset)
Returns
loss : float the average value of the loss function metric : dict a dictionary of average values of metric functions
def predict(
    self,
    data: Union[DataLoader, BehaviorDataset, str],
    raw_output: bool = False,
    apply_primary_function: bool = True,
    augment_n: int = 0,
    embedding: bool = False,
) -> torch.Tensor:
    """Generate a prediction with the Task model.

    The batch size and the `to_ram` flag are taken from the stored training parameters
    (32 and `False` if not set).

    Parameters
    ----------
    data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str
        the data to make the prediction for
    raw_output : bool, default False
        if `True`, the raw predicted probabilities are returned
    apply_primary_function : bool, default True
        if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
        the input)
    augment_n : int, default 0
        the number of augmentations to average results over
    embedding : bool, default False
        if `True`, the embedding is returned instead of the prediction

    Returns
    -------
    prediction : torch.Tensor
        a prediction for the input data

    """
    training = self.training_parameters
    to_ram = training.get("to_ram", False)
    batch_size = int(training.get("batch_size", 32))
    return self.task.predict(
        data,
        raw_output,
        apply_primary_function,
        augment_n,
        batch_size,
        to_ram,
        embedding=embedding,
    )
Make a prediction with the Task model.
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
the data to evaluate on (if not provided, evaluate on the Task validation dataset)
raw_output : bool, default False
if True
, the raw predicted probabilities are returned
apply_primary_function : bool, default True
if True
, the primary predict function is applied (to map the model output into a shape corresponding to
the input)
augment_n : int, default 0
the number of augmentations to average results over
embedding : bool, default False
if True
, the embedding is returned instead of the prediction
Returns
prediction : torch.Tensor a prediction for the input data
def dataset(self, mode: str = "train") -> BehaviorDataset:
    """Return one of the datasets held by the task.

    Parameters
    ----------
    mode : {'train', 'val', 'test'}
        the dataset to get

    Returns
    -------
    dataset : dlc2action.data.dataset.BehaviorDataset
        the dataset

    """
    task = self.task
    return task.dataset(mode)
Get a dataset.
Parameters
mode : {'train', 'val', 'test'} the dataset to get
Returns
dataset : dlc2action.data.dataset.BehaviorDataset the dataset
def generate_full_length_prediction(
    self,
    dataset: Union[BehaviorDataset, str] = None,
    augment_n: int = 10,
) -> Dict:
    """Assemble a prediction for the original (full-length) input sequences.

    The batch size is taken from the stored training parameters (`batch_size`, 32 if not set).

    Parameters
    ----------
    dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
        the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task`
        instance validation dataset)
    augment_n : int, default 10
        the number of augmentations to average results over

    Returns
    -------
    prediction : dict
        a nested dictionary where first level keys are video ids, second level keys are clip ids and values
        are prediction tensors

    """
    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_full_length_prediction(dataset, batch_size, augment_n)
Compile a prediction for the original input sequences.
Parameters
dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
the dataset to generate a prediction for (if None
, generate for the dlc2action.task.universal_task.Task
instance validation dataset)
augment_n : int, default 10
the number of augmentations to average results over
Returns
prediction : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are prediction tensors
def generate_submission(
    self,
    frame_number_map_file: str,
    dataset: Union[BehaviorDataset, str] = None,
    augment_n: int = 10,
) -> Dict:
    """Build a MABe-22 style submission dictionary.

    The batch size is taken from the stored training parameters (`batch_size`, 32 if not set).

    Parameters
    ----------
    frame_number_map_file : str
        path to the frame number map file
    dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
        the dataset to generate a prediction for (if `None`, generate for the validation dataset)
    augment_n : int, default 10
        the number of augmentations to average results over

    Returns
    -------
    submission : dict
        a dictionary with frame number mapping and embeddings

    """
    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_submission(
        frame_number_map_file, dataset, batch_size, augment_n
    )
Generate a MABe-22 style submission dictionary.
Parameters
frame_number_map_file : str
path to the frame number map file
dataset : BehaviorDataset, optional
the dataset to generate a prediction for (if None
, generate for the validation dataset)
augment_n : int, default 10
the number of augmentations to average results over
Returns
submission : dict a dictionary with frame number mapping and embeddings
def behaviors_dict(self):
    """Return the behavior dictionary of the task.

    Keys are label indices and values are label names.

    Returns
    -------
    behaviors_dict : dict
        behavior dictionary

    """
    task = self.task
    return task.behaviors_dict()
Get a behavior dictionary.
Keys are label indices and values are label names.
Returns
behaviors_dict : dict behavior dictionary
def count_classes(self, bouts: bool = False) -> Dict:
    """Return the class counts of the task subsets.

    Parameters
    ----------
    bouts : bool, default False
        if `True`, instead of frame counts segment counts are returned

    Returns
    -------
    class_counts : dict
        a dictionary where first-level keys are "train", "val" and "test", second-level keys are
        class names and values are class counts (in frames, or in segments if `bouts` is `True`)

    """
    task = self.task
    return task.count_classes(bouts)
Get a dictionary of class counts in different modes.
Parameters
bouts : bool, default False
if True
, instead of frame counts segment counts are returned
Returns
class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)
def visualize_results(
    self,
    save_path: str = None,
    add_legend: bool = True,
    ground_truth: bool = True,
    colormap: str = "viridis",
    hide_axes: bool = False,
    min_classes: int = 1,
    width: float = 10,
    whole_video: bool = False,
    transparent: bool = False,
    dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
    drop_classes: Set = None,
    search_classes: Set = None,
    smooth_interval_prediction: int = None,
    font_size: float = None,
    num_plots: int = 1,
    window_size: int = 400,
) -> None:
    """Visualize random predictions.

    All arguments are forwarded to `Task.visualize_results` and its result is returned as is.

    Parameters
    ----------
    save_path : str, optional
        the path where the plot will be saved
    add_legend : bool, default True
        if True, legend will be added to the plot
    ground_truth : bool, default True
        if True, ground truth will be added to the plot
    colormap : str, default 'viridis'
        the `matplotlib` colormap to use
    hide_axes : bool, default False
        if `True`, the axes will be hidden on the plot
    min_classes : int, default 1
        the minimum number of classes in a displayed interval
    width : float, default 10
        the width of the plot
    whole_video : bool, default False
        if `True`, whole videos are plotted instead of segments
    transparent : bool, default False
        if `True`, the background on the plot is transparent
    dataset : BehaviorDataset | DataLoader | str | None, optional
        the dataset to make the prediction for (if not provided, the validation dataset is used)
    drop_classes : set, optional
        a set of class names to not be displayed
    search_classes : set, optional
        if given, only intervals where at least one of the classes is in ground truth will be shown
    smooth_interval_prediction : int, optional
        if given, the prediction will be smoothed over the given number of frames
    font_size : float, optional
        forwarded to the task as `font_size` (presumably the plot font size - confirm in the task)
    num_plots : int, default 1
        forwarded to the task as `num_plots` (presumably the number of plots - confirm in the task)
    window_size : int, default 400
        forwarded to the task as `window_size` (presumably the plotted interval length - confirm in the task)

    """
    # the first twelve arguments are passed positionally, exactly as the task expects them
    return self.task.visualize_results(
        save_path,
        add_legend,
        ground_truth,
        colormap,
        hide_axes,
        min_classes,
        width,
        whole_video,
        transparent,
        dataset,
        drop_classes,
        search_classes,
        font_size=font_size,
        smooth_interval_prediction=smooth_interval_prediction,
        num_plots=num_plots,
        window_size=window_size,
    )
Visualize random predictions.
Parameters
save_path : str, optional
the path where the plot will be saved
add_legend : bool, default True
if True, legend will be added to the plot
ground_truth : bool, default True
if True, ground truth will be added to the plot
colormap : str, default 'viridis'
the matplotlib
colormap to use
hide_axes : bool, default False
if True
, the axes will be hidden on the plot
min_classes : int, default 1
the minimum number of classes in a displayed interval
width : float, default 10
the width of the plot
whole_video : bool, default False
if True
, whole videos are plotted instead of segments
transparent : bool, default False
if True
, the background on the plot is transparent
dataset : BehaviorDataset | DataLoader | str | None, optional
the dataset to make the prediction for (if not provided, the validation dataset is used)
drop_classes : set, optional
a set of class names to not be displayed
search_classes : set, optional
if given, only intervals where at least one of the classes is in ground truth will be shown
smooth_interval_prediction : int, optional
if given, the prediction will be smoothed over the given number of frames
def generate_uncertainty_score(
    self,
    classes: List,
    augment_n: int = 0,
    method: str = "least_confidence",
    predicted: torch.Tensor = None,
    behaviors_dict: Dict = None,
) -> Dict:
    """Compute frame-wise uncertainty scores for active learning.

    The batch size is taken from the stored training parameters (`batch_size`, 32 if not set).

    Parameters
    ----------
    classes : list
        a list of class names or indices; their confidence scores will be computed separately and stacked
    augment_n : int, default 0
        the number of augmentations to average over
    method : {"least_confidence", "entropy"}
        the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
        `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
    predicted : torch.Tensor, optional
        if given, the predictions will be used instead of the model's predictions
    behaviors_dict : dict, optional
        if given, the behaviors dictionary will be used instead of the model's behaviors dictionary

    Returns
    -------
    score_dicts : dict
        a nested dictionary where first level keys are video ids, second level keys are clip ids and values
        are score tensors

    """
    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_uncertainty_score(
        classes, augment_n, batch_size, method, predicted, behaviors_dict
    )
Generate frame-wise scores for active learning.
Parameters
classes : list
a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
the number of augmentations to average over
method : {"least_confidence", "entropy"}
the method used to calculate the scores from the probability predictions ("least_confidence"
: 1 - p_i
if
p_i > 0.5
or p_i
if p_i < 0.5
; "entropy"
: - p_i * log(p_i) - (1 - p_i) * log(1 - p_i)
)
predicted : torch.Tensor, optional
if given, the predictions will be used instead of the model's predictions
behaviors_dict : dict, optional
if given, the behaviors dictionary will be used instead of the model's behaviors dictionary
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
def generate_bald_score(
    self,
    classes: List,
    augment_n: int = 0,
    num_models: int = 10,
    kernel_size: int = 11,
) -> Dict:
    """Compute frame-wise Bayesian Active Learning by Disagreement scores for active learning.

    The batch size is taken from the stored training parameters (`batch_size`, 32 if not set).

    Parameters
    ----------
    classes : list
        a list of class names or indices; their confidence scores will be computed separately and stacked
    augment_n : int, default 0
        the number of augmentations to average over
    num_models : int, default 10
        the number of dropout masks to apply
    kernel_size : int, default 11
        the size of the smoothing gaussian kernel

    Returns
    -------
    score_dicts : dict
        a nested dictionary where first level keys are video ids, second level keys are clip ids and values
        are score tensors

    """
    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_bald_score(
        classes, augment_n, batch_size, num_models, kernel_size
    )
Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
Parameters
classes : list a list of class names or indices; their confidence scores will be computed separately and stacked augment_n : int, default 0 the number of augmentations to average over num_models : int, default 10 the number of dropout masks to apply kernel_size : int, default 11 the size of the smoothing gaussian kernel
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
def exists(self, mode) -> bool:
    """Tell whether the task has a dataloader for the given subset.

    Parameters
    ----------
    mode : {"train", "val", "test"}
        the name of the subset to check for

    Returns
    -------
    exists : bool
        `True` if the subset exists (the task returns a dataloader that is not `None`)

    """
    return self.task.dataloader(mode) is not None
Check whether the task has a train/test/validation subset.
Parameters
mode : {"train", "val", "test"} the name of the subset to check for
Returns
exists : bool
True
if the subset exists
def get_normalization_stats(self) -> Dict:
    """Return the pre-computed normalization stats of the task.

    Returns
    -------
    normalization_stats : dict
        a dictionary of means and stds

    """
    task = self.task
    return task.get_normalization_stats()
Get the pre-computed normalization stats.
Returns
normalization_stats : dict a dictionary of means and stds