sparsh35 committed on
Commit
3134917
·
verified ·
1 Parent(s): d34e0ce

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. README.md +88 -0
  3. config.json +112 -0
  4. easydel-model.parameters +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ easydel-model.parameters filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # BaseTrainer
3
+
4
+ ## 🚀 Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL)
5
+
6
+ EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning
7
+ models. With a primary focus on Jax, EasyDeL aims to provide convenient and effective solutions for
8
+ training Flax/Jax models on TPU/GPU, for both serving and training purposes.
9
+
10
+ ## 📦 Installation & Usage
11
+
12
+ ```python
13
+ from easydel import AutoEasyDeLModelForCausalLM
14
+ from jax import numpy as jnp, lax
15
+
16
+ model = AutoEasyDeLModelForCausalLM.from_pretrained(
17
+ "REPO_ID/BaseTrainer",
18
+ dtype=...,
19
+ param_dtype=...,
20
+ precision=lax.Precision("fastest"),
21
+ auto_shard_model=True,
22
+ )
23
+ ```
24
+
25
+ ## 🔧 Training Configuration
26
+
27
+ ### Model Details
28
+ - **Architecture**: gemma3_text
29
+ - **Platform**: TPU
30
+ - **Number of Devices**: 16
31
+
32
+ ### Training Parameters
33
+ - **Learning Rate**: 4e-05 → 4e-06
34
+ - **Optimizer**: adamw
35
+ - **Scheduler**: cosine
36
+ - **Warmup Steps**: 50
37
+ - **Weight Decay**: 0.02
38
+ - **Loss Config**: LossConfig(
39
+ ignore_index : -100
40
+ label_smoothing : 0.0
41
+ z_loss : 0.0
42
+ loss_normalizing_factor : NUM_REAL_TARGET_TOKENS
43
+ num_labels : None
44
+ problem_type : None
45
+ divide_weight_sum : False
46
+ shift_tokens : True
47
+ break_on_nan : True
48
+ reduction : None
49
+ num_classification_labels : None
50
+ classification_problem_type : None
51
+ )
52
+
53
+ ### Training Setup
54
+ - **Epochs**: 3
55
+ - **Batch Size**: 8
56
+ - **Sequence Length**: 8192
57
+ - **Dtype**: `<class 'jax.numpy.bfloat16'>`
58
+ - **Params Dtype**: `<class 'jax.numpy.bfloat16'>`
59
+
60
+ ### Advanced Configuration
61
+ - **Gradient Checkpointing**: `""` (empty — no checkpointing policy set in config.json)
62
+ - **Gradient Accumulation Steps**: 1
63
+ - **Max Training Steps**: None
64
+ - **Max Evaluation Steps**: None
65
+ - **Training Duration**: 7H
66
+
67
+ ### Sharding Configuration
68
+ ```python
69
+ # Partition Rules
70
+ ( ('model/embed_tokens/embedding', PartitionSpec(('fsdp', 'sp'), 'tp')),
71
+ ('self_attn/q_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
72
+ ('self_attn/k_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
73
+ ('self_attn/v_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
74
+ ('self_attn/o_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
75
+ ('mlp/gate_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
76
+ ('mlp/up_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
77
+ ('mlp/down_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
78
+ ('input_layernorm/kernel', PartitionSpec(None,)),
79
+ ('post_attention_layernorm/kernel', PartitionSpec(None,)),
80
+ ('pre_feedforward_layernorm/kernel', PartitionSpec(None,)),
81
+ ('post_feedforward_layernorm/kernel', PartitionSpec(None,)),
82
+ ('model/norm/kernel', PartitionSpec(None,)),
83
+ ('lm_head/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
84
+ ('.*', PartitionSpec(None,)))
85
+ ```
86
+
87
+ ---
88
+ *Generated with EasyDeL v0.1.3*
config.json ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Gemma3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "attn_logit_softcapping": null,
8
+ "attn_mechanism": "splash",
9
+ "axis_dims": [
10
+ 1,
11
+ -1,
12
+ 2,
13
+ 1
14
+ ],
15
+ "axis_names": [
16
+ "dp",
17
+ "fsdp",
18
+ "tp",
19
+ "sp"
20
+ ],
21
+ "backend": null,
22
+ "bits": null,
23
+ "blocksize_b": 1,
24
+ "blocksize_k": 128,
25
+ "blocksize_q": 128,
26
+ "bos_token_id": 2,
27
+ "cache_implementation": "hybrid",
28
+ "dcn_axis_dims": null,
29
+ "easy_method": "train",
30
+ "eos_token_id": 1,
31
+ "fcm_max_ratio": 0.0,
32
+ "fcm_min_ratio": 0.0,
33
+ "final_logit_softcapping": null,
34
+ "flash_attention_backward_pass_impl": "triton",
35
+ "freq_max_position_embeddings": 8192,
36
+ "gradient_checkpointing": "",
37
+ "hardware_abstraction": false,
38
+ "head_dim": 256,
39
+ "hidden_activation": "gelu_pytorch_tanh",
40
+ "hidden_size": 2560,
41
+ "initializer_range": 0.02,
42
+ "intermediate_size": 10240,
43
+ "kv_cache_quantization_blocksize": 64,
44
+ "kv_cache_quantization_method": "None",
45
+ "kv_cache_sharding_sequence_axis_name": "sp",
46
+ "mask_max_position_embeddings": 8192,
47
+ "max_position_embeddings": 131072,
48
+ "model_type": "gemma3_text",
49
+ "num_attention_heads": 8,
50
+ "num_hidden_layers": 34,
51
+ "num_key_value_heads": 4,
52
+ "pad_token_id": 0,
53
+ "pallas_k_block_size": 128,
54
+ "pallas_m_block_size": 128,
55
+ "pallas_n_block_size": 128,
56
+ "partition_axis": {
57
+ "attention_dim_axis": null,
58
+ "batch_axis": [
59
+ "fsdp",
60
+ "dp"
61
+ ],
62
+ "bias_head_sequence_axis": null,
63
+ "bias_key_sequence_axis": null,
64
+ "data_parallel_axis": "dp",
65
+ "expert_axis": "ep",
66
+ "expert_gate_axis": null,
67
+ "expert_parallel_axis": "ep",
68
+ "fully_sharded_data_parallel_axis": "fsdp",
69
+ "generation_attention_dim_axis": null,
70
+ "generation_batch_axis": null,
71
+ "generation_head_axis": "tp",
72
+ "generation_key_sequence_axis": "sp",
73
+ "generation_query_sequence_axis": null,
74
+ "head_axis": "tp",
75
+ "hidden_state_axis": "tp",
76
+ "key_sequence_axis": "sp",
77
+ "mlp_intermediate_axis": "tp",
78
+ "query_sequence_axis": "sp",
79
+ "sequence_axis": "sp",
80
+ "sequence_parallel_axis": "sp",
81
+ "tensor_parallel_axis": "tp",
82
+ "vocab_axis": "tp"
83
+ },
84
+ "platform": "jax",
85
+ "precompute_masks": true,
86
+ "pretraining_tp": 1,
87
+ "quantization_blocksize": 64,
88
+ "quantization_method": "None",
89
+ "quantization_pattern": ".*",
90
+ "query_pre_attn_scalar": 256,
91
+ "rms_norm_eps": 1e-06,
92
+ "rope_local_base_freq": 10000.0,
93
+ "rope_scaling": {
94
+ "factor": 8.0,
95
+ "rope_type": "linear"
96
+ },
97
+ "rope_theta": 1000000.0,
98
+ "scan_attention_layers": false,
99
+ "scan_layers": false,
100
+ "scan_mlp_chunk_size": 1024,
101
+ "scan_ring_attention": true,
102
+ "sequence_axis_name": "sp",
103
+ "shard_attention_computation": true,
104
+ "sliding_window": 1024,
105
+ "sliding_window_pattern": 6,
106
+ "transformers_version": "4.50.3",
107
+ "use_cache": true,
108
+ "use_scan_mlp": false,
109
+ "use_sharded_kv_caching": false,
110
+ "use_sharding_constraint": false,
111
+ "vocab_size": 262208
112
+ }
easydel-model.parameters ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a5eca75e5f5eefa1ff263feac16d54061573e335f3add37d6f38996bcacd3ca
3
+ size 9103083144