# BaseTrainer

## 🚀 Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL)

EasyDeL is an open-source framework for training machine learning models, with a primary focus on JAX. It provides
tools for training and serving Flax/JAX models on TPUs and GPUs.

## ๐Ÿ“ฆ Installation & Usage
 
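```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

# Replace REPO_ID with the Hugging Face namespace that hosts this checkpoint.
model = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/BaseTrainer",
    dtype=jnp.bfloat16,        # computation dtype; bfloat16 matches the training setup below
    param_dtype=jnp.bfloat16,  # parameter dtype; bfloat16 matches the training setup below
    precision=lax.Precision("fastest"),
    auto_shard_model=True,     # shard the model across the available devices
)
```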

## ๐Ÿ”ง Training Configuration

### Model Details
- **Architecture**: gemma3_text
- **Platform**: TPU
- **Number of Devices**: 16

### Training Parameters
- **Learning Rate**: 4e-05 → 4e-06
- **Optimizer**: adamw
- **Scheduler**: cosine
- **Warmup Steps**: 50
- **Weight Decay**: 0.02
- **Loss Config** (`LossConfig`):
  - `ignore_index`: -100
  - `label_smoothing`: 0.0
  - `z_loss`: 0.0
  - `loss_normalizing_factor`: `NUM_REAL_TARGET_TOKENS`
  - `num_labels`: None
  - `problem_type`: None
  - `divide_weight_sum`: False
  - `shift_tokens`: True
  - `break_on_nan`: True
  - `reduction`: None
  - `num_classification_labels`: None
  - `classification_problem_type`: None
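
To make the loss settings above concrete, here is a minimal NumPy sketch of a causal-LM cross-entropy that follows them: labels are shifted by one position (`shift_tokens: True`), positions labelled `-100` are ignored (`ignore_index`), and the summed loss is divided by the number of non-ignored targets (`NUM_REAL_TARGET_TOKENS`). Label smoothing and z-loss are omitted because both are 0.0 here. This is an illustrative approximation, not EasyDeL's implementation; the function name `causal_lm_loss` is hypothetical.

```python
import numpy as np

IGNORE_INDEX = -100  # matches ignore_index in the loss config above


def causal_lm_loss(logits: np.ndarray, labels: np.ndarray) -> float:
    """Hypothetical helper: shifted token cross-entropy averaged over real targets.

    logits: [batch, seq, vocab]; labels: [batch, seq] token ids or IGNORE_INDEX.
    """
    # shift_tokens=True: the logits at position t are scored against token t+1.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = labels != IGNORE_INDEX
    safe_labels = np.where(mask, labels, 0)  # avoid indexing with -100
    token_nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # NUM_REAL_TARGET_TOKENS: divide by the count of non-ignored targets.
    num_real = max(int(mask.sum()), 1)
    return float((token_nll * mask).sum() / num_real)
```

With uniform logits over a vocabulary of 4, every real target costs `log(4)`, so the averaged loss is `log(4)` no matter how many positions are masked out.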

### Training Setup
- **Epochs**: 3
- **Batch Size**: 8
- **Sequence Length**: 8192 
- **Dtype**: `jax.numpy.bfloat16`
- **Params Dtype**: `jax.numpy.bfloat16`
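
As a quick sanity check on scale, the batch size and sequence length above bound the number of tokens processed per optimizer update. The sketch below assumes "Batch Size" is the global batch (not per device) and uses the gradient accumulation setting of 1 listed under Advanced Configuration; padding means the real token count can be lower. The helper name is hypothetical.

```python
def tokens_per_optimizer_step(batch_size: int, sequence_length: int,
                              gradient_accumulation_steps: int = 1) -> int:
    """Hypothetical helper: upper bound on tokens per optimizer update (padding included)."""
    return batch_size * sequence_length * gradient_accumulation_steps


# Values from this card; assumes "Batch Size" is the global batch.
print(tokens_per_optimizer_step(8, 8192, 1))  # 65536
```

If the batch size were instead per device, the global figure would be correspondingly larger.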

### Advanced Configuration
- **Gradient Checkpointing**:   
- **Gradient Accumulation Steps**: 1
- **Max Training Steps**: None
- **Max Evaluation Steps**: None
- **Training Duration**: 7 hours
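
The learning-rate settings under Training Parameters (cosine scheduler, 50 warmup steps, 4e-05 → 4e-06) can be pictured with the pure-Python sketch below. It assumes a linear warmup from 0 to the peak followed by cosine decay to the final value; EasyDeL's exact warmup shape may differ. Because Max Training Steps is None, the total step count would come from the 3 epochs and the dataset size, which this card does not state, so `total_steps` is a hypothetical input.

```python
import math


def warmup_cosine_lr(step: int, total_steps: int, warmup_steps: int = 50,
                     peak_lr: float = 4e-5, end_lr: float = 4e-6) -> float:
    """Assumed schedule: linear warmup to peak_lr, then cosine decay to end_lr."""
    if step < warmup_steps:
        # Linear warmup from 0 (an assumption, not confirmed by the card).
        return peak_lr * step / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)  # clamp after the end
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))       # goes from 1 to 0
    return end_lr + (peak_lr - end_lr) * cosine
```

For a hypothetical 1,000-step run, the rate peaks at 4e-05 at step 50, is about 2.2e-05 halfway through the decay, and reaches 4e-06 at step 1,000. Optax's `warmup_cosine_decay_schedule` produces a similar shape if you prefer a library schedule.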

### Sharding Configuration
```python
from jax.sharding import PartitionSpec

# Partition rules: (regex matched against the parameter path, PartitionSpec over mesh axes)
partition_rules = (
    ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"), "tp")),
    ("self_attn/q_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    ("self_attn/k_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    ("self_attn/v_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    ("self_attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    ("input_layernorm/kernel", PartitionSpec(None)),
    ("post_attention_layernorm/kernel", PartitionSpec(None)),
    ("pre_feedforward_layernorm/kernel", PartitionSpec(None)),
    ("post_feedforward_layernorm/kernel", PartitionSpec(None)),
    ("model/norm/kernel", PartitionSpec(None)),
    ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    (".*", PartitionSpec(None)),  # catch-all: leave remaining parameters unsharded
)
```
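
A rule table like the one above is typically applied by testing each parameter's path against the patterns in order and using the first match, which is why the catch-all `.*` comes last. The sketch below illustrates that first-match lookup under two assumptions: that paths are `/`-joined parameter names and that matching uses `re.search`. EasyDeL's own matcher may differ in detail. To stay dependency-free it represents each `PartitionSpec` as a plain tuple and uses only a subset of the rules; `match_partition_rule` and `RULES` are hypothetical names.

```python
import re
from typing import Optional, Sequence, Tuple

# Simplified stand-in for jax.sharding.PartitionSpec: a tuple of mesh-axis entries.
Spec = Tuple[object, ...]


def match_partition_rule(param_path: str,
                         rules: Sequence[Tuple[str, Spec]]) -> Optional[Spec]:
    """Return the spec of the first rule whose regex is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    return None  # no rule matched (cannot happen when ".*" is present)


# Subset of the rules above; order matters, so the catch-all stays last.
RULES = [
    ("model/embed_tokens/embedding", (("fsdp", "sp"), "tp")),
    ("self_attn/q_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("input_layernorm/kernel", (None,)),
    (".*", (None,)),
]

print(match_partition_rule("model/layers/3/self_attn/q_proj/kernel", RULES))
```

Any path not covered by a specific pattern falls through to `.*` and is left unsharded.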

---
*Generated with EasyDeL v0.1.3*