dlc2action.options
Here all option dictionaries are stored
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Here all option dictionaries are stored
"""

from torch.optim import Adam, SGD
from dlc2action.data.input_store import *
from dlc2action.data.annotation_store import *
from dlc2action.feature_extraction import *
from dlc2action.transformer import *
from dlc2action.loss import MS_TCN_Loss
from dlc2action.model.mlp import MLP
from dlc2action.model.c2f_tcn import C2F_TCN
from dlc2action.model.asformer import ASFormer
from dlc2action.model.transformer import Transformer
from dlc2action.model.edtcn import EDTCN
from dlc2action.ssl.contrastive import *
from dlc2action.ssl.masked import *
from dlc2action.ssl.segment_order import *
from dlc2action.ssl.tcc import TCCSSL
from dlc2action.metric.metrics import *


input_stores = {
    "dlc_tracklet": DLCTrackletStore,
    "dlc_track": DLCTrackStore,
    "pku-mmd": PKUMMDInputStore,
    "calms21": CalMS21InputStore,
    "np_3d": Numpy3DInputStore,
    "features": LoadedFeaturesInputStore,
    "simba": SIMBAInputStore,
}

annotation_stores = {
    "dlc": DLCAnnotationStore,
    "pku-mmd": PKUMMDAnnotationStore,
    "boris": BorisAnnotationStore,
    "none": EmptyAnnotationStore,
    "calms21": CalMS21AnnotationStore,
    "csv": CSVAnnotationStore,
    "simba": SIMBAAnnotationStore,
}

feature_extractors = {"kinematic": KinematicExtractor}

ssl_constructors = {
    "masked_features": MaskedFeaturesSSL_FC,
    "masked_joints": MaskedKinematicSSL_FC,
    "masked_frames": MaskedFramesSSL_FC,
    "contrastive": ContrastiveSSL,
    "pairwise": PairwiseSSL,
    "contrastive_masked": ContrastiveMaskedSSL,
    "pairwise_masked": PairwiseMaskedSSL,
    "reverse": ReverseSSL,
    "order": OrderSSL,
    "contrastive_regression": ContrastiveRegressionSSL,
    "tcc": TCCSSL,
}

transformers = {"kinematic": KinematicTransformer}

losses = {
    "ms_tcn": MS_TCN_Loss,
}
losses_multistage = [
    "ms_tcn",
]  # losses that expect predictions of shape (#stages, #batch, #classes, #frames)

metrics = {
    "accuracy": Accuracy,
    "precision": Precision,
    "f1": F1,
    "recall": Recall,
    "count": Count,
    # "mAP": mAP,
    "segmental_precision": SegmentalPrecision,
    "segmental_recall": SegmentalRecall,
    "segmental_f1": SegmentalF1,
    "edit_distance": EditDistance,
    "f_beta": Fbeta,
    "segmental_f_beta": SegmentalFbeta,
    "semisegmental_precision": SemiSegmentalPrecision,
    "semisegmental_recall": SemiSegmentalRecall,
    "semisegmental_f1": SemiSegmentalF1,
    "pr-auc": PR_AUC,
    "semisegmental_pr-auc": SemiSegmentalPR_AUC,
    "mAP": PKU_mAP,
}
metrics_minimize = [
    "edit_distance"
]  # metrics that decrease when prediction quality increases
metrics_no_direction = ["count"]  # metrics that do not indicate prediction quality

optimizers = {"Adam": Adam, "SGD": SGD}

models = {
    "asformer": ASFormer,
    "mlp": MLP,
    "c2f_tcn": C2F_TCN,
    "edtcn": EDTCN,
    "transformer": Transformer,
}

blanks = [
    "dataset_inverse_weights",
    "dataset_proportional_weights",
    "dataset_classes",
    "dataset_features",
    "dataset_len_segment",
    "dataset_bodyparts",
    "dataset_boundary_weight",
    "model_features",
]

extractor_to_transformer = {
    "kinematic": "kinematic",
    "heatmap": "heatmap",
}  # keys are feature extractor names, values are transformer names

partition_methods = {
    "random": [
        "random",
        "random:test-from-name",
        "random:test-from-name:{name}",
        "random:equalize:segments",
        "random:equalize:videos",
    ],
    "fixed": [
        "val-from-name:{val_name}:test-from-name:{test_name}",
        "time",
        "time:start-from:{frac}",
        "time:start-from:{frac}:strict",
        "time:strict",
        "file",
        "folders",
    ],
}

basic_parameters = {
    "data": [
        "data_suffix",
        "feature_suffix",
        "annotation_suffix",
        "canvas_shape",
        "ignored_bodyparts",
        "likelihood_threshold",
        "behaviors",
        "filter_annotated",
        "filter_background",
        "visibility_min_score",
        "visibility_min_frac",
    ],
    "augmentations": {
        "heatmap": ["augmentations", "rotation_degree_limits"],
        "kinematic": [
            "augmentations",
            "rotation_limits",
            "mirror_dim",
            "noise_std",
            "zoom_limits",
            "masking_probability",
        ],
    },
    "features": {
        "heatmap": ["keys", "channel_policy", "heatmap_width", "sigma"],
        "kinematic": [
            "keys",
            "averaging_window",
            "distance_pairs",
            "angle_pairs",
            "zone_vertices",
            "zone_bools",
            "zone_distances",
            "area_vertices",
        ],
    },
    "model": {
        "asformer": [
            "num_decoders",
            "num_layers",
            "r1",
            "r2",
            "num_f_maps",
            "channel_masking_rate",
        ],
        "c2f_tcn": ["num_f_maps", "feature_dim"],
        "edtcn": ["kernel_size", "mid_channels"],
        "mlp": ["f_maps_list", "dropout_rates"],
        "transformer": ["num_f_maps", "N", "heads", "num_pool"],
    },
    "general": [
        "model_name",
        "metric_functions",
        "ignored_clips",
        "len_segment",
        "overlap",
        "interactive",
    ],
    "losses": {
        "ms_tcn": ["focal", "gamma", "alpha"],
        "clip": ["focal", "gamma", "alpha", "fix_text"],
    },
    "metrics": {
        "f1": ["average", "ignored_classes", "threshold_value"],
        "precision": ["average", "ignored_classes", "threshold_value"],
        "recall": ["average", "ignored_classes", "threshold_value"],
        "f_beta": ["average", "ignored_classes", "threshold_value", "beta"],
        "count": ["classes"],
        "segmental_precision": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "segmental_recall": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "segmental_f1": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "segmental_f_beta": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "pr-auc": ["average", "ignored_classes", "threshold_step"],
        "mAP": ["average", "ignored_classes", "iou_threshold", "threshold_value"],
        "semisegmental_precision": ["average", "ignored_classes", "iou_threshold"],
        "semisegmental_recall": ["average", "ignored_classes", "iou_threshold"],
        "semisegmental_f1": ["average", "ignored_classes", "iou_threshold"],
    },
    "training": [
        "lr",
        "device",
        "num_epochs",
        "to_ram",
        "batch_size",
        "normalize",
        "temporal_subsampling_size",
        "parallel",
        "val_frac",
        "test_frac",
        "partition_method",
    ],
}

model_hyperparameters = {
    "asformer": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/num_decoders": ("int", 1, 4),
        "model/num_f_maps": ("categorical", [32, 64, 128]),
        "model/num_layers": ("int", 5, 10),
        "model/channel_masking_rate": ("float", 0.2, 0.4),
        "general/len_segment": ("categorical", [64, 128]),
    },
    "c2f_tcn": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/num_f_maps": ("int_log", 32, 128),
        "general/len_segment": ("categorical", [512, 1024]),
    },
    "edtcn": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "general/len_segment": ("categorical", [128, 256, 512]),
    },
    "transformer": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/N": ("int", 5, 12),
        "model/heads": ("categorical", [1, 2, 4, 8]),
        "model/num_pool": ("int", 0, 4),
        "model/add_batchnorm": ("categorical", [True, False]),
        "general/len_segment": ("categorical", [64, 128]),
    },
    "mlp": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/dropout_rates": ("float", 0.3, 0.6),
    },
}
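Most of these dictionaries are plain name-to-class registries: for example, models["c2f_tcn"] resolves to C2F_TCN and optimizers["Adam"] to torch.optim.Adam, so other parts of the package can refer to a component by its string key. The entries in model_hyperparameters instead read as search-space specifications of the form ("float", low, high), ("float_log", low, high), ("int", low, high), ("int_log", low, high), or ("categorical", choices), keyed by the configuration path they modify (e.g. "losses/ms_tcn/alpha"). The sketch below shows one way such a specification could be translated into an Optuna suggestion; the suggest_from_spec helper and the objective function are hypothetical, written only for this illustration, and dlc2action's own hyperparameter search may be implemented differently.

import optuna

from dlc2action.options import model_hyperparameters


def suggest_from_spec(trial: optuna.trial.Trial, name: str, spec):
    # Hypothetical helper: map a (kind, ...) spec tuple to an Optuna suggestion.
    kind = spec[0]
    if kind == "float":
        return trial.suggest_float(name, spec[1], spec[2])
    if kind == "float_log":
        return trial.suggest_float(name, spec[1], spec[2], log=True)
    if kind == "int":
        return trial.suggest_int(name, spec[1], spec[2])
    if kind == "int_log":
        return trial.suggest_int(name, spec[1], spec[2], log=True)
    if kind == "categorical":
        return trial.suggest_categorical(name, spec[1])
    raise ValueError(f"unknown specification kind: {kind!r}")


def objective(trial: optuna.trial.Trial) -> float:
    # Sample one value per entry of the "c2f_tcn" search space.
    params = {
        name: suggest_from_spec(trial, name, spec)
        for name, spec in model_hyperparameters["c2f_tcn"].items()
    }
    # In a real search, the sampled params would be written into the experiment
    # configuration and a model trained and evaluated; a constant stands in here.
    print(params)
    return 0.0


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=5)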