dlc2action.task.task_dispatcher
Class that provides an interface for dlc2action.task.universal_task.Task
1# 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved. 3# 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL. 5# 6""" 7Class that provides an interface for `dlc2action.task.universal_task.Task` 8""" 9 10import inspect 11from typing import Dict, Union, Tuple, List, Callable, Set 12from torch.utils.data import DataLoader 13from torch.optim import Optimizer 14from collections.abc import Iterable, Mapping 15import torch 16from copy import deepcopy 17import numpy as np 18from optuna.trial import Trial 19import warnings 20 21from dlc2action.data.dataset import BehaviorDataset 22from dlc2action.task.universal_task import Task 23from dlc2action.transformer.base_transformer import Transformer 24from dlc2action.ssl.base_ssl import SSLConstructor 25from dlc2action.ssl.base_ssl import EmptySSL 26from dlc2action.model.base_model import LoadedModel, Model 27from dlc2action.metric.base_metric import Metric 28from dlc2action.utils import PostProcessor 29 30from dlc2action import options 31 32 33class TaskDispatcher: 34 """ 35 A class that manages the interactions between config dictionaries and a Task 36 """ 37 38 def __init__(self, parameters: Dict) -> None: 39 """ 40 Parameters 41 ---------- 42 parameters : dict 43 a dictionary of task parameters 44 """ 45 46 pars = deepcopy(parameters) 47 self.class_weights = None 48 self.general_parameters = pars.get("general", {}) 49 self.data_parameters = pars.get("data", {}) 50 self.model_parameters = pars.get("model", {}) 51 self.training_parameters = pars.get("training", {}) 52 self.loss_parameters = pars.get("losses", {}) 53 self.metric_parameters = pars.get("metrics", {}) 54 self.ssl_parameters = pars.get("ssl", {}) 55 self.aug_parameters = pars.get("augmentations", {}) 56 self.feature_parameters = pars.get("features", {}) 57 self.blanks = {blank: [] for blank in options.blanks} 58 59 self.task = None 60 self._initialize_task() 61 
self._print_behaviors() 62 63 @staticmethod 64 def complete_function_parameters(parameters, function, general_dicts: List) -> Dict: 65 """ 66 Complete a parameter dictionary with values from other dictionaries if required by a function 67 68 Parameters 69 ---------- 70 parameters : dict 71 the function parameters dictionary 72 function : callable 73 the function to be inspected 74 general_dicts : list 75 a list of dictionaries where the missing values will be pulled from 76 """ 77 78 parameter_names = inspect.getfullargspec(function).args 79 for param in parameter_names: 80 for dic in general_dicts: 81 if param not in parameters and param in dic: 82 parameters[param] = dic[param] 83 return parameters 84 85 @staticmethod 86 def complete_dataset_parameters( 87 parameters: dict, 88 general_dict: dict, 89 data_type: str, 90 annotation_type: str, 91 ) -> Dict: 92 """ 93 Complete a parameter dictionary with values from other dictionaries if required by a dataset 94 95 Parameters 96 ---------- 97 parameters : dict 98 the function parameters dictionary 99 general_dict : dict 100 the dictionary where the missing values will be pulled from 101 data_type : str 102 the input type of the dataset 103 annotation_type : str 104 the annotation type of the dataset 105 106 Returns 107 ------- 108 parameters : dict 109 the updated parameter dictionary 110 """ 111 112 params = deepcopy(parameters) 113 parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type) 114 for param in parameter_names: 115 if param not in params and param in general_dict: 116 params[param] = general_dict[param] 117 return params 118 119 @staticmethod 120 def check(parameters: Dict, name: str) -> bool: 121 """ 122 Check whether there is a non-`None` value under the name key in the parameters dictionary 123 124 Parameters 125 ---------- 126 parameters : dict 127 the dictionary to check 128 name : str 129 the key to check 130 131 Returns 132 ------- 133 result : bool 134 True if a non-`None` value 
exists 135 """ 136 137 if name in parameters and parameters[name] is not None: 138 return True 139 else: 140 return False 141 142 @staticmethod 143 def get(parameters: Dict, name: str, default): 144 """ 145 Get the value under the name key or the default if it is `None` or does not exist 146 147 Parameters 148 ---------- 149 parameters : dict 150 the dictionary to check 151 name : str 152 the key to check 153 default 154 the default value to return 155 156 Returns 157 ------- 158 value 159 the resulting value 160 """ 161 162 if TaskDispatcher.check(parameters, name): 163 return parameters[name] 164 else: 165 return default 166 167 @staticmethod 168 def make_dataloader( 169 dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False 170 ) -> DataLoader: 171 """ 172 Make a torch dataloader from a dataset 173 174 Parameters 175 ---------- 176 dataset : dlc2action.data.dataset.BehaviorDataset 177 the dataset 178 batch_size : int 179 the batch size 180 181 Returns 182 ------- 183 dataloader : DataLoader 184 the dataloader (or `None` if the length of the dataset is 0) 185 """ 186 187 if dataset is None or len(dataset) == 0: 188 return None 189 else: 190 return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle) 191 192 def _construct_ssl(self) -> List: 193 """ 194 Generate SSL constructors 195 """ 196 197 ssl_list = deepcopy(self.general_parameters.get("ssl", None)) 198 if not isinstance(ssl_list, Iterable): 199 ssl_list = [ssl_list] 200 for i, ssl in enumerate(ssl_list): 201 if type(ssl) is str: 202 if ssl in options.ssl_constructors: 203 pars = self.get(self.ssl_parameters, ssl, default={}) 204 pars = self.complete_function_parameters( 205 parameters=pars, 206 function=options.ssl_constructors[ssl], 207 general_dicts=[ 208 self.model_parameters, 209 self.data_parameters, 210 self.general_parameters, 211 ], 212 ) 213 ssl_list[i] = options.ssl_constructors[ssl](**pars) 214 else: 215 raise ValueError( 216 f"The {ssl} SSL is not available, please 
choose from {list(options.ssl_constructors.keys())}" 217 ) 218 elif ssl is None: 219 ssl_list[i] = EmptySSL() 220 elif not isinstance(ssl, SSLConstructor): 221 raise TypeError( 222 f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}" 223 ) 224 return ssl_list 225 226 def _construct_model(self) -> Model: 227 """ 228 Generate a model 229 """ 230 231 if self.check(self.general_parameters, "model"): 232 pars = self.complete_function_parameters( 233 function=LoadedModel, 234 parameters=self.model_parameters, 235 general_dicts=[self.general_parameters], 236 ) 237 model = LoadedModel(**pars) 238 elif self.check(self.general_parameters, "model_name"): 239 name = self.general_parameters["model_name"] 240 if name in options.models: 241 pars = self.complete_function_parameters( 242 function=options.models[name], 243 parameters=self.model_parameters, 244 general_dicts=[self.general_parameters], 245 ) 246 model = options.models[name](**pars) 247 else: 248 raise ValueError( 249 f"The {name} model is not available, please choose from {list(options.models.keys())}" 250 ) 251 else: 252 raise ValueError( 253 "You need to provide either a model or its name in the model_parameters!" 254 ) 255 256 if self.get(self.training_parameters, "freeze_features", False): 257 model.freeze_feature_extractor() 258 return model 259 260 def _construct_dataset(self) -> BehaviorDataset: 261 """ 262 Generate a dataset 263 """ 264 265 data_type = self.general_parameters.get("data_type", None) 266 if data_type is None: 267 raise ValueError( 268 "You need to provide the data_type parameter in the data parameters!" 269 ) 270 annotation_type = self.get(self.general_parameters, "annotation_type", "none") 271 feature_extraction = self.general_parameters.get("feature_extraction", "none") 272 if feature_extraction is None: 273 raise ValueError( 274 "You need to provide the feature_extraction parameter in the data parameters!" 
275 ) 276 feature_extraction_pars = self.complete_function_parameters( 277 self.feature_parameters, 278 options.feature_extractors[feature_extraction], 279 [self.general_parameters, self.data_parameters], 280 ) 281 282 pars = self.complete_dataset_parameters( 283 self.data_parameters, 284 self.general_parameters, 285 data_type=data_type, 286 annotation_type=annotation_type, 287 ) 288 pars["feature_extraction_pars"] = feature_extraction_pars 289 dataset = BehaviorDataset(**pars) 290 291 if self.get(self.general_parameters, "save_dataset", default=False): 292 save_data_path = self.data_parameters.get("saved_data_path", None) 293 dataset.save(save_path=save_data_path) 294 295 return dataset 296 297 def _construct_transformer(self) -> Transformer: 298 """ 299 Generate a transformer 300 """ 301 302 features = self.general_parameters["feature_extraction"] 303 name = options.extractor_to_transformer[features] 304 if name in options.transformers: 305 transformer_class = options.transformers[name] 306 pars = self.complete_function_parameters( 307 function=transformer_class, 308 parameters=self.aug_parameters, 309 general_dicts=[self.general_parameters], 310 ) 311 transformer = transformer_class(**pars) 312 else: 313 raise ValueError(f"The {name} transformer is not available") 314 return transformer 315 316 def _construct_loss(self) -> torch.nn.Module: 317 """ 318 Generate a loss function 319 """ 320 321 if "loss_function" not in self.general_parameters: 322 raise ValueError( 323 'Please add a "loss_function" key to the parameters["general"] dictionary (either a name ' 324 f"from {list(options.losses.keys())} or a function)" 325 ) 326 else: 327 loss_function = self.general_parameters["loss_function"] 328 if type(loss_function) is str: 329 if loss_function in options.losses: 330 pars = self.get(self.loss_parameters, loss_function, default={}) 331 pars = self._set_loss_weights(pars) 332 pars = self.complete_function_parameters( 333 function=options.losses[loss_function], 334 
parameters=pars, 335 general_dicts=[self.general_parameters], 336 ) 337 loss = options.losses[loss_function](**pars) 338 else: 339 raise ValueError( 340 f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}" 341 ) 342 else: 343 loss = loss_function 344 return loss 345 346 def _construct_metrics(self) -> List: 347 """ 348 Generate the metric 349 """ 350 351 metric_functions = self.get( 352 self.general_parameters, "metric_functions", default={} 353 ) 354 if isinstance(metric_functions, Iterable): 355 metrics = {} 356 for func in metric_functions: 357 if isinstance(func, str): 358 if func in options.metrics: 359 pars = self.get(self.metric_parameters, func, default={}) 360 pars = self.complete_function_parameters( 361 function=options.metrics[func], 362 parameters=pars, 363 general_dicts=[self.general_parameters], 364 ) 365 metrics[func] = options.metrics[func](**pars) 366 else: 367 raise ValueError( 368 f"The {func} metric is not available, please choose from {list(options.metrics.keys())}" 369 ) 370 elif isinstance(func, Metric): 371 name = "function_1" 372 i = 1 373 while name in metrics: 374 i += 1 375 name = f"function_{i}" 376 metrics[name] = func 377 else: 378 raise TypeError( 379 'The elements of parameters["general"]["metric_functions"] have to be either strings ' 380 f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead" 381 ) 382 elif isinstance(metric_functions, dict): 383 metrics = metric_functions 384 else: 385 raise TypeError( 386 'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;' 387 f"got {type(metric_functions)} instead" 388 ) 389 return metrics 390 391 def _construct_optimizer(self) -> Optimizer: 392 """ 393 Generate an optimizer 394 """ 395 396 if "optimizer" in self.training_parameters: 397 name = self.training_parameters["optimizer"] 398 if name in options.optimizers: 399 optimizer = options.optimizers[name] 400 else: 401 raise 
ValueError( 402 f"The {name} optimizer is not available, please choose from {list(options.optimizers.keys())}" 403 ) 404 else: 405 optimizer = None 406 return optimizer 407 408 def _construct_predict_functions(self) -> Tuple[Callable, Callable]: 409 """ 410 Construct predict functions 411 """ 412 413 predict_function = self.training_parameters.get("predict_function", None) 414 primary_predict_function = self.training_parameters.get( 415 "primary_predict_function", None 416 ) 417 model_name = self.general_parameters.get("model_name", "") 418 threshold = self.training_parameters.get("hard_threshold", 0.5) 419 if not isinstance(predict_function, Callable): 420 if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]: 421 if self.general_parameters["exclusive"]: 422 func = lambda x: torch.softmax(x, dim=1) 423 else: 424 func = lambda x: torch.sigmoid(x) 425 426 def primary_predict_function(x): 427 if len(x.shape) != 4: 428 x = x.reshape((4, -1, x.shape[-2], x.shape[-1])) 429 weights = [1, 1, 1, 1] 430 ensemble_prob = func(x[0]) * weights[0] / sum(weights) 431 for i, outp_ele in enumerate(x[1:]): 432 ensemble_prob = ensemble_prob + func(outp_ele) * weights[ 433 i + 1 434 ] / sum(weights) 435 return ensemble_prob 436 437 else: 438 if model_name.startswith("ms_tcn") or model_name in [ 439 "asformer", 440 "transformer", 441 "c3d_ms", 442 "transformer_ms", 443 ]: 444 f = lambda x: x[-1] if len(x.shape) == 4 else x 445 elif model_name == "asrf": 446 447 def f(x): 448 x = x[-1] 449 # bounds = x[:, 0, :].unsqueeze(1) 450 cls = x[:, 1:, :] 451 # device = x.device 452 # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy()) 453 # x = torch.tensor(x).to(device) 454 return cls 455 456 elif model_name == "actionclip": 457 458 def f(x): 459 video_embedding, text_embedding, logit_scale = ( 460 x["video"], 461 x["text"], 462 x["logit_scale"], 463 ) 464 B, Ff, T = video_embedding.shape 465 video_embedding = 
video_embedding.permute(0, 2, 1).reshape( 466 (B * T, -1) 467 ) 468 video_embedding /= video_embedding.norm(dim=-1, keepdim=True) 469 text_embedding /= text_embedding.norm(dim=-1, keepdim=True) 470 similarity = logit_scale * video_embedding @ text_embedding.T 471 similarity = similarity.reshape((B, T, -1)).permute(0, 2, 1) 472 return similarity 473 474 else: 475 f = lambda x: x 476 if self.general_parameters["exclusive"]: 477 primary_predict_function = lambda x: torch.softmax(f(x), dim=1) 478 else: 479 primary_predict_function = lambda x: torch.sigmoid(f(x)) 480 if self.general_parameters["exclusive"]: 481 predict_function = lambda x: torch.max(x.data, dim=1)[1] 482 else: 483 predict_function = lambda x: (x > threshold).int() 484 return primary_predict_function, predict_function 485 486 def _get_parameters_from_training(self) -> Dict: 487 """ 488 Get the training parameters that need to be passed to the Task 489 """ 490 491 task_training_par_names = [ 492 "lr", 493 "parallel", 494 "device", 495 "verbose", 496 "log_file", 497 "augment_train", 498 "augment_val", 499 "hard_threshold", 500 "ssl_losses", 501 "model_save_path", 502 "model_save_epochs", 503 "pseudolabel", 504 "pseudolabel_start", 505 "correction_interval", 506 "pseudolabel_alpha_f", 507 "alpha_growth_stop", 508 "num_epochs", 509 "validation_interval", 510 "ignore_tags", 511 "skip_metrics", 512 ] 513 task_training_pars = { 514 name: self.training_parameters[name] 515 for name in task_training_par_names 516 if self.check(self.training_parameters, name) 517 } 518 if self.check(self.general_parameters, "ssl"): 519 ssl_weights = [ 520 self.training_parameters["ssl_weights"][x] 521 for x in self.general_parameters["ssl"] 522 ] 523 task_training_pars["ssl_weights"] = ssl_weights 524 return task_training_pars 525 526 def _update_parameters_from_ssl(self, ssl_list: list) -> None: 527 """ 528 Update the necessary parameters given the list of SSL constructors 529 """ 530 531 if self.task is not None: 532 
self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 533 self.task.set_ssl_losses([ssl.loss for ssl in ssl_list]) 534 self.task.set_keep_target_none( 535 [ssl.type in ["contrastive"] for ssl in ssl_list] 536 ) 537 self.task.set_generate_ssl_input( 538 [ssl.type == "contrastive" for ssl in ssl_list] 539 ) 540 self.data_parameters["ssl_transformations"] = [ 541 ssl.transformation for ssl in ssl_list 542 ] 543 self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list] 544 self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list] 545 self.model_parameters["ssl_modules"] = [ 546 ssl.construct_module() for ssl in ssl_list 547 ] 548 self.aug_parameters["generate_ssl_input"] = [ 549 x.type == "contrastive" for x in ssl_list 550 ] 551 self.aug_parameters["keep_target_none"] = [ 552 x.type == "contrastive" for x in ssl_list 553 ] 554 555 def _set_loss_weights(self, parameters): 556 """ 557 Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values 558 """ 559 560 for k in list(parameters.keys()): 561 if parameters[k] in [ 562 "dataset_inverse_weights", 563 "dataset_proportional_weights", 564 ]: 565 if parameters[k] == "dataset_inverse_weights": 566 parameters[k] = self.class_weights 567 else: 568 parameters[k] = self.proportional_class_weights 569 print("Initializing class weights:") 570 string = " " 571 if isinstance(parameters[k], Mapping): 572 for key, val in parameters[k].items(): 573 string += ": ".join( 574 ( 575 " " + str(key), 576 ", ".join((map(lambda x: str(np.round(x, 3)), val))), 577 ) 578 ) 579 else: 580 string += ", ".join( 581 (map(lambda x: str(np.round(x, 3)), parameters[k])) 582 ) 583 print(string) 584 return parameters 585 586 def _partition_dataset( 587 self, dataset: BehaviorDataset 588 ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]: 589 """ 590 Partition the dataset into train, validation and test subsamples 591 """ 592 593 use_test = 
self.get(self.training_parameters, "use_test", 0) 594 split_path = self.training_parameters.get("split_path", None) 595 partition_method = self.training_parameters.get("partition_method", "random") 596 val_frac = self.get(self.training_parameters, "val_frac", 0) 597 test_frac = self.get(self.training_parameters, "test_frac", 0) 598 save_split = self.get(self.training_parameters, "save_split", True) 599 normalize = self.get(self.training_parameters, "normalize", False) 600 skip_normalization_keys = self.training_parameters.get( 601 "skip_normalization_keys" 602 ) 603 stats = self.training_parameters.get("stats") 604 train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val( 605 use_test, 606 split_path, 607 partition_method, 608 val_frac, 609 test_frac, 610 save_split, 611 normalize, 612 skip_normalization_keys, 613 stats, 614 ) 615 bs = int(self.training_parameters.get("batch_size", 32)) 616 train_dataloader, test_dataloader, val_dataloader = ( 617 self.make_dataloader(train_dataset, batch_size=bs, shuffle=True), 618 self.make_dataloader(test_dataset, batch_size=bs, shuffle=False), 619 self.make_dataloader(val_dataset, batch_size=bs, shuffle=False), 620 ) 621 return train_dataloader, test_dataloader, val_dataloader 622 623 def _initialize_task(self): 624 """ 625 Create a `dlc2action.task.universal_task.Task` instance 626 """ 627 628 dataset = self._construct_dataset() 629 self._update_data_blanks(dataset) 630 model = self._construct_model() 631 self._update_model_blanks(model) 632 ssl_list = self._construct_ssl() 633 self._update_parameters_from_ssl(ssl_list) 634 model.set_ssl(ssl_constructors=ssl_list) 635 dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 636 transformer = self._construct_transformer() 637 metrics = self._construct_metrics() 638 optimizer = self._construct_optimizer() 639 primary_predict_function, predict_function = self._construct_predict_functions() 640 641 task_training_pars = 
self._get_parameters_from_training() 642 train_dataloader, test_dataloader, val_dataloader = self._partition_dataset( 643 dataset 644 ) 645 self.class_weights = train_dataloader.dataset.class_weights() 646 self.proportional_class_weights = train_dataloader.dataset.class_weights(True) 647 loss = self._construct_loss() 648 exclusive = self.general_parameters["exclusive"] 649 650 task_pars = { 651 "train_dataloader": train_dataloader, 652 "model": model, 653 "loss": loss, 654 "transformer": transformer, 655 "metrics": metrics, 656 "val_dataloader": val_dataloader, 657 "test_dataloader": test_dataloader, 658 "exclusive": exclusive, 659 "optimizer": optimizer, 660 "predict_function": predict_function, 661 "primary_predict_function": primary_predict_function, 662 } 663 task_pars.update(task_training_pars) 664 665 self.task = Task(**task_pars) 666 checkpoint_path = self.training_parameters.get("checkpoint_path", None) 667 if checkpoint_path is not None: 668 only_model = self.get(self.training_parameters, "only_load_model", False) 669 load_strict = self.get(self.training_parameters, "load_strict", True) 670 self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 671 if ( 672 self.general_parameters["only_load_annotated"] 673 and self.general_parameters.get("ssl") is not None 674 ): 675 warnings.warn( 676 "Note that you are using SSL modules but only loading annotated files! 
Set " 677 "general/only_load_annotated to False to change that" 678 ) 679 680 def _update_data_blanks( 681 self, dataset: BehaviorDataset = None, remember: bool = False 682 ) -> None: 683 """ 684 Update all blanks from a dataset 685 """ 686 687 if dataset is None: 688 dataset = self.dataset() 689 self._update_dim_parameter(dataset, remember) 690 self._update_bodyparts_parameter(dataset, remember) 691 self._update_num_classes_parameter(dataset, remember) 692 self._update_len_segment_parameter(dataset, remember) 693 self._update_boundary_parameter(dataset, remember) 694 695 def _update_model_blanks(self, model: Model, remember: bool = False) -> None: 696 self._update_features_parameter(model, remember) 697 698 def _update_parameter(self, blank_name: str, value, remember: bool = False): 699 parameters = [ 700 self.model_parameters, 701 self.ssl_parameters, 702 self.general_parameters, 703 self.feature_parameters, 704 self.data_parameters, 705 self.training_parameters, 706 self.metric_parameters, 707 self.loss_parameters, 708 self.aug_parameters, 709 ] 710 par_names = [ 711 "model", 712 "ssl", 713 "general", 714 "feature", 715 "data", 716 "training", 717 "metrics", 718 "losses", 719 "augmentations", 720 ] 721 for names in self.blanks[blank_name]: 722 group = names[0] 723 key = names[1] 724 ind = par_names.index(group) 725 if len(names) == 3: 726 if names[2] in parameters[ind][key]: 727 parameters[ind][key][names[2]] = value 728 else: 729 if key in parameters[ind]: 730 parameters[ind][key] = value 731 for name, dic in zip(par_names, parameters): 732 for k, v in dic.items(): 733 if v == blank_name: 734 dic[k] = value 735 if [name, k] not in self.blanks[blank_name]: 736 self.blanks[blank_name].append([name, k]) 737 elif isinstance(v, Mapping): 738 for kk, vv in v.items(): 739 if vv == blank_name: 740 dic[k][kk] = value 741 if [name, k, kk] not in self.blanks[blank_name]: 742 self.blanks[blank_name].append([name, k, kk]) 743 744 def _update_features_parameter(self, model: 
Model, remember: bool = False) -> None: 745 """ 746 Fill the `"model_features"` blank 747 """ 748 749 value = model.features_shape() 750 self._update_parameter("model_features", value, remember) 751 752 def _update_bodyparts_parameter( 753 self, dataset: BehaviorDataset, remember: bool = False 754 ) -> None: 755 """ 756 Fill the `"dataset_bodyparts"` blank 757 """ 758 759 value = dataset.bodyparts_order() 760 self._update_parameter("dataset_bodyparts", value, remember) 761 762 def _update_dim_parameter( 763 self, dataset: BehaviorDataset, remember: bool = False 764 ) -> None: 765 """ 766 Fill the `"dataset_features"` blank 767 """ 768 769 value = dataset.features_shape() 770 self._update_parameter("dataset_features", value, remember) 771 772 def _update_boundary_parameter( 773 self, dataset: BehaviorDataset, remember: bool = False 774 ) -> None: 775 """ 776 Fill the `"dataset_features"` blank 777 """ 778 779 value = dataset.boundary_class_weight() 780 self._update_parameter("dataset_boundary_weight", value, remember) 781 782 def _update_num_classes_parameter( 783 self, dataset: BehaviorDataset, remember: bool = False 784 ) -> None: 785 """ 786 Fill in the `"dataset_classes"` blank 787 """ 788 789 value = dataset.num_classes() 790 self._update_parameter("dataset_classes", value, remember) 791 792 def _update_len_segment_parameter( 793 self, dataset: BehaviorDataset, remember: bool = False 794 ) -> None: 795 """ 796 Fill in the `"dataset_len_segment"` blank 797 """ 798 799 value = dataset.len_segment() 800 self._update_parameter("dataset_len_segment", value, remember) 801 802 def _print_behaviors(self): 803 behavior_set = self.behaviors_dict() 804 print(f"Behavior indices:") 805 for key, value in sorted(behavior_set.items()): 806 print(f" {key}: {value}") 807 808 def update_task(self, parameters: Dict) -> None: 809 """ 810 Update the `dlc2action.task.universal_task.Task` instance given the parameter updates 811 812 Parameters 813 ---------- 814 parameters : dict 815 
the dictionary of parameter updates 816 """ 817 818 pars = deepcopy(parameters) 819 # for blank_name in self.blanks: 820 # for names in self.blanks[blank_name]: 821 # group = names[0] 822 # key = names[1] 823 # if len(names) == 3: 824 # if ( 825 # group in pars 826 # and key in pars[group] 827 # and names[2] in pars[group][key] 828 # ): 829 # pars[group][key].pop(names[2]) 830 # else: 831 # if group in pars and key in pars[group]: 832 # pars[group].pop(key) 833 stay = False 834 if "ssl" in pars: 835 for key in pars["ssl"]: 836 if key in self.ssl_parameters: 837 self.ssl_parameters[key].update(pars["ssl"][key]) 838 else: 839 self.ssl_parameters[key] = pars["ssl"][key] 840 841 if "general" in pars: 842 if stay: 843 stay = False 844 if ( 845 "model_name" in pars["general"] 846 and pars["general"]["model_name"] 847 != self.general_parameters["model_name"] 848 ): 849 if "model" not in pars: 850 raise ValueError( 851 "When updating a task with a new model name you need to pass the parameters for the " 852 "new model" 853 ) 854 self.model_parameters = {} 855 self.general_parameters.update(pars["general"]) 856 data_related = [ 857 "num_classes", 858 "exclusive", 859 "data_type", 860 "annotation_type", 861 ] 862 ssl_related = ["ssl", "exclusive", "num_classes"] 863 loss_related = ["num_classes", "loss_function", "exclusive"] 864 augmentation_related = ["augmentation_type"] 865 metric_related = ["metric_functions"] 866 related_lists = [ 867 data_related, 868 ssl_related, 869 loss_related, 870 augmentation_related, 871 metric_related, 872 ] 873 names = ["data", "ssl", "losses", "augmentations", "metrics"] 874 for related_list, name in zip(related_lists, names): 875 if ( 876 any([x in pars["general"] for x in related_list]) 877 and name not in pars 878 ): 879 pars[name] = {} 880 881 if "training" in pars: 882 if "data" not in pars or not stay: 883 for x in [ 884 "to_ram", 885 "use_test", 886 "partition_method", 887 "val_frac", 888 "test_frac", 889 "save_split", 890 
"batch_size", 891 "save_split", 892 ]: 893 if ( 894 x in pars["training"] 895 and pars["training"][x] != self.training_parameters[x] 896 ): 897 if "data" not in pars: 898 pars["data"] = {} 899 stay = True 900 self.training_parameters.update(pars["training"]) 901 self.task.update_parameters(self._get_parameters_from_training()) 902 903 if "data" in pars or "features" in pars: 904 for k, v in pars["data"].items(): 905 if k not in self.data_parameters or v != self.data_parameters[k]: 906 stay = True 907 for k, v in pars["features"].items(): 908 if k not in self.feature_parameters or v != self.feature_parameters[k]: 909 stay = True 910 if stay: 911 self.data_parameters.update(pars["data"]) 912 self.feature_parameters.update(pars["features"]) 913 dataset = self._construct_dataset() 914 ( 915 train_dataloader, 916 test_dataloader, 917 val_dataloader, 918 ) = self._partition_dataset(dataset) 919 self.task.set_dataloaders( 920 train_dataloader, val_dataloader, test_dataloader 921 ) 922 self.class_weights = train_dataloader.dataset.class_weights() 923 self.proportional_class_weights = ( 924 train_dataloader.dataset.class_weights(True) 925 ) 926 if "losses" not in pars: 927 pars["losses"] = {} 928 929 if "model" in pars: 930 self.model_parameters.update(pars["model"]) 931 932 self._update_data_blanks() 933 934 if "augmentations" in pars: 935 self.aug_parameters.update(pars["augmentations"]) 936 transformer = self._construct_transformer() 937 self.task.set_transformer(transformer) 938 939 if "losses" in pars: 940 for key in pars["losses"]: 941 if key in self.loss_parameters: 942 self.loss_parameters[key].update(pars["losses"][key]) 943 else: 944 self.loss_parameters[key] = pars["losses"][key] 945 self.loss_parameters.update(pars["losses"]) 946 loss = self._construct_loss() 947 self.task.set_loss(loss) 948 949 if "metrics" in pars: 950 for key in pars["metrics"]: 951 if key in self.metric_parameters: 952 self.metric_parameters[key].update(pars["metrics"][key]) 953 else: 954 
self.metric_parameters[key] = pars["metrics"][key] 955 metrics = self._construct_metrics() 956 self.task.set_metrics(metrics) 957 958 self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"]) 959 self._set_loss_weights( 960 pars.get("losses", {}).get(self.general_parameters["loss_function"], {}) 961 ) 962 model = self._construct_model() 963 predict_functions = self._construct_predict_functions() 964 self.task.set_predict_functions(*predict_functions) 965 self._update_model_blanks(model) 966 ssl_list = self._construct_ssl() 967 self._update_parameters_from_ssl(ssl_list) 968 model.set_ssl(ssl_constructors=ssl_list) 969 self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 970 self.task.set_model(model) 971 if "training" in pars and "checkpoint_path" in pars["training"]: 972 checkpoint_path = pars["training"]["checkpoint_path"] 973 only_model = pars["training"].get("only_load_model", False) 974 load_strict = pars["training"].get("load_strict", True) 975 self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 976 if ( 977 self.general_parameters["only_load_annotated"] 978 and self.general_parameters.get("ssl") is not None 979 ): 980 warnings.warn( 981 "Note that you are using SSL modules but only loading annotated files! Set " 982 "general/only_load_annotated to False to change that" 983 ) 984 if self.task.dataset("train").annotation_class() != "none": 985 self._print_behaviors() 986 987 def train( 988 self, 989 trial: Trial = None, 990 optimized_metric: str = None, 991 autostop_metric: str = None, 992 autostop_interval: int = 10, 993 autostop_threshold: float = 0.001, 994 loading_bar: bool = False, 995 ) -> Tuple: 996 """ 997 Train the task and return a log of epoch-average loss and metric 998 999 You can use the autostop parameters to finish training when the parameters are not improving. 
It will be 1000 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 1001 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 1002 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 1003 1004 Parameters 1005 ---------- 1006 trial : Trial 1007 an `optuna` trial (for hyperparameter searches) 1008 optimized_metric : str 1009 the name of the metric being optimized (for hyperparameter searches) 1010 to_ram : bool, default False 1011 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 1012 if the dataset is too large) 1013 autostop_interval : int, default 50 1014 the number of epochs to average the autostop metric over 1015 autostop_threshold : float, default 0.001 1016 the autostop difference threshold 1017 autostop_metric : str, optional 1018 the autostop metric (can be any one of the tracked metrics of `'loss'`) 1019 main_task_on : bool, default True 1020 if `False`, the main task (action segmentation) will not be used in training 1021 ssl_on : bool, default True 1022 if `False`, the SSL task will not be used in training 1023 1024 Returns 1025 ------- 1026 loss_log: list 1027 a list of float loss function values for each epoch 1028 metrics_log: dict 1029 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 1030 names, values are lists of function values) 1031 """ 1032 1033 to_ram = self.training_parameters.get("to_ram", False) 1034 logs = self.task.train( 1035 trial, 1036 optimized_metric, 1037 to_ram, 1038 autostop_metric=autostop_metric, 1039 autostop_interval=autostop_interval, 1040 autostop_threshold=autostop_threshold, 1041 main_task_on=self.training_parameters.get("main_task_on", True), 1042 ssl_on=self.training_parameters.get("ssl_on", True), 1043 
temporal_subsampling_size=self.training_parameters.get( 1044 "temporal_subsampling_size" 1045 ), 1046 loading_bar=loading_bar, 1047 ) 1048 return logs 1049 1050 def save_model(self, save_path: str) -> None: 1051 """ 1052 Save the model of the `dlc2action.task.universal_task.Task` instance 1053 1054 Parameters 1055 ---------- 1056 save_path : str 1057 the path to the saved file 1058 """ 1059 1060 self.task.save_model(save_path) 1061 1062 def evaluate( 1063 self, 1064 data: Union[DataLoader, BehaviorDataset, str] = None, 1065 augment_n: int = 0, 1066 verbose: bool = True, 1067 ) -> Tuple: 1068 """ 1069 Evaluate the Task model 1070 1071 Parameters 1072 ---------- 1073 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1074 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 1075 augment_n : int, default 0 1076 the number of augmentations to average results over 1077 verbose : bool, default True 1078 if True, the process is reported to standard output 1079 1080 Returns 1081 ------- 1082 loss : float 1083 the average value of the loss function 1084 ssl_loss : float 1085 the average value of the SSL loss function 1086 metric : dict 1087 a dictionary of average values of metric functions 1088 """ 1089 1090 res = self.task.evaluate( 1091 data, 1092 augment_n, 1093 int(self.training_parameters.get("batch_size", 32)), 1094 verbose, 1095 ) 1096 return res 1097 1098 def evaluate_prediction( 1099 self, 1100 prediction: torch.Tensor, 1101 data: Union[DataLoader, BehaviorDataset, str] = None, 1102 ) -> Tuple: 1103 """ 1104 Compute metrics for a prediction 1105 1106 Parameters 1107 ---------- 1108 prediction : torch.Tensor 1109 the prediction 1110 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1111 the data the prediction was made for (if not provided, take the validation dataset) 1112 1113 Returns 1114 ------- 1115 loss : float 1116 the average value of the loss function 
1117 metric : dict 1118 a dictionary of average values of metric functions 1119 """ 1120 1121 return self.task.evaluate_prediction( 1122 prediction, data, int(self.training_parameters.get("batch_size", 32)) 1123 ) 1124 1125 def predict( 1126 self, 1127 data: Union[DataLoader, BehaviorDataset, str], 1128 raw_output: bool = False, 1129 apply_primary_function: bool = True, 1130 augment_n: int = 0, 1131 embedding: bool = False, 1132 ) -> torch.Tensor: 1133 """ 1134 Make a prediction with the Task model 1135 1136 Parameters 1137 ---------- 1138 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1139 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 1140 raw_output : bool, default False 1141 if `True`, the raw predicted probabilities are returned 1142 apply_primary_function : bool, default True 1143 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 1144 the input) 1145 augment_n : int, default 0 1146 the number of augmentations to average results over 1147 1148 Returns 1149 ------- 1150 prediction : torch.Tensor 1151 a prediction for the input data 1152 """ 1153 1154 to_ram = self.training_parameters.get("to_ram", False) 1155 return self.task.predict( 1156 data, 1157 raw_output, 1158 apply_primary_function, 1159 augment_n, 1160 int(self.training_parameters.get("batch_size", 32)), 1161 to_ram, 1162 embedding=embedding, 1163 ) 1164 1165 def dataset(self, mode: str = "train") -> BehaviorDataset: 1166 """ 1167 Get a dataset 1168 1169 Parameters 1170 ---------- 1171 mode : {'train', 'val', 'test'} 1172 the dataset to get 1173 1174 Returns 1175 ------- 1176 dataset : dlc2action.data.dataset.BehaviorDataset 1177 the dataset 1178 """ 1179 1180 return self.task.dataset(mode) 1181 1182 def generate_full_length_prediction( 1183 self, 1184 dataset: Union[BehaviorDataset, str] = None, 1185 augment_n: int = 10, 1186 ) -> Dict: 1187 """ 1188 Compile a prediction 
for the original input sequences 1189 1190 Parameters 1191 ---------- 1192 dataset : dlc2action.data.dataset.BehaviorDataset | str, optional 1193 the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task` 1194 instance validation dataset) 1195 augment_n : int, default 10 1196 the number of augmentations to average results over 1197 1198 Returns 1199 ------- 1200 prediction : dict 1201 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1202 are prediction tensors 1203 """ 1204 1205 return self.task.generate_full_length_prediction( 1206 dataset, int(self.training_parameters.get("batch_size", 32)), augment_n 1207 ) 1208 1209 def generate_submission( 1210 self, 1211 frame_number_map_file: str, 1212 dataset: Union[BehaviorDataset, str] = None, 1213 augment_n: int = 10, 1214 ) -> Dict: 1215 """ 1216 Generate a MABe-22 style submission dictionary 1217 1218 Parameters 1219 ---------- 1220 frame_number_map_file : str 1221 path to the frame number map file 1222 dataset : BehaviorDataset, optional 1223 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1224 augment_n : int, default 10 1225 the number of augmentations to average results over 1226 1227 Returns 1228 ------- 1229 submission : dict 1230 a dictionary with frame number mapping and embeddings 1231 """ 1232 1233 return self.task.generate_submission( 1234 frame_number_map_file, 1235 dataset, 1236 int(self.training_parameters.get("batch_size", 32)), 1237 augment_n, 1238 ) 1239 1240 def behaviors_dict(self): 1241 """ 1242 Get a behavior dictionary 1243 1244 Keys are label indices and values are label names. 
1245 1246 Returns 1247 ------- 1248 behaviors_dict : dict 1249 behavior dictionary 1250 """ 1251 1252 return self.task.behaviors_dict() 1253 1254 def count_classes(self, bouts: bool = False) -> Dict: 1255 """ 1256 Get a dictionary of class counts in different modes 1257 1258 Parameters 1259 ---------- 1260 bouts : bool, default False 1261 if `True`, instead of frame counts segment counts are returned 1262 1263 Returns 1264 ------- 1265 class_counts : dict 1266 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1267 class names and values are class counts (in frames) 1268 """ 1269 1270 return self.task.count_classes(bouts) 1271 1272 def _visualize_results_label( 1273 self, 1274 label: str, 1275 save_path: str = None, 1276 add_legend: bool = True, 1277 ground_truth: bool = True, 1278 hide_axes: bool = False, 1279 width: int = 10, 1280 whole_video: bool = False, 1281 transparent: bool = False, 1282 dataset: BehaviorDataset = None, 1283 smooth_interval: int = 0, 1284 title: str = None, 1285 ): 1286 return self.task._visualize_results_label( 1287 label, 1288 save_path, 1289 add_legend, 1290 ground_truth, 1291 hide_axes, 1292 width, 1293 whole_video, 1294 transparent, 1295 dataset, 1296 smooth_interval=smooth_interval, 1297 title=title, 1298 ) 1299 1300 def visualize_results( 1301 self, 1302 save_path: str = None, 1303 add_legend: bool = True, 1304 ground_truth: bool = True, 1305 colormap: str = "viridis", 1306 hide_axes: bool = False, 1307 min_classes: int = 1, 1308 width: float = 10, 1309 whole_video: bool = False, 1310 transparent: bool = False, 1311 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1312 drop_classes: Set = None, 1313 search_classes: Set = None, 1314 smooth_interval_prediction: int = None, 1315 ) -> None: 1316 """ 1317 Visualize random predictions 1318 1319 Parameters 1320 ---------- 1321 save_path : str, optional 1322 the path where the plot will be saved 1323 add_legend : bool, default True 1324 if 
True, legend will be added to the plot 1325 ground_truth : bool, default True 1326 if True, ground truth will be added to the plot 1327 colormap : str, default 'Accent' 1328 the `matplotlib` colormap to use 1329 hide_axes : bool, default True 1330 if `True`, the axes will be hidden on the plot 1331 min_classes : int, default 1 1332 the minimum number of classes in a displayed interval 1333 width : float, default 10 1334 the width of the plot 1335 whole_video : bool, default False 1336 if `True`, whole videos are plotted instead of segments 1337 transparent : bool, default False 1338 if `True`, the background on the plot is transparent 1339 dataset : BehaviorDataset | DataLoader | str | None, optional 1340 the dataset to make the prediction for (if not provided, the validation dataset is used) 1341 drop_classes : set, optional 1342 a set of class names to not be displayed 1343 search_classes : set, optional 1344 if given, only intervals where at least one of the classes is in ground truth will be shown 1345 """ 1346 1347 return self.task.visualize_results( 1348 save_path, 1349 add_legend, 1350 ground_truth, 1351 colormap, 1352 hide_axes, 1353 min_classes, 1354 width, 1355 whole_video, 1356 transparent, 1357 dataset, 1358 drop_classes, 1359 search_classes, 1360 smooth_interval_prediction=smooth_interval_prediction, 1361 ) 1362 1363 def generate_uncertainty_score( 1364 self, 1365 classes: List, 1366 augment_n: int = 0, 1367 method: str = "least_confidence", 1368 predicted: torch.Tensor = None, 1369 behaviors_dict: Dict = None, 1370 ) -> Dict: 1371 """ 1372 Generate frame-wise scores for active learning 1373 1374 Parameters 1375 ---------- 1376 classes : list 1377 a list of class names or indices; their confidence scores will be computed separately and stacked 1378 augment_n : int, default 0 1379 the number of augmentations to average over 1380 method : {"least_confidence", "entropy"} 1381 the method used to calculate the scores from the probability predictions 
(`"least_confidence"`: `1 - p_i` if 1382 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1383 1384 Returns 1385 ------- 1386 score_dicts : dict 1387 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1388 are score tensors 1389 """ 1390 1391 return self.task.generate_uncertainty_score( 1392 classes, 1393 augment_n, 1394 int(self.training_parameters.get("batch_size", 32)), 1395 method, 1396 predicted, 1397 behaviors_dict, 1398 ) 1399 1400 def generate_bald_score( 1401 self, 1402 classes: List, 1403 augment_n: int = 0, 1404 num_models: int = 10, 1405 kernel_size: int = 11, 1406 ) -> Dict: 1407 """ 1408 Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning 1409 1410 Parameters 1411 ---------- 1412 classes : list 1413 a list of class names or indices; their confidence scores will be computed separately and stacked 1414 augment_n : int, default 0 1415 the number of augmentations to average over 1416 num_models : int, default 10 1417 the number of dropout masks to apply 1418 kernel_size : int, default 11 1419 the size of the smoothing gaussian kernel 1420 1421 Returns 1422 ------- 1423 score_dicts : dict 1424 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1425 are score tensors 1426 """ 1427 1428 return self.task.generate_bald_score( 1429 classes, 1430 augment_n, 1431 int(self.training_parameters.get("batch_size", 32)), 1432 num_models, 1433 kernel_size, 1434 ) 1435 1436 def get_normalization_stats(self) -> Dict: 1437 """ 1438 Get the pre-computed normalization stats 1439 1440 Returns 1441 ------- 1442 normalization_stats : dict 1443 a dictionary of means and stds 1444 """ 1445 1446 return self.task.get_normalization_stats() 1447 1448 def exists(self, mode) -> bool: 1449 """ 1450 Check whether the task has a train/test/validation subset 1451 1452 Parameters 1453 ---------- 1454 mode : 
{"train", "val", "test"} 1455 the name of the subset to check for 1456 1457 Returns 1458 ------- 1459 exists : bool 1460 `True` if the subset exists 1461 """ 1462 1463 dl = self.task.dataloader(mode) 1464 if dl is None: 1465 return False 1466 else: 1467 return True
34class TaskDispatcher: 35 """ 36 A class that manages the interactions between config dictionaries and a Task 37 """ 38 39 def __init__(self, parameters: Dict) -> None: 40 """ 41 Parameters 42 ---------- 43 parameters : dict 44 a dictionary of task parameters 45 """ 46 47 pars = deepcopy(parameters) 48 self.class_weights = None 49 self.general_parameters = pars.get("general", {}) 50 self.data_parameters = pars.get("data", {}) 51 self.model_parameters = pars.get("model", {}) 52 self.training_parameters = pars.get("training", {}) 53 self.loss_parameters = pars.get("losses", {}) 54 self.metric_parameters = pars.get("metrics", {}) 55 self.ssl_parameters = pars.get("ssl", {}) 56 self.aug_parameters = pars.get("augmentations", {}) 57 self.feature_parameters = pars.get("features", {}) 58 self.blanks = {blank: [] for blank in options.blanks} 59 60 self.task = None 61 self._initialize_task() 62 self._print_behaviors() 63 64 @staticmethod 65 def complete_function_parameters(parameters, function, general_dicts: List) -> Dict: 66 """ 67 Complete a parameter dictionary with values from other dictionaries if required by a function 68 69 Parameters 70 ---------- 71 parameters : dict 72 the function parameters dictionary 73 function : callable 74 the function to be inspected 75 general_dicts : list 76 a list of dictionaries where the missing values will be pulled from 77 """ 78 79 parameter_names = inspect.getfullargspec(function).args 80 for param in parameter_names: 81 for dic in general_dicts: 82 if param not in parameters and param in dic: 83 parameters[param] = dic[param] 84 return parameters 85 86 @staticmethod 87 def complete_dataset_parameters( 88 parameters: dict, 89 general_dict: dict, 90 data_type: str, 91 annotation_type: str, 92 ) -> Dict: 93 """ 94 Complete a parameter dictionary with values from other dictionaries if required by a dataset 95 96 Parameters 97 ---------- 98 parameters : dict 99 the function parameters dictionary 100 general_dict : dict 101 the 
dictionary where the missing values will be pulled from 102 data_type : str 103 the input type of the dataset 104 annotation_type : str 105 the annotation type of the dataset 106 107 Returns 108 ------- 109 parameters : dict 110 the updated parameter dictionary 111 """ 112 113 params = deepcopy(parameters) 114 parameter_names = BehaviorDataset.get_parameters(data_type, annotation_type) 115 for param in parameter_names: 116 if param not in params and param in general_dict: 117 params[param] = general_dict[param] 118 return params 119 120 @staticmethod 121 def check(parameters: Dict, name: str) -> bool: 122 """ 123 Check whether there is a non-`None` value under the name key in the parameters dictionary 124 125 Parameters 126 ---------- 127 parameters : dict 128 the dictionary to check 129 name : str 130 the key to check 131 132 Returns 133 ------- 134 result : bool 135 True if a non-`None` value exists 136 """ 137 138 if name in parameters and parameters[name] is not None: 139 return True 140 else: 141 return False 142 143 @staticmethod 144 def get(parameters: Dict, name: str, default): 145 """ 146 Get the value under the name key or the default if it is `None` or does not exist 147 148 Parameters 149 ---------- 150 parameters : dict 151 the dictionary to check 152 name : str 153 the key to check 154 default 155 the default value to return 156 157 Returns 158 ------- 159 value 160 the resulting value 161 """ 162 163 if TaskDispatcher.check(parameters, name): 164 return parameters[name] 165 else: 166 return default 167 168 @staticmethod 169 def make_dataloader( 170 dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False 171 ) -> DataLoader: 172 """ 173 Make a torch dataloader from a dataset 174 175 Parameters 176 ---------- 177 dataset : dlc2action.data.dataset.BehaviorDataset 178 the dataset 179 batch_size : int 180 the batch size 181 182 Returns 183 ------- 184 dataloader : DataLoader 185 the dataloader (or `None` if the length of the dataset is 0) 
186 """ 187 188 if dataset is None or len(dataset) == 0: 189 return None 190 else: 191 return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle) 192 193 def _construct_ssl(self) -> List: 194 """ 195 Generate SSL constructors 196 """ 197 198 ssl_list = deepcopy(self.general_parameters.get("ssl", None)) 199 if not isinstance(ssl_list, Iterable): 200 ssl_list = [ssl_list] 201 for i, ssl in enumerate(ssl_list): 202 if type(ssl) is str: 203 if ssl in options.ssl_constructors: 204 pars = self.get(self.ssl_parameters, ssl, default={}) 205 pars = self.complete_function_parameters( 206 parameters=pars, 207 function=options.ssl_constructors[ssl], 208 general_dicts=[ 209 self.model_parameters, 210 self.data_parameters, 211 self.general_parameters, 212 ], 213 ) 214 ssl_list[i] = options.ssl_constructors[ssl](**pars) 215 else: 216 raise ValueError( 217 f"The {ssl} SSL is not available, please choose from {list(options.ssl_constructors.keys())}" 218 ) 219 elif ssl is None: 220 ssl_list[i] = EmptySSL() 221 elif not isinstance(ssl, SSLConstructor): 222 raise TypeError( 223 f"The ssl parameter has to be a list of either strings, SSLConstructor instances or None, got {type(ssl)}" 224 ) 225 return ssl_list 226 227 def _construct_model(self) -> Model: 228 """ 229 Generate a model 230 """ 231 232 if self.check(self.general_parameters, "model"): 233 pars = self.complete_function_parameters( 234 function=LoadedModel, 235 parameters=self.model_parameters, 236 general_dicts=[self.general_parameters], 237 ) 238 model = LoadedModel(**pars) 239 elif self.check(self.general_parameters, "model_name"): 240 name = self.general_parameters["model_name"] 241 if name in options.models: 242 pars = self.complete_function_parameters( 243 function=options.models[name], 244 parameters=self.model_parameters, 245 general_dicts=[self.general_parameters], 246 ) 247 model = options.models[name](**pars) 248 else: 249 raise ValueError( 250 f"The {name} model is not available, please choose from 
{list(options.models.keys())}" 251 ) 252 else: 253 raise ValueError( 254 "You need to provide either a model or its name in the model_parameters!" 255 ) 256 257 if self.get(self.training_parameters, "freeze_features", False): 258 model.freeze_feature_extractor() 259 return model 260 261 def _construct_dataset(self) -> BehaviorDataset: 262 """ 263 Generate a dataset 264 """ 265 266 data_type = self.general_parameters.get("data_type", None) 267 if data_type is None: 268 raise ValueError( 269 "You need to provide the data_type parameter in the data parameters!" 270 ) 271 annotation_type = self.get(self.general_parameters, "annotation_type", "none") 272 feature_extraction = self.general_parameters.get("feature_extraction", "none") 273 if feature_extraction is None: 274 raise ValueError( 275 "You need to provide the feature_extraction parameter in the data parameters!" 276 ) 277 feature_extraction_pars = self.complete_function_parameters( 278 self.feature_parameters, 279 options.feature_extractors[feature_extraction], 280 [self.general_parameters, self.data_parameters], 281 ) 282 283 pars = self.complete_dataset_parameters( 284 self.data_parameters, 285 self.general_parameters, 286 data_type=data_type, 287 annotation_type=annotation_type, 288 ) 289 pars["feature_extraction_pars"] = feature_extraction_pars 290 dataset = BehaviorDataset(**pars) 291 292 if self.get(self.general_parameters, "save_dataset", default=False): 293 save_data_path = self.data_parameters.get("saved_data_path", None) 294 dataset.save(save_path=save_data_path) 295 296 return dataset 297 298 def _construct_transformer(self) -> Transformer: 299 """ 300 Generate a transformer 301 """ 302 303 features = self.general_parameters["feature_extraction"] 304 name = options.extractor_to_transformer[features] 305 if name in options.transformers: 306 transformer_class = options.transformers[name] 307 pars = self.complete_function_parameters( 308 function=transformer_class, 309 parameters=self.aug_parameters, 310 
general_dicts=[self.general_parameters], 311 ) 312 transformer = transformer_class(**pars) 313 else: 314 raise ValueError(f"The {name} transformer is not available") 315 return transformer 316 317 def _construct_loss(self) -> torch.nn.Module: 318 """ 319 Generate a loss function 320 """ 321 322 if "loss_function" not in self.general_parameters: 323 raise ValueError( 324 'Please add a "loss_function" key to the parameters["general"] dictionary (either a name ' 325 f"from {list(options.losses.keys())} or a function)" 326 ) 327 else: 328 loss_function = self.general_parameters["loss_function"] 329 if type(loss_function) is str: 330 if loss_function in options.losses: 331 pars = self.get(self.loss_parameters, loss_function, default={}) 332 pars = self._set_loss_weights(pars) 333 pars = self.complete_function_parameters( 334 function=options.losses[loss_function], 335 parameters=pars, 336 general_dicts=[self.general_parameters], 337 ) 338 loss = options.losses[loss_function](**pars) 339 else: 340 raise ValueError( 341 f"The {loss_function} loss is not available, please choose from {list(options.losses.keys())}" 342 ) 343 else: 344 loss = loss_function 345 return loss 346 347 def _construct_metrics(self) -> List: 348 """ 349 Generate the metric 350 """ 351 352 metric_functions = self.get( 353 self.general_parameters, "metric_functions", default={} 354 ) 355 if isinstance(metric_functions, Iterable): 356 metrics = {} 357 for func in metric_functions: 358 if isinstance(func, str): 359 if func in options.metrics: 360 pars = self.get(self.metric_parameters, func, default={}) 361 pars = self.complete_function_parameters( 362 function=options.metrics[func], 363 parameters=pars, 364 general_dicts=[self.general_parameters], 365 ) 366 metrics[func] = options.metrics[func](**pars) 367 else: 368 raise ValueError( 369 f"The {func} metric is not available, please choose from {list(options.metrics.keys())}" 370 ) 371 elif isinstance(func, Metric): 372 name = "function_1" 373 i = 1 374 
while name in metrics: 375 i += 1 376 name = f"function_{i}" 377 metrics[name] = func 378 else: 379 raise TypeError( 380 'The elements of parameters["general"]["metric_functions"] have to be either strings ' 381 f"from {list(options.metrics.keys())} or Metric instances; got {type(func)} instead" 382 ) 383 elif isinstance(metric_functions, dict): 384 metrics = metric_functions 385 else: 386 raise TypeError( 387 'The value at parameters["general"]["metric_functions"] can be either list, dictionary or None;' 388 f"got {type(metric_functions)} instead" 389 ) 390 return metrics 391 392 def _construct_optimizer(self) -> Optimizer: 393 """ 394 Generate an optimizer 395 """ 396 397 if "optimizer" in self.training_parameters: 398 name = self.training_parameters["optimizer"] 399 if name in options.optimizers: 400 optimizer = options.optimizers[name] 401 else: 402 raise ValueError( 403 f"The {name} optimizer is not available, please choose from {list(options.optimizers.keys())}" 404 ) 405 else: 406 optimizer = None 407 return optimizer 408 409 def _construct_predict_functions(self) -> Tuple[Callable, Callable]: 410 """ 411 Construct predict functions 412 """ 413 414 predict_function = self.training_parameters.get("predict_function", None) 415 primary_predict_function = self.training_parameters.get( 416 "primary_predict_function", None 417 ) 418 model_name = self.general_parameters.get("model_name", "") 419 threshold = self.training_parameters.get("hard_threshold", 0.5) 420 if not isinstance(predict_function, Callable): 421 if model_name in ["c2f_tcn", "c2f_transformer", "c2f_tcn_p"]: 422 if self.general_parameters["exclusive"]: 423 func = lambda x: torch.softmax(x, dim=1) 424 else: 425 func = lambda x: torch.sigmoid(x) 426 427 def primary_predict_function(x): 428 if len(x.shape) != 4: 429 x = x.reshape((4, -1, x.shape[-2], x.shape[-1])) 430 weights = [1, 1, 1, 1] 431 ensemble_prob = func(x[0]) * weights[0] / sum(weights) 432 for i, outp_ele in enumerate(x[1:]): 433 
ensemble_prob = ensemble_prob + func(outp_ele) * weights[ 434 i + 1 435 ] / sum(weights) 436 return ensemble_prob 437 438 else: 439 if model_name.startswith("ms_tcn") or model_name in [ 440 "asformer", 441 "transformer", 442 "c3d_ms", 443 "transformer_ms", 444 ]: 445 f = lambda x: x[-1] if len(x.shape) == 4 else x 446 elif model_name == "asrf": 447 448 def f(x): 449 x = x[-1] 450 # bounds = x[:, 0, :].unsqueeze(1) 451 cls = x[:, 1:, :] 452 # device = x.device 453 # x = PostProcessor("refinement_with_boundary")._refinement_with_boundary(cls.detach().cpu().numpy(), bounds.detach().cpu().numpy()) 454 # x = torch.tensor(x).to(device) 455 return cls 456 457 elif model_name == "actionclip": 458 459 def f(x): 460 video_embedding, text_embedding, logit_scale = ( 461 x["video"], 462 x["text"], 463 x["logit_scale"], 464 ) 465 B, Ff, T = video_embedding.shape 466 video_embedding = video_embedding.permute(0, 2, 1).reshape( 467 (B * T, -1) 468 ) 469 video_embedding /= video_embedding.norm(dim=-1, keepdim=True) 470 text_embedding /= text_embedding.norm(dim=-1, keepdim=True) 471 similarity = logit_scale * video_embedding @ text_embedding.T 472 similarity = similarity.reshape((B, T, -1)).permute(0, 2, 1) 473 return similarity 474 475 else: 476 f = lambda x: x 477 if self.general_parameters["exclusive"]: 478 primary_predict_function = lambda x: torch.softmax(f(x), dim=1) 479 else: 480 primary_predict_function = lambda x: torch.sigmoid(f(x)) 481 if self.general_parameters["exclusive"]: 482 predict_function = lambda x: torch.max(x.data, dim=1)[1] 483 else: 484 predict_function = lambda x: (x > threshold).int() 485 return primary_predict_function, predict_function 486 487 def _get_parameters_from_training(self) -> Dict: 488 """ 489 Get the training parameters that need to be passed to the Task 490 """ 491 492 task_training_par_names = [ 493 "lr", 494 "parallel", 495 "device", 496 "verbose", 497 "log_file", 498 "augment_train", 499 "augment_val", 500 "hard_threshold", 501 "ssl_losses", 
502 "model_save_path", 503 "model_save_epochs", 504 "pseudolabel", 505 "pseudolabel_start", 506 "correction_interval", 507 "pseudolabel_alpha_f", 508 "alpha_growth_stop", 509 "num_epochs", 510 "validation_interval", 511 "ignore_tags", 512 "skip_metrics", 513 ] 514 task_training_pars = { 515 name: self.training_parameters[name] 516 for name in task_training_par_names 517 if self.check(self.training_parameters, name) 518 } 519 if self.check(self.general_parameters, "ssl"): 520 ssl_weights = [ 521 self.training_parameters["ssl_weights"][x] 522 for x in self.general_parameters["ssl"] 523 ] 524 task_training_pars["ssl_weights"] = ssl_weights 525 return task_training_pars 526 527 def _update_parameters_from_ssl(self, ssl_list: list) -> None: 528 """ 529 Update the necessary parameters given the list of SSL constructors 530 """ 531 532 if self.task is not None: 533 self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 534 self.task.set_ssl_losses([ssl.loss for ssl in ssl_list]) 535 self.task.set_keep_target_none( 536 [ssl.type in ["contrastive"] for ssl in ssl_list] 537 ) 538 self.task.set_generate_ssl_input( 539 [ssl.type == "contrastive" for ssl in ssl_list] 540 ) 541 self.data_parameters["ssl_transformations"] = [ 542 ssl.transformation for ssl in ssl_list 543 ] 544 self.training_parameters["ssl_losses"] = [ssl.loss for ssl in ssl_list] 545 self.model_parameters["ssl_types"] = [ssl.type for ssl in ssl_list] 546 self.model_parameters["ssl_modules"] = [ 547 ssl.construct_module() for ssl in ssl_list 548 ] 549 self.aug_parameters["generate_ssl_input"] = [ 550 x.type == "contrastive" for x in ssl_list 551 ] 552 self.aug_parameters["keep_target_none"] = [ 553 x.type == "contrastive" for x in ssl_list 554 ] 555 556 def _set_loss_weights(self, parameters): 557 """ 558 Replace the `"dataset_inverse_weights"` blank in loss parameters with class weight values 559 """ 560 561 for k in list(parameters.keys()): 562 if parameters[k] in [ 563 
"dataset_inverse_weights", 564 "dataset_proportional_weights", 565 ]: 566 if parameters[k] == "dataset_inverse_weights": 567 parameters[k] = self.class_weights 568 else: 569 parameters[k] = self.proportional_class_weights 570 print("Initializing class weights:") 571 string = " " 572 if isinstance(parameters[k], Mapping): 573 for key, val in parameters[k].items(): 574 string += ": ".join( 575 ( 576 " " + str(key), 577 ", ".join((map(lambda x: str(np.round(x, 3)), val))), 578 ) 579 ) 580 else: 581 string += ", ".join( 582 (map(lambda x: str(np.round(x, 3)), parameters[k])) 583 ) 584 print(string) 585 return parameters 586 587 def _partition_dataset( 588 self, dataset: BehaviorDataset 589 ) -> Tuple[BehaviorDataset, BehaviorDataset, BehaviorDataset]: 590 """ 591 Partition the dataset into train, validation and test subsamples 592 """ 593 594 use_test = self.get(self.training_parameters, "use_test", 0) 595 split_path = self.training_parameters.get("split_path", None) 596 partition_method = self.training_parameters.get("partition_method", "random") 597 val_frac = self.get(self.training_parameters, "val_frac", 0) 598 test_frac = self.get(self.training_parameters, "test_frac", 0) 599 save_split = self.get(self.training_parameters, "save_split", True) 600 normalize = self.get(self.training_parameters, "normalize", False) 601 skip_normalization_keys = self.training_parameters.get( 602 "skip_normalization_keys" 603 ) 604 stats = self.training_parameters.get("stats") 605 train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val( 606 use_test, 607 split_path, 608 partition_method, 609 val_frac, 610 test_frac, 611 save_split, 612 normalize, 613 skip_normalization_keys, 614 stats, 615 ) 616 bs = int(self.training_parameters.get("batch_size", 32)) 617 train_dataloader, test_dataloader, val_dataloader = ( 618 self.make_dataloader(train_dataset, batch_size=bs, shuffle=True), 619 self.make_dataloader(test_dataset, batch_size=bs, shuffle=False), 620 
self.make_dataloader(val_dataset, batch_size=bs, shuffle=False), 621 ) 622 return train_dataloader, test_dataloader, val_dataloader 623 624 def _initialize_task(self): 625 """ 626 Create a `dlc2action.task.universal_task.Task` instance 627 """ 628 629 dataset = self._construct_dataset() 630 self._update_data_blanks(dataset) 631 model = self._construct_model() 632 self._update_model_blanks(model) 633 ssl_list = self._construct_ssl() 634 self._update_parameters_from_ssl(ssl_list) 635 model.set_ssl(ssl_constructors=ssl_list) 636 dataset.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 637 transformer = self._construct_transformer() 638 metrics = self._construct_metrics() 639 optimizer = self._construct_optimizer() 640 primary_predict_function, predict_function = self._construct_predict_functions() 641 642 task_training_pars = self._get_parameters_from_training() 643 train_dataloader, test_dataloader, val_dataloader = self._partition_dataset( 644 dataset 645 ) 646 self.class_weights = train_dataloader.dataset.class_weights() 647 self.proportional_class_weights = train_dataloader.dataset.class_weights(True) 648 loss = self._construct_loss() 649 exclusive = self.general_parameters["exclusive"] 650 651 task_pars = { 652 "train_dataloader": train_dataloader, 653 "model": model, 654 "loss": loss, 655 "transformer": transformer, 656 "metrics": metrics, 657 "val_dataloader": val_dataloader, 658 "test_dataloader": test_dataloader, 659 "exclusive": exclusive, 660 "optimizer": optimizer, 661 "predict_function": predict_function, 662 "primary_predict_function": primary_predict_function, 663 } 664 task_pars.update(task_training_pars) 665 666 self.task = Task(**task_pars) 667 checkpoint_path = self.training_parameters.get("checkpoint_path", None) 668 if checkpoint_path is not None: 669 only_model = self.get(self.training_parameters, "only_load_model", False) 670 load_strict = self.get(self.training_parameters, "load_strict", True) 671 
self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 672 if ( 673 self.general_parameters["only_load_annotated"] 674 and self.general_parameters.get("ssl") is not None 675 ): 676 warnings.warn( 677 "Note that you are using SSL modules but only loading annotated files! Set " 678 "general/only_load_annotated to False to change that" 679 ) 680 681 def _update_data_blanks( 682 self, dataset: BehaviorDataset = None, remember: bool = False 683 ) -> None: 684 """ 685 Update all blanks from a dataset 686 """ 687 688 if dataset is None: 689 dataset = self.dataset() 690 self._update_dim_parameter(dataset, remember) 691 self._update_bodyparts_parameter(dataset, remember) 692 self._update_num_classes_parameter(dataset, remember) 693 self._update_len_segment_parameter(dataset, remember) 694 self._update_boundary_parameter(dataset, remember) 695 696 def _update_model_blanks(self, model: Model, remember: bool = False) -> None: 697 self._update_features_parameter(model, remember) 698 699 def _update_parameter(self, blank_name: str, value, remember: bool = False): 700 parameters = [ 701 self.model_parameters, 702 self.ssl_parameters, 703 self.general_parameters, 704 self.feature_parameters, 705 self.data_parameters, 706 self.training_parameters, 707 self.metric_parameters, 708 self.loss_parameters, 709 self.aug_parameters, 710 ] 711 par_names = [ 712 "model", 713 "ssl", 714 "general", 715 "feature", 716 "data", 717 "training", 718 "metrics", 719 "losses", 720 "augmentations", 721 ] 722 for names in self.blanks[blank_name]: 723 group = names[0] 724 key = names[1] 725 ind = par_names.index(group) 726 if len(names) == 3: 727 if names[2] in parameters[ind][key]: 728 parameters[ind][key][names[2]] = value 729 else: 730 if key in parameters[ind]: 731 parameters[ind][key] = value 732 for name, dic in zip(par_names, parameters): 733 for k, v in dic.items(): 734 if v == blank_name: 735 dic[k] = value 736 if [name, k] not in self.blanks[blank_name]: 737 
self.blanks[blank_name].append([name, k]) 738 elif isinstance(v, Mapping): 739 for kk, vv in v.items(): 740 if vv == blank_name: 741 dic[k][kk] = value 742 if [name, k, kk] not in self.blanks[blank_name]: 743 self.blanks[blank_name].append([name, k, kk]) 744 745 def _update_features_parameter(self, model: Model, remember: bool = False) -> None: 746 """ 747 Fill the `"model_features"` blank 748 """ 749 750 value = model.features_shape() 751 self._update_parameter("model_features", value, remember) 752 753 def _update_bodyparts_parameter( 754 self, dataset: BehaviorDataset, remember: bool = False 755 ) -> None: 756 """ 757 Fill the `"dataset_bodyparts"` blank 758 """ 759 760 value = dataset.bodyparts_order() 761 self._update_parameter("dataset_bodyparts", value, remember) 762 763 def _update_dim_parameter( 764 self, dataset: BehaviorDataset, remember: bool = False 765 ) -> None: 766 """ 767 Fill the `"dataset_features"` blank 768 """ 769 770 value = dataset.features_shape() 771 self._update_parameter("dataset_features", value, remember) 772 773 def _update_boundary_parameter( 774 self, dataset: BehaviorDataset, remember: bool = False 775 ) -> None: 776 """ 777 Fill the `"dataset_features"` blank 778 """ 779 780 value = dataset.boundary_class_weight() 781 self._update_parameter("dataset_boundary_weight", value, remember) 782 783 def _update_num_classes_parameter( 784 self, dataset: BehaviorDataset, remember: bool = False 785 ) -> None: 786 """ 787 Fill in the `"dataset_classes"` blank 788 """ 789 790 value = dataset.num_classes() 791 self._update_parameter("dataset_classes", value, remember) 792 793 def _update_len_segment_parameter( 794 self, dataset: BehaviorDataset, remember: bool = False 795 ) -> None: 796 """ 797 Fill in the `"dataset_len_segment"` blank 798 """ 799 800 value = dataset.len_segment() 801 self._update_parameter("dataset_len_segment", value, remember) 802 803 def _print_behaviors(self): 804 behavior_set = self.behaviors_dict() 805 print(f"Behavior 
indices:") 806 for key, value in sorted(behavior_set.items()): 807 print(f" {key}: {value}") 808 809 def update_task(self, parameters: Dict) -> None: 810 """ 811 Update the `dlc2action.task.universal_task.Task` instance given the parameter updates 812 813 Parameters 814 ---------- 815 parameters : dict 816 the dictionary of parameter updates 817 """ 818 819 pars = deepcopy(parameters) 820 # for blank_name in self.blanks: 821 # for names in self.blanks[blank_name]: 822 # group = names[0] 823 # key = names[1] 824 # if len(names) == 3: 825 # if ( 826 # group in pars 827 # and key in pars[group] 828 # and names[2] in pars[group][key] 829 # ): 830 # pars[group][key].pop(names[2]) 831 # else: 832 # if group in pars and key in pars[group]: 833 # pars[group].pop(key) 834 stay = False 835 if "ssl" in pars: 836 for key in pars["ssl"]: 837 if key in self.ssl_parameters: 838 self.ssl_parameters[key].update(pars["ssl"][key]) 839 else: 840 self.ssl_parameters[key] = pars["ssl"][key] 841 842 if "general" in pars: 843 if stay: 844 stay = False 845 if ( 846 "model_name" in pars["general"] 847 and pars["general"]["model_name"] 848 != self.general_parameters["model_name"] 849 ): 850 if "model" not in pars: 851 raise ValueError( 852 "When updating a task with a new model name you need to pass the parameters for the " 853 "new model" 854 ) 855 self.model_parameters = {} 856 self.general_parameters.update(pars["general"]) 857 data_related = [ 858 "num_classes", 859 "exclusive", 860 "data_type", 861 "annotation_type", 862 ] 863 ssl_related = ["ssl", "exclusive", "num_classes"] 864 loss_related = ["num_classes", "loss_function", "exclusive"] 865 augmentation_related = ["augmentation_type"] 866 metric_related = ["metric_functions"] 867 related_lists = [ 868 data_related, 869 ssl_related, 870 loss_related, 871 augmentation_related, 872 metric_related, 873 ] 874 names = ["data", "ssl", "losses", "augmentations", "metrics"] 875 for related_list, name in zip(related_lists, names): 876 if ( 877 
any([x in pars["general"] for x in related_list]) 878 and name not in pars 879 ): 880 pars[name] = {} 881 882 if "training" in pars: 883 if "data" not in pars or not stay: 884 for x in [ 885 "to_ram", 886 "use_test", 887 "partition_method", 888 "val_frac", 889 "test_frac", 890 "save_split", 891 "batch_size", 892 "save_split", 893 ]: 894 if ( 895 x in pars["training"] 896 and pars["training"][x] != self.training_parameters[x] 897 ): 898 if "data" not in pars: 899 pars["data"] = {} 900 stay = True 901 self.training_parameters.update(pars["training"]) 902 self.task.update_parameters(self._get_parameters_from_training()) 903 904 if "data" in pars or "features" in pars: 905 for k, v in pars["data"].items(): 906 if k not in self.data_parameters or v != self.data_parameters[k]: 907 stay = True 908 for k, v in pars["features"].items(): 909 if k not in self.feature_parameters or v != self.feature_parameters[k]: 910 stay = True 911 if stay: 912 self.data_parameters.update(pars["data"]) 913 self.feature_parameters.update(pars["features"]) 914 dataset = self._construct_dataset() 915 ( 916 train_dataloader, 917 test_dataloader, 918 val_dataloader, 919 ) = self._partition_dataset(dataset) 920 self.task.set_dataloaders( 921 train_dataloader, val_dataloader, test_dataloader 922 ) 923 self.class_weights = train_dataloader.dataset.class_weights() 924 self.proportional_class_weights = ( 925 train_dataloader.dataset.class_weights(True) 926 ) 927 if "losses" not in pars: 928 pars["losses"] = {} 929 930 if "model" in pars: 931 self.model_parameters.update(pars["model"]) 932 933 self._update_data_blanks() 934 935 if "augmentations" in pars: 936 self.aug_parameters.update(pars["augmentations"]) 937 transformer = self._construct_transformer() 938 self.task.set_transformer(transformer) 939 940 if "losses" in pars: 941 for key in pars["losses"]: 942 if key in self.loss_parameters: 943 self.loss_parameters[key].update(pars["losses"][key]) 944 else: 945 self.loss_parameters[key] = 
pars["losses"][key] 946 self.loss_parameters.update(pars["losses"]) 947 loss = self._construct_loss() 948 self.task.set_loss(loss) 949 950 if "metrics" in pars: 951 for key in pars["metrics"]: 952 if key in self.metric_parameters: 953 self.metric_parameters[key].update(pars["metrics"][key]) 954 else: 955 self.metric_parameters[key] = pars["metrics"][key] 956 metrics = self._construct_metrics() 957 self.task.set_metrics(metrics) 958 959 self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"]) 960 self._set_loss_weights( 961 pars.get("losses", {}).get(self.general_parameters["loss_function"], {}) 962 ) 963 model = self._construct_model() 964 predict_functions = self._construct_predict_functions() 965 self.task.set_predict_functions(*predict_functions) 966 self._update_model_blanks(model) 967 ssl_list = self._construct_ssl() 968 self._update_parameters_from_ssl(ssl_list) 969 model.set_ssl(ssl_constructors=ssl_list) 970 self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list]) 971 self.task.set_model(model) 972 if "training" in pars and "checkpoint_path" in pars["training"]: 973 checkpoint_path = pars["training"]["checkpoint_path"] 974 only_model = pars["training"].get("only_load_model", False) 975 load_strict = pars["training"].get("load_strict", True) 976 self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict) 977 if ( 978 self.general_parameters["only_load_annotated"] 979 and self.general_parameters.get("ssl") is not None 980 ): 981 warnings.warn( 982 "Note that you are using SSL modules but only loading annotated files! 
Set " 983 "general/only_load_annotated to False to change that" 984 ) 985 if self.task.dataset("train").annotation_class() != "none": 986 self._print_behaviors() 987 988 def train( 989 self, 990 trial: Trial = None, 991 optimized_metric: str = None, 992 autostop_metric: str = None, 993 autostop_interval: int = 10, 994 autostop_threshold: float = 0.001, 995 loading_bar: bool = False, 996 ) -> Tuple: 997 """ 998 Train the task and return a log of epoch-average loss and metric 999 1000 You can use the autostop parameters to finish training when the parameters are not improving. It will be 1001 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 1002 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 1003 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 1004 1005 Parameters 1006 ---------- 1007 trial : Trial 1008 an `optuna` trial (for hyperparameter searches) 1009 optimized_metric : str 1010 the name of the metric being optimized (for hyperparameter searches) 1011 to_ram : bool, default False 1012 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 1013 if the dataset is too large) 1014 autostop_interval : int, default 50 1015 the number of epochs to average the autostop metric over 1016 autostop_threshold : float, default 0.001 1017 the autostop difference threshold 1018 autostop_metric : str, optional 1019 the autostop metric (can be any one of the tracked metrics of `'loss'`) 1020 main_task_on : bool, default True 1021 if `False`, the main task (action segmentation) will not be used in training 1022 ssl_on : bool, default True 1023 if `False`, the SSL task will not be used in training 1024 1025 Returns 1026 ------- 1027 loss_log: list 1028 a list of float loss function values for each epoch 1029 metrics_log: dict 1030 a dictionary of metric value logs 
(first-level keys are 'train' and 'val', second-level keys are metric 1031 names, values are lists of function values) 1032 """ 1033 1034 to_ram = self.training_parameters.get("to_ram", False) 1035 logs = self.task.train( 1036 trial, 1037 optimized_metric, 1038 to_ram, 1039 autostop_metric=autostop_metric, 1040 autostop_interval=autostop_interval, 1041 autostop_threshold=autostop_threshold, 1042 main_task_on=self.training_parameters.get("main_task_on", True), 1043 ssl_on=self.training_parameters.get("ssl_on", True), 1044 temporal_subsampling_size=self.training_parameters.get( 1045 "temporal_subsampling_size" 1046 ), 1047 loading_bar=loading_bar, 1048 ) 1049 return logs 1050 1051 def save_model(self, save_path: str) -> None: 1052 """ 1053 Save the model of the `dlc2action.task.universal_task.Task` instance 1054 1055 Parameters 1056 ---------- 1057 save_path : str 1058 the path to the saved file 1059 """ 1060 1061 self.task.save_model(save_path) 1062 1063 def evaluate( 1064 self, 1065 data: Union[DataLoader, BehaviorDataset, str] = None, 1066 augment_n: int = 0, 1067 verbose: bool = True, 1068 ) -> Tuple: 1069 """ 1070 Evaluate the Task model 1071 1072 Parameters 1073 ---------- 1074 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1075 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 1076 augment_n : int, default 0 1077 the number of augmentations to average results over 1078 verbose : bool, default True 1079 if True, the process is reported to standard output 1080 1081 Returns 1082 ------- 1083 loss : float 1084 the average value of the loss function 1085 ssl_loss : float 1086 the average value of the SSL loss function 1087 metric : dict 1088 a dictionary of average values of metric functions 1089 """ 1090 1091 res = self.task.evaluate( 1092 data, 1093 augment_n, 1094 int(self.training_parameters.get("batch_size", 32)), 1095 verbose, 1096 ) 1097 return res 1098 1099 def evaluate_prediction( 
1100 self, 1101 prediction: torch.Tensor, 1102 data: Union[DataLoader, BehaviorDataset, str] = None, 1103 ) -> Tuple: 1104 """ 1105 Compute metrics for a prediction 1106 1107 Parameters 1108 ---------- 1109 prediction : torch.Tensor 1110 the prediction 1111 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1112 the data the prediction was made for (if not provided, take the validation dataset) 1113 1114 Returns 1115 ------- 1116 loss : float 1117 the average value of the loss function 1118 metric : dict 1119 a dictionary of average values of metric functions 1120 """ 1121 1122 return self.task.evaluate_prediction( 1123 prediction, data, int(self.training_parameters.get("batch_size", 32)) 1124 ) 1125 1126 def predict( 1127 self, 1128 data: Union[DataLoader, BehaviorDataset, str], 1129 raw_output: bool = False, 1130 apply_primary_function: bool = True, 1131 augment_n: int = 0, 1132 embedding: bool = False, 1133 ) -> torch.Tensor: 1134 """ 1135 Make a prediction with the Task model 1136 1137 Parameters 1138 ---------- 1139 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 1140 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 1141 raw_output : bool, default False 1142 if `True`, the raw predicted probabilities are returned 1143 apply_primary_function : bool, default True 1144 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 1145 the input) 1146 augment_n : int, default 0 1147 the number of augmentations to average results over 1148 1149 Returns 1150 ------- 1151 prediction : torch.Tensor 1152 a prediction for the input data 1153 """ 1154 1155 to_ram = self.training_parameters.get("to_ram", False) 1156 return self.task.predict( 1157 data, 1158 raw_output, 1159 apply_primary_function, 1160 augment_n, 1161 int(self.training_parameters.get("batch_size", 32)), 1162 to_ram, 1163 embedding=embedding, 1164 ) 1165 
1166 def dataset(self, mode: str = "train") -> BehaviorDataset: 1167 """ 1168 Get a dataset 1169 1170 Parameters 1171 ---------- 1172 mode : {'train', 'val', 'test'} 1173 the dataset to get 1174 1175 Returns 1176 ------- 1177 dataset : dlc2action.data.dataset.BehaviorDataset 1178 the dataset 1179 """ 1180 1181 return self.task.dataset(mode) 1182 1183 def generate_full_length_prediction( 1184 self, 1185 dataset: Union[BehaviorDataset, str] = None, 1186 augment_n: int = 10, 1187 ) -> Dict: 1188 """ 1189 Compile a prediction for the original input sequences 1190 1191 Parameters 1192 ---------- 1193 dataset : dlc2action.data.dataset.BehaviorDataset | str, optional 1194 the dataset to generate a prediction for (if `None`, generate for the `dlc2action.task.universal_task.Task` 1195 instance validation dataset) 1196 augment_n : int, default 10 1197 the number of augmentations to average results over 1198 1199 Returns 1200 ------- 1201 prediction : dict 1202 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1203 are prediction tensors 1204 """ 1205 1206 return self.task.generate_full_length_prediction( 1207 dataset, int(self.training_parameters.get("batch_size", 32)), augment_n 1208 ) 1209 1210 def generate_submission( 1211 self, 1212 frame_number_map_file: str, 1213 dataset: Union[BehaviorDataset, str] = None, 1214 augment_n: int = 10, 1215 ) -> Dict: 1216 """ 1217 Generate a MABe-22 style submission dictionary 1218 1219 Parameters 1220 ---------- 1221 frame_number_map_file : str 1222 path to the frame number map file 1223 dataset : BehaviorDataset, optional 1224 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1225 augment_n : int, default 10 1226 the number of augmentations to average results over 1227 1228 Returns 1229 ------- 1230 submission : dict 1231 a dictionary with frame number mapping and embeddings 1232 """ 1233 1234 return self.task.generate_submission( 1235 
frame_number_map_file, 1236 dataset, 1237 int(self.training_parameters.get("batch_size", 32)), 1238 augment_n, 1239 ) 1240 1241 def behaviors_dict(self): 1242 """ 1243 Get a behavior dictionary 1244 1245 Keys are label indices and values are label names. 1246 1247 Returns 1248 ------- 1249 behaviors_dict : dict 1250 behavior dictionary 1251 """ 1252 1253 return self.task.behaviors_dict() 1254 1255 def count_classes(self, bouts: bool = False) -> Dict: 1256 """ 1257 Get a dictionary of class counts in different modes 1258 1259 Parameters 1260 ---------- 1261 bouts : bool, default False 1262 if `True`, instead of frame counts segment counts are returned 1263 1264 Returns 1265 ------- 1266 class_counts : dict 1267 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1268 class names and values are class counts (in frames) 1269 """ 1270 1271 return self.task.count_classes(bouts) 1272 1273 def _visualize_results_label( 1274 self, 1275 label: str, 1276 save_path: str = None, 1277 add_legend: bool = True, 1278 ground_truth: bool = True, 1279 hide_axes: bool = False, 1280 width: int = 10, 1281 whole_video: bool = False, 1282 transparent: bool = False, 1283 dataset: BehaviorDataset = None, 1284 smooth_interval: int = 0, 1285 title: str = None, 1286 ): 1287 return self.task._visualize_results_label( 1288 label, 1289 save_path, 1290 add_legend, 1291 ground_truth, 1292 hide_axes, 1293 width, 1294 whole_video, 1295 transparent, 1296 dataset, 1297 smooth_interval=smooth_interval, 1298 title=title, 1299 ) 1300 1301 def visualize_results( 1302 self, 1303 save_path: str = None, 1304 add_legend: bool = True, 1305 ground_truth: bool = True, 1306 colormap: str = "viridis", 1307 hide_axes: bool = False, 1308 min_classes: int = 1, 1309 width: float = 10, 1310 whole_video: bool = False, 1311 transparent: bool = False, 1312 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1313 drop_classes: Set = None, 1314 search_classes: Set = None, 1315 
smooth_interval_prediction: int = None, 1316 ) -> None: 1317 """ 1318 Visualize random predictions 1319 1320 Parameters 1321 ---------- 1322 save_path : str, optional 1323 the path where the plot will be saved 1324 add_legend : bool, default True 1325 if True, legend will be added to the plot 1326 ground_truth : bool, default True 1327 if True, ground truth will be added to the plot 1328 colormap : str, default 'Accent' 1329 the `matplotlib` colormap to use 1330 hide_axes : bool, default True 1331 if `True`, the axes will be hidden on the plot 1332 min_classes : int, default 1 1333 the minimum number of classes in a displayed interval 1334 width : float, default 10 1335 the width of the plot 1336 whole_video : bool, default False 1337 if `True`, whole videos are plotted instead of segments 1338 transparent : bool, default False 1339 if `True`, the background on the plot is transparent 1340 dataset : BehaviorDataset | DataLoader | str | None, optional 1341 the dataset to make the prediction for (if not provided, the validation dataset is used) 1342 drop_classes : set, optional 1343 a set of class names to not be displayed 1344 search_classes : set, optional 1345 if given, only intervals where at least one of the classes is in ground truth will be shown 1346 """ 1347 1348 return self.task.visualize_results( 1349 save_path, 1350 add_legend, 1351 ground_truth, 1352 colormap, 1353 hide_axes, 1354 min_classes, 1355 width, 1356 whole_video, 1357 transparent, 1358 dataset, 1359 drop_classes, 1360 search_classes, 1361 smooth_interval_prediction=smooth_interval_prediction, 1362 ) 1363 1364 def generate_uncertainty_score( 1365 self, 1366 classes: List, 1367 augment_n: int = 0, 1368 method: str = "least_confidence", 1369 predicted: torch.Tensor = None, 1370 behaviors_dict: Dict = None, 1371 ) -> Dict: 1372 """ 1373 Generate frame-wise scores for active learning 1374 1375 Parameters 1376 ---------- 1377 classes : list 1378 a list of class names or indices; their confidence 
scores will be computed separately and stacked 1379 augment_n : int, default 0 1380 the number of augmentations to average over 1381 method : {"least_confidence", "entropy"} 1382 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1383 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1384 1385 Returns 1386 ------- 1387 score_dicts : dict 1388 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1389 are score tensors 1390 """ 1391 1392 return self.task.generate_uncertainty_score( 1393 classes, 1394 augment_n, 1395 int(self.training_parameters.get("batch_size", 32)), 1396 method, 1397 predicted, 1398 behaviors_dict, 1399 ) 1400 1401 def generate_bald_score( 1402 self, 1403 classes: List, 1404 augment_n: int = 0, 1405 num_models: int = 10, 1406 kernel_size: int = 11, 1407 ) -> Dict: 1408 """ 1409 Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning 1410 1411 Parameters 1412 ---------- 1413 classes : list 1414 a list of class names or indices; their confidence scores will be computed separately and stacked 1415 augment_n : int, default 0 1416 the number of augmentations to average over 1417 num_models : int, default 10 1418 the number of dropout masks to apply 1419 kernel_size : int, default 11 1420 the size of the smoothing gaussian kernel 1421 1422 Returns 1423 ------- 1424 score_dicts : dict 1425 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1426 are score tensors 1427 """ 1428 1429 return self.task.generate_bald_score( 1430 classes, 1431 augment_n, 1432 int(self.training_parameters.get("batch_size", 32)), 1433 num_models, 1434 kernel_size, 1435 ) 1436 1437 def get_normalization_stats(self) -> Dict: 1438 """ 1439 Get the pre-computed normalization stats 1440 1441 Returns 1442 ------- 1443 normalization_stats : dict 1444 a 
dictionary of means and stds 1445 """ 1446 1447 return self.task.get_normalization_stats() 1448 1449 def exists(self, mode) -> bool: 1450 """ 1451 Check whether the task has a train/test/validation subset 1452 1453 Parameters 1454 ---------- 1455 mode : {"train", "val", "test"} 1456 the name of the subset to check for 1457 1458 Returns 1459 ------- 1460 exists : bool 1461 `True` if the subset exists 1462 """ 1463 1464 dl = self.task.dataloader(mode) 1465 if dl is None: 1466 return False 1467 else: 1468 return True
A class that manages the interactions between config dictionaries and a Task
def __init__(self, parameters: Dict) -> None:
    """
    Store the parameter groups of a configuration and build the task

    The configuration is deep-copied first, so the dictionary passed by the caller is never
    modified. A group that is missing from the configuration is stored as an empty dictionary.

    Parameters
    ----------
    parameters : dict
        a dictionary of task parameters
    """

    config = deepcopy(parameters)

    def group(name: str) -> Dict:
        # a missing group is treated as "no parameters set"
        return config.get(name, {})

    self.class_weights = None
    self.general_parameters = group("general")
    self.data_parameters = group("data")
    self.model_parameters = group("model")
    self.training_parameters = group("training")
    self.loss_parameters = group("losses")
    self.metric_parameters = group("metrics")
    self.ssl_parameters = group("ssl")
    self.aug_parameters = group("augmentations")
    self.feature_parameters = group("features")
    # every blank gets its own list (filled later with the locations where the blank is used)
    self.blanks = {blank_name: [] for blank_name in options.blanks}

    self.task = None
    self._initialize_task()
    self._print_behaviors()
Parameters
parameters : dict a dictionary of task parameters
@staticmethod
def complete_function_parameters(parameters, function, general_dicts: List) -> Dict:
    """
    Fill in the missing arguments of a function from a list of fallback dictionaries

    The argument names of `function` are read with `inspect.getfullargspec(function).args`. For
    every name that is not yet a key of `parameters`, the value is copied from the first dictionary
    in `general_dicts` that contains it. Names found in none of the dictionaries are left unset.

    Parameters
    ----------
    parameters : dict
        the function parameters dictionary (it is modified in place)
    function : callable
        the function to be inspected
    general_dicts : list
        a list of dictionaries where the missing values will be pulled from (earlier dictionaries
        take precedence)

    Returns
    -------
    parameters : dict
        the completed dictionary (the same object that was passed in)
    """

    for arg_name in inspect.getfullargspec(function).args:
        if arg_name in parameters:
            # values that are already set are never overwritten
            continue
        source = next(
            (candidate for candidate in general_dicts if arg_name in candidate), None
        )
        if source is not None:
            parameters[arg_name] = source[arg_name]
    return parameters
Complete a parameter dictionary with values from other dictionaries if required by a function
Parameters
parameters : dict the function parameters dictionary function : callable the function to be inspected general_dicts : list a list of dictionaries where the missing values will be pulled from
@staticmethod
def complete_dataset_parameters(
    parameters: dict,
    general_dict: dict,
    data_type: str,
    annotation_type: str,
) -> Dict:
    """
    Fill in the missing parameters of a dataset from a fallback dictionary

    The list of parameter names is obtained from `BehaviorDataset.get_parameters(data_type, annotation_type)`.
    The input dictionary is deep-copied, so it is left untouched; values that are already set in it are
    never overwritten.

    Parameters
    ----------
    parameters : dict
        the function parameters dictionary
    general_dict : dict
        the dictionary where the missing values will be pulled from
    data_type : str
        the input type of the dataset
    annotation_type : str
        the annotation type of the dataset

    Returns
    -------
    parameters : dict
        the updated parameter dictionary (a new object)
    """

    completed = deepcopy(parameters)
    required = BehaviorDataset.get_parameters(data_type, annotation_type)
    completed.update(
        {
            name: general_dict[name]
            for name in required
            if name not in completed and name in general_dict
        }
    )
    return completed
Complete a parameter dictionary with values from other dictionaries if required by a dataset
Parameters
parameters : dict the function parameters dictionary general_dict : dict the dictionary where the missing values will be pulled from data_type : str the input type of the dataset annotation_type : str the annotation type of the dataset
Returns
parameters : dict the updated parameter dictionary
@staticmethod
def check(parameters: Dict, name: str) -> bool:
    """
    Tell whether a key is present in a dictionary with a value other than `None`

    Falsy values such as `0`, `False` or an empty container still count as present.

    Parameters
    ----------
    parameters : dict
        the dictionary to check
    name : str
        the key to check

    Returns
    -------
    result : bool
        `True` if the key exists and its value is not `None`
    """

    return name in parameters and parameters[name] is not None
Check whether there is a non-None
value under the name key in the parameters dictionary
Parameters
parameters : dict the dictionary to check name : str the key to check
Returns
result : bool
True if a non-None
value exists
@staticmethod
def get(parameters: Dict, name: str, default):
    """
    Look up a key and fall back to a default when it is missing or set to `None`

    Unlike `dict.get`, an explicit `None` stored under the key is also replaced by the default.

    Parameters
    ----------
    parameters : dict
        the dictionary to check
    name : str
        the key to check
    default
        the default value to return

    Returns
    -------
    value
        the stored value if it exists and is not `None`, otherwise `default`
    """

    has_value = TaskDispatcher.check(parameters, name)
    return parameters[name] if has_value else default
Get the value under the name key or the default if it is None
or does not exist
Parameters
parameters : dict the dictionary to check name : str the key to check default the default value to return
Returns
value the resulting value
@staticmethod
def make_dataloader(
    dataset: BehaviorDataset, batch_size: int = 32, shuffle: bool = False
) -> DataLoader:
    """
    Wrap a dataset in a torch dataloader

    Parameters
    ----------
    dataset : dlc2action.data.dataset.BehaviorDataset
        the dataset (can be `None`)
    batch_size : int, default 32
        the batch size (it is converted to `int` before being passed on)
    shuffle : bool, default False
        the value passed to the `shuffle` argument of `torch.utils.data.DataLoader`

    Returns
    -------
    dataloader : DataLoader
        the dataloader (or `None` if the dataset is `None` or its length is 0)
    """

    # no dataloader is created for a missing or empty subset
    if dataset is None:
        return None
    if len(dataset) == 0:
        return None
    return DataLoader(dataset, batch_size=int(batch_size), shuffle=shuffle)
Make a torch dataloader from a dataset
Parameters
dataset : dlc2action.data.dataset.BehaviorDataset the dataset batch_size : int the batch size shuffle : bool the value passed to the shuffle argument of the torch DataLoader
Returns
dataloader : DataLoader
the dataloader (or None
if the length of the dataset is 0)
@staticmethod
def _merge_parameter_groups(current: Dict, updates: Dict) -> None:
    """
    Merge a dictionary of updates into a dictionary of stored parameters (in place)

    When the stored value and the update under the same key are both dictionaries, they are merged
    key by key, so stored entries that are not mentioned in the update survive. In every other
    case the stored value is replaced by the update.
    """

    for key, value in updates.items():
        if isinstance(current.get(key), dict) and isinstance(value, Mapping):
            current[key].update(value)
        else:
            current[key] = value

def update_task(self, parameters: Dict) -> None:
    """
    Update the `dlc2action.task.universal_task.Task` instance given the parameter updates

    The stored parameter groups are updated from `parameters` and the dependent parts of the task
    are re-created: the dataloaders (only when a data, feature or split-related training value
    actually changed), the transformer, the loss and the metrics (only when the corresponding group
    is updated, directly or through a related `"general"` key), and always the model, the predict
    functions and the SSL constructors.

    Parameters
    ----------
    parameters : dict
        the dictionary of parameter updates (it is deep-copied and not modified); the groups read
        here are `"general"`, `"data"`, `"features"`, `"model"`, `"training"`, `"losses"`,
        `"metrics"`, `"ssl"` and `"augmentations"`, and any of them can be omitted

    Raises
    ------
    ValueError
        if `general/model_name` is changed and no `"model"` group is passed
    """

    pars = deepcopy(parameters)
    # becomes True as soon as something requires the dataset to be re-created
    stay = False

    if "ssl" in pars:
        self._merge_parameter_groups(self.ssl_parameters, pars["ssl"])

    if "general" in pars:
        general_update = pars["general"]
        if (
            "model_name" in general_update
            and general_update["model_name"]
            != self.general_parameters.get("model_name")
        ):
            if "model" not in pars:
                raise ValueError(
                    "When updating a task with a new model name you need to pass the parameters for the "
                    "new model"
                )
            # the parameters of the previous model do not apply to the new one
            self.model_parameters = {}
        self.general_parameters.update(general_update)
        # general keys that influence other groups: when one of them is updated, the group is
        # marked as updated (with an empty dictionary) so that the code below processes it
        related_keys = {
            "data": ["num_classes", "exclusive", "data_type", "annotation_type"],
            "ssl": ["ssl", "exclusive", "num_classes"],
            "losses": ["num_classes", "loss_function", "exclusive"],
            "augmentations": ["augmentation_type"],
            "metrics": ["metric_functions"],
        }
        for group_name, keys in related_keys.items():
            if group_name not in pars and any(x in general_update for x in keys):
                pars[group_name] = {}

    if "training" in pars:
        training_update = pars["training"]
        # training keys that affect how the dataset is loaded and partitioned
        dataset_related = (
            "to_ram",
            "use_test",
            "partition_method",
            "val_frac",
            "test_frac",
            "save_split",
            "batch_size",
        )
        # .get() is used because a key that was never configured cannot be indexed; such a key
        # counts as changed unless the update sets it to None
        if any(
            x in training_update
            and training_update[x] != self.training_parameters.get(x)
            for x in dataset_related
        ):
            pars.setdefault("data", {})
            stay = True
        self.training_parameters.update(training_update)
        self.task.update_parameters(self._get_parameters_from_training())

    if "data" in pars or "features" in pars:
        # only one of the two groups may be present, so both are read with a default
        data_update = pars.get("data", {})
        feature_update = pars.get("features", {})
        for stored, update in (
            (self.data_parameters, data_update),
            (self.feature_parameters, feature_update),
        ):
            if any(k not in stored or v != stored[k] for k, v in update.items()):
                stay = True
        if stay:
            self.data_parameters.update(data_update)
            self.feature_parameters.update(feature_update)
            dataset = self._construct_dataset()
            (
                train_dataloader,
                test_dataloader,
                val_dataloader,
            ) = self._partition_dataset(dataset)
            self.task.set_dataloaders(
                train_dataloader, val_dataloader, test_dataloader
            )
            self.class_weights = train_dataloader.dataset.class_weights()
            self.proportional_class_weights = (
                train_dataloader.dataset.class_weights(True)
            )
            # the class weights were recomputed, so the loss is re-created below
            if "losses" not in pars:
                pars["losses"] = {}

    if "model" in pars:
        self.model_parameters.update(pars["model"])

    self._update_data_blanks()

    if "augmentations" in pars:
        self.aug_parameters.update(pars["augmentations"])
        transformer = self._construct_transformer()
        self.task.set_transformer(transformer)

    if "losses" in pars:
        # nested merge only: a top-level dict.update() afterwards would replace each merged
        # dictionary with the partial one from the update and drop the other stored values
        self._merge_parameter_groups(self.loss_parameters, pars["losses"])
        loss = self._construct_loss()
        self.task.set_loss(loss)

    if "metrics" in pars:
        self._merge_parameter_groups(self.metric_parameters, pars["metrics"])
        metrics = self._construct_metrics()
        self.task.set_metrics(metrics)

    self.task.set_ssl_transformations(self.data_parameters["ssl_transformations"])
    self._set_loss_weights(
        pars.get("losses", {}).get(self.general_parameters["loss_function"], {})
    )
    model = self._construct_model()
    predict_functions = self._construct_predict_functions()
    self.task.set_predict_functions(*predict_functions)
    self._update_model_blanks(model)
    ssl_list = self._construct_ssl()
    self._update_parameters_from_ssl(ssl_list)
    model.set_ssl(ssl_constructors=ssl_list)
    self.task.set_ssl_transformations([ssl.transformation for ssl in ssl_list])
    self.task.set_model(model)

    # the checkpoint options are read the same way as in _initialize_task: a None path is
    # skipped and a None flag falls back to its default
    training_update = pars.get("training", {})
    checkpoint_path = training_update.get("checkpoint_path")
    if checkpoint_path is not None:
        only_model = self.get(training_update, "only_load_model", False)
        load_strict = self.get(training_update, "load_strict", True)
        self.task.load_from_checkpoint(checkpoint_path, only_model, load_strict)
    if (
        self.general_parameters["only_load_annotated"]
        and self.general_parameters.get("ssl") is not None
    ):
        warnings.warn(
            "Note that you are using SSL modules but only loading annotated files! Set "
            "general/only_load_annotated to False to change that"
        )
    if self.task.dataset("train").annotation_class() != "none":
        self._print_behaviors()
Update the dlc2action.task.universal_task.Task
instance given the parameter updates
Parameters
parameters : dict the dictionary of parameter updates
def train(
    self,
    trial: Trial = None,
    optimized_metric: str = None,
    autostop_metric: str = None,
    autostop_interval: int = 10,
    autostop_threshold: float = 0.001,
    loading_bar: bool = False,
) -> Tuple:
    """
    Run training on the wrapped task and hand back its logs

    The autostop arguments let training finish early once the tracked value stops improving: training is
    stopped if the average of `autostop_metric` over the last `autostop_interval` epochs is smaller than
    the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, at
    epoch 120 with an interval of 50 the averages over epochs 70-120 and 20-70 are compared.

    Parameters
    ----------
    trial : Trial
        an `optuna` trial (for hyperparameter searches)
    optimized_metric : str
        the name of the metric being optimized (for hyperparameter searches)
    autostop_metric : str, optional
        the autostop metric (can be any one of the tracked metrics or `'loss'`)
    autostop_interval : int, default 10
        the number of epochs to average the autostop metric over
    autostop_threshold : float, default 0.001
        the autostop difference threshold
    loading_bar : bool, default False
        forwarded unchanged to the task's `train` method

    Returns
    -------
    logs : tuple
        whatever the task's `train` method returns (documented upstream as a list of per-epoch loss values
        and a dictionary of metric value logs)

    Notes
    -----
    `to_ram` (default `False`), `main_task_on` (default `True`), `ssl_on` (default `True`) and
    `temporal_subsampling_size` (default `None`) are not arguments of this method; they are looked up in
    `self.training_parameters`.
    """

    settings = self.training_parameters
    load_to_ram = settings.get("to_ram", False)
    return self.task.train(
        trial,
        optimized_metric,
        load_to_ram,
        autostop_metric=autostop_metric,
        autostop_interval=autostop_interval,
        autostop_threshold=autostop_threshold,
        main_task_on=settings.get("main_task_on", True),
        ssl_on=settings.get("ssl_on", True),
        temporal_subsampling_size=settings.get("temporal_subsampling_size"),
        loading_bar=loading_bar,
    )
Train the task and return a log of epoch-average loss and metric
You can use the autostop parameters to finish training when the parameters are not improving. It will be
stopped if the average value of autostop_metric
over the last autostop_interval
epochs is smaller than
the average over the previous autostop_interval
epochs + autostop_threshold
. For example, if the
current epoch is 120 and autostop_interval
is 50, the averages over epochs 70-120 and 20-70 will be compared.
Parameters
trial : Trial
an optuna
trial (for hyperparameter searches)
optimized_metric : str
the name of the metric being optimized (for hyperparameter searches)
to_ram : bool, default False
if True
, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
if the dataset is too large)
autostop_interval : int, default 10
the number of epochs to average the autostop metric over
autostop_threshold : float, default 0.001
the autostop difference threshold
autostop_metric : str, optional
the autostop metric (can be any one of the tracked metrics or 'loss'
)
main_task_on : bool, default True
if False
, the main task (action segmentation) will not be used in training
ssl_on : bool, default True
if False
, the SSL task will not be used in training
Returns
loss_log: list a list of float loss function values for each epoch metrics_log: dict a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric names, values are lists of function values)
def save_model(self, save_path: str) -> None:
    """
    Write the model of the wrapped `dlc2action.task.universal_task.Task` to disk

    Parameters
    ----------
    save_path : str
        the path of the file to save to (passed straight to the task's `save_model`)
    """

    task = self.task
    task.save_model(save_path)
Save the model of the dlc2action.task.universal_task.Task
instance
Parameters
save_path : str the path to the saved file
def evaluate(
    self,
    data: Union[DataLoader, BehaviorDataset, str] = None,
    augment_n: int = 0,
    verbose: bool = True,
) -> Tuple:
    """
    Run an evaluation of the Task model

    Parameters
    ----------
    data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
        the data to evaluate on (if not provided, evaluate on the Task validation dataset)
    augment_n : int, default 0
        the number of augmentations to average results over
    verbose : bool, default True
        if `True`, the process is reported to standard output

    Returns
    -------
    loss : float
        the average value of the loss function
    ssl_loss : float
        the average value of the SSL loss function
    metric : dict
        a dictionary of average values of metric functions

    Notes
    -----
    The batch size is taken from the `"batch_size"` training parameter (32 if unset) and cast to `int`.
    """

    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.evaluate(data, augment_n, batch_size, verbose)
Evaluate the Task model
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data to evaluate on (if not provided, evaluate on the Task validation dataset) augment_n : int, default 0 the number of augmentations to average results over verbose : bool, default True if True, the process is reported to standard output
Returns
loss : float the average value of the loss function ssl_loss : float the average value of the SSL loss function metric : dict a dictionary of average values of metric functions
def evaluate_prediction(
    self,
    prediction: torch.Tensor,
    data: Union[DataLoader, BehaviorDataset, str] = None,
) -> Tuple:
    """
    Score an existing prediction with the task's loss and metrics

    Parameters
    ----------
    prediction : torch.Tensor
        the prediction
    data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
        the data the prediction was made for (if not provided, take the validation dataset)

    Returns
    -------
    loss : float
        the average value of the loss function
    metric : dict
        a dictionary of average values of metric functions

    Notes
    -----
    The batch size is taken from the `"batch_size"` training parameter (32 if unset) and cast to `int`.
    """

    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.evaluate_prediction(prediction, data, batch_size)
Compute metrics for a prediction
Parameters
prediction : torch.Tensor the prediction data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data the prediction was made for (if not provided, take the validation dataset)
Returns
loss : float the average value of the loss function metric : dict a dictionary of average values of metric functions
def predict(
    self,
    data: Union[DataLoader, BehaviorDataset, str],
    raw_output: bool = False,
    apply_primary_function: bool = True,
    augment_n: int = 0,
    embedding: bool = False,
) -> torch.Tensor:
    """
    Produce a prediction with the Task model

    Parameters
    ----------
    data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str
        the data to make a prediction for
    raw_output : bool, default False
        if `True`, the raw predicted probabilities are returned
    apply_primary_function : bool, default True
        if `True`, the primary predict function is applied (to map the model output into a shape
        corresponding to the input)
    augment_n : int, default 0
        the number of augmentations to average results over
    embedding : bool, default False
        forwarded to the task's `predict` as a keyword argument
        (NOTE(review): presumably switches the output to embeddings; confirm against `Task.predict`)

    Returns
    -------
    prediction : torch.Tensor
        a prediction for the input data

    Notes
    -----
    The batch size (`"batch_size"`, default 32, cast to `int`) and the `"to_ram"` flag (default `False`)
    are read from `self.training_parameters`.
    """

    settings = self.training_parameters
    load_to_ram = settings.get("to_ram", False)
    batch_size = int(settings.get("batch_size", 32))
    return self.task.predict(
        data,
        raw_output,
        apply_primary_function,
        augment_n,
        batch_size,
        load_to_ram,
        embedding=embedding,
    )
Make a prediction with the Task model
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str
the data to make a prediction for (this argument has no default value in the signature)
raw_output : bool, default False
if True
, the raw predicted probabilities are returned
apply_primary_function : bool, default True
if True
, the primary predict function is applied (to map the model output into a shape corresponding to
the input)
augment_n : int, default 0
the number of augmentations to average results over
Returns
prediction : torch.Tensor a prediction for the input data
def dataset(self, mode: str = "train") -> BehaviorDataset:
    """
    Fetch one of the task's datasets

    Parameters
    ----------
    mode : {'train', 'val', 'test'}
        the dataset to get

    Returns
    -------
    dataset : dlc2action.data.dataset.BehaviorDataset
        the dataset returned by the task for this mode
    """

    requested = self.task.dataset(mode)
    return requested
Get a dataset
Parameters
mode : {'train', 'val', 'test'} the dataset to get
Returns
dataset : dlc2action.data.dataset.BehaviorDataset the dataset
def generate_full_length_prediction(
    self,
    dataset: Union[BehaviorDataset, str] = None,
    augment_n: int = 10,
) -> Dict:
    """
    Assemble a prediction for the original input sequences

    Parameters
    ----------
    dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
        the dataset to generate a prediction for (if `None`, generate for the
        `dlc2action.task.universal_task.Task` instance validation dataset)
    augment_n : int, default 10
        the number of augmentations to average results over

    Returns
    -------
    prediction : dict
        a nested dictionary where first level keys are video ids, second level keys are clip ids and values
        are prediction tensors

    Notes
    -----
    The batch size is taken from the `"batch_size"` training parameter (32 if unset) and cast to `int`.
    """

    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_full_length_prediction(dataset, batch_size, augment_n)
Compile a prediction for the original input sequences
Parameters
dataset : dlc2action.data.dataset.BehaviorDataset | str, optional
the dataset to generate a prediction for (if None
, generate for the dlc2action.task.universal_task.Task
instance validation dataset)
augment_n : int, default 10
the number of augmentations to average results over
Returns
prediction : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are prediction tensors
def generate_submission(
    self,
    frame_number_map_file: str,
    dataset: Union[BehaviorDataset, str] = None,
    augment_n: int = 10,
) -> Dict:
    """
    Build a MABe-22 style submission dictionary

    Parameters
    ----------
    frame_number_map_file : str
        path to the frame number map file
    dataset : BehaviorDataset | str, optional
        the dataset to generate a prediction for (if `None`, generate for the validation dataset)
    augment_n : int, default 10
        the number of augmentations to average results over

    Returns
    -------
    submission : dict
        a dictionary with frame number mapping and embeddings

    Notes
    -----
    The batch size is taken from the `"batch_size"` training parameter (32 if unset) and cast to `int`.
    """

    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_submission(
        frame_number_map_file, dataset, batch_size, augment_n
    )
Generate a MABe-22 style submission dictionary
Parameters
frame_number_map_file : str
path to the frame number map file
dataset : BehaviorDataset, optional
the dataset to generate a prediction for (if None
, generate for the validation dataset)
augment_n : int, default 10
the number of augmentations to average results over
Returns
submission : dict a dictionary with frame number mapping and embeddings
def behaviors_dict(self):
    """
    Return the behavior dictionary of the task

    Keys are label indices and values are label names.

    Returns
    -------
    behaviors_dict : dict
        behavior dictionary
    """

    mapping = self.task.behaviors_dict()
    return mapping
Get a behavior dictionary
Keys are label indices and values are label names.
Returns
behaviors_dict : dict behavior dictionary
def count_classes(self, bouts: bool = False) -> Dict:
    """
    Return class counts for the different dataset modes

    Parameters
    ----------
    bouts : bool, default False
        if `True`, instead of frame counts segment counts are returned

    Returns
    -------
    class_counts : dict
        a dictionary where first-level keys are "train", "val" and "test", second-level keys are
        class names and values are class counts
    """

    counts = self.task.count_classes(bouts)
    return counts
Get a dictionary of class counts in different modes
Parameters
bouts : bool, default False
if True
, instead of frame counts segment counts are returned
Returns
class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)
def visualize_results(
    self,
    save_path: str = None,
    add_legend: bool = True,
    ground_truth: bool = True,
    colormap: str = "viridis",
    hide_axes: bool = False,
    min_classes: int = 1,
    width: float = 10,
    whole_video: bool = False,
    transparent: bool = False,
    dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
    drop_classes: Set = None,
    search_classes: Set = None,
    smooth_interval_prediction: int = None,
) -> None:
    """
    Plot random predictions

    Parameters
    ----------
    save_path : str, optional
        the path where the plot will be saved
    add_legend : bool, default True
        if `True`, legend will be added to the plot
    ground_truth : bool, default True
        if `True`, ground truth will be added to the plot
    colormap : str, default 'viridis'
        the `matplotlib` colormap to use
    hide_axes : bool, default False
        if `True`, the axes will be hidden on the plot
    min_classes : int, default 1
        the minimum number of classes in a displayed interval
    width : float, default 10
        the width of the plot
    whole_video : bool, default False
        if `True`, whole videos are plotted instead of segments
    transparent : bool, default False
        if `True`, the background on the plot is transparent
    dataset : BehaviorDataset | DataLoader | str | None, optional
        the dataset to make the prediction for (if not provided, the validation dataset is used)
    drop_classes : set, optional
        a set of class names to not be displayed
    search_classes : set, optional
        if given, only intervals where at least one of the classes is in ground truth will be shown
    smooth_interval_prediction : int, optional
        forwarded to the task's `visualize_results` as a keyword argument
        (NOTE(review): presumably a smoothing window for the prediction; confirm against the task)
    """

    # Everything except the smoothing interval is forwarded positionally, in signature order.
    positional = (
        save_path,
        add_legend,
        ground_truth,
        colormap,
        hide_axes,
        min_classes,
        width,
        whole_video,
        transparent,
        dataset,
        drop_classes,
        search_classes,
    )
    return self.task.visualize_results(
        *positional, smooth_interval_prediction=smooth_interval_prediction
    )
Visualize random predictions
Parameters
save_path : str, optional
the path where the plot will be saved
add_legend : bool, default True
if True, legend will be added to the plot
ground_truth : bool, default True
if True, ground truth will be added to the plot
colormap : str, default 'viridis'
the matplotlib
colormap to use
hide_axes : bool, default False
if True
, the axes will be hidden on the plot
min_classes : int, default 1
the minimum number of classes in a displayed interval
width : float, default 10
the width of the plot
whole_video : bool, default False
if True
, whole videos are plotted instead of segments
transparent : bool, default False
if True
, the background on the plot is transparent
dataset : BehaviorDataset | DataLoader | str | None, optional
the dataset to make the prediction for (if not provided, the validation dataset is used)
drop_classes : set, optional
a set of class names to not be displayed
search_classes : set, optional
if given, only intervals where at least one of the classes is in ground truth will be shown
def generate_uncertainty_score(
    self,
    classes: List,
    augment_n: int = 0,
    method: str = "least_confidence",
    predicted: torch.Tensor = None,
    behaviors_dict: Dict = None,
) -> Dict:
    """
    Compute frame-wise uncertainty scores for active learning

    Parameters
    ----------
    classes : list
        a list of class names or indices; their confidence scores will be computed separately and stacked
    augment_n : int, default 0
        the number of augmentations to average over
    method : {"least_confidence", "entropy"}
        the method used to calculate the scores from the probability predictions (`"least_confidence"`:
        `1 - p_i` if `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`:
        `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
    predicted : torch.Tensor, optional
        forwarded unchanged to the task
        (NOTE(review): presumably a pre-computed prediction; confirm against the task)
    behaviors_dict : dict, optional
        forwarded unchanged to the task
        (NOTE(review): presumably a label index to name mapping; confirm against the task)

    Returns
    -------
    score_dicts : dict
        a nested dictionary where first level keys are video ids, second level keys are clip ids and values
        are score tensors

    Notes
    -----
    The batch size is taken from the `"batch_size"` training parameter (32 if unset) and cast to `int`.
    """

    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_uncertainty_score(
        classes, augment_n, batch_size, method, predicted, behaviors_dict
    )
Generate frame-wise scores for active learning
Parameters
classes : list
a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
the number of augmentations to average over
method : {"least_confidence", "entropy"}
the method used to calculate the scores from the probability predictions ("least_confidence"
: 1 - p_i
if
p_i > 0.5
or p_i
if p_i < 0.5
; "entropy"
: - p_i * log(p_i) - (1 - p_i) * log(1 - p_i)
)
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
def generate_bald_score(
    self,
    classes: List,
    augment_n: int = 0,
    num_models: int = 10,
    kernel_size: int = 11,
) -> Dict:
    """
    Compute frame-wise Bayesian Active Learning by Disagreement scores for active learning

    Parameters
    ----------
    classes : list
        a list of class names or indices; their confidence scores will be computed separately and stacked
    augment_n : int, default 0
        the number of augmentations to average over
    num_models : int, default 10
        the number of dropout masks to apply
    kernel_size : int, default 11
        the size of the smoothing gaussian kernel

    Returns
    -------
    score_dicts : dict
        a nested dictionary where first level keys are video ids, second level keys are clip ids and values
        are score tensors

    Notes
    -----
    The batch size is taken from the `"batch_size"` training parameter (32 if unset) and cast to `int`.
    """

    batch_size = int(self.training_parameters.get("batch_size", 32))
    return self.task.generate_bald_score(
        classes, augment_n, batch_size, num_models, kernel_size
    )
Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
Parameters
classes : list a list of class names or indices; their confidence scores will be computed separately and stacked augment_n : int, default 0 the number of augmentations to average over num_models : int, default 10 the number of dropout masks to apply kernel_size : int, default 11 the size of the smoothing gaussian kernel
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
def get_normalization_stats(self) -> Dict:
    """
    Return the pre-computed normalization stats of the task

    Returns
    -------
    normalization_stats : dict
        a dictionary of means and stds
    """

    stats = self.task.get_normalization_stats()
    return stats
Get the pre-computed normalization stats
Returns
normalization_stats : dict a dictionary of means and stds
def exists(self, mode) -> bool:
    """
    Tell whether the task has a train/test/validation subset

    Parameters
    ----------
    mode : {"train", "val", "test"}
        the name of the subset to check for

    Returns
    -------
    exists : bool
        `True` if the task returns a dataloader (anything other than `None`) for this subset
    """

    return self.task.dataloader(mode) is not None
Check whether the task has a train/test/validation subset
Parameters
mode : {"train", "val", "test"} the name of the subset to check for
Returns
exists : bool
True
if the subset exists