# BaseTrainer

## 🚀 Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL)

EasyDeL is an open-source framework designed to enhance and streamline the training of machine learning models. With a primary focus on JAX, EasyDeL aims to provide convenient and effective solutions for training and serving Flax/JAX models on TPU and GPU.

## 📦 Installation & Usage

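```python
# Install EasyDeL first, e.g. `pip install easydel`.
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

# Replace REPO_ID with the Hugging Face namespace that hosts this checkpoint.
model = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/BaseTrainer",
    dtype=jnp.bfloat16,  # computation dtype; bfloat16 matches the training setup below
    param_dtype=jnp.bfloat16,  # parameter dtype; bfloat16 matches the training setup below
    precision=lax.Precision("fastest"),  # matmul precision: trade accuracy for speed
    auto_shard_model=True,  # let EasyDeL shard the parameters across available devices
)
```
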
## 🔧 Training Configuration

### Model Details

- **Architecture**: gemma3_text
- **Platform**: TPU
- **Number of Devices**: 16

### Training Parameters

- **Learning Rate**: 4e-05 → 4e-06
- **Optimizer**: adamw
- **Scheduler**: cosine
- **Warmup Steps**: 50
- **Weight Decay**: 0.02
- **Loss Config** (`LossConfig`):
  - `ignore_index`: -100
  - `label_smoothing`: 0.0
  - `z_loss`: 0.0
  - `loss_normalizing_factor`: NUM_REAL_TARGET_TOKENS
  - `num_labels`: None
  - `problem_type`: None
  - `divide_weight_sum`: False
  - `shift_tokens`: True
  - `break_on_nan`: True
  - `reduction`: None
  - `num_classification_labels`: None
  - `classification_problem_type`: None
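
The learning-rate and loss settings above can be illustrated with a small pure-Python sketch. It assumes a linear warmup from 0 to the peak rate of 4e-05 over the 50 warmup steps, followed by a cosine decay to 4e-06, and it takes `total_steps` as an input because the card lists no maximum step count. The loss function mirrors what `shift_tokens=True`, `ignore_index=-100` and `NUM_REAL_TARGET_TOKENS` suggest: position *t* is scored against token *t + 1*, ignored labels are skipped, and the summed negative log-likelihood is divided by the number of real target tokens. Both helpers (`lr_at_step`, `causal_lm_loss`) are hypothetical and are not EasyDeL's implementation.

```python
import math

PEAK_LR = 4e-5  # "Learning Rate: 4e-05 → 4e-06" (assumed peak)
END_LR = 4e-6   # assumed final value of the cosine decay
WARMUP_STEPS = 50


def lr_at_step(step: int, total_steps: int) -> float:
    """Assumed schedule: linear warmup 0 -> PEAK_LR, then cosine decay to END_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    decay_steps = max(1, total_steps - WARMUP_STEPS)
    progress = min(1.0, (step - WARMUP_STEPS) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes 1 -> 0
    return END_LR + (PEAK_LR - END_LR) * cosine


def causal_lm_loss(log_probs: list[list[float]], labels: list[int],
                   ignore_index: int = -100) -> float:
    """Mean next-token NLL over non-ignored targets (hypothetical helper).

    log_probs[t][v] is the log-probability of vocabulary id v at position t.
    """
    total, num_real_targets = 0.0, 0
    # shift_tokens=True: the prediction at position t is scored against labels[t + 1].
    for row, target in zip(log_probs[:-1], labels[1:]):
        if target == ignore_index:
            continue  # e.g. padding or prompt tokens contribute nothing
        total -= row[target]
        num_real_targets += 1
    # NUM_REAL_TARGET_TOKENS: divide by the count of real targets, not the sequence length.
    return total / max(num_real_targets, 1)


if __name__ == "__main__":
    print(lr_at_step(50, total_steps=1_000))  # learning rate right after warmup
    example = [[math.log(0.5), math.log(0.5)],
               [math.log(0.25), math.log(0.75)],
               [0.0, 0.0]]
    print(causal_lm_loss(example, labels=[-100, 1, 0]))
```

Normalizing by real target tokens keeps the loss scale independent of how much padding or masked prompt text a batch contains.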

### Training Setup

- **Epochs**: 3
- **Batch Size**: 8
- **Sequence Length**: 8192
- **Dtype**: `jax.numpy.bfloat16`
- **Params Dtype**: `jax.numpy.bfloat16`

### Advanced Configuration

- **Gradient Checkpointing**:
- **Gradient Accumulation Steps**: 1
- **Max Training Steps**: None
- **Max Evaluation Steps**: None
- **Training Duration**: 7H
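
For a rough sense of scale, the batch size, sequence length, and accumulation settings above multiply into the number of tokens processed per optimizer step. This sketch assumes that the listed batch size of 8 is the global batch (the card does not say whether it is per device), that every sequence is padded or packed to the full 8192 tokens, and that a partial final batch is dropped. `num_examples` is a hypothetical dataset size, since the card does not state one, and the helper names are illustrative only.

```python
BATCH_SIZE = 8        # "Batch Size" (assumed global, not per device)
SEQ_LEN = 8192        # "Sequence Length"
GRAD_ACCUM_STEPS = 1  # "Gradient Accumulation Steps"
EPOCHS = 3            # "Epochs"


def tokens_per_optimizer_step(batch_size: int = BATCH_SIZE, seq_len: int = SEQ_LEN,
                              grad_accum: int = GRAD_ACCUM_STEPS) -> int:
    """Token positions processed per optimizer update (padding included)."""
    return batch_size * seq_len * grad_accum


def total_optimizer_steps(num_examples: int, epochs: int = EPOCHS,
                          batch_size: int = BATCH_SIZE,
                          grad_accum: int = GRAD_ACCUM_STEPS) -> int:
    """Optimizer updates for a hypothetical dataset of num_examples sequences."""
    steps_per_epoch = num_examples // (batch_size * grad_accum)  # partial batch dropped (assumption)
    return steps_per_epoch * epochs


if __name__ == "__main__":
    print(tokens_per_optimizer_step())                 # 65536
    print(total_optimizer_steps(num_examples=10_000))  # 3750
```

If 8 is a per-device figure instead, the global batch, and therefore the token count per step, grows with the number of data-parallel replicas.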

### Sharding Configuration

```python
from jax.sharding import PartitionSpec

# Partition rules: (regex over the parameter path, PartitionSpec)
partition_rules = (
    ('model/embed_tokens/embedding', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('self_attn/q_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/k_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/v_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/o_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/gate_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/up_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/down_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('input_layernorm/kernel', PartitionSpec(None,)),
    ('post_attention_layernorm/kernel', PartitionSpec(None,)),
    ('pre_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('post_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('model/norm/kernel', PartitionSpec(None,)),
    ('lm_head/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('.*', PartitionSpec(None,)),  # fallback: replicate everything else
)
```
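
The rules above are ordered `(pattern, PartitionSpec)` pairs. The sketch below assumes, as is common for regex-based sharding rules in JAX libraries, that each parameter's path is joined with `/` and that the first pattern found by `re.search` decides its spec, so the trailing `'.*'` acts as a replicate-everything fallback. It copies only a subset of the rules, uses plain tuples in place of `jax.sharding.PartitionSpec` so it runs without JAX, and the helper `spec_for` and the example parameter paths are hypothetical.

```python
import re

# Subset of the rules above; plain tuples stand in for jax.sharding.PartitionSpec.
FSDP_SP = ("fsdp", "sp")
RULES = (
    ("model/embed_tokens/embedding", (FSDP_SP, "tp")),
    ("self_attn/q_proj/kernel", ("tp", FSDP_SP)),
    ("self_attn/o_proj/kernel", (FSDP_SP, "tp")),
    ("mlp/down_proj/kernel", ("tp", FSDP_SP)),
    ("input_layernorm/kernel", (None,)),
    (".*", (None,)),  # fallback: replicate anything not matched above
)


def spec_for(param_path: str, rules=RULES) -> tuple:
    """Return the spec of the first rule whose regex is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


if __name__ == "__main__":
    for path in ("model/layers/0/self_attn/q_proj/kernel",  # hypothetical paths
                 "model/layers/0/input_layernorm/kernel",
                 "model/layers/0/self_attn/q_norm/kernel"):
        print(path, "->", spec_for(path))
```

Because matching stops at the first hit, specific patterns must come before broad ones; a bare `kernel` pattern placed early would shadow every rule after it.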

---

*Generated with EasyDeL v0.1.3*