# BaseTrainer
## 🚀 Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL)
EasyDeL is an open-source framework designed to streamline the training of machine learning
models. With a primary focus on JAX, EasyDeL aims to provide convenient and effective solutions for
training and serving Flax/JAX models on TPU/GPU.
## 📦 Installation & Usage
```python
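from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

# Replace REPO_ID with the account or organization that hosts this checkpoint.
model = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/BaseTrainer",
    dtype=...,        # placeholder; training used bfloat16 (see Training Setup)
    param_dtype=...,  # placeholder; training used bfloat16 (see Training Setup)
    precision=lax.Precision("fastest"),
    auto_shard_model=True,
)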
```
## 🔧 Training Configuration
### Model Details
- **Architecture**: gemma3_text
- **Platform**: TPU
- **Number of Devices**: 16
### Training Parameters
- **Learning Rate**: 4e-05 → 4e-06
- **Optimizer**: adamw
- **Scheduler**: cosine
- **Warmup Steps**: 50
- **Weight Decay**: 0.02
- **Loss Config**: `LossConfig` with the following fields:
  - `ignore_index`: -100
  - `label_smoothing`: 0.0
  - `z_loss`: 0.0
  - `loss_normalizing_factor`: NUM_REAL_TARGET_TOKENS
  - `num_labels`: None
  - `problem_type`: None
  - `divide_weight_sum`: False
  - `shift_tokens`: True
  - `break_on_nan`: True
  - `reduction`: None
  - `num_classification_labels`: None
  - `classification_problem_type`: None
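
The learning-rate settings above (cosine scheduler, 50 warmup steps, 4e-05 decaying to 4e-06) can be sketched as a plain function. This is a minimal illustration, not EasyDeL's implementation: it assumes a linear warmup from zero to the peak, and `total_steps` is a hypothetical value, since this card does not state the number of training steps.

```python
import math

PEAK_LR = 4e-05     # from the card: starting (peak) learning rate
END_LR = 4e-06      # from the card: final learning rate
WARMUP_STEPS = 50   # from the card


def cosine_lr(step: int, total_steps: int) -> float:
    """Learning rate at `step`, assuming linear warmup then cosine decay."""
    if step < WARMUP_STEPS:
        # Assumption: warmup ramps linearly from 0 up to the peak.
        return PEAK_LR * step / WARMUP_STEPS
    # Fraction of the decay phase completed, clamped to at most 1.
    decay_steps = max(1, total_steps - WARMUP_STEPS)
    progress = min(1.0, (step - WARMUP_STEPS) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return END_LR + (PEAK_LR - END_LR) * cosine


# Hypothetical run length; the real value depends on the dataset size.
total_steps = 1_000
for step in (0, 50, 525, 1_000):
    print(step, cosine_lr(step, total_steps))
```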
### Training Setup
- **Epochs**: 3
- **Batch Size**: 8
- **Sequence Length**: 8192
- **Dtype**: `jax.numpy.bfloat16`
- **Params Dtype**: `jax.numpy.bfloat16`
### Advanced Configuration
- **Gradient Checkpointing**:
- **Gradient Accumulation Steps**: 1
- **Max Training Steps**: None
- **Max Evaluation Steps**: None
- **Training Duration**: 7H
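
As a quick piece of arithmetic on the settings above, the number of tokens consumed per optimizer step follows from the batch size, sequence length, and gradient accumulation steps. The sketch assumes the batch size of 8 is the global batch (the card does not say whether it is global or per device) and that every sequence fills the full 8192 tokens.

```python
batch_size = 8           # from the card
sequence_length = 8192   # from the card
grad_accum_steps = 1     # from the card

# Assumption: batch_size is global and every sequence is full length.
tokens_per_step = batch_size * sequence_length * grad_accum_steps
print(tokens_per_step)  # 65536
```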
### Sharding Configuration
```python
from jax.sharding import PartitionSpec

# Partition Rules: (regex over the parameter path, PartitionSpec) pairs
partition_rules = (
    ('model/embed_tokens/embedding', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('self_attn/q_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/k_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/v_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/o_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/gate_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/up_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/down_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('input_layernorm/kernel', PartitionSpec(None,)),
    ('post_attention_layernorm/kernel', PartitionSpec(None,)),
    ('pre_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('post_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('model/norm/kernel', PartitionSpec(None,)),
    ('lm_head/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('.*', PartitionSpec(None,)),
)
```
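
Partition rules like the ones above pair a regular expression with a `PartitionSpec`. A common convention is that each parameter path takes the spec of the first rule whose pattern matches, with `.*` as the catch-all. The sketch below illustrates that lookup under this assumption. It is not EasyDeL's implementation: `first_matching_spec` is a hypothetical helper, the rule list is abbreviated, the parameter paths are illustrative, and plain tuples stand in for `PartitionSpec` so the example runs without JAX.

```python
import re

# Abbreviated copy of the rules; tuples stand in for PartitionSpec.
RULES = (
    ("self_attn/q_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("input_layernorm/kernel", (None,)),
    ("lm_head/kernel", (("fsdp", "sp"), "tp")),
    (".*", (None,)),
)


def first_matching_spec(param_path: str, rules=RULES):
    """Return the spec of the first rule whose regex matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


# Hypothetical parameter paths, for illustration only.
print(first_matching_spec("model/layers/0/self_attn/q_proj/kernel"))
print(first_matching_spec("model/layers/0/self_attn/q_norm/kernel"))
```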
---
*Generated with EasyDeL v0.1.3*