dlc2action.options

All option dictionaries are stored here.

#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
All option dictionaries are stored here.
"""

from torch.optim import Adam, SGD
from dlc2action.data.input_store import *
from dlc2action.data.annotation_store import *
from dlc2action.feature_extraction import *
from dlc2action.transformer import *
from dlc2action.loss import MS_TCN_Loss
from dlc2action.model.mlp import MLP
from dlc2action.model.c2f_tcn import C2F_TCN
from dlc2action.model.asformer import ASFormer
from dlc2action.model.transformer import Transformer
from dlc2action.model.edtcn import EDTCN
from dlc2action.ssl.contrastive import *
from dlc2action.ssl.masked import *
from dlc2action.ssl.segment_order import *
from dlc2action.ssl.tcc import TCCSSL
from dlc2action.metric.metrics import *


input_stores = {
    "dlc_tracklet": DLCTrackletStore,
    "dlc_track": DLCTrackStore,
    "pku-mmd": PKUMMDInputStore,
    "calms21": CalMS21InputStore,
    "np_3d": Numpy3DInputStore,
    "features": LoadedFeaturesInputStore,
    "simba": SIMBAInputStore,
}

annotation_stores = {
    "dlc": DLCAnnotationStore,
    "pku-mmd": PKUMMDAnnotationStore,
    "boris": BorisAnnotationStore,
    "none": EmptyAnnotationStore,
    "calms21": CalMS21AnnotationStore,
    "csv": CSVAnnotationStore,
    "simba": SIMBAAnnotationStore,
}

feature_extractors = {"kinematic": KinematicExtractor}
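
# Illustrative sketch (not part of the module's own API): the dictionaries above
# act as registries, so data components can be resolved from configuration
# strings. The keys used below ("dlc_track", "boris", "kinematic") are taken
# from the registries defined above.
def _example_resolve_data_options():
    input_store_cls = input_stores["dlc_track"]        # -> DLCTrackStore
    annotation_store_cls = annotation_stores["boris"]  # -> BorisAnnotationStore
    extractor_cls = feature_extractors["kinematic"]    # -> KinematicExtractor
    return input_store_cls, annotation_store_cls, extractor_cls
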

ssl_constructors = {
    "masked_features": MaskedFeaturesSSL_FC,
    "masked_joints": MaskedKinematicSSL_FC,
    "masked_frames": MaskedFramesSSL_FC,
    "contrastive": ContrastiveSSL,
    "pairwise": PairwiseSSL,
    "contrastive_masked": ContrastiveMaskedSSL,
    "pairwise_masked": PairwiseMaskedSSL,
    "reverse": ReverseSSL,
    "order": OrderSSL,
    "contrastive_regression": ContrastiveRegressionSSL,
    "tcc": TCCSSL,
}

transformers = {"kinematic": KinematicTransformer}

losses = {
    "ms_tcn": MS_TCN_Loss,
}
losses_multistage = [
    "ms_tcn",
]  # losses that expect predictions of shape (#stages, #batch, #classes, #frames)
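
# Sketch of the multi-stage shape convention noted above (illustrative only):
# models trained with a loss listed in `losses_multistage` produce one
# prediction per refinement stage, stacked along the first dimension.
def _example_multistage_shape(num_stages=4, batch=8, num_classes=5, num_frames=128):
    import torch

    predictions = torch.randn(num_stages, batch, num_classes, num_frames)
    # a single-stage prediction would instead have shape (batch, num_classes, num_frames)
    return predictions.shape  # torch.Size([4, 8, 5, 128])
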

metrics = {
    "accuracy": Accuracy,
    "precision": Precision,
    "f1": F1,
    "recall": Recall,
    "count": Count,
    # "mAP": mAP,
    "segmental_precision": SegmentalPrecision,
    "segmental_recall": SegmentalRecall,
    "segmental_f1": SegmentalF1,
    "edit_distance": EditDistance,
    "f_beta": Fbeta,
    "segmental_f_beta": SegmentalFbeta,
    "semisegmental_precision": SemiSegmentalPrecision,
    "semisegmental_recall": SemiSegmentalRecall,
    "semisegmental_f1": SemiSegmentalF1,
    "pr-auc": PR_AUC,
    "semisegmental_pr-auc": SemiSegmentalPR_AUC,
    "mAP": PKU_mAP,
}
metrics_minimize = [
    "edit_distance"
]  # metrics that decrease when prediction quality increases
metrics_no_direction = ["count"]  # metrics that do not indicate prediction quality
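
# Hypothetical helper (not a dlc2action function) showing how the direction
# lists above could be interpreted when comparing two values of a metric.
def _example_is_improvement(metric_name, new_value, best_value):
    if metric_name in metrics_no_direction:
        return False  # e.g. "count" carries no notion of better or worse
    if metric_name in metrics_minimize:
        return new_value < best_value  # e.g. "edit_distance": lower is better
    return new_value > best_value  # all other metrics: higher is better
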

optimizers = {"Adam": Adam, "SGD": SGD}
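
# Illustrative sketch: an optimizer class can be looked up by its configuration
# name and constructed with the standard torch arguments (the `model` argument
# and the learning rate below are placeholders).
def _example_make_optimizer(model, optimizer_name="Adam", lr=1e-3):
    optimizer_cls = optimizers[optimizer_name]  # Adam or SGD
    return optimizer_cls(model.parameters(), lr=lr)
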

models = {
    "asformer": ASFormer,
    "mlp": MLP,
    "c2f_tcn": C2F_TCN,
    "edtcn": EDTCN,
    "transformer": Transformer,
}

blanks = [
    "dataset_inverse_weights",
    "dataset_proportional_weights",
    "dataset_classes",
    "dataset_features",
    "dataset_len_segment",
    "dataset_bodyparts",
    "dataset_boundary_weight",
    "model_features",
]

extractor_to_transformer = {
    "kinematic": "kinematic",
    "heatmap": "heatmap",
}  # keys are feature extractor names, values are transformer names
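
# Illustrative sketch: the mapping above pairs each feature extractor name with
# the name of the transformer that processes its features, so the matching
# transformer class can be resolved in two lookups.
def _example_matching_transformer(extractor_name="kinematic"):
    transformer_name = extractor_to_transformer[extractor_name]
    return transformers[transformer_name]  # -> KinematicTransformer
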

partition_methods = {
    "random": [
        "random",
        "random:test-from-name",
        "random:test-from-name:{name}",
        "random:equalize:segments",
        "random:equalize:videos",
    ],
    "fixed": [
        "val-from-name:{val_name}:test-from-name:{test_name}",
        "time",
        "time:start-from:{frac}",
        "time:start-from:{frac}:strict",
        "time:strict",
        "file",
        "folders",
    ],
}
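
# Illustrative sketch: the parts in curly braces are placeholders that the user
# fills in when choosing a partition method; the values below ("mouse42", 0.5)
# are made-up examples, not defaults.
def _example_partition_strings():
    by_name = "random:test-from-name:{name}".format(name="mouse42")
    by_time = "time:start-from:{frac}".format(frac=0.5)
    return by_name, by_time  # "random:test-from-name:mouse42", "time:start-from:0.5"
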

basic_parameters = {
    "data": [
        "data_suffix",
        "feature_suffix",
        "annotation_suffix",
        "canvas_shape",
        "ignored_bodyparts",
        "likelihood_threshold",
        "behaviors",
        "filter_annotated",
        "filter_background",
        "visibility_min_score",
        "visibility_min_frac",
    ],
    "augmentations": {
        "heatmap": ["augmentations", "rotation_degree_limits"],
        "kinematic": [
            "augmentations",
            "rotation_limits",
            "mirror_dim",
            "noise_std",
            "zoom_limits",
            "masking_probability",
        ],
    },
    "features": {
        "heatmap": ["keys", "channel_policy", "heatmap_width", "sigma"],
        "kinematic": [
            "keys",
            "averaging_window",
            "distance_pairs",
            "angle_pairs",
            "zone_vertices",
            "zone_bools",
            "zone_distances",
            "area_vertices",
        ],
    },
    "model": {
        "asformer": [
            "num_decoders",
            "num_layers",
            "r1",
            "r2",
            "num_f_maps",
            "channel_masking_rate",
        ],
        "c2f_tcn": ["num_f_maps", "feature_dim"],
        "edtcn": ["kernel_size", "mid_channels"],
        "mlp": ["f_maps_list", "dropout_rates"],
        "transformer": ["num_f_maps", "N", "heads", "num_pool"],
    },
    "general": [
        "model_name",
        "metric_functions",
        "ignored_clips",
        "len_segment",
        "overlap",
        "interactive",
    ],
    "losses": {
        "ms_tcn": ["focal", "gamma", "alpha"],
        "clip": ["focal", "gamma", "alpha", "fix_text"],
    },
    "metrics": {
        "f1": ["average", "ignored_classes", "threshold_value"],
        "precision": ["average", "ignored_classes", "threshold_value"],
        "recall": ["average", "ignored_classes", "threshold_value"],
        "f_beta": ["average", "ignored_classes", "threshold_value", "beta"],
        "count": ["classes"],
        "segmental_precision": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "segmental_recall": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "segmental_f1": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "segmental_f_beta": [
            "average",
            "ignored_classes",
            "threshold_value",
            "iou_threshold",
        ],
        "pr-auc": ["average", "ignored_classes", "threshold_step"],
        "mAP": ["average", "ignored_classes", "iou_threshold", "threshold_value"],
        "semisegmental_precision": ["average", "ignored_classes", "iou_threshold"],
        "semisegmental_recall": ["average", "ignored_classes", "iou_threshold"],
        "semisegmental_f1": ["average", "ignored_classes", "iou_threshold"],
    },
    "training": [
        "lr",
        "device",
        "num_epochs",
        "to_ram",
        "batch_size",
        "normalize",
        "temporal_subsampling_size",
        "parallel",
        "val_frac",
        "test_frac",
        "partition_method",
    ],
}
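
# Illustrative sketch: `basic_parameters` groups the most commonly edited
# settings, so the relevant parameter names for a given component can be
# looked up directly (the names below come from the dictionary above).
def _example_list_basic_parameters(model_name="c2f_tcn", metric_name="f1"):
    model_params = basic_parameters["model"][model_name]  # ["num_f_maps", "feature_dim"]
    metric_params = basic_parameters["metrics"][metric_name]
    training_params = basic_parameters["training"]
    return model_params, metric_params, training_params
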

model_hyperparameters = {
    "asformer": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/num_decoders": ("int", 1, 4),
        "model/num_f_maps": ("categorical", [32, 64, 128]),
        "model/num_layers": ("int", 5, 10),
        "model/channel_masking_rate": ("float", 0.2, 0.4),
        "general/len_segment": ("categorical", [64, 128]),
    },
    "c2f_tcn": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/num_f_maps": ("int_log", 32, 128),
        "general/len_segment": ("categorical", [512, 1024]),
    },
    "edtcn": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "general/len_segment": ("categorical", [128, 256, 512]),
    },
    "transformer": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/N": ("int", 5, 12),
        "model/heads": ("categorical", [1, 2, 4, 8]),
        "model/num_pool": ("int", 0, 4),
        "model/add_batchnorm": ("categorical", [True, False]),
        "general/len_segment": ("categorical", [64, 128]),
    },
    "mlp": {
        "losses/ms_tcn/alpha": ("float_log", 1e-5, 1e-2),
        "losses/ms_tcn/focal": ("categorical", [True, False]),
        "training/temporal_subsampling_size": ("float", 0.75, 1),
        "model/dropout_rates": ("float", 0.3, 0.6),
    },
}
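
# Illustrative sketch: each entry above describes a search space as
# ("float", low, high), ("float_log", low, high), ("int", low, high),
# ("int_log", low, high) or ("categorical", choices). The translation into
# Optuna suggestions below is an assumption made for illustration, not
# dlc2action's own code.
def _example_suggest_hyperparameters(trial, model_name="c2f_tcn"):
    values = {}
    for name, space in model_hyperparameters[model_name].items():
        if space[0] == "categorical":
            values[name] = trial.suggest_categorical(name, space[1])
        elif space[0] == "int":
            values[name] = trial.suggest_int(name, space[1], space[2])
        elif space[0] == "int_log":
            values[name] = trial.suggest_int(name, space[1], space[2], log=True)
        elif space[0] == "float":
            values[name] = trial.suggest_float(name, space[1], space[2])
        elif space[0] == "float_log":
            values[name] = trial.suggest_float(name, space[1], space[2], log=True)
    return values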