For instance, the following device map would work properly for T0pp (as long as you have the GPU memory):
```python
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```
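As a rough sketch of how such a map might be inspected and used, the snippet below groups the modules by GPU and wraps the loading call in a function. It assumes the Transformers library (with Accelerate installed) and the `bigscience/T0pp` checkpoint name, neither of which appears in the lines above; the helper names `modules_per_device` and `load_t0pp` are hypothetical and only serve the illustration.

```python
from collections import defaultdict

# Device map from the text: module name -> GPU index.
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}


def modules_per_device(device_map):
    """Group module names by the device they are assigned to (hypothetical helper)."""
    grouped = defaultdict(list)
    for module_name, device in device_map.items():
        grouped[device].append(module_name)
    return dict(grouped)


def load_t0pp(device_map):
    """Hypothetical loader. Assumes transformers and accelerate are installed
    and that two GPUs with enough memory are available."""
    # Imported lazily so the rest of the snippet runs without the heavy dependency.
    from transformers import AutoModelForSeq2SeqLM

    # Assumed checkpoint name; device_map places each listed module on its GPU.
    return AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map=device_map)


# Show which modules end up on which GPU.
print(modules_per_device(device_map))
```

Calling `load_t0pp(device_map)` downloads the full checkpoint, so it is left uncalled here; the grouping helper is cheap and can be used to sanity-check a hand-written map before loading.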
Another way to reduce the memory footprint of your model is to instantiate it in a lower-precision dtype (such as torch.float16) or to use the direct quantization techniques described below.
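To give a feel for the savings, here is a back-of-the-envelope estimate of the memory needed for the weights alone at different precisions. It assumes a model of roughly 11 billion parameters (the approximate size of T0pp) and ignores activations, buffers, and framework overhead; `weight_memory_gb` and `BYTES_PER_PARAM` are hypothetical names used only for this sketch.

```python
# Bytes used to store one parameter at each precision.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def weight_memory_gb(num_params, dtype):
    """Estimate weight storage in GB (1 GB = 1e9 bytes) for a given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Assumption: about 11 billion parameters, roughly the size of T0pp.
num_params = 11_000_000_000

for dtype in ("float32", "float16", "int8"):
    print(f"{dtype}: {weight_memory_gb(num_params, dtype):.1f} GB")
```

Halving the bytes per parameter halves the weight footprint. In Transformers, the lower precision is typically requested by passing `torch_dtype=torch.float16` to `from_pretrained`, though the exact argument may differ between library versions.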