PyTorch
ssl-aasist
custom_code
File size: 2,122 Bytes
b1b22fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# @package _group_

# Run-wide settings: precision, logging and the custom-code directory.
common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tb
  # Explicit mantissa so YAML 1.1 loaders (which require a "." in floats)
  # read a number here rather than the string "1e-6".
  min_loss_scale: 1.0e-6
  fp16_no_flatten_grads: true
  # Interpolated from the PWD environment variable when the config is loaded,
  # so the resolved path depends on the directory the run is launched from.
  user_dir: ${env:PWD}/examples/data2vec

# Checkpoint retention.
# NOTE(review): key names suggest "save every 50000 updates, keep only the
# most recent one, write no per-epoch checkpoints" - confirm against the
# trainer's checkpoint options.
checkpoint:
  no_epoch_checkpoints: true
  save_interval_updates: 50000
  keep_interval_updates: 1

# Task definition: masked_lm over 512-token samples.
task:
  _name: masked_lm
  # Absolute, machine-specific dataset path; override it on other systems.
  data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
  # Plain scalar `none` is the string "none", not a YAML null.
  sample_break_mode: none
  tokens_per_sample: 512
  include_target_tokens: true
  random_token_prob: 0
  leave_unmasked_prob: 0
  # Booleans below use canonical lowercase, matching the rest of this file.
  include_index: true
  skip_masking: true
  d2v2_multi: true

# Batching and validation switches; validation is disabled in this config.
dataset:
  batch_size: 2
  disable_validation: true
  ignore_unused_valid_subsets: true
  skip_invalid_size_inputs_valid_test: true

# Distributed launch settings; the world size is fixed at 32 here.
distributed_training:
  ddp_backend: c10d
  distributed_world_size: 32

# Criterion named `model`, with the list of keys passed as `log_keys`.
criterion:
  _name: model
  log_keys: [ema_decay, target_var, pred_var, model_norm, ema_norm, masked_pct]

# Optimization limits: gradient clip norm and the total update budget.
optimization:
  clip_norm: 1
  max_update: 600000

# Composite optimizer with a single `default` group that carries its own
# optimizer (adam) and learning-rate scheduler (cosine).
optimizer:
  _name: composite
  dynamic_groups: true
  groups:
    default:
      lr_float: 0.0001
      optimizer:
        _name: adam
        adam_betas: [0.9, 0.98]
        # Explicit mantissa so YAML 1.1 loaders parse a float, not a string.
        adam_eps: 1.0e-06
        weight_decay: 0.01
      lr_scheduler:
        _name: cosine
        warmup_updates: 4000

# Top-level scheduler is set to `pass_through`; the cosine schedule in this
# file is declared inside optimizer.groups.default.lr_scheduler instead.
lr_scheduler: pass_through

# data2vec_multi model settings; supported_modality is TEXT and only the
# `text` entry under `modalities` is configured.
model:
  _name: data2vec_multi

  loss_beta: 0
  loss_scale: 1

  depth: 24
  num_heads: 16
  embed_dim: 1024
  clone_batch: 8

  # EMA settings.
  # NOTE(review): names suggest the decay is annealed from ema_decay to
  # ema_end_decay by update ema_anneal_end_step - confirm in the model code.
  ema_decay: 0.9999
  ema_end_decay: 0.99999
  ema_anneal_end_step: 100000
  ema_encoder_only: true

  # Target normalization switches; average_top_k_layers equals depth (24).
  average_top_k_layers: 24
  layer_norm_target_layer: true
  instance_norm_target_layer: false
  batch_norm_target_layer: false
  instance_norm_targets: true
  layer_norm_targets: false

  layerdrop: 0
  # Explicit mantissa so YAML 1.1 loaders parse a float, not the string
  # "1e-5".
  norm_eps: 1.0e-5

  supported_modality: TEXT

  modalities:
    text:
      # Masking parameters for the text modality.
      mask_prob: 0.5
      mask_length: 1
      mask_noise_std: 0.01
      prenet_depth: 0
      # Decoder sub-configuration for the text modality.
      decoder:
        input_dropout: 0.1
        decoder_dim: 768
        decoder_groups: 1
        decoder_kernel: 9
        decoder_layers: 5
        decoder_residual: false
        projection_layers: 2
        projection_ratio: 2.0