For instance, the following device map would work properly for T0pp (as long as you have the GPU memory):
```python
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```
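To make the shape of such a map easier to reason about, here is a minimal sketch in plain Python that groups the entries by device and checks that every module is covered by some entry. The helper names `group_by_device` and `missing_modules` are hypothetical (they are not part of any library), and the list of module names passed in is an assumption about the model's top-level layout.

```python
from collections import defaultdict

# Device map from the text: module name -> GPU index.
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}


def group_by_device(device_map):
    """Invert a device map into {device: [module names]} (hypothetical helper)."""
    groups = defaultdict(list)
    for module_name, device in device_map.items():
        groups[device].append(module_name)
    return dict(groups)


def missing_modules(device_map, module_names):
    """Return the names in module_names that no device map entry covers.

    An entry covers a module if it names the module itself or one of its
    parents, e.g. "encoder" covers "encoder.block.0" (hypothetical helper).
    """
    missing = []
    for name in module_names:
        covered = any(
            name == key or name.startswith(key + ".") for key in device_map
        )
        if not covered:
            missing.append(name)
    return missing


# Assumed top-level module names, for illustration only.
top_level = ["shared", "encoder", "decoder", "lm_head"]

print(group_by_device(device_map))
print(missing_modules(device_map, top_level))
```

An empty list from `missing_modules` suggests the map leaves no module unassigned. The check only compares names, so it says nothing about whether each GPU actually has enough memory for the modules placed on it.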
Another way to reduce the memory footprint of your model is to instantiate it in a lower-precision dtype (like torch.float16) or to use direct quantization techniques as described below.
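As a rough illustration of why the dtype matters, the sketch below estimates the memory needed for the weights alone at different precisions. The helper `weight_memory_gib` is hypothetical, the parameter count of about 11 billion for T0pp is an approximation, and the estimate ignores activations, buffers, and any framework overhead.

```python
# Bytes used to store one parameter at each precision.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def weight_memory_gib(num_params, dtype):
    """Estimate weight storage in GiB for a given dtype (hypothetical helper)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


# Assumption: roughly 11 billion parameters, used here only for illustration.
num_params = 11_000_000_000

for dtype in ("float32", "float16", "int8"):
    print(f"{dtype}: {weight_memory_gib(num_params, dtype):.1f} GiB")
```

Halving the bytes per parameter halves the weight storage, which is the effect of loading in torch.float16 instead of the default float32. In Transformers this is typically requested through the `torch_dtype` argument of `from_pretrained`.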