arulpm committed
Commit ea60981 · verified · 1 Parent(s): 0b84586

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+     "word_embedding_dimension": 768,
+     "pooling_mode_cls_token": true,
+     "pooling_mode_mean_tokens": false,
+     "pooling_mode_max_tokens": false,
+     "pooling_mode_mean_sqrt_len_tokens": false,
+     "pooling_mode_weightedmean_tokens": false,
+     "pooling_mode_lasttoken": false,
+     "include_prompt": true
+ }
README.md CHANGED
@@ -1,3 +1,570 @@
- ---
- license: apache-2.0
- ---
+ ---
+ language:
+ - multilingual
+ license: apache-2.0
+ tags:
+ - sentence-transformers
+ - sentence-similarity
+ - feature-extraction
+ - generated_from_trainer
+ - dataset_size:217680
+ - loss:MultipleNegativesRankingLoss
+ base_model: Alibaba-NLP/gte-multilingual-base
+ widget:
+ - source_sentence: "Keluhan: bintik2 merah di badan. 3 hari\r\ngatal +\nAnamnesa Pemeriksaan\
+     \ Dokter: bintik2 merah di badan. 3 hari\r\ngatal +"
+   sentences:
+   - 'Obat: Devit 5000iu
+
+     Deskripsi Obat: Vitamin D3 dosis tinggi untuk defisiensi'
+   - 'Obat: Betaver 6 mg
+
+     Deskripsi Obat: Obat untuk mengobati vertigo dan gangguan keseimbangan'
+   - 'Obat: Captopril 12,5 mg
+
+     Deskripsi Obat: Obat antihipertensi golongan ACE inhibitor dosis rendah untuk
+     menurunkan tekanan darah'
+ - source_sentence: "Keluhan: 2 minggu sampai sekarang BAK berdarah, sebelumnya sempat\
+     \ berdarah, saat BAK kadang sakit \nAnamnesa Pemeriksaan Dokter: bak disertai\
+     \ darah sejak 2 minggu, kadang disertai nyeri, riwayat sebelumnya bab disertai\
+     \ darah "
+   sentences:
+   - 'Obat: Hufarizine 10 mg
+
+     Deskripsi Obat: Antihistamin untuk mengatasi alergi'
+   - 'Obat: Cefixime 200 mg
+
+     Deskripsi Obat: Antibiotik golongan sefalosporin untuk mengobati infeksi bakteri'
+   - 'Obat: Omeprazole
+
+     Deskripsi Obat: Obat penghambat pompa proton untuk mengatasi tukak lambung'
+ - source_sentence: 'Keluhan: Batuk batuk. Sedang minum obat pasca operasi gigi
+
+     Anamnesa Pemeriksaan Dokter: batuk2'
+   sentences:
+   - 'Obat: Pyfaton tablet
+
+     Deskripsi Obat: Obat untuk mengobati gangguan pencernaan'
+   - 'Obat: Cefadroxil 500 mg
+
+     Deskripsi Obat: Antibiotik golongan sefalosporin untuk mengobati infeksi bakteri
+     saluran pernapasan dan kulit'
+   - 'Obat: Acyclovir 400 mg
+
+     Deskripsi Obat: Obat antiviral untuk mengobati infeksi virus herpes simpleks dan
+     varicella zoster'
+ - source_sentence: 'Keluhan: nan
+
+     Anamnesa Pemeriksaan Dokter: demam hari ke 5, mual muntah, ada kemerahan di betis
+     kaki kanan sejak 3 hari, susah masuk makan karena mual'
+   sentences:
+   - 'Obat: Caviplex
+
+     Deskripsi Obat: Suplemen multivitamin untuk memenuhi kebutuhan vitamin dan mineral'
+   - 'Obat: Ciprofloxacin 500 mg
+
+     Deskripsi Obat: Antibiotik golongan fluorokuinolon untuk mengobati infeksi bakteri'
+   - 'Obat: OBH
+
+     Deskripsi Obat: Obat batuk untuk meredakan batuk kering'
+ - source_sentence: 'Keluhan: nan
+
+     Anamnesa Pemeriksaan Dokter: mual pusing'
+   sentences:
+   - 'Obat: Buscopan
+
+     Deskripsi Obat: Obat untuk mengurangi kejang otot polos saluran pencernaan'
+   - 'Obat: Sangobion
+
+     Deskripsi Obat: Suplemen penambah darah untuk mengobati anemia defisiensi besi'
+   - 'Obat: Blocand 16 mg
+
+     Deskripsi Obat: Obat antihipertensi golongan ARB dosis tinggi untuk menurunkan
+     tekanan darah'
+ pipeline_tag: sentence-similarity
+ library_name: sentence-transformers
+ ---
+
+ # GTE Multilingual fine-tuned on clinical-to-drug mapping
+
+ This is a [sentence-transformers](https://www.SBERT.net) model finetuned from [Alibaba-NLP/gte-multilingual-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-base). It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
+
+ ## Model Details
+
+ ### Model Description
+ - **Model Type:** Sentence Transformer
+ - **Base model:** [Alibaba-NLP/gte-multilingual-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-base) <!-- at revision 9bbca17d9273fd0d03d5725c7a4b0f6b45142062 -->
+ - **Maximum Sequence Length:** 8192 tokens
+ - **Output Dimensionality:** 768 dimensions
+ - **Similarity Function:** Cosine Similarity
+ <!-- - **Training Dataset:** Unknown -->
+ - **Language:** multilingual
+ - **License:** apache-2.0
+
+ ### Model Sources
+
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
+ - **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
+
+ ### Full Model Architecture
+
+ ```
+ SentenceTransformer(
+   (0): Transformer({'max_seq_length': 8192, 'do_lower_case': False}) with Transformer model: NewModel
+   (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
+   (2): Normalize()
+ )
+ ```
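The three modules amount to a `transformers` forward pass, CLS-token pooling, and L2 normalization. A minimal sketch of that pipeline with plain `transformers`, assuming the placeholder model id from the usage snippet below, and `trust_remote_code=True` for the custom `NewModel` backbone:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_id = "sentence_transformers_model_id"  # placeholder, as in the usage snippet below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

batch = tokenizer(["Keluhan: mual pusing"], padding=True, truncation=True,
                  max_length=8192, return_tensors="pt")
with torch.no_grad():
    outputs = model(**batch)

# (1) Pooling: pooling_mode_cls_token=True -> take the first ([CLS]) token
embeddings = outputs.last_hidden_state[:, 0]
# (2) Normalize: unit-length vectors, so dot product equals cosine similarity
embeddings = F.normalize(embeddings, p=2, dim=1)
print(embeddings.shape)  # torch.Size([1, 768])
```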
+
+ ## Usage
+
+ ### Direct Usage (Sentence Transformers)
+
+ First install the Sentence Transformers library:
+
+ ```bash
+ pip install -U sentence-transformers
+ ```
+
+ Then you can load this model and run inference.
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # Download from the 🤗 Hub
+ model = SentenceTransformer("sentence_transformers_model_id")
+ # Run inference
+ sentences = [
+     'Keluhan: nan\nAnamnesa Pemeriksaan Dokter: mual pusing',
+     'Obat: Buscopan\nDeskripsi Obat: Obat untuk mengurangi kejang otot polos saluran pencernaan',
+     'Obat: Blocand 16 mg\nDeskripsi Obat: Obat antihipertensi golongan ARB dosis tinggi untuk menurunkan tekanan darah',
+ ]
+ embeddings = model.encode(sentences)
+ print(embeddings.shape)
+ # [3, 768]
+
+ # Get the similarity scores for the embeddings
+ similarities = model.similarity(embeddings, embeddings)
+ print(similarities.shape)
+ # [3, 3]
+ ```
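Since the model embeds clinical notes (anchors) and drug descriptions (positives) into the same space, a natural use is ranking candidate drugs for a note. A small sketch, with an invented note in the anchor format and drug strings borrowed from the samples in this card:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence_transformers_model_id")  # placeholder id

note = "Keluhan: nyeri perut, mual\nAnamnesa Pemeriksaan Dokter: nyeri ulu hati sejak 2 hari"
drugs = [
    "Obat: Lambucid\nDeskripsi Obat: Antasida untuk meredakan gejala asam lambung berlebih",
    "Obat: OBH\nDeskripsi Obat: Obat batuk untuk meredakan batuk kering",
    "Obat: Sangobion\nDeskripsi Obat: Suplemen penambah darah untuk mengobati anemia defisiensi besi",
]

note_emb = model.encode([note])
drug_embs = model.encode(drugs)
scores = model.similarity(note_emb, drug_embs)  # cosine similarities, shape [1, 3]

# Rank candidates from most to least similar
for idx in scores[0].argsort(descending=True).tolist():
    print(f"{scores[0][idx]:.3f}  {drugs[idx].splitlines()[0]}")
```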
+
+ <!--
+ ### Direct Usage (Transformers)
+
+ <details><summary>Click to see the direct usage in Transformers</summary>
+
+ </details>
+ -->
+
+ <!--
+ ### Downstream Usage (Sentence Transformers)
+
+ You can finetune this model on your own dataset.
+
+ <details><summary>Click to expand</summary>
+
+ </details>
+ -->
+
+ <!--
+ ### Out-of-Scope Use
+
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
+ -->
+
+ <!--
+ ## Bias, Risks and Limitations
+
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
+ -->
+
+ <!--
+ ### Recommendations
+
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
+ -->
+
+ ## Training Details
+
+ ### Training Dataset
+
+ #### Unnamed Dataset
+
+ * Size: 217,680 training samples
+ * Columns: <code>anchor</code> and <code>positive</code>
+ * Approximate statistics based on the first 1000 samples:
+   |         | anchor | positive |
+   |:--------|:-------|:---------|
+   | type    | string | string   |
+   | details | <ul><li>min: 17 tokens</li><li>mean: 42.44 tokens</li><li>max: 177 tokens</li></ul> | <ul><li>min: 15 tokens</li><li>mean: 22.55 tokens</li><li>max: 38 tokens</li></ul> |
+ * Samples:
+   | anchor | positive |
+   |:-------|:---------|
+   | <code>Keluhan: nan<br>Anamnesa Pemeriksaan Dokter: gusi kiri bengkak
+   <br>nyeri tenggorokan, ada amandel 2 hari yang lalu
+   <br>puisng </code> | <code>Obat: Eflagen 50 mg<br>Deskripsi Obat: Obat antiinflamasi nonsteroid (NSAID) untuk mengurangi nyeri</code> |
+   | <code>Keluhan: pasien mengatakan kuku kaki jempol kiri merah dan nyeri sejak 1 bulan yll hilang timbul<br>Anamnesa Pemeriksaan Dokter: pasien mengatakan kuku kaki jempol kiri merah dan nyeri sejak 1 bulan yll hilang timbul. merah + bengkak + </code> | <code>Obat: Mefinal 500 mg<br>Deskripsi Obat: Obat antiinflamasi nonsteroid (NSAID) untuk mengurangi nyeri</code> |
+   | <code>Keluhan: batuk dahak sudah 1 minggu<br>Anamnesa Pemeriksaan Dokter: batuk dahak sejak 1 minggu
+   <br>RPO : Acetylcystein, OBH</code> | <code>Obat: Cefixime 200 mg<br>Deskripsi Obat: Antibiotik golongan sefalosporin untuk mengobati infeksi bakteri</code> |
+ * Loss: [<code>MultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#multiplenegativesrankingloss) with these parameters:
+   ```json
+   {
+       "scale": 20.0,
+       "similarity_fct": "cos_sim"
+   }
+   ```
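For reference, these parameters correspond to the following loss construction in sentence-transformers; a sketch, assuming training starts from the published `Alibaba-NLP/gte-multilingual-base` checkpoint:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.util import cos_sim

model = SentenceTransformer("Alibaba-NLP/gte-multilingual-base", trust_remote_code=True)
# scale=20.0 with cosine similarity matches the parameters above
loss = MultipleNegativesRankingLoss(model, scale=20.0, similarity_fct=cos_sim)
```

With this loss, every other positive in a batch serves as an in-batch negative for a given anchor, which is why the `no_duplicates` batch sampler listed under the training hyperparameters matters: duplicate positives in one batch would become false negatives.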
+
+ ### Evaluation Dataset
+
+ #### Unnamed Dataset
+
+ * Size: 24,187 evaluation samples
+ * Columns: <code>anchor</code> and <code>positive</code>
+ * Approximate statistics based on the first 1000 samples:
+   |         | anchor | positive |
+   |:--------|:-------|:---------|
+   | type    | string | string   |
+   | details | <ul><li>min: 16 tokens</li><li>mean: 42.79 tokens</li><li>max: 112 tokens</li></ul> | <ul><li>min: 15 tokens</li><li>mean: 22.79 tokens</li><li>max: 38 tokens</li></ul> |
+ * Samples:
+   | anchor | positive |
+   |:-------|:---------|
+   | <code>Keluhan: NYERI PERUT, MUAL, MUNTAH 1KALI HARI INI,<br>Anamnesa Pemeriksaan Dokter: NYERI PERUT, MUAL, MUNTAH 1KALI HARI INI, tenggorokan gatla
+   <br>rpo myalanta
+   <br>alergi obat dsiangkal </code> | <code>Obat: Lambucid<br>Deskripsi Obat: Antasida untuk meredakan gejala asam lambung berlebih</code> |
+   | <code>Keluhan: badan lemas tadi pagi, mual- muntah- batuk pilek- <br>Anamnesa Pemeriksaan Dokter: lemas sejak pagi ini. demam - pingsan - sempat terjatuh karena lemas. pola makan tidak teratur.
+   <br>rpo tablet tambah darah. riw anemia + dikatakan saat SMA 3 tahun yll. tidak ingat hb berapa. </code> | <code>Obat: Caviplex<br>Deskripsi Obat: Suplemen multivitamin untuk memenuhi kebutuhan vitamin dan mineral</code> |
+   | <code>Keluhan: rencana rujukan rsMD spP tgl 23-04-24
+   <br>keluhan hari ini kepala pusing terasa berat, batuk, nyeri badan. <br>Anamnesa Pemeriksaan Dokter: rencana rujukan rsMD spP tgl 23-04-24 keluhan hari ini kepala pusing terasa berat, batuk, nyeri badan.</code> | <code>Obat: Profat sirup<br>Deskripsi Obat: Suplemen penambah darah sirup untuk anak</code> |
+ * Loss: [<code>MultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#multiplenegativesrankingloss) with these parameters:
+   ```json
+   {
+       "scale": 20.0,
+       "similarity_fct": "cos_sim"
+   }
+   ```
+
+ ### Training Hyperparameters
+ #### Non-Default Hyperparameters
+
+ - `eval_strategy`: steps
+ - `per_device_train_batch_size`: 32
+ - `per_device_eval_batch_size`: 32
+ - `num_train_epochs`: 1
+ - `warmup_ratio`: 0.1
+ - `fp16`: True
+ - `batch_sampler`: no_duplicates
+
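These non-default values map onto a `SentenceTransformerTrainingArguments` configuration roughly as follows; a sketch, in which `output_dir` and `eval_steps` are assumptions (the latter inferred from the validation-loss cadence in the training logs below), while everything else is taken directly from the list above:

```python
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="gte-clinical-drug",  # assumption: not stated in this card
    eval_strategy="steps",
    eval_steps=1000,                 # assumption: validation loss is logged every 1000 steps
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=1,
    warmup_ratio=0.1,
    fp16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
```

A `SentenceTransformerTrainer` would then combine these arguments with the model, the anchor/positive datasets, and the loss described earlier.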
+ #### All Hyperparameters
+ <details><summary>Click to expand</summary>
+
+ - `overwrite_output_dir`: False
+ - `do_predict`: False
+ - `eval_strategy`: steps
+ - `prediction_loss_only`: True
+ - `per_device_train_batch_size`: 32
+ - `per_device_eval_batch_size`: 32
+ - `per_gpu_train_batch_size`: None
+ - `per_gpu_eval_batch_size`: None
+ - `gradient_accumulation_steps`: 1
+ - `eval_accumulation_steps`: None
+ - `torch_empty_cache_steps`: None
+ - `learning_rate`: 5e-05
+ - `weight_decay`: 0.0
+ - `adam_beta1`: 0.9
+ - `adam_beta2`: 0.999
+ - `adam_epsilon`: 1e-08
+ - `max_grad_norm`: 1.0
+ - `num_train_epochs`: 1
+ - `max_steps`: -1
+ - `lr_scheduler_type`: linear
+ - `lr_scheduler_kwargs`: {}
+ - `warmup_ratio`: 0.1
+ - `warmup_steps`: 0
+ - `log_level`: passive
+ - `log_level_replica`: warning
+ - `log_on_each_node`: True
+ - `logging_nan_inf_filter`: True
+ - `save_safetensors`: True
+ - `save_on_each_node`: False
+ - `save_only_model`: False
+ - `restore_callback_states_from_checkpoint`: False
+ - `no_cuda`: False
+ - `use_cpu`: False
+ - `use_mps_device`: False
+ - `seed`: 42
+ - `data_seed`: None
+ - `jit_mode_eval`: False
+ - `use_ipex`: False
+ - `bf16`: False
+ - `fp16`: True
+ - `fp16_opt_level`: O1
+ - `half_precision_backend`: auto
+ - `bf16_full_eval`: False
+ - `fp16_full_eval`: False
+ - `tf32`: None
+ - `local_rank`: 0
+ - `ddp_backend`: None
+ - `tpu_num_cores`: None
+ - `tpu_metrics_debug`: False
+ - `debug`: []
+ - `dataloader_drop_last`: False
+ - `dataloader_num_workers`: 0
+ - `dataloader_prefetch_factor`: None
+ - `past_index`: -1
+ - `disable_tqdm`: False
+ - `remove_unused_columns`: True
+ - `label_names`: None
+ - `load_best_model_at_end`: False
+ - `ignore_data_skip`: False
+ - `fsdp`: []
+ - `fsdp_min_num_params`: 0
+ - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
+ - `fsdp_transformer_layer_cls_to_wrap`: None
+ - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
+ - `deepspeed`: None
+ - `label_smoothing_factor`: 0.0
+ - `optim`: adamw_torch
+ - `optim_args`: None
+ - `adafactor`: False
+ - `group_by_length`: False
+ - `length_column_name`: length
+ - `ddp_find_unused_parameters`: None
+ - `ddp_bucket_cap_mb`: None
+ - `ddp_broadcast_buffers`: False
+ - `dataloader_pin_memory`: True
+ - `dataloader_persistent_workers`: False
+ - `skip_memory_metrics`: True
+ - `use_legacy_prediction_loop`: False
+ - `push_to_hub`: False
+ - `resume_from_checkpoint`: None
+ - `hub_model_id`: None
+ - `hub_strategy`: every_save
+ - `hub_private_repo`: None
+ - `hub_always_push`: False
+ - `hub_revision`: None
+ - `gradient_checkpointing`: False
+ - `gradient_checkpointing_kwargs`: None
+ - `include_inputs_for_metrics`: False
+ - `include_for_metrics`: []
+ - `eval_do_concat_batches`: True
+ - `fp16_backend`: auto
+ - `push_to_hub_model_id`: None
+ - `push_to_hub_organization`: None
+ - `mp_parameters`:
+ - `auto_find_batch_size`: False
+ - `full_determinism`: False
+ - `torchdynamo`: None
+ - `ray_scope`: last
+ - `ddp_timeout`: 1800
+ - `torch_compile`: False
+ - `torch_compile_backend`: None
+ - `torch_compile_mode`: None
+ - `include_tokens_per_second`: False
+ - `include_num_input_tokens_seen`: False
+ - `neftune_noise_alpha`: None
+ - `optim_target_modules`: None
+ - `batch_eval_metrics`: False
+ - `eval_on_start`: False
+ - `use_liger_kernel`: False
+ - `liger_kernel_config`: None
+ - `eval_use_gather_object`: False
+ - `average_tokens_across_devices`: False
+ - `prompts`: None
+ - `batch_sampler`: no_duplicates
+ - `multi_dataset_batch_sampler`: proportional
+
+ </details>
+
+ ### Training Logs
+ <details><summary>Click to expand</summary>
+
+ | Epoch | Step | Training Loss | Validation Loss |
+ |:------:|:----:|:-------------:|:---------------:|
+ | 0.0073 | 50 | 3.3431 | - |
+ | 0.0147 | 100 | 3.1904 | - |
+ | 0.0220 | 150 | 3.0541 | - |
+ | 0.0294 | 200 | 2.972 | - |
+ | 0.0367 | 250 | 2.8877 | - |
+ | 0.0441 | 300 | 2.8234 | - |
+ | 0.0514 | 350 | 2.749 | - |
+ | 0.0588 | 400 | 2.7435 | - |
+ | 0.0661 | 450 | 2.7368 | - |
+ | 0.0735 | 500 | 2.6943 | - |
+ | 0.0808 | 550 | 2.7168 | - |
+ | 0.0882 | 600 | 2.7194 | - |
+ | 0.0955 | 650 | 2.6096 | - |
+ | 0.1029 | 700 | 2.7118 | - |
+ | 0.1102 | 750 | 2.7036 | - |
+ | 0.1176 | 800 | 2.6625 | - |
+ | 0.1249 | 850 | 2.6362 | - |
+ | 0.1323 | 900 | 2.599 | - |
+ | 0.1396 | 950 | 2.572 | - |
+ | 0.1470 | 1000 | 2.6124 | 2.0072 |
+ | 0.1543 | 1050 | 2.5467 | - |
+ | 0.1617 | 1100 | 2.5713 | - |
+ | 0.1690 | 1150 | 2.5741 | - |
+ | 0.1764 | 1200 | 2.5794 | - |
+ | 0.1837 | 1250 | 2.5231 | - |
+ | 0.1911 | 1300 | 2.5312 | - |
+ | 0.1984 | 1350 | 2.4483 | - |
+ | 0.2058 | 1400 | 2.5178 | - |
+ | 0.2131 | 1450 | 2.4795 | - |
+ | 0.2205 | 1500 | 2.5426 | - |
+ | 0.2278 | 1550 | 2.502 | - |
+ | 0.2352 | 1600 | 2.5378 | - |
+ | 0.2425 | 1650 | 2.4746 | - |
+ | 0.2499 | 1700 | 2.4356 | - |
+ | 0.2572 | 1750 | 2.5303 | - |
+ | 0.2646 | 1800 | 2.514 | - |
+ | 0.2719 | 1850 | 2.5207 | - |
+ | 0.2793 | 1900 | 2.4671 | - |
+ | 0.2866 | 1950 | 2.4367 | - |
+ | 0.2940 | 2000 | 2.4873 | 1.9339 |
+ | 0.3013 | 2050 | 2.4513 | - |
+ | 0.3087 | 2100 | 2.4695 | - |
+ | 0.3160 | 2150 | 2.4309 | - |
+ | 0.3234 | 2200 | 2.4439 | - |
+ | 0.3307 | 2250 | 2.4242 | - |
+ | 0.3381 | 2300 | 2.4569 | - |
+ | 0.3454 | 2350 | 2.4157 | - |
+ | 0.3528 | 2400 | 2.4709 | - |
+ | 0.3601 | 2450 | 2.4202 | - |
+ | 0.3675 | 2500 | 2.4401 | - |
+ | 0.3748 | 2550 | 2.4096 | - |
+ | 0.3822 | 2600 | 2.3878 | - |
+ | 0.3895 | 2650 | 2.4766 | - |
+ | 0.3969 | 2700 | 2.4149 | - |
+ | 0.4042 | 2750 | 2.4197 | - |
+ | 0.4116 | 2800 | 2.3656 | - |
+ | 0.4189 | 2850 | 2.4679 | - |
+ | 0.4263 | 2900 | 2.3749 | - |
+ | 0.4336 | 2950 | 2.4146 | - |
+ | 0.4410 | 3000 | 2.3942 | 1.8871 |
+ | 0.4483 | 3050 | 2.418 | - |
+ | 0.4557 | 3100 | 2.4504 | - |
+ | 0.4630 | 3150 | 2.3759 | - |
+ | 0.4704 | 3200 | 2.3671 | - |
+ | 0.4777 | 3250 | 2.4433 | - |
+ | 0.4851 | 3300 | 2.4036 | - |
+ | 0.4924 | 3350 | 2.3539 | - |
+ | 0.4998 | 3400 | 2.3806 | - |
+ | 0.5071 | 3450 | 2.3737 | - |
+ | 0.5145 | 3500 | 2.4127 | - |
+ | 0.5218 | 3550 | 2.4243 | - |
+ | 0.5292 | 3600 | 2.3528 | - |
+ | 0.5365 | 3650 | 2.3788 | - |
+ | 0.5439 | 3700 | 2.3968 | - |
+ | 0.5512 | 3750 | 2.3896 | - |
+ | 0.5586 | 3800 | 2.3966 | - |
+ | 0.5659 | 3850 | 2.3571 | - |
+ | 0.5733 | 3900 | 2.3437 | - |
+ | 0.5806 | 3950 | 2.3353 | - |
+ | 0.5880 | 4000 | 2.3335 | 1.8599 |
+ | 0.5953 | 4050 | 2.3778 | - |
+ | 0.6027 | 4100 | 2.3929 | - |
+ | 0.6100 | 4150 | 2.3818 | - |
+ | 0.6174 | 4200 | 2.3874 | - |
+ | 0.6247 | 4250 | 2.3224 | - |
+ | 0.6321 | 4300 | 2.3317 | - |
+ | 0.6394 | 4350 | 2.3761 | - |
+ | 0.6468 | 4400 | 2.4066 | - |
+ | 0.6541 | 4450 | 2.3406 | - |
+ | 0.6615 | 4500 | 2.3844 | - |
+ | 0.6688 | 4550 | 2.2993 | - |
+ | 0.6762 | 4600 | 2.337 | - |
+ | 0.6835 | 4650 | 2.37 | - |
+ | 0.6909 | 4700 | 2.3126 | - |
+ | 0.6982 | 4750 | 2.3818 | - |
+ | 0.7056 | 4800 | 2.3849 | - |
+ | 0.7129 | 4850 | 2.3379 | - |
+ | 0.7203 | 4900 | 2.3518 | - |
+ | 0.7276 | 4950 | 2.3354 | - |
+ | 0.7350 | 5000 | 2.3443 | 1.8349 |
+ | 0.7423 | 5050 | 2.3396 | - |
+ | 0.7497 | 5100 | 2.3086 | - |
+ | 0.7570 | 5150 | 2.3392 | - |
+ | 0.7644 | 5200 | 2.3316 | - |
+ | 0.7717 | 5250 | 2.3092 | - |
+ | 0.7791 | 5300 | 2.3794 | - |
+ | 0.7864 | 5350 | 2.331 | - |
+ | 0.7938 | 5400 | 2.2554 | - |
+ | 0.8011 | 5450 | 2.3266 | - |
+ | 0.8085 | 5500 | 2.3314 | - |
+ | 0.8158 | 5550 | 2.3357 | - |
+ | 0.8232 | 5600 | 2.3523 | - |
+ | 0.8305 | 5650 | 2.3253 | - |
+ | 0.8379 | 5700 | 2.3021 | - |
+ | 0.8452 | 5750 | 2.3342 | - |
+ | 0.8526 | 5800 | 2.2839 | - |
+ | 0.8599 | 5850 | 2.3136 | - |
+ | 0.8673 | 5900 | 2.3562 | - |
+ | 0.8746 | 5950 | 2.2878 | - |
+ | 0.8820 | 6000 | 2.3219 | 1.8173 |
+ | 0.8893 | 6050 | 2.2941 | - |
+ | 0.8967 | 6100 | 2.3245 | - |
+ | 0.9040 | 6150 | 2.2561 | - |
+ | 0.9114 | 6200 | 2.3327 | - |
+ | 0.9187 | 6250 | 2.3047 | - |
+ | 0.9261 | 6300 | 2.2916 | - |
+ | 0.9334 | 6350 | 2.3495 | - |
+ | 0.9408 | 6400 | 1.9273 | - |
+ | 0.9481 | 6450 | 1.3917 | - |
+ | 0.9555 | 6500 | 1.4726 | - |
+ | 0.9628 | 6550 | 1.3922 | - |
+ | 0.9702 | 6600 | 1.4664 | - |
+ | 0.9775 | 6650 | 1.4329 | - |
+ | 0.9849 | 6700 | 1.4046 | - |
+ | 0.9922 | 6750 | 1.3891 | - |
+ | 0.9996 | 6800 | 1.4731 | - |
+
+ </details>
+
+ ### Framework Versions
+ - Python: 3.11.13
+ - Sentence Transformers: 4.1.0
+ - Transformers: 4.53.2
+ - PyTorch: 2.6.0+cu124
+ - Accelerate: 1.8.1
+ - Datasets: 2.14.4
+ - Tokenizers: 0.21.2
+
+ ## Citation
+
+ ### BibTeX
+
+ #### Sentence Transformers
+ ```bibtex
+ @inproceedings{reimers-2019-sentence-bert,
+     title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
+     author = "Reimers, Nils and Gurevych, Iryna",
+     booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
+     month = "11",
+     year = "2019",
+     publisher = "Association for Computational Linguistics",
+     url = "https://arxiv.org/abs/1908.10084",
+ }
+ ```
+
+ #### MultipleNegativesRankingLoss
+ ```bibtex
+ @misc{henderson2017efficient,
+     title={Efficient Natural Language Response Suggestion for Smart Reply},
+     author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
+     year={2017},
+     eprint={1705.00652},
+     archivePrefix={arXiv},
+     primaryClass={cs.CL}
+ }
+ ```
+
+ <!--
+ ## Glossary
+
+ *Clearly define terms in order to be accessible across audiences.*
+ -->
+
+ <!--
+ ## Model Card Authors
+
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
+ -->
+
+ <!--
+ ## Model Card Contact
+
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
+ -->
config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "architectures": [
+     "NewModel"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration.NewConfig",
+     "AutoModel": "modeling.NewModel",
+     "AutoModelForMaskedLM": "Alibaba-NLP/new-impl--modeling.NewForMaskedLM",
+     "AutoModelForMultipleChoice": "Alibaba-NLP/new-impl--modeling.NewForMultipleChoice",
+     "AutoModelForQuestionAnswering": "Alibaba-NLP/new-impl--modeling.NewForQuestionAnswering",
+     "AutoModelForSequenceClassification": "Alibaba-NLP/new-impl--modeling.NewForSequenceClassification",
+     "AutoModelForTokenClassification": "Alibaba-NLP/new-impl--modeling.NewForTokenClassification"
+   },
+   "classifier_dropout": 0.0,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "id2label": {
+     "0": "LABEL_0"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "LABEL_0": 0
+   },
+   "layer_norm_eps": 1e-12,
+   "layer_norm_type": "layer_norm",
+   "logn_attention_clip1": false,
+   "logn_attention_scale": false,
+   "max_position_embeddings": 8192,
+   "model_type": "new",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pack_qkv": true,
+   "pad_token_id": 1,
+   "position_embedding_type": "rope",
+   "rope_scaling": {
+     "factor": 8.0,
+     "type": "ntk"
+   },
+   "rope_theta": 20000,
+   "torch_dtype": "float32",
+   "transformers_version": "4.53.2",
+   "type_vocab_size": 1,
+   "unpad_inputs": false,
+   "use_memory_efficient_attention": false,
+   "vocab_size": 250048
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "__version__": {
+     "sentence_transformers": "4.1.0",
+     "transformers": "4.53.2",
+     "pytorch": "2.6.0+cu124"
+   },
+   "prompts": {},
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
configuration.py ADDED
@@ -0,0 +1,145 @@
+ # coding=utf-8
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ NEW model configuration"""
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class NewConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`NewModel`] or a [`TFNewModel`]. It is used to
+     instantiate a NEW model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the NEW
+     [izhx/new-base-en](https://huggingface.co/izhx/new-base-en) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 30522):
+             Vocabulary size of the NEW model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`NewModel`] or [`TFNewModel`].
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_hidden_layers (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (`int`, *optional*, defaults to 3072):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+         hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"silu"` and `"gelu_new"` are supported.
+         hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+             The dropout ratio for the attention probabilities.
+         max_position_embeddings (`int`, *optional*, defaults to 512):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         type_vocab_size (`int`, *optional*, defaults to 2):
+             The vocabulary size of the `token_type_ids` passed when calling [`NewModel`] or [`TFNewModel`].
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+             The epsilon used by the layer normalization layers.
+         position_embedding_type (`str`, *optional*, defaults to `"rope"`):
+             Type of position embedding. Choose one of `"absolute"`, `"rope"`.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+             strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+             `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+             `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+             these scaling strategies behave:
+             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+             experimental feature, subject to breaking API changes in future versions.
+         classifier_dropout (`float`, *optional*):
+             The dropout ratio for the classification head.
+
+     Examples:
+
+     ```python
+     >>> from transformers import NewConfig, NewModel
+
+     >>> # Initializing a NEW izhx/new-base-en style configuration
+     >>> configuration = NewConfig()
+
+     >>> # Initializing a model (with random weights) from the izhx/new-base-en style configuration
+     >>> model = NewModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "new"
+
+     def __init__(
+         self,
+         vocab_size=30528,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.0,
+         max_position_embeddings=2048,
+         type_vocab_size=1,
+         initializer_range=0.02,
+         layer_norm_type='layer_norm',
+         layer_norm_eps=1e-12,
+         # pad_token_id=0,
+         position_embedding_type="rope",
+         rope_theta=10000.0,
+         rope_scaling=None,
+         classifier_dropout=None,
+         pack_qkv=True,
+         unpad_inputs=False,
+         use_memory_efficient_attention=False,
+         logn_attention_scale=False,
+         logn_attention_clip1=False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.initializer_range = initializer_range
+         self.layer_norm_type = layer_norm_type
+         self.layer_norm_eps = layer_norm_eps
+         self.position_embedding_type = position_embedding_type
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.classifier_dropout = classifier_dropout
+
+         self.pack_qkv = pack_qkv
+         self.unpad_inputs = unpad_inputs
+         self.use_memory_efficient_attention = use_memory_efficient_attention
+         self.logn_attention_scale = logn_attention_scale
+         self.logn_attention_clip1 = logn_attention_clip1
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aeef5d178ff139843493689b25190a5da682613f294102a88051ed03888ea60
+ size 1221487872
modeling.py ADDED
@@ -0,0 +1,1418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch NEW model."""
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutput,
29
+ BaseModelOutputWithPooling,
30
+ MaskedLMOutput,
31
+ MultipleChoiceModelOutput,
32
+ QuestionAnsweringModelOutput,
33
+ SequenceClassifierOutput,
34
+ ModelOutput,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import logging
38
+
39
+ try:
40
+ import xformers.ops as xops
41
+ except ImportError as e:
42
+ xops = None
43
+
44
+ from .configuration import NewConfig
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
51
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
52
+ class IndexFirstAxis(torch.autograd.Function):
53
+ @staticmethod
54
+ def forward(ctx, input, indices):
55
+ ctx.save_for_backward(indices)
56
+ assert input.ndim >= 2
57
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
58
+ second_dim = other_shape.numel()
59
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
60
+ # return input[indices]
61
+ # return torch.gather(
62
+ # rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
63
+ # ).reshape(-1, *other_shape)
64
+ return torch.gather(
65
+ input.view(ctx.first_axis_dim, second_dim),
66
+ 0,
67
+ indices.unsqueeze(-1).expand(indices.size(0), second_dim)
68
+ ).reshape(-1, *other_shape)
69
+
70
+ @staticmethod
71
+ def backward(ctx, grad_output):
72
+ (indices,) = ctx.saved_tensors
73
+ assert grad_output.ndim >= 2
74
+ other_shape = grad_output.shape[1:]
75
+ # grad_output = rearrange(grad_output, "b ... -> b (...)")
76
+ grad_output = grad_output.view(grad_output.size(0), other_shape.numel())
77
+ grad_input = torch.zeros(
78
+ [ctx.first_axis_dim, grad_output.shape[1]],
79
+ device=grad_output.device,
80
+ dtype=grad_output.dtype,
81
+ )
82
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
83
+ # grad_input[indices] = grad_output
84
+ # grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
85
+ grad_input.scatter_(
86
+ 0, indices.unsqueeze(-1).expand(indices.size(0), grad_output.size(1)), grad_output
87
+ )
88
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
89
+
90
+
91
+ index_first_axis = IndexFirstAxis.apply
92
+
93
+
94
+ def unpad_input(hidden_states, attention_mask=None, indices=None):
95
+ """
96
+ Arguments:
97
+ hidden_states: (batch, seqlen, ...)
98
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
99
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
100
+ Return:
101
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
102
+ """
103
+ if indices is None:
104
+ assert attention_mask is not None
105
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
106
+
107
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
108
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
109
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
110
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
111
+ # so we write custom forward and backward to make it a bit faster.
112
+ hidden_states = hidden_states.view(-1, *hidden_states.shape[2:])
113
+ return index_first_axis(hidden_states, indices)
114
+
115
+
116
+ class IndexPutFirstAxis(torch.autograd.Function):
117
+ @staticmethod
118
+ def forward(
119
+ ctx,
120
+ values: torch.Tensor,
121
+ indices: torch.Tensor,
122
+ first_axis_dim
123
+ ) -> torch.Tensor:
124
+ ctx.save_for_backward(indices)
125
+ assert indices.ndim == 1
126
+ assert values.ndim >= 2
127
+ output = torch.zeros(
128
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
129
+ )
130
+ output[indices] = values
131
+ return output
132
+
133
+ @staticmethod
134
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
135
+ indices, = ctx.saved_tensors
136
+ grad_values = grad_output[indices]
137
+ return grad_values, None, None
138
+
139
+
140
+ index_put_first_axis = IndexPutFirstAxis.apply
141
+
142
+
143
+ def pad_input(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
144
+ """Add padding to sequences.
145
+
146
+ Arguments:
147
+ inputs: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
148
+ indices: (total_nnz), `indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()`
149
+ batch: int batch_size
150
+ seqlen: int max sequence length
151
+
152
+ Returns:
153
+ inputs: (batch, seqlen, ...)
154
+ """
155
+ output = index_put_first_axis(inputs, indices, batch * seqlen)
156
+ return output.view(batch, seqlen, *inputs.shape[1:])
157
+
158
+
159
+ def rotate_half(x):
160
+ """Rotates half the hidden dims of the input."""
161
+ x1 = x[..., : x.shape[-1] // 2]
162
+ x2 = x[..., x.shape[-1] // 2 :]
163
+ return torch.cat((-x2, x1), dim=-1)
164
+
165
+
166
+ def apply_rotary_pos_emb(q, k, cos, sin):
167
+ """Applies Rotary Position Embedding to the query and key tensors.
168
+
169
+ Args:
170
+ q (`torch.Tensor`): The query tensor.
171
+ k (`torch.Tensor`): The key tensor.
172
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
173
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
174
+ Returns:
175
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
176
+ """
177
+ cos, sin = cos.to(q.dtype), sin.to(q.dtype)
178
+ q_embed = (q * cos) + (rotate_half(q) * sin)
179
+ k_embed = (k * cos) + (rotate_half(k) * sin)
180
+ return q_embed, k_embed
181
+
182
+
183
+ class RotaryEmbedding(torch.nn.Module):
184
+ def __init__(self, dim, max_position_embeddings=512, base=10000.0, device=None):
185
+ super().__init__()
186
+
187
+ self.dim = dim
188
+ self.max_position_embeddings = max_position_embeddings
189
+ self.base = base
190
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
191
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
192
+
193
+ # Build here to make `torch.jit.trace` work.
194
+ self._set_cos_sin_cache(
195
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
196
+ )
197
+
198
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
199
+ self.max_seq_len_cached = seq_len
200
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
201
+
202
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
203
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
204
+ emb = torch.cat((freqs, freqs), dim=-1)
205
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
206
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
207
+
208
+ def forward(self, x, seq_len=None):
209
+ # x: [bs, num_attention_heads, seq_len, head_size]
210
+ if seq_len > self.max_seq_len_cached:
211
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
212
+
213
+ return (
214
+ self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
215
+ self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
216
+ )
217
+
218
+
219
+ class NTKScalingRotaryEmbedding(RotaryEmbedding):
220
+ """RotaryEmbedding extended with fixed and mixed NTK scaling. https://kexue.fm/archives/9706 """
221
+
222
+ def __init__(self, dim, max_position_embeddings=512, base=10000, device=None, scaling_factor=1.0, mixed_b=None):
223
+ self.scaling_factor = scaling_factor
224
+ self.mixed_b = mixed_b
225
+ super().__init__(dim, max_position_embeddings, base, device)
226
+ max_position_embeddings = max_position_embeddings * self.scaling_factor
227
+ self._set_cos_sin_cache(max_position_embeddings, self.inv_freq.device, torch.get_default_dtype())
228
+
229
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
230
+ self.max_seq_len_cached = seq_len
231
+
232
+ if seq_len > self.max_position_embeddings:
233
+ base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
234
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
235
+
236
+ if self.mixed_b is None:
237
+ inv_freq = inv_freq / self.scaling_factor ** (2 / self.dim) # (6)
238
+ else:
239
+ a = torch.tensor(self.scaling_factor).log() / (self.dim / 2) ** self.mixed_b # (13)
240
+ lambda_1_m = (a * torch.arange(1, self.dim // 2 + 1).float().to(device) ** self.mixed_b).exp() # (12)
241
+ inv_freq = inv_freq / lambda_1_m # (10)
242
+
243
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
244
+
245
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
246
+
247
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
248
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
249
+ emb = torch.cat((freqs, freqs), dim=-1)
250
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
251
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
252
+
253
+
254
+ class RMSNorm(nn.Module):
255
+ def __init__(self, hidden_size, eps=1e-6):
256
+ """
257
+ RMSNorm is equivalent to T5LayerNorm
258
+ """
259
+ super().__init__()
260
+ self.weight = nn.Parameter(torch.ones(hidden_size))
261
+ self.variance_epsilon = eps
262
+
263
+ def forward(self, hidden_states):
264
+ input_dtype = hidden_states.dtype
265
+ hidden_states = hidden_states.to(torch.float32)
266
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
267
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
268
+ return self.weight * hidden_states.to(input_dtype)
269
+
270
+
271
+ LAYER_NORM = {
272
+ 'layer_norm': nn.LayerNorm,
273
+ 'rms_norm': RMSNorm
274
+ }
275
+
276
+
277
+ class NewEmbeddings(nn.Module):
278
+ """
279
+ Embedding and Unpadding.
280
+ """
281
+
282
+ def __init__(self, config: NewConfig):
283
+ super().__init__()
284
+ self.padding_idx = config.pad_token_id
285
+ self.word_embeddings = nn.Embedding(
286
+ config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
287
+ )
288
+
289
+ self.position_embedding_type = config.position_embedding_type
290
+ if self.position_embedding_type == 'absolute':
291
+ self.position_embeddings = nn.Embedding(
292
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
293
+ )
294
+ elif self.position_embedding_type == 'rope':
295
+ self._init_rope(config)
296
+ else:
297
+ raise ValueError
298
+
299
+ self.type_vocab_size = config.type_vocab_size
300
+ if self.type_vocab_size > 0:
301
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
302
+
303
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
304
+ # any TensorFlow checkpoint file
305
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
306
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
307
+ # position_ids is contiguous in memory and excluded when serialized
308
+ self.register_buffer(
309
+ "position_ids", torch.arange(config.max_position_embeddings), persistent=False
310
+ )
311
+
312
+ def _init_rope(self, config):
313
+ kwargs = dict(
314
+ dim=int(config.hidden_size / config.num_attention_heads),
315
+ max_position_embeddings=config.max_position_embeddings,
316
+ base=config.rope_theta
317
+ )
318
+ if config.rope_scaling is None:
319
+ self.rotary_emb = RotaryEmbedding(**kwargs)
320
+ else:
321
+ kwargs.update(scaling_factor=config.rope_scaling["factor"])
322
+ scaling_type = config.rope_scaling["type"]
323
+ if scaling_type == 'ntk':
324
+ kwargs.update(mixed_b=config.rope_scaling.get('mixed_b', None))
325
+ self.rotary_emb = NTKScalingRotaryEmbedding(**kwargs)
326
+ # elif scaling_type == "linear":
327
+ # self.rotary_emb = LinearScalingRotaryEmbedding(**kwargs)
328
+ # elif scaling_type == "dynamic":
329
+ # self.rotary_emb = DynamicNTKScalingRotaryEmbedding(**kwargs)
330
+ else:
331
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
332
+
333
+ def forward(
334
+ self,
335
+ unpad_inputs: bool,
336
+ input_ids: Optional[torch.Tensor] = None,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ length: Optional[List[int]] = None,
339
+ token_type_ids: Optional[torch.Tensor] = None,
340
+ position_ids: Optional[torch.Tensor] = None,
341
+ inputs_embeds: Optional[torch.Tensor] = None,
342
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple], Optional[List[int]]]:
343
+ """
344
+ """
345
+ if inputs_embeds is None:
346
+ device, input_shape = input_ids.device, input_ids.shape
347
+ else:
348
+ device, input_shape = inputs_embeds.device, inputs_embeds.shape[:2]
349
+ batch_size, seq_length = input_shape
350
+
351
+ # Set attention_mask if it's None
352
+ if attention_mask is None:
353
+ attention_mask = torch.ones(input_shape, device=device)
354
+ if length is not None:
355
+ for i, l in enumerate(length):
356
+ attention_mask[i, l:] = 0
357
+
358
+ # Set attention_mask_bool for unpadding
359
+ if unpad_inputs:
360
+ attention_mask_bool = attention_mask.bool()
361
+ if length is None:
362
+ length = attention_mask.sum(-1).tolist()
363
+
364
+ # Get word embeddings
365
+ if inputs_embeds is None:
366
+ if unpad_inputs:
367
+ input_ids = input_ids[attention_mask_bool].unsqueeze(0)
368
+ inputs_embeds = self.word_embeddings(input_ids)
369
+ else:
370
+ if unpad_inputs:
371
+ inputs_embeds = inputs_embeds[attention_mask_bool].unsqueeze(0)
372
+ embeddings = inputs_embeds
373
+
374
+ # Set and unpad position_ids
375
+ if position_ids is None:
376
+ if seq_length > self.position_ids.size(0):
377
+ self.register_buffer(
378
+ "position_ids", torch.arange(seq_length, device=embeddings.device), persistent=False
379
+ )
380
+ if unpad_inputs:
381
+ # [1, cumsum_seq_len]
382
+ position_ids = torch.cat([self.position_ids[:l] for l in length]).unsqueeze(0)
383
+ else:
384
+ # [bs, seq_len]
385
+ position_ids = self.position_ids[:seq_length].expand(batch_size, -1)
386
+ elif unpad_inputs:
387
+ position_ids = position_ids[attention_mask_bool].unsqueeze(0) # [1, cumsum_seq_len]
388
+
389
+ # Compute rotary embedding
390
+ if self.position_embedding_type == 'rope':
391
+ rope_cos, rope_sin = self.rotary_emb(inputs_embeds, seq_len=seq_length)
392
+ rope_cos = rope_cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
393
+ rope_sin = rope_sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
394
+ rope_embeds = rope_cos, rope_sin
395
+ else:
396
+ rope_embeds = None
397
+
398
+ if self.type_vocab_size > 0:
399
+ if token_type_ids is None:
400
+ token_type_ids = position_ids.mul(0)
401
+ else:
402
+ if self.type_vocab_size < 2:
403
+ token_type_ids.mul_(0)
404
+ if unpad_inputs:
405
+ token_type_ids = token_type_ids[attention_mask_bool].unsqueeze(0)
406
+
407
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
408
+ embeddings = embeddings + token_type_embeddings
409
+
410
+ # BERT position
411
+ if self.position_embedding_type == "absolute":
412
+ position_embeddings = self.position_embeddings(position_ids)
413
+ embeddings = embeddings + position_embeddings
414
+
415
+ embeddings = self.LayerNorm(embeddings)
416
+ embeddings = self.dropout(embeddings)
417
+
418
+ return embeddings, attention_mask, rope_embeds, length
419
+
420
+
421
+ class NewAttention(nn.Module):
422
+ def __init__(self, config: NewConfig, pack_qkv=None, use_memory_efficient_attention=None):
423
+ super().__init__()
424
+ self.config = config
425
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
426
+ raise ValueError(
427
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
428
+ f"heads ({config.num_attention_heads})"
429
+ )
430
+
431
+ self.hidden_size = config.hidden_size
432
+ self.num_attention_heads = config.num_attention_heads
433
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
434
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
435
+
436
+ if pack_qkv is None:
437
+ pack_qkv = config.pack_qkv
438
+ self.pack_qkv = pack_qkv
439
+
440
+ if self.pack_qkv:
441
+ self.qkv_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=True)
442
+ else:
443
+ self.q_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
444
+ self.k_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
445
+ self.v_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
446
+
447
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
448
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
449
+
450
+ if use_memory_efficient_attention is None:
451
+ use_memory_efficient_attention = self.config.use_memory_efficient_attention
452
+ self.use_memory_efficient_attention = use_memory_efficient_attention
453
+ self.memory_efficient_attention = None if xops is None else xops.memory_efficient_attention
454
+ if self.use_memory_efficient_attention:
455
+ assert self.memory_efficient_attention is not None, 'please install xformers'
456
+
457
+ def forward(
458
+ self,
459
+ hidden_states: torch.Tensor,
460
+ attention_bias: torch.FloatTensor,
461
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
462
+ padding_inputs: Optional[Tuple] = None, # indices, batch, seqlen
463
+ attention_scale: Optional[torch.FloatTensor] = None,
464
+ head_mask: Optional[torch.FloatTensor] = None,
465
+ output_attentions: Optional[bool] = False,
466
+ qkv_inputs: Optional[Tuple] = None, # For RetroMAE
467
+ ) -> Tuple[torch.Tensor, ...]:
468
+ shape_hd = (self.num_attention_heads, self.attention_head_size)
469
+ # qkv
470
+ if self.pack_qkv and qkv_inputs is None:
471
+ qkv_pack = self.qkv_proj(hidden_states).split(self.all_head_size, dim=-1)
472
+ else:
473
+ if qkv_inputs is None:
474
+ qkv_inputs = (hidden_states, hidden_states, hidden_states)
475
+ qkv_pack = [
476
+ getattr(self, n + '_proj')(s) for s, n in zip(qkv_inputs, 'qkv')
477
+ ]
478
+ query_states, key_states, value_states = [t.view(t.shape[:-1] + shape_hd) for t in qkv_pack]
479
+
480
+ if self.config.position_embedding_type == 'rope':
481
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, *rope_embeds)
482
+
483
+ dtype = query_states.dtype
484
+
485
+ if self.config.logn_attention_scale and attention_scale is not None:
486
+ # https://kexue.fm/archives/8823
487
+ query_states = query_states * attention_scale.to(dtype)
488
+
489
+ if padding_inputs is not None:
490
+ query_states = pad_input(query_states.squeeze(), *padding_inputs)
491
+ key_states = pad_input(key_states.squeeze(), *padding_inputs)
492
+ value_states = pad_input(value_states.squeeze(), *padding_inputs)
493
+
494
+ if self.use_memory_efficient_attention:
495
+ assert self.memory_efficient_attention is not None, "xformers is not loaded"
496
+ assert output_attentions is False, "memory_efficient_attention does not output attentions"
497
+ assert head_mask is None, "head_mask is not supported yet"
498
+ attention_probs = None
499
+ if torch.is_tensor(attention_bias):
500
+ attention_bias = attention_bias.to(dtype)
501
+ context_layer = self.memory_efficient_attention(
502
+ query_states,
503
+ key_states,
504
+ value_states,
505
+ attn_bias=attention_bias,
506
+ p=self.dropout.p
507
+ )
508
+ else:
509
+ if output_attentions and isinstance(self, NewSdpaAttention):
510
+ raise RuntimeError("SDPA does not output attentions")
511
+ context_layer, attention_probs = self._attention(
512
+ query_states, key_states, value_states, attention_bias, head_mask
513
+ )
514
+
515
+ if padding_inputs is not None:
516
+ context_layer = unpad_input(context_layer, indices=padding_inputs[0])
517
+
518
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
519
+ context_layer = context_layer.view(new_context_layer_shape)
520
+
521
+ # output proj
522
+ attn_output = self.o_proj(context_layer)
523
+
524
+ # add attentions if we output them
525
+ outputs = (attn_output, attention_probs) if output_attentions else (attn_output,)
526
+ return outputs
527
+
528
+ def _attention(self, query_states, key_states, value_states, attention_bias, head_mask):
529
+ """
530
+ Args:
531
+ q/k/v: (B, L, n_head, head_dim),
532
+ Returns:
533
+ attn_output: (B, L, n_head, head_dim)
534
+ """
535
+ query_states = query_states.transpose(1, 2)
536
+ key_states = key_states.transpose(1, 2)
537
+ value_states = value_states.transpose(1, 2)
538
+ # Take the dot product between "query" and "key" to get the raw attention scores.
539
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
540
+
541
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
542
+ if attention_bias is not None:
543
+ # Apply the attention bias (precomputed for all layers in the model's forward() function)
544
+ attention_scores = attention_scores + attention_bias
545
+
546
+ # Normalize the attention scores to probabilities.
547
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
548
+
549
+ # This is actually dropping out entire tokens to attend to, which might
550
+ # seem a bit unusual, but is taken from the original Transformer paper.
551
+ if self.dropout.p > 0:
552
+ attention_probs = self.dropout(attention_probs)
553
+
554
+ # Mask heads if we want to
555
+ if head_mask is not None:
556
+ attention_probs = attention_probs * head_mask
557
+
558
+ context_layer = torch.matmul(attention_probs, value_states)
559
+
560
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
561
+ return context_layer, attention_probs
562
+
563
+
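Note: `apply_rotary_pos_emb` is defined earlier in this file and does not appear in this chunk of the diff. For readers skimming, a minimal sketch of the usual rotate-half RoPE formulation such helpers follow is below; the function names and shapes here are illustrative assumptions, not the exact implementation.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(q, k, cos, sin):
    # q, k: [bs, seq_len, n_head, head_dim]; cos, sin broadcast over the head axis,
    # e.g. [bs, seq_len, 1, head_dim] as produced in NewEmbeddings above.
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
```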
564
+ class NewSdpaAttention(NewAttention):
565
+ """
566
+ New attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
567
+ `NewAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
568
+ SDPA API.
569
+ """
570
+ def __init__(self, config: NewConfig, **kwargs):
571
+ super().__init__(config, **kwargs)
572
+ # torch.backends.cuda.enable_mem_efficient_sdp(False)
573
+ # logger.warning(
574
+ # "Disable memory efficient attention kernel for `NewSdpaAttention`, you can set "
575
+ # "`use_memory_efficient_attention=True` if it expected to use."
576
+ # )
577
+
578
+ def _attention(self, query_states, key_states, value_states, attention_bias, head_mask):
579
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
580
+ query_states.transpose(1, 2),
581
+ key_states.transpose(1, 2),
582
+ value_states.transpose(1, 2),
583
+ attn_mask=attention_bias,
584
+ dropout_p=self.dropout.p if self.training else 0.0,
585
+ )
586
+ attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
587
+ return attn_output, None
588
+
589
+
590
+ NEW_ATTENTION_CLASSES = {
591
+ "eager": NewAttention,
592
+ # "flash_attention_2": , # TODO
593
+ "sdpa": NewSdpaAttention,
594
+ }
595
+
596
+
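Note: a quick sanity check that the eager `_attention` math above agrees with `torch.nn.functional.scaled_dot_product_attention` used by `NewSdpaAttention` (hypothetical shapes, no mask, no dropout):

```python
import math
import torch
import torch.nn.functional as F

bs, n_head, seq, dim = 2, 4, 8, 16
q, k, v = (torch.randn(bs, n_head, seq, dim) for _ in range(3))

# Eager path: scaled dot product, softmax, weighted sum of values.
scores = (q @ k.transpose(-1, -2)) / math.sqrt(dim)
eager = F.softmax(scores, dim=-1) @ v

# Fused path.
sdpa = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(eager, sdpa, atol=1e-5))  # True, up to numerics
```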
597
+ class NewGatedMLP(nn.Module):
598
+ """
599
+ GLU Variants Improve Transformer.
600
+ """
601
+
602
+ def __init__(self, config: NewConfig):
603
+ super().__init__()
604
+ self.intermediate_size = config.intermediate_size
605
+ self.up_gate_proj = nn.Linear(config.hidden_size, self.intermediate_size * 2, bias=False)
606
+ self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=True)
607
+ self.act_fn = ACT2FN[config.hidden_act]
608
+ if config.hidden_dropout_prob > 0:
609
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
610
+ else:
611
+ self.hidden_dropout = None
612
+
613
+ def forward(self, hidden_states):
614
+ up_gate = self.up_gate_proj(hidden_states)
615
+ up_states, gate = torch.split(up_gate, self.intermediate_size, dim=-1)
616
+ gate = self.act_fn(gate)
617
+ gated_states = gate * up_states
618
+ if self.hidden_dropout is not None:
619
+ gated_states = self.hidden_dropout(gated_states)
620
+ down_states = self.down_proj(gated_states)
621
+ return down_states
622
+
623
+
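Note: a toy run of the gated MLP above, assuming `hidden_act` is `gelu` and using made-up sizes:

```python
import torch
import torch.nn.functional as F

hidden, inter = 8, 16
x = torch.randn(2, 4, hidden)
up_gate_proj = torch.nn.Linear(hidden, inter * 2, bias=False)
down_proj = torch.nn.Linear(inter, hidden, bias=True)

up, gate = up_gate_proj(x).split(inter, dim=-1)  # same split order as forward()
y = down_proj(F.gelu(gate) * up)                 # GLU: activation(gate) * up
print(y.shape)  # torch.Size([2, 4, 8])
```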
624
+ class NewLayer(nn.Module):
625
+ def __init__(
626
+ self,
627
+ config: NewConfig,
628
+ pack_qkv=None,
629
+ use_memory_efficient_attention=None,
630
+ attn_implementation=None
631
+ ):
632
+ super().__init__()
633
+ if attn_implementation is None:
634
+ attn_implementation = config._attn_implementation
635
+ if use_memory_efficient_attention is None:
636
+ use_memory_efficient_attention = config.use_memory_efficient_attention
637
+ if use_memory_efficient_attention:
638
+ if attn_implementation != 'eager':
639
+ logger.warning_once(f"Override {attn_implementation=} to 'eager' as {use_memory_efficient_attention=}")
640
+ attn_implementation = 'eager' # Since it will be SDPA by default for torch>=2.1.1
641
+ self.attention = NEW_ATTENTION_CLASSES[attn_implementation](
642
+ config, pack_qkv=pack_qkv, use_memory_efficient_attention=use_memory_efficient_attention
643
+ )
644
+ self.mlp = NewGatedMLP(config)
645
+
646
+ ln_class = LAYER_NORM[config.layer_norm_type]
647
+ self.attn_ln = ln_class(config.hidden_size, eps=config.layer_norm_eps)
648
+ self.mlp_ln = ln_class(config.hidden_size, eps=config.layer_norm_eps)
649
+
650
+ if config.hidden_dropout_prob > 0:
651
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
652
+ else:
653
+ self.hidden_dropout = None
654
+
655
+ def forward(
656
+ self,
657
+ hidden_states: torch.Tensor,
658
+ attention_bias: torch.FloatTensor,
659
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
660
+ padding_inputs: Optional[Tuple] = None, # indices, batch, seqlen
661
+ attention_scale: Optional[torch.FloatTensor] = None,
662
+ subset_indices: Optional[torch.LongTensor] = None,
663
+ head_mask: Optional[torch.FloatTensor] = None,
664
+ output_attentions: Optional[bool] = False,
665
+ qkv_inputs: Optional[Tuple] = None, # For RetroMAE
666
+ ) -> Tuple[torch.Tensor, ...]:
667
+ # Multi head self attention
668
+ residual = hidden_states if qkv_inputs is None else qkv_inputs[0]
669
+ attention_outputs = self.attention(
670
+ hidden_states,
671
+ attention_bias,
672
+ rope_embeds,
673
+ padding_inputs,
674
+ attention_scale,
675
+ head_mask,
676
+ output_attentions=output_attentions,
677
+ qkv_inputs=qkv_inputs,
678
+ )
679
+ hidden_states = attention_outputs[0]
680
+ if self.hidden_dropout is not None:
681
+ hidden_states = self.hidden_dropout(hidden_states)
682
+ hidden_states = residual + hidden_states
683
+
684
+ # In pretraining, after the attention of the last layer, only the masked tokens are needed.
685
+ if subset_indices is not None:
686
+ hidden_states = hidden_states[subset_indices]
687
+
688
+ hidden_states = self.attn_ln(hidden_states)
689
+
690
+ # Fully Connected
691
+ residual = hidden_states
692
+ hidden_states = self.mlp(hidden_states)
693
+ if self.hidden_dropout is not None:
694
+ hidden_states = self.hidden_dropout(hidden_states)
695
+ hidden_states = residual + hidden_states
696
+ hidden_states = self.mlp_ln(hidden_states)
697
+
698
+ # add self attentions if we output attention weights
699
+ outputs = (hidden_states,) + attention_outputs[1:]
700
+ return outputs
701
+
702
+
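Note: schematically, `NewLayer` uses a residual-then-LayerNorm ("post-LN") ordering. A minimal sketch with hypothetical callables, dropout omitted for brevity:

```python
def new_layer_schematic(h, attn, mlp, attn_ln, mlp_ln):
    # Mirrors NewLayer.forward above: residual add first, then LayerNorm.
    h = attn_ln(h + attn(h))
    h = mlp_ln(h + mlp(h))
    return h
```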
703
+ class NewEncoder(nn.Module):
704
+ def __init__(self, config):
705
+ super().__init__()
706
+ self.config = config
707
+ self.layer = nn.ModuleList([NewLayer(config) for _ in range(config.num_hidden_layers)])
708
+ self.gradient_checkpointing = False
709
+
710
+ def forward(
711
+ self,
712
+ hidden_states: torch.Tensor,
713
+ attention_bias: Optional[torch.FloatTensor] = None,
714
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
715
+ padding_inputs: Optional[Tuple] = None, # indices, batch, seqlen
716
+ attention_scale: Optional[torch.FloatTensor] = None,
717
+ subset_indices: Optional[torch.LongTensor] = None,
718
+ head_mask: Optional[torch.FloatTensor] = None,
719
+ output_attentions: Optional[bool] = False,
720
+ output_hidden_states: Optional[bool] = False,
721
+ return_dict: Optional[bool] = True,
722
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
723
+ all_hidden_states = () if output_hidden_states else None
724
+ all_self_attentions = () if output_attentions else None
725
+
726
+ for i, layer_module in enumerate(self.layer):
727
+ if output_hidden_states:
728
+ all_hidden_states = all_hidden_states + (hidden_states,)
729
+
730
+ if i >= len(self.layer) - 1:
731
+ layer_subset_indices = subset_indices
732
+ else:
733
+ layer_subset_indices = None
734
+
735
+ layer_head_mask = head_mask[i] if head_mask is not None else None
736
+
737
+ if self.gradient_checkpointing and self.training:
738
+ layer_outputs = self._gradient_checkpointing_func(
739
+ layer_module.__call__,
740
+ hidden_states,
741
+ attention_bias,
742
+ rope_embeds,
743
+ padding_inputs,
744
+ attention_scale,
745
+ layer_subset_indices,
746
+ layer_head_mask,
747
+ )
748
+ else:
749
+ layer_outputs = layer_module(
750
+ hidden_states,
751
+ attention_bias,
752
+ rope_embeds,
753
+ padding_inputs,
754
+ attention_scale,
755
+ layer_subset_indices,
756
+ layer_head_mask,
757
+ output_attentions,
758
+ )
759
+
760
+ hidden_states = layer_outputs[0]
761
+ if output_attentions:
762
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
763
+
764
+ if output_hidden_states:
765
+ all_hidden_states = all_hidden_states + (hidden_states,)
766
+
767
+ if not return_dict:
768
+ return tuple(
769
+ v
770
+ for v in [
771
+ hidden_states,
772
+ all_hidden_states,
773
+ all_self_attentions,
774
+ ]
775
+ if v is not None
776
+ )
777
+ return BaseModelOutput(
778
+ last_hidden_state=hidden_states,
779
+ hidden_states=all_hidden_states,
780
+ attentions=all_self_attentions,
781
+ )
782
+
783
+
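Note: two practical points about `NewEncoder`, sketched as hedged pseudo-usage (assumes `model` is a loaded `NewModel`; `gradient_checkpointing_enable` is the standard `PreTrainedModel` toggle, which this model advertises via `supports_gradient_checkpointing`):

```python
# model.gradient_checkpointing_enable()  # sets encoder.gradient_checkpointing = True
# out = model(**batch, output_hidden_states=True)
# hidden_states then has num_hidden_layers + 1 entries: embeddings plus each layer output.
# assert len(out.hidden_states) == model.config.num_hidden_layers + 1
```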
784
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->New
785
+ class NewPooler(nn.Module):
786
+ def __init__(self, config):
787
+ super().__init__()
788
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
789
+ self.activation = nn.Tanh()
790
+
791
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
792
+ # We "pool" the model by simply taking the hidden state corresponding
793
+ # to the first token.
794
+ first_token_tensor = hidden_states[:, 0]
795
+ pooled_output = self.dense(first_token_tensor)
796
+ pooled_output = self.activation(pooled_output)
797
+ return pooled_output
798
+
799
+
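Note: `NewPooler` is plain first-token ("CLS") pooling before the dense + tanh head; the same first-token choice reappears in the sentence-transformers Pooling module listed in `modules.json` further down. A one-liner equivalent with illustrative sizes:

```python
import torch

last_hidden_state = torch.randn(2, 5, 768)  # [batch, seq, hidden]
cls_embedding = last_hidden_state[:, 0]     # what NewPooler feeds into dense + tanh
print(cls_embedding.shape)                  # torch.Size([2, 768])
```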
800
+ class NewPreTrainedModel(PreTrainedModel):
801
+ """
802
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
803
+ models.
804
+ """
805
+
806
+ config_class = NewConfig
807
+ base_model_prefix = "new"
808
+ supports_gradient_checkpointing = True
809
+ _supports_sdpa = True
810
+
811
+ def _init_weights(self, module):
812
+ """Initialize the weights"""
813
+ if isinstance(module, nn.Linear):
814
+ # Slightly different from the TF version which uses truncated_normal for initialization
815
+ # cf https://github.com/pytorch/pytorch/pull/5617
816
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
817
+ if module.bias is not None:
818
+ module.bias.data.zero_()
819
+ elif isinstance(module, nn.Embedding):
820
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
821
+ if module.padding_idx is not None:
822
+ module.weight.data[module.padding_idx].zero_()
823
+ elif isinstance(module, nn.LayerNorm):
824
+ module.bias.data.zero_()
825
+ module.weight.data.fill_(1.0)
826
+
827
+
828
+ class NewModel(NewPreTrainedModel):
829
+ """
830
+ The bare New Model transformer outputting raw hidden-states without any specific head on top.
831
+ """
832
+
833
+ def __init__(self, config: NewConfig, add_pooling_layer=False):
834
+ super().__init__(config)
835
+ self.config = config
836
+
837
+ self.embeddings = NewEmbeddings(config)
838
+ self.encoder = NewEncoder(config)
839
+
840
+ self.pooler = NewPooler(config) if add_pooling_layer else None
841
+
842
+ # Initialize weights and apply final processing
843
+ self.post_init()
844
+
845
+ def get_input_embeddings(self):
846
+ return self.embeddings.word_embeddings
847
+
848
+ def set_input_embeddings(self, value):
849
+ self.embeddings.word_embeddings = value
850
+
851
+ def forward(
852
+ self,
853
+ input_ids: Optional[torch.Tensor] = None,
854
+ attention_mask: Optional[torch.Tensor] = None,
855
+ length: Optional[List[int]] = None,
856
+ subset_indices: Optional[torch.LongTensor] = None,
857
+ token_type_ids: Optional[torch.Tensor] = None,
858
+ position_ids: Optional[torch.Tensor] = None,
859
+ head_mask: Optional[torch.Tensor] = None,
860
+ inputs_embeds: Optional[torch.Tensor] = None,
861
+ output_attentions: Optional[bool] = None,
862
+ output_hidden_states: Optional[bool] = None,
863
+ return_dict: Optional[bool] = None,
864
+ unpad_inputs: Optional[bool] = None,
865
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
866
+ r"""
867
+ length (`list` of length `batch_size`, *optional*):
868
+ Unpadded sequence lengths. If `None`, the padded `last_hidden_state` is returned.
869
+ subset_indices (`torch.LongTensor`, *optional*):
870
+ Indices of token positions to keep after the last encoder layer (e.g. only the masked positions during MLM pretraining).
871
+ unpad_inputs (`bool`, *optional*):
872
+ Whether to unpad the inputs before the encoder. Defaults to `config.unpad_inputs`.
873
+ """
874
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
875
+ output_hidden_states = (
876
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
877
+ )
878
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
879
+ unpad_inputs = unpad_inputs if unpad_inputs is not None else self.config.unpad_inputs
880
+ output_padded = length is None
881
+
882
+ if input_ids is not None and inputs_embeds is not None:
883
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
884
+ elif input_ids is not None:
885
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
886
+ input_shape = input_ids.size()
887
+ elif inputs_embeds is not None:
888
+ input_shape = inputs_embeds.size()[:-1]
889
+ else:
890
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
891
+
892
+ # TODO: not used
893
+ # # Prepare head mask if needed
894
+ # # 1.0 in head_mask indicate we keep the head
895
+ # # attention_probs has shape bsz x n_heads x N x N
896
+ # # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
897
+ # # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
898
+ # head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
899
+
900
+ # Get embeddings, may unpad them
901
+ (embedding_output, attention_mask, rope_embeds, length) = self.embeddings(
902
+ unpad_inputs,
903
+ input_ids=input_ids,
904
+ attention_mask=attention_mask,
905
+ length=length,
906
+ token_type_ids=token_type_ids,
907
+ position_ids=position_ids,
908
+ inputs_embeds=inputs_embeds
909
+ )
910
+
911
+ batch_size, seq_length = input_shape
912
+ if unpad_inputs and self.config.use_memory_efficient_attention:
913
+ attention_bias = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(length)
914
+ else:
915
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
916
+ # ourselves in which case we just need to make it broadcastable to all heads.
917
+ attention_bias = self.get_extended_attention_mask(attention_mask, input_shape)
918
+ if self.config.use_memory_efficient_attention:
919
+ # Expand the bias from [bs, 1, 1, seq_len] to [bs, num_heads, seq_len, seq_len], the shape xformers expects
920
+ attention_bias = attention_bias.expand(-1, self.config.num_attention_heads, seq_length, -1)
921
+
922
+ padding_inputs = None
923
+ if unpad_inputs and (output_padded or not self.config.use_memory_efficient_attention):
924
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
925
+ if not self.config.use_memory_efficient_attention:
926
+ padding_inputs = (indices, *input_shape)
927
+
928
+ attention_scale = None
929
+ if self.config.logn_attention_scale:
930
+ logger.warning_once("TODO: logn_attention_scale")
931
+ # # attention scale log_512(input_len)
932
+ # attention_scale = attention_mask.sum(1).log() / torch.tensor(self.config.max_position_embeddings).log()
933
+ # # inference-time logn scale need clip 1
934
+ # if self.config.logn_attention_clip1:
935
+ # attention_scale.clip_(1)
936
+ # attention_scale = attention_scale[:, None, None, None]
937
+ # else:
938
+ # attention_scale = None
939
+
940
+ encoder_outputs = self.encoder(
941
+ embedding_output,
942
+ attention_bias=attention_bias,
943
+ rope_embeds=rope_embeds,
944
+ padding_inputs=padding_inputs,
945
+ attention_scale=attention_scale,
946
+ subset_indices=subset_indices,
947
+ head_mask=head_mask,
948
+ output_attentions=output_attentions,
949
+ output_hidden_states=output_hidden_states,
950
+ return_dict=return_dict,
951
+ )
952
+ sequence_output = encoder_outputs[0]
953
+ if unpad_inputs and output_padded:
954
+ sequence_output = pad_input(
955
+ sequence_output.squeeze(), indices, batch_size, seq_length
956
+ )
957
+
958
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
959
+
960
+ if not return_dict:
961
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
962
+
963
+ return BaseModelOutputWithPooling(
964
+ last_hidden_state=sequence_output,
965
+ pooler_output=pooled_output,
966
+ hidden_states=encoder_outputs.hidden_states,
967
+ attentions=encoder_outputs.attentions,
968
+ )
969
+
970
+
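Note: a minimal sketch of running the bare encoder. The repo id is a placeholder, and `trust_remote_code=True` is required because `NewModel` is defined in this remote-code file; the CLS + L2-normalize step matches the pooling and normalization modules declared later in this upload.

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo = "your-username/your-model"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

batch = tokenizer(["a sample sentence"], return_tensors="pt")
with torch.no_grad():
    out = model(**batch)
emb = torch.nn.functional.normalize(out.last_hidden_state[:, 0], dim=-1)
print(emb.shape)
```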
971
+ class NewLMPredictionHead(nn.Module):
972
+ def __init__(self, config):
973
+ super().__init__()
974
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
975
+ self.transform_act_fn = ACT2FN[config.hidden_act]
976
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
977
+
978
+ # The output weights are the same as the input embeddings, but there is
979
+ # an output-only bias for each token.
980
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
981
+
982
+ def forward(self, hidden_states):
983
+ hidden_states = self.dense(hidden_states)
984
+ hidden_states = self.transform_act_fn(hidden_states)
985
+ hidden_states = self.norm(hidden_states)
986
+ hidden_states = self.decoder(hidden_states)
987
+ return hidden_states
988
+
989
+
990
+ class NewForMaskedLM(NewPreTrainedModel):
991
+ _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]
992
+
993
+ def __init__(self, config: NewConfig):
994
+ super().__init__(config)
995
+ self.new = NewModel(config, add_pooling_layer=False)
996
+ self.lm_head = NewLMPredictionHead(config)
997
+ self.loss_fct = nn.CrossEntropyLoss()
998
+
999
+ # Initialize weights and apply final processing
1000
+ self.post_init()
1001
+
1002
+ def get_output_embeddings(self):
1003
+ return self.lm_head.decoder
1004
+
1005
+ def set_output_embeddings(self, new_embeddings):
1006
+ self.lm_head.decoder = new_embeddings
1007
+
1008
+ def forward(
1009
+ self,
1010
+ input_ids: Optional[torch.Tensor] = None,
1011
+ attention_mask: Optional[torch.Tensor] = None,
1012
+ token_type_ids: Optional[torch.Tensor] = None,
1013
+ position_ids: Optional[torch.Tensor] = None,
1014
+ head_mask: Optional[torch.Tensor] = None,
1015
+ inputs_embeds: Optional[torch.Tensor] = None,
1016
+ labels: Optional[torch.Tensor] = None,
1017
+ output_attentions: Optional[bool] = None,
1018
+ output_hidden_states: Optional[bool] = None,
1019
+ return_dict: Optional[bool] = None,
1020
+ unpad_inputs: Optional[bool] = None,
1021
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1022
+ r"""
1023
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1024
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1025
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1026
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1027
+ """
1028
+
1029
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1030
+
1031
+ if labels is None or not self.new.config.unpad_inputs:
1032
+ length = None
1033
+ subset_indices = None
1034
+ else:
1035
+ length = attention_mask.sum(-1).tolist()
1036
+ labels = labels[attention_mask.bool()].unsqueeze(0)
1037
+ subset_indices = labels > -100
1038
+
1039
+ outputs = self.new(
1040
+ input_ids,
1041
+ attention_mask=attention_mask,
1042
+ length=length,
1043
+ subset_indices=subset_indices,
1044
+ token_type_ids=token_type_ids,
1045
+ position_ids=position_ids,
1046
+ head_mask=head_mask,
1047
+ inputs_embeds=inputs_embeds,
1048
+ output_attentions=output_attentions,
1049
+ output_hidden_states=output_hidden_states,
1050
+ return_dict=return_dict,
1051
+ unpad_inputs=unpad_inputs,
1052
+ )
1053
+
1054
+ sequence_output = outputs[0]
1055
+ prediction_scores = self.lm_head(sequence_output)
1056
+
1057
+ masked_lm_loss = None
1058
+ if labels is not None:
1059
+ if subset_indices is None:
1060
+ mask = attention_mask.bool()
1061
+ prediction_scores = prediction_scores[mask]
1062
+ labels = labels[mask]
1063
+ else:
1064
+ labels = labels[subset_indices]
1065
+ masked_lm_loss = self.loss_fct(prediction_scores, labels)
1066
+
1067
+ if not return_dict:
1068
+ output = (prediction_scores,) + outputs[2:]
1069
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1070
+
1071
+ return MaskedLMOutput(
1072
+ loss=masked_lm_loss,
1073
+ logits=prediction_scores,
1074
+ hidden_states=outputs.hidden_states,
1075
+ attentions=outputs.attentions,
1076
+ )
1077
+
1078
+
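Note: a toy trace of the label unpadding performed in `NewForMaskedLM.forward` above when `unpad_inputs` is enabled (made-up tensors):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
labels = torch.tensor([[-100, 42, -100, -100], [7, -100, -100, -100]])

flat_labels = labels[attention_mask.bool()].unsqueeze(0)  # [1, 5]: padding removed
subset_indices = flat_labels > -100                       # positions with a real MLM label
print(flat_labels)     # tensor([[-100,   42, -100,    7, -100]])
print(subset_indices)  # tensor([[False,  True, False,  True, False]])
```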
1079
+ class NewForSequenceClassification(NewPreTrainedModel):
1080
+ def __init__(self, config):
1081
+ super().__init__(config)
1082
+ self.num_labels = config.num_labels
1083
+ self.config = config
1084
+
1085
+ self.new = NewModel(config, add_pooling_layer=True)
1086
+ classifier_dropout = (
1087
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1088
+ )
1089
+ self.dropout = nn.Dropout(classifier_dropout)
1090
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1091
+
1092
+ # Initialize weights and apply final processing
1093
+ self.post_init()
1094
+
1095
+ def forward(
1096
+ self,
1097
+ input_ids: Optional[torch.Tensor] = None,
1098
+ attention_mask: Optional[torch.Tensor] = None,
1099
+ token_type_ids: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.Tensor] = None,
1101
+ head_mask: Optional[torch.Tensor] = None,
1102
+ inputs_embeds: Optional[torch.Tensor] = None,
1103
+ labels: Optional[torch.Tensor] = None,
1104
+ output_attentions: Optional[bool] = None,
1105
+ output_hidden_states: Optional[bool] = None,
1106
+ return_dict: Optional[bool] = None,
1107
+ unpad_inputs: Optional[bool] = None,
1108
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1109
+ r"""
1110
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1111
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1112
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1113
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1114
+ """
1115
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1116
+
1117
+ outputs = self.new(
1118
+ input_ids,
1119
+ attention_mask=attention_mask,
1120
+ token_type_ids=token_type_ids,
1121
+ position_ids=position_ids,
1122
+ head_mask=head_mask,
1123
+ inputs_embeds=inputs_embeds,
1124
+ output_attentions=output_attentions,
1125
+ output_hidden_states=output_hidden_states,
1126
+ return_dict=return_dict,
1127
+ unpad_inputs=unpad_inputs,
1128
+ )
1129
+
1130
+ pooled_output = outputs[1]
1131
+
1132
+ pooled_output = self.dropout(pooled_output)
1133
+ logits = self.classifier(pooled_output)
1134
+
1135
+ loss = None
1136
+ if labels is not None:
1137
+ if self.config.problem_type is None:
1138
+ if self.num_labels == 1:
1139
+ self.config.problem_type = "regression"
1140
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1141
+ self.config.problem_type = "single_label_classification"
1142
+ else:
1143
+ self.config.problem_type = "multi_label_classification"
1144
+
1145
+ if self.config.problem_type == "regression":
1146
+ loss_fct = nn.MSELoss()
1147
+ if self.num_labels == 1:
1148
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1149
+ else:
1150
+ loss = loss_fct(logits, labels)
1151
+ elif self.config.problem_type == "single_label_classification":
1152
+ loss_fct = nn.CrossEntropyLoss()
1153
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1154
+ elif self.config.problem_type == "multi_label_classification":
1155
+ loss_fct = nn.BCEWithLogitsLoss()
1156
+ loss = loss_fct(logits, labels)
1157
+
1158
+ if not return_dict:
1159
+ output = (logits,) + outputs[2:]
1160
+ return ((loss,) + output) if loss is not None else output
1161
+
1162
+ return SequenceClassifierOutput(
1163
+ loss=loss,
1164
+ logits=logits,
1165
+ hidden_states=outputs.hidden_states,
1166
+ attentions=outputs.attentions,
1167
+ )
1168
+
1169
+
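Note: the `problem_type` inference above reduces to the following rule, sketched as a standalone helper (the function name is hypothetical):

```python
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the inference logic in NewForSequenceClassification.forward above.
    if num_labels == 1:
        return "regression"                      # MSELoss
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"     # CrossEntropyLoss
    return "multi_label_classification"          # BCEWithLogitsLoss

print(infer_problem_type(3, torch.tensor([0, 2, 1])))       # single_label_classification
print(infer_problem_type(3, torch.tensor([[1., 0., 1.]])))  # multi_label_classification
```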
1170
+ class NewForMultipleChoice(NewPreTrainedModel):
1171
+ def __init__(self, config):
1172
+ super().__init__(config)
1173
+
1174
+ self.new = NewModel(config, add_pooling_layer=True)
1175
+ classifier_dropout = (
1176
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1177
+ )
1178
+ self.dropout = nn.Dropout(classifier_dropout)
1179
+ self.classifier = nn.Linear(config.hidden_size, 1)
1180
+
1181
+ # Initialize weights and apply final processing
1182
+ self.post_init()
1183
+
1184
+ def forward(
1185
+ self,
1186
+ input_ids: Optional[torch.Tensor] = None,
1187
+ attention_mask: Optional[torch.Tensor] = None,
1188
+ token_type_ids: Optional[torch.Tensor] = None,
1189
+ position_ids: Optional[torch.Tensor] = None,
1190
+ head_mask: Optional[torch.Tensor] = None,
1191
+ inputs_embeds: Optional[torch.Tensor] = None,
1192
+ labels: Optional[torch.Tensor] = None,
1193
+ output_attentions: Optional[bool] = None,
1194
+ output_hidden_states: Optional[bool] = None,
1195
+ return_dict: Optional[bool] = None,
1196
+ unpad_inputs: Optional[bool] = None,
1197
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1198
+ r"""
1199
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1200
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1201
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1202
+ `input_ids` above)
1203
+ """
1204
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1205
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1206
+
1207
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1208
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1209
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1210
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1211
+ inputs_embeds = (
1212
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1213
+ if inputs_embeds is not None
1214
+ else None
1215
+ )
1216
+
1217
+ outputs = self.new(
1218
+ input_ids,
1219
+ attention_mask=attention_mask,
1220
+ token_type_ids=token_type_ids,
1221
+ position_ids=position_ids,
1222
+ head_mask=head_mask,
1223
+ inputs_embeds=inputs_embeds,
1224
+ output_attentions=output_attentions,
1225
+ output_hidden_states=output_hidden_states,
1226
+ return_dict=return_dict,
1227
+ unpad_inputs=unpad_inputs,
1228
+ )
1229
+
1230
+ pooled_output = outputs[1]
1231
+
1232
+ pooled_output = self.dropout(pooled_output)
1233
+ logits = self.classifier(pooled_output)
1234
+ reshaped_logits = logits.view(-1, num_choices)
1235
+
1236
+ loss = None
1237
+ if labels is not None:
1238
+ loss_fct = nn.CrossEntropyLoss()
1239
+ loss = loss_fct(reshaped_logits, labels)
1240
+
1241
+ if not return_dict:
1242
+ output = (reshaped_logits,) + outputs[2:]
1243
+ return ((loss,) + output) if loss is not None else output
1244
+
1245
+ return MultipleChoiceModelOutput(
1246
+ loss=loss,
1247
+ logits=reshaped_logits,
1248
+ hidden_states=outputs.hidden_states,
1249
+ attentions=outputs.attentions,
1250
+ )
1251
+
1252
+
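Note: toy shapes for the multiple-choice flattening above (hypothetical values):

```python
import torch

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

flat = input_ids.view(-1, input_ids.size(-1))  # [8, 16]: one row per (example, choice)
logits = torch.randn(batch_size * num_choices, 1)
reshaped = logits.view(-1, num_choices)        # [2, 4]: cross-entropy over choices per example
print(flat.shape, reshaped.shape)
```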
1253
+ @dataclass
1254
+ class NewTokenClassifierOutput(ModelOutput):
1255
+ loss: Optional[torch.FloatTensor] = None
1256
+ logits: torch.FloatTensor = None
1257
+ last_hidden_state: torch.FloatTensor = None
1258
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
1259
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
1260
+
1261
+
1262
+ class NewForTokenClassification(NewPreTrainedModel):
1263
+ def __init__(self, config):
1264
+ super().__init__(config)
1265
+ self.num_labels = config.num_labels
1266
+
1267
+ self.new = NewModel(config, add_pooling_layer=False)
1268
+ classifier_dropout = (
1269
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1270
+ )
1271
+ self.dropout = nn.Dropout(classifier_dropout)
1272
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1273
+
1274
+ # Initialize weights and apply final processing
1275
+ self.post_init()
1276
+
1277
+ def forward(
1278
+ self,
1279
+ input_ids: Optional[torch.Tensor] = None,
1280
+ attention_mask: Optional[torch.Tensor] = None,
1281
+ token_type_ids: Optional[torch.Tensor] = None,
1282
+ position_ids: Optional[torch.Tensor] = None,
1283
+ head_mask: Optional[torch.Tensor] = None,
1284
+ inputs_embeds: Optional[torch.Tensor] = None,
1285
+ labels: Optional[torch.Tensor] = None,
1286
+ output_attentions: Optional[bool] = None,
1287
+ output_hidden_states: Optional[bool] = None,
1288
+ return_dict: Optional[bool] = None,
1289
+ unpad_inputs: Optional[bool] = None,
1290
+ ) -> Union[Tuple[torch.Tensor], NewTokenClassifierOutput]:
1291
+ r"""
1292
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1293
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1294
+ """
1295
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1296
+
1297
+ outputs = self.new(
1298
+ input_ids,
1299
+ attention_mask=attention_mask,
1300
+ token_type_ids=token_type_ids,
1301
+ position_ids=position_ids,
1302
+ head_mask=head_mask,
1303
+ inputs_embeds=inputs_embeds,
1304
+ output_attentions=output_attentions,
1305
+ output_hidden_states=output_hidden_states,
1306
+ return_dict=return_dict,
1307
+ unpad_inputs=unpad_inputs,
1308
+ )
1309
+
1310
+ sequence_output = outputs[0]
1311
+
1312
+ sequence_output = self.dropout(sequence_output)
1313
+ logits = self.classifier(sequence_output)
1314
+
1315
+ loss = None
1316
+ if labels is not None:
1317
+ loss_fct = nn.CrossEntropyLoss()
1318
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1319
+
1320
+ if not return_dict:
1321
+ output = (logits,) + outputs[2:]
1322
+ return ((loss,) + output) if loss is not None else output
1323
+
1324
+ return NewTokenClassifierOutput(
1325
+ loss=loss,
1326
+ logits=logits,
1327
+ last_hidden_state=sequence_output,
1328
+ hidden_states=outputs.hidden_states,
1329
+ attentions=outputs.attentions,
1330
+ )
1331
+
1332
+
1333
+ class NewForQuestionAnswering(NewPreTrainedModel):
1334
+ def __init__(self, config):
1335
+ super().__init__(config)
1336
+ self.num_labels = config.num_labels
1337
+
1338
+ self.new = NewModel(config, add_pooling_layer=False)
1339
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1340
+
1341
+ # Initialize weights and apply final processing
1342
+ self.post_init()
1343
+
1344
+ def forward(
1345
+ self,
1346
+ input_ids: Optional[torch.Tensor] = None,
1347
+ attention_mask: Optional[torch.Tensor] = None,
1348
+ token_type_ids: Optional[torch.Tensor] = None,
1349
+ position_ids: Optional[torch.Tensor] = None,
1350
+ head_mask: Optional[torch.Tensor] = None,
1351
+ inputs_embeds: Optional[torch.Tensor] = None,
1352
+ start_positions: Optional[torch.Tensor] = None,
1353
+ end_positions: Optional[torch.Tensor] = None,
1354
+ output_attentions: Optional[bool] = None,
1355
+ output_hidden_states: Optional[bool] = None,
1356
+ return_dict: Optional[bool] = None,
1357
+ unpad_inputs: Optional[bool] = None,
1358
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1359
+ r"""
1360
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1361
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1362
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1363
+ are not taken into account for computing the loss.
1364
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1365
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1366
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1367
+ are not taken into account for computing the loss.
1368
+ """
1369
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1370
+
1371
+ outputs = self.new(
1372
+ input_ids,
1373
+ attention_mask=attention_mask,
1374
+ token_type_ids=token_type_ids,
1375
+ position_ids=position_ids,
1376
+ head_mask=head_mask,
1377
+ inputs_embeds=inputs_embeds,
1378
+ output_attentions=output_attentions,
1379
+ output_hidden_states=output_hidden_states,
1380
+ return_dict=return_dict,
1381
+ unpad_inputs=unpad_inputs,
1382
+ )
1383
+
1384
+ sequence_output = outputs[0]
1385
+
1386
+ logits = self.qa_outputs(sequence_output)
1387
+ start_logits, end_logits = logits.split(1, dim=-1)
1388
+ start_logits = start_logits.squeeze(-1).contiguous()
1389
+ end_logits = end_logits.squeeze(-1).contiguous()
1390
+
1391
+ total_loss = None
1392
+ if start_positions is not None and end_positions is not None:
1393
+ # If we are on multi-GPU, the split can add a dimension; squeeze it away
1394
+ if len(start_positions.size()) > 1:
1395
+ start_positions = start_positions.squeeze(-1)
1396
+ if len(end_positions.size()) > 1:
1397
+ end_positions = end_positions.squeeze(-1)
1398
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1399
+ ignored_index = start_logits.size(1)
1400
+ start_positions = start_positions.clamp(0, ignored_index)
1401
+ end_positions = end_positions.clamp(0, ignored_index)
1402
+
1403
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
1404
+ start_loss = loss_fct(start_logits, start_positions)
1405
+ end_loss = loss_fct(end_logits, end_positions)
1406
+ total_loss = (start_loss + end_loss) / 2
1407
+
1408
+ if not return_dict:
1409
+ output = (start_logits, end_logits) + outputs[2:]
1410
+ return ((total_loss,) + output) if total_loss is not None else output
1411
+
1412
+ return QuestionAnsweringModelOutput(
1413
+ loss=total_loss,
1414
+ start_logits=start_logits,
1415
+ end_logits=end_logits,
1416
+ hidden_states=outputs.hidden_states,
1417
+ attentions=outputs.attentions,
1418
+ )
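Note: to close out the modeling code, a naive sketch of decoding a span from the QA head's logits above (real pipelines score all valid start/end pairs instead of taking independent argmaxes):

```python
import torch

start_logits = torch.randn(1, 16)  # hypothetical [batch, seq_len] outputs
end_logits = torch.randn(1, 16)

start = int(start_logits.argmax(-1))
end = int(end_logits.argmax(-1))
if end < start:
    start, end = end, start  # crude fix-up, for illustration only
print(start, end)
```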
modules.json ADDED
@@ -0,0 +1,20 @@
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
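Note: `modules.json` above wires the model as Transformer -> Pooling -> Normalize, which is what `sentence_transformers` reconstructs at load time. A sketch with a placeholder repo id; the 768-dim output matches this base model's hidden size:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("your-username/your-model", trust_remote_code=True)
embeddings = model.encode(["contoh kalimat", "a sample sentence"])
print(embeddings.shape)  # (2, 768); rows are L2-normalized by the Normalize module
```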
optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31fea2f2e35733cfe3b09094949fb7a2832d8cddcdd9b41f5490ed8a5c9e1d0c
3
+ size 2443060474
rng_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19ceae518e6cfb53055ed5e11269eae77fa624a98a082815654db1c9ff34138d
3
+ size 14244
scaler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77a0fc31716765e66cba03a3306e9c1711d520ccfadc681d20f84aedbf882573
3
+ size 988
scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a094d9d30b33158ba23e4fd3d4d66833a78100feef4aada49cb7f17ad27d06e5
3
+ size 1064
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "max_seq_length": 8192,
3
+ "do_lower_case": false
4
+ }
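Note: `sentence_bert_config.json` caps inputs at 8192 tokens. The cap can be lowered after loading to trade long-context support for speed and memory; `max_seq_length` is a settable attribute on `SentenceTransformer` (placeholder repo id below):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("your-username/your-model", trust_remote_code=True)
model.max_seq_length = 512  # override the 8192-token default from sentence_bert_config.json
```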
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa7a6ad87a7ce8fe196787355f6af7d03aee94d19c54a5eb1392ed18c8ef451a
3
+ size 17082988
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "extra_special_tokens": {},
49
+ "mask_token": "<mask>",
50
+ "model_max_length": 8192,
51
+ "pad_token": "<pad>",
52
+ "sep_token": "</s>",
53
+ "tokenizer_class": "XLMRobertaTokenizerFast",
54
+ "unk_token": "<unk>"
55
+ }
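Note: the tokenizer files above describe an XLM-R style fast tokenizer. A quick check of the special tokens and length limit, with a placeholder repo id; expected values are taken from `tokenizer_config.json` above:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-username/your-model")
print(type(tok).__name__)                            # XLMRobertaTokenizerFast
print(tok.cls_token, tok.sep_token, tok.mask_token)  # <s> </s> <mask>
print(tok.model_max_length)                          # 8192
```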
trainer_state.json ADDED
@@ -0,0 +1,1034 @@
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 1.0,
6
+ "eval_steps": 1000,
7
+ "global_step": 6803,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.007349698662354844,
14
+ "grad_norm": 7.311169624328613,
15
+ "learning_rate": 3.597650513950074e-06,
16
+ "loss": 3.3431,
17
+ "step": 50
18
+ },
19
+ {
20
+ "epoch": 0.014699397324709687,
21
+ "grad_norm": 7.183924198150635,
22
+ "learning_rate": 7.268722466960353e-06,
23
+ "loss": 3.1904,
24
+ "step": 100
25
+ },
26
+ {
27
+ "epoch": 0.022049095987064532,
28
+ "grad_norm": 8.947137832641602,
29
+ "learning_rate": 1.0939794419970633e-05,
30
+ "loss": 3.0541,
31
+ "step": 150
32
+ },
33
+ {
34
+ "epoch": 0.029398794649419375,
35
+ "grad_norm": 7.888134956359863,
36
+ "learning_rate": 1.461086637298091e-05,
37
+ "loss": 2.972,
38
+ "step": 200
39
+ },
40
+ {
41
+ "epoch": 0.03674849331177422,
42
+ "grad_norm": 7.55863618850708,
43
+ "learning_rate": 1.828193832599119e-05,
44
+ "loss": 2.8877,
45
+ "step": 250
46
+ },
47
+ {
48
+ "epoch": 0.044098191974129064,
49
+ "grad_norm": 9.108429908752441,
50
+ "learning_rate": 2.195301027900147e-05,
51
+ "loss": 2.8234,
52
+ "step": 300
53
+ },
54
+ {
55
+ "epoch": 0.0514478906364839,
56
+ "grad_norm": 8.362081527709961,
57
+ "learning_rate": 2.5624082232011748e-05,
58
+ "loss": 2.749,
59
+ "step": 350
60
+ },
61
+ {
62
+ "epoch": 0.05879758929883875,
63
+ "grad_norm": 6.888792991638184,
64
+ "learning_rate": 2.929515418502203e-05,
65
+ "loss": 2.7435,
66
+ "step": 400
67
+ },
68
+ {
69
+ "epoch": 0.0661472879611936,
70
+ "grad_norm": 7.607847213745117,
71
+ "learning_rate": 3.296622613803231e-05,
72
+ "loss": 2.7368,
73
+ "step": 450
74
+ },
75
+ {
76
+ "epoch": 0.07349698662354844,
77
+ "grad_norm": 5.798498153686523,
78
+ "learning_rate": 3.663729809104259e-05,
79
+ "loss": 2.6943,
80
+ "step": 500
81
+ },
82
+ {
83
+ "epoch": 0.08084668528590327,
84
+ "grad_norm": 7.030722618103027,
85
+ "learning_rate": 4.030837004405287e-05,
86
+ "loss": 2.7168,
87
+ "step": 550
88
+ },
89
+ {
90
+ "epoch": 0.08819638394825813,
91
+ "grad_norm": 5.483092308044434,
92
+ "learning_rate": 4.397944199706314e-05,
93
+ "loss": 2.7194,
94
+ "step": 600
95
+ },
96
+ {
97
+ "epoch": 0.09554608261061297,
98
+ "grad_norm": 7.004276275634766,
99
+ "learning_rate": 4.7650513950073424e-05,
100
+ "loss": 2.6096,
101
+ "step": 650
102
+ },
103
+ {
104
+ "epoch": 0.1028957812729678,
105
+ "grad_norm": 5.524320125579834,
106
+ "learning_rate": 4.985298921920941e-05,
107
+ "loss": 2.7118,
108
+ "step": 700
109
+ },
110
+ {
111
+ "epoch": 0.11024547993532265,
112
+ "grad_norm": 6.163999557495117,
113
+ "learning_rate": 4.9444625939235546e-05,
114
+ "loss": 2.7036,
115
+ "step": 750
116
+ },
117
+ {
118
+ "epoch": 0.1175951785976775,
119
+ "grad_norm": 6.143119812011719,
120
+ "learning_rate": 4.903626265926169e-05,
121
+ "loss": 2.6625,
122
+ "step": 800
123
+ },
124
+ {
125
+ "epoch": 0.12494487726003234,
126
+ "grad_norm": 5.09917688369751,
127
+ "learning_rate": 4.8627899379287814e-05,
128
+ "loss": 2.6362,
129
+ "step": 850
130
+ },
131
+ {
132
+ "epoch": 0.1322945759223872,
133
+ "grad_norm": 5.649909019470215,
134
+ "learning_rate": 4.821953609931395e-05,
135
+ "loss": 2.599,
136
+ "step": 900
137
+ },
138
+ {
139
+ "epoch": 0.13964427458474202,
140
+ "grad_norm": 4.9785966873168945,
141
+ "learning_rate": 4.781117281934009e-05,
142
+ "loss": 2.572,
143
+ "step": 950
144
+ },
145
+ {
146
+ "epoch": 0.14699397324709687,
147
+ "grad_norm": 5.37718391418457,
148
+ "learning_rate": 4.7402809539366224e-05,
149
+ "loss": 2.6124,
150
+ "step": 1000
151
+ },
152
+ {
153
+ "epoch": 0.14699397324709687,
154
+ "eval_loss": 2.007162094116211,
155
+ "eval_runtime": 104.7294,
156
+ "eval_samples_per_second": 230.948,
157
+ "eval_steps_per_second": 7.219,
158
+ "step": 1000
159
+ },
160
+ {
161
+ "epoch": 0.15434367190945172,
162
+ "grad_norm": 5.040450572967529,
163
+ "learning_rate": 4.699444625939235e-05,
164
+ "loss": 2.5467,
165
+ "step": 1050
166
+ },
167
+ {
168
+ "epoch": 0.16169337057180655,
169
+ "grad_norm": 4.658956527709961,
170
+ "learning_rate": 4.658608297941849e-05,
171
+ "loss": 2.5713,
172
+ "step": 1100
173
+ },
174
+ {
175
+ "epoch": 0.1690430692341614,
176
+ "grad_norm": 5.415022373199463,
177
+ "learning_rate": 4.617771969944463e-05,
178
+ "loss": 2.5741,
179
+ "step": 1150
180
+ },
181
+ {
182
+ "epoch": 0.17639276789651626,
183
+ "grad_norm": 5.289612770080566,
184
+ "learning_rate": 4.576935641947077e-05,
185
+ "loss": 2.5794,
186
+ "step": 1200
187
+ },
188
+ {
189
+ "epoch": 0.18374246655887108,
190
+ "grad_norm": 5.042024612426758,
191
+ "learning_rate": 4.5360993139496896e-05,
192
+ "loss": 2.5231,
193
+ "step": 1250
194
+ },
195
+ {
196
+ "epoch": 0.19109216522122593,
197
+ "grad_norm": 5.330864429473877,
198
+ "learning_rate": 4.495262985952303e-05,
199
+ "loss": 2.5312,
200
+ "step": 1300
201
+ },
202
+ {
203
+ "epoch": 0.19844186388358076,
204
+ "grad_norm": 5.096045017242432,
205
+ "learning_rate": 4.454426657954917e-05,
206
+ "loss": 2.4483,
207
+ "step": 1350
208
+ },
209
+ {
210
+ "epoch": 0.2057915625459356,
211
+ "grad_norm": 5.265764236450195,
212
+ "learning_rate": 4.4135903299575306e-05,
213
+ "loss": 2.5178,
214
+ "step": 1400
215
+ },
216
+ {
217
+ "epoch": 0.21314126120829047,
218
+ "grad_norm": 4.734705924987793,
219
+ "learning_rate": 4.372754001960144e-05,
220
+ "loss": 2.4795,
221
+ "step": 1450
222
+ },
223
+ {
224
+ "epoch": 0.2204909598706453,
225
+ "grad_norm": 5.286141395568848,
226
+ "learning_rate": 4.3319176739627575e-05,
227
+ "loss": 2.5426,
228
+ "step": 1500
229
+ },
230
+ {
231
+ "epoch": 0.22784065853300015,
232
+ "grad_norm": 4.941354751586914,
233
+ "learning_rate": 4.291081345965371e-05,
234
+ "loss": 2.502,
235
+ "step": 1550
236
+ },
237
+ {
238
+ "epoch": 0.235190357195355,
239
+ "grad_norm": 4.677373886108398,
240
+ "learning_rate": 4.250245017967985e-05,
241
+ "loss": 2.5378,
242
+ "step": 1600
243
+ },
244
+ {
245
+ "epoch": 0.24254005585770982,
246
+ "grad_norm": 5.507444858551025,
247
+ "learning_rate": 4.209408689970598e-05,
248
+ "loss": 2.4746,
249
+ "step": 1650
250
+ },
251
+ {
252
+ "epoch": 0.24988975452006468,
253
+ "grad_norm": 5.301682949066162,
254
+ "learning_rate": 4.168572361973211e-05,
255
+ "loss": 2.4356,
256
+ "step": 1700
257
+ },
258
+ {
259
+ "epoch": 0.25723945318241953,
260
+ "grad_norm": 5.333432197570801,
261
+ "learning_rate": 4.1277360339758254e-05,
262
+ "loss": 2.5303,
263
+ "step": 1750
264
+ },
265
+ {
266
+ "epoch": 0.2645891518447744,
267
+ "grad_norm": 4.362396717071533,
268
+ "learning_rate": 4.086899705978439e-05,
269
+ "loss": 2.514,
270
+ "step": 1800
271
+ },
272
+ {
273
+ "epoch": 0.2719388505071292,
274
+ "grad_norm": 4.468241214752197,
275
+ "learning_rate": 4.046063377981052e-05,
276
+ "loss": 2.5207,
277
+ "step": 1850
278
+ },
279
+ {
280
+ "epoch": 0.27928854916948403,
281
+ "grad_norm": 5.397381782531738,
282
+ "learning_rate": 4.005227049983666e-05,
283
+ "loss": 2.4671,
284
+ "step": 1900
285
+ },
286
+ {
287
+ "epoch": 0.2866382478318389,
288
+ "grad_norm": 5.816707611083984,
289
+ "learning_rate": 3.965207448546227e-05,
290
+ "loss": 2.4367,
291
+ "step": 1950
292
+ },
293
+ {
294
+ "epoch": 0.29398794649419374,
295
+ "grad_norm": 5.598290920257568,
296
+ "learning_rate": 3.9243711205488406e-05,
297
+ "loss": 2.4873,
298
+ "step": 2000
299
+ },
300
+ {
301
+ "epoch": 0.29398794649419374,
302
+ "eval_loss": 1.933936595916748,
303
+ "eval_runtime": 104.3275,
304
+ "eval_samples_per_second": 231.837,
305
+ "eval_steps_per_second": 7.246,
306
+ "step": 2000
307
+ },
308
+ {
309
+ "epoch": 0.3013376451565486,
310
+ "grad_norm": 5.163125514984131,
311
+ "learning_rate": 3.883534792551454e-05,
312
+ "loss": 2.4513,
313
+ "step": 2050
314
+ },
315
+ {
316
+ "epoch": 0.30868734381890345,
317
+ "grad_norm": 4.974937915802002,
318
+ "learning_rate": 3.8426984645540675e-05,
319
+ "loss": 2.4695,
320
+ "step": 2100
321
+ },
322
+ {
323
+ "epoch": 0.31603704248125825,
324
+ "grad_norm": 5.368651866912842,
325
+ "learning_rate": 3.801862136556681e-05,
326
+ "loss": 2.4309,
327
+ "step": 2150
328
+ },
329
+ {
330
+ "epoch": 0.3233867411436131,
331
+ "grad_norm": 4.559326171875,
332
+ "learning_rate": 3.7618425351192424e-05,
333
+ "loss": 2.4439,
334
+ "step": 2200
335
+ },
336
+ {
337
+ "epoch": 0.33073643980596795,
338
+ "grad_norm": 5.623340606689453,
339
+ "learning_rate": 3.721006207121856e-05,
340
+ "loss": 2.4242,
341
+ "step": 2250
342
+ },
343
+ {
344
+ "epoch": 0.3380861384683228,
345
+ "grad_norm": 5.139235973358154,
346
+ "learning_rate": 3.680169879124469e-05,
347
+ "loss": 2.4569,
348
+ "step": 2300
349
+ },
350
+ {
351
+ "epoch": 0.34543583713067766,
352
+ "grad_norm": 5.401805877685547,
353
+ "learning_rate": 3.639333551127083e-05,
354
+ "loss": 2.4157,
355
+ "step": 2350
356
+ },
357
+ {
358
+ "epoch": 0.3527855357930325,
359
+ "grad_norm": 4.80020809173584,
360
+ "learning_rate": 3.598497223129697e-05,
361
+ "loss": 2.4709,
362
+ "step": 2400
363
+ },
364
+ {
365
+ "epoch": 0.3601352344553873,
366
+ "grad_norm": 4.676275730133057,
367
+ "learning_rate": 3.5576608951323096e-05,
368
+ "loss": 2.4202,
369
+ "step": 2450
370
+ },
371
+ {
372
+ "epoch": 0.36748493311774216,
373
+ "grad_norm": 4.569772243499756,
374
+ "learning_rate": 3.516824567134923e-05,
375
+ "loss": 2.4401,
376
+ "step": 2500
377
+ },
378
+ {
379
+ "epoch": 0.374834631780097,
380
+ "grad_norm": 5.404001712799072,
381
+ "learning_rate": 3.475988239137537e-05,
382
+ "loss": 2.4096,
383
+ "step": 2550
384
+ },
385
+ {
386
+ "epoch": 0.38218433044245187,
387
+ "grad_norm": 5.0445475578308105,
388
+ "learning_rate": 3.4351519111401506e-05,
389
+ "loss": 2.3878,
390
+ "step": 2600
391
+ },
392
+ {
393
+ "epoch": 0.3895340291048067,
394
+ "grad_norm": 6.409811019897461,
395
+ "learning_rate": 3.394315583142764e-05,
396
+ "loss": 2.4766,
397
+ "step": 2650
398
+ },
399
+ {
400
+ "epoch": 0.3968837277671615,
401
+ "grad_norm": 4.700074672698975,
402
+ "learning_rate": 3.3534792551453774e-05,
403
+ "loss": 2.4149,
404
+ "step": 2700
405
+ },
406
+ {
407
+ "epoch": 0.4042334264295164,
408
+ "grad_norm": 5.682339668273926,
409
+ "learning_rate": 3.312642927147991e-05,
410
+ "loss": 2.4197,
411
+ "step": 2750
412
+ },
413
+ {
414
+ "epoch": 0.4115831250918712,
415
+ "grad_norm": 4.715968608856201,
416
+ "learning_rate": 3.272623325710552e-05,
417
+ "loss": 2.3656,
418
+ "step": 2800
419
+ },
420
+ {
421
+ "epoch": 0.4189328237542261,
422
+ "grad_norm": 4.35789155960083,
423
+ "learning_rate": 3.231786997713166e-05,
424
+ "loss": 2.4679,
425
+ "step": 2850
426
+ },
427
+ {
428
+ "epoch": 0.42628252241658093,
429
+ "grad_norm": 4.864765167236328,
430
+ "learning_rate": 3.190950669715779e-05,
431
+ "loss": 2.3749,
432
+ "step": 2900
433
+ },
434
+ {
435
+ "epoch": 0.4336322210789358,
436
+ "grad_norm": 4.375899314880371,
437
+ "learning_rate": 3.1501143417183927e-05,
438
+ "loss": 2.4146,
439
+ "step": 2950
440
+ },
441
+ {
442
+ "epoch": 0.4409819197412906,
443
+ "grad_norm": 4.743485450744629,
444
+ "learning_rate": 3.109278013721007e-05,
445
+ "loss": 2.3942,
446
+ "step": 3000
447
+ },
448
+ {
449
+ "epoch": 0.4409819197412906,
450
+ "eval_loss": 1.8871009349822998,
451
+ "eval_runtime": 102.8121,
452
+ "eval_samples_per_second": 235.254,
453
+ "eval_steps_per_second": 7.353,
454
+ "step": 3000
455
+ },
+ {
+ "epoch": 0.44833161840364544,
+ "grad_norm": 5.192530155181885,
+ "learning_rate": 3.0684416857236195e-05,
+ "loss": 2.418,
+ "step": 3050
+ },
+ {
+ "epoch": 0.4556813170660003,
+ "grad_norm": 5.368825912475586,
+ "learning_rate": 3.0276053577262336e-05,
+ "loss": 2.4504,
+ "step": 3100
+ },
+ {
+ "epoch": 0.46303101572835514,
+ "grad_norm": 6.064142227172852,
+ "learning_rate": 2.986769029728847e-05,
+ "loss": 2.3759,
+ "step": 3150
+ },
+ {
+ "epoch": 0.47038071439071,
+ "grad_norm": 5.2487101554870605,
+ "learning_rate": 2.9459327017314602e-05,
+ "loss": 2.3671,
+ "step": 3200
+ },
+ {
+ "epoch": 0.47773041305306485,
+ "grad_norm": 4.877628326416016,
+ "learning_rate": 2.9050963737340743e-05,
+ "loss": 2.4433,
+ "step": 3250
+ },
+ {
+ "epoch": 0.48508011171541965,
+ "grad_norm": 5.295327663421631,
+ "learning_rate": 2.8642600457366874e-05,
+ "loss": 2.4036,
+ "step": 3300
+ },
+ {
+ "epoch": 0.4924298103777745,
+ "grad_norm": 5.654299259185791,
+ "learning_rate": 2.823423717739301e-05,
+ "loss": 2.3539,
+ "step": 3350
+ },
+ {
+ "epoch": 0.49977950904012935,
+ "grad_norm": 5.245526313781738,
+ "learning_rate": 2.7825873897419146e-05,
+ "loss": 2.3806,
+ "step": 3400
+ },
+ {
+ "epoch": 0.5071292077024842,
+ "grad_norm": 4.889758110046387,
+ "learning_rate": 2.741751061744528e-05,
+ "loss": 2.3737,
+ "step": 3450
+ },
+ {
+ "epoch": 0.5144789063648391,
+ "grad_norm": 4.38112735748291,
+ "learning_rate": 2.7009147337471415e-05,
+ "loss": 2.4127,
+ "step": 3500
+ },
+ {
+ "epoch": 0.5218286050271939,
+ "grad_norm": 4.222320556640625,
+ "learning_rate": 2.6600784057497553e-05,
+ "loss": 2.4243,
+ "step": 3550
+ },
+ {
+ "epoch": 0.5291783036895488,
+ "grad_norm": 4.443734645843506,
+ "learning_rate": 2.6192420777523684e-05,
+ "loss": 2.3528,
+ "step": 3600
+ },
+ {
+ "epoch": 0.5365280023519036,
+ "grad_norm": 4.743895053863525,
+ "learning_rate": 2.5784057497549825e-05,
+ "loss": 2.3788,
+ "step": 3650
+ },
+ {
+ "epoch": 0.5438777010142584,
+ "grad_norm": 5.054786205291748,
+ "learning_rate": 2.5375694217575956e-05,
+ "loss": 2.3968,
+ "step": 3700
+ },
+ {
+ "epoch": 0.5512273996766133,
+ "grad_norm": 5.545046806335449,
+ "learning_rate": 2.4967330937602094e-05,
+ "loss": 2.3896,
+ "step": 3750
+ },
+ {
+ "epoch": 0.5585770983389681,
+ "grad_norm": 5.790585994720459,
+ "learning_rate": 2.4558967657628228e-05,
+ "loss": 2.3966,
+ "step": 3800
+ },
+ {
+ "epoch": 0.565926797001323,
+ "grad_norm": 4.9621500968933105,
+ "learning_rate": 2.4150604377654362e-05,
+ "loss": 2.3571,
+ "step": 3850
+ },
+ {
+ "epoch": 0.5732764956636778,
+ "grad_norm": 5.133751392364502,
+ "learning_rate": 2.3742241097680497e-05,
+ "loss": 2.3437,
+ "step": 3900
+ },
+ {
+ "epoch": 0.5806261943260327,
+ "grad_norm": 4.68387508392334,
+ "learning_rate": 2.3333877817706634e-05,
+ "loss": 2.3353,
+ "step": 3950
+ },
+ {
+ "epoch": 0.5879758929883875,
+ "grad_norm": 4.301448345184326,
+ "learning_rate": 2.292551453773277e-05,
+ "loss": 2.3335,
+ "step": 4000
+ },
+ {
+ "epoch": 0.5879758929883875,
+ "eval_loss": 1.8598763942718506,
+ "eval_runtime": 102.4241,
+ "eval_samples_per_second": 236.146,
+ "eval_steps_per_second": 7.381,
+ "step": 4000
+ },
+ {
+ "epoch": 0.5953255916507423,
+ "grad_norm": 4.841070175170898,
+ "learning_rate": 2.2517151257758903e-05,
+ "loss": 2.3778,
+ "step": 4050
+ },
+ {
+ "epoch": 0.6026752903130972,
+ "grad_norm": 4.520423412322998,
+ "learning_rate": 2.2108787977785038e-05,
+ "loss": 2.3929,
+ "step": 4100
+ },
+ {
+ "epoch": 0.610024988975452,
+ "grad_norm": 4.455601215362549,
+ "learning_rate": 2.1700424697811172e-05,
+ "loss": 2.3818,
+ "step": 4150
+ },
+ {
+ "epoch": 0.6173746876378069,
+ "grad_norm": 4.496808052062988,
+ "learning_rate": 2.129206141783731e-05,
+ "loss": 2.3874,
+ "step": 4200
+ },
+ {
+ "epoch": 0.6247243863001617,
+ "grad_norm": 4.8610429763793945,
+ "learning_rate": 2.0883698137863444e-05,
+ "loss": 2.3224,
+ "step": 4250
+ },
+ {
+ "epoch": 0.6320740849625165,
+ "grad_norm": 5.088446617126465,
+ "learning_rate": 2.0475334857889582e-05,
+ "loss": 2.3317,
+ "step": 4300
+ },
+ {
+ "epoch": 0.6394237836248714,
+ "grad_norm": 5.25501012802124,
+ "learning_rate": 2.0066971577915713e-05,
+ "loss": 2.3761,
+ "step": 4350
+ },
+ {
+ "epoch": 0.6467734822872262,
+ "grad_norm": 4.978085517883301,
+ "learning_rate": 1.965860829794185e-05,
+ "loss": 2.4066,
+ "step": 4400
+ },
+ {
+ "epoch": 0.6541231809495811,
+ "grad_norm": 4.8442230224609375,
+ "learning_rate": 1.9250245017967985e-05,
+ "loss": 2.3406,
+ "step": 4450
+ },
+ {
+ "epoch": 0.6614728796119359,
+ "grad_norm": 5.3364105224609375,
+ "learning_rate": 1.8841881737994123e-05,
+ "loss": 2.3844,
+ "step": 4500
+ },
+ {
+ "epoch": 0.6688225782742907,
+ "grad_norm": 4.952212810516357,
+ "learning_rate": 1.8433518458020254e-05,
+ "loss": 2.2993,
+ "step": 4550
+ },
+ {
+ "epoch": 0.6761722769366456,
+ "grad_norm": 4.741338729858398,
+ "learning_rate": 1.802515517804639e-05,
+ "loss": 2.337,
+ "step": 4600
+ },
+ {
+ "epoch": 0.6835219755990004,
+ "grad_norm": 5.139461994171143,
+ "learning_rate": 1.7616791898072526e-05,
+ "loss": 2.37,
+ "step": 4650
+ },
+ {
+ "epoch": 0.6908716742613553,
+ "grad_norm": 5.813575744628906,
+ "learning_rate": 1.7208428618098664e-05,
+ "loss": 2.3126,
+ "step": 4700
+ },
+ {
+ "epoch": 0.6982213729237101,
+ "grad_norm": 5.255462169647217,
+ "learning_rate": 1.6800065338124795e-05,
+ "loss": 2.3818,
+ "step": 4750
+ },
+ {
+ "epoch": 0.705571071586065,
+ "grad_norm": 4.611236095428467,
+ "learning_rate": 1.6391702058150933e-05,
+ "loss": 2.3849,
+ "step": 4800
+ },
+ {
+ "epoch": 0.7129207702484198,
+ "grad_norm": 4.705516338348389,
+ "learning_rate": 1.5983338778177067e-05,
+ "loss": 2.3379,
+ "step": 4850
+ },
+ {
+ "epoch": 0.7202704689107746,
+ "grad_norm": 5.51972770690918,
+ "learning_rate": 1.5574975498203205e-05,
+ "loss": 2.3518,
+ "step": 4900
+ },
+ {
+ "epoch": 0.7276201675731295,
+ "grad_norm": 5.037983417510986,
+ "learning_rate": 1.5166612218229337e-05,
+ "loss": 2.3354,
+ "step": 4950
+ },
+ {
+ "epoch": 0.7349698662354843,
+ "grad_norm": 5.397042274475098,
+ "learning_rate": 1.4758248938255473e-05,
+ "loss": 2.3443,
+ "step": 5000
+ },
+ {
+ "epoch": 0.7349698662354843,
+ "eval_loss": 1.8348900079727173,
+ "eval_runtime": 101.89,
+ "eval_samples_per_second": 237.384,
+ "eval_steps_per_second": 7.42,
+ "step": 5000
+ },
+ {
+ "epoch": 0.7423195648978392,
+ "grad_norm": 4.795029163360596,
+ "learning_rate": 1.434988565828161e-05,
+ "loss": 2.3396,
+ "step": 5050
+ },
+ {
+ "epoch": 0.749669263560194,
+ "grad_norm": 5.308424949645996,
+ "learning_rate": 1.3941522378307742e-05,
+ "loss": 2.3086,
+ "step": 5100
+ },
+ {
+ "epoch": 0.7570189622225488,
+ "grad_norm": 5.047533988952637,
+ "learning_rate": 1.3533159098333878e-05,
+ "loss": 2.3392,
+ "step": 5150
+ },
+ {
+ "epoch": 0.7643686608849037,
+ "grad_norm": 4.623427391052246,
+ "learning_rate": 1.3124795818360014e-05,
+ "loss": 2.3316,
+ "step": 5200
+ },
+ {
+ "epoch": 0.7717183595472585,
+ "grad_norm": 5.1898956298828125,
+ "learning_rate": 1.271643253838615e-05,
+ "loss": 2.3092,
+ "step": 5250
+ },
+ {
+ "epoch": 0.7790680582096134,
+ "grad_norm": 4.365305423736572,
+ "learning_rate": 1.2308069258412285e-05,
+ "loss": 2.3794,
+ "step": 5300
+ },
+ {
+ "epoch": 0.7864177568719682,
+ "grad_norm": 4.959578990936279,
+ "learning_rate": 1.189970597843842e-05,
+ "loss": 2.331,
+ "step": 5350
+ },
+ {
+ "epoch": 0.793767455534323,
+ "grad_norm": 5.128349781036377,
+ "learning_rate": 1.1491342698464555e-05,
+ "loss": 2.2554,
+ "step": 5400
+ },
+ {
+ "epoch": 0.801117154196678,
+ "grad_norm": 5.27305269241333,
+ "learning_rate": 1.108297941849069e-05,
+ "loss": 2.3266,
+ "step": 5450
+ },
+ {
+ "epoch": 0.8084668528590327,
+ "grad_norm": 4.942054271697998,
+ "learning_rate": 1.0674616138516826e-05,
+ "loss": 2.3314,
+ "step": 5500
+ },
+ {
+ "epoch": 0.8158165515213877,
+ "grad_norm": 6.351059913635254,
+ "learning_rate": 1.026625285854296e-05,
+ "loss": 2.3357,
+ "step": 5550
+ },
+ {
+ "epoch": 0.8231662501837425,
+ "grad_norm": 4.822713851928711,
+ "learning_rate": 9.857889578569096e-06,
+ "loss": 2.3523,
+ "step": 5600
+ },
+ {
+ "epoch": 0.8305159488460974,
+ "grad_norm": 5.248088359832764,
+ "learning_rate": 9.44952629859523e-06,
+ "loss": 2.3253,
+ "step": 5650
+ },
+ {
+ "epoch": 0.8378656475084522,
+ "grad_norm": 4.384091377258301,
+ "learning_rate": 9.041163018621367e-06,
+ "loss": 2.3021,
+ "step": 5700
+ },
+ {
+ "epoch": 0.845215346170807,
+ "grad_norm": 5.19210958480835,
+ "learning_rate": 8.632799738647501e-06,
+ "loss": 2.3342,
+ "step": 5750
+ },
+ {
+ "epoch": 0.8525650448331619,
+ "grad_norm": 4.561094760894775,
+ "learning_rate": 8.224436458673635e-06,
+ "loss": 2.2839,
+ "step": 5800
+ },
+ {
+ "epoch": 0.8599147434955167,
+ "grad_norm": 6.146600723266602,
+ "learning_rate": 7.816073178699772e-06,
+ "loss": 2.3136,
+ "step": 5850
+ },
+ {
+ "epoch": 0.8672644421578716,
+ "grad_norm": 4.4296746253967285,
+ "learning_rate": 7.407709898725907e-06,
+ "loss": 2.3562,
+ "step": 5900
+ },
+ {
+ "epoch": 0.8746141408202264,
+ "grad_norm": 5.189635276794434,
+ "learning_rate": 6.999346618752042e-06,
+ "loss": 2.2878,
+ "step": 5950
+ },
+ {
+ "epoch": 0.8819638394825812,
+ "grad_norm": 5.818453788757324,
+ "learning_rate": 6.590983338778177e-06,
+ "loss": 2.3219,
+ "step": 6000
+ },
+ {
+ "epoch": 0.8819638394825812,
+ "eval_loss": 1.8173038959503174,
+ "eval_runtime": 103.0814,
+ "eval_samples_per_second": 234.64,
+ "eval_steps_per_second": 7.334,
+ "step": 6000
+ },
+ {
+ "epoch": 0.8893135381449361,
+ "grad_norm": 5.1032633781433105,
+ "learning_rate": 6.182620058804313e-06,
+ "loss": 2.2941,
+ "step": 6050
+ },
+ {
+ "epoch": 0.8966632368072909,
+ "grad_norm": 5.225054740905762,
+ "learning_rate": 5.7742567788304485e-06,
+ "loss": 2.3245,
+ "step": 6100
+ },
+ {
+ "epoch": 0.9040129354696458,
+ "grad_norm": 4.597662448883057,
+ "learning_rate": 5.365893498856583e-06,
+ "loss": 2.2561,
+ "step": 6150
+ },
+ {
+ "epoch": 0.9113626341320006,
+ "grad_norm": 5.72780704498291,
+ "learning_rate": 4.957530218882718e-06,
+ "loss": 2.3327,
+ "step": 6200
+ },
+ {
+ "epoch": 0.9187123327943554,
+ "grad_norm": 4.458365440368652,
+ "learning_rate": 4.549166938908853e-06,
+ "loss": 2.3047,
+ "step": 6250
+ },
+ {
+ "epoch": 0.9260620314567103,
+ "grad_norm": 4.625770092010498,
+ "learning_rate": 4.140803658934989e-06,
+ "loss": 2.2916,
+ "step": 6300
+ },
+ {
+ "epoch": 0.9334117301190651,
+ "grad_norm": 5.372076511383057,
+ "learning_rate": 3.732440378961124e-06,
+ "loss": 2.3495,
+ "step": 6350
+ },
+ {
+ "epoch": 0.94076142878142,
+ "grad_norm": 7.00476598739624,
+ "learning_rate": 3.324077098987259e-06,
+ "loss": 1.9273,
+ "step": 6400
+ },
+ {
+ "epoch": 0.9481111274437748,
+ "grad_norm": 7.588624000549316,
+ "learning_rate": 2.9157138190133943e-06,
+ "loss": 1.3917,
+ "step": 6450
+ },
+ {
+ "epoch": 0.9554608261061297,
+ "grad_norm": 7.540070056915283,
+ "learning_rate": 2.5073505390395295e-06,
+ "loss": 1.4726,
+ "step": 6500
+ },
+ {
+ "epoch": 0.9628105247684845,
+ "grad_norm": 6.5697126388549805,
+ "learning_rate": 2.0989872590656647e-06,
+ "loss": 1.3922,
+ "step": 6550
+ },
+ {
+ "epoch": 0.9701602234308393,
+ "grad_norm": 7.318496227264404,
+ "learning_rate": 1.6906239790918002e-06,
+ "loss": 1.4664,
+ "step": 6600
+ },
+ {
+ "epoch": 0.9775099220931942,
+ "grad_norm": 8.529196739196777,
+ "learning_rate": 1.2822606991179354e-06,
+ "loss": 1.4329,
+ "step": 6650
+ },
+ {
+ "epoch": 0.984859620755549,
+ "grad_norm": 6.7521843910217285,
+ "learning_rate": 8.738974191440705e-07,
+ "loss": 1.4046,
+ "step": 6700
+ },
+ {
+ "epoch": 0.9922093194179039,
+ "grad_norm": 8.020008087158203,
+ "learning_rate": 4.6553413917020586e-07,
+ "loss": 1.3891,
+ "step": 6750
+ },
+ {
+ "epoch": 0.9995590180802587,
+ "grad_norm": 7.628392696380615,
+ "learning_rate": 5.717085919634107e-08,
+ "loss": 1.4731,
+ "step": 6800
+ }
+ ],
+ "logging_steps": 50,
+ "max_steps": 6803,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 1,
+ "save_steps": 1000,
+ "stateful_callbacks": {
+ "TrainerControl": {
+ "args": {
+ "should_epoch_stop": false,
+ "should_evaluate": false,
+ "should_log": false,
+ "should_save": true,
+ "should_training_stop": true
+ },
+ "attributes": {}
+ }
+ },
+ "total_flos": 0.0,
+ "train_batch_size": 32,
+ "trial_name": null,
+ "trial_params": null
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1205f968e0e4b4869d9aeaab5fe89517f53bf2ee762a0ced4c05fcf02990b3fa
+ size 5688