Quantizing a model to 8-bit roughly halves its memory usage compared with 16-bit weights. For large models, set device_map="auto" to distribute the model efficiently across the available GPUs:
from transformers import AutoModelForCausalLM
model_8bit = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b7", device_map="auto", load_in_8bit=True)
By default, all other modules, such as torch.nn.LayerNorm, are converted to torch.float16.
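The memory saving can be sanity-checked with back-of-the-envelope arithmetic. The sketch below is only an estimate: it assumes a nominal 1.7 billion parameters (read off the "1b7" in the checkpoint name, not measured), counts weights only, and ignores the modules that stay in torch.float16 as well as any quantization overhead. The helper name weight_memory_gb is hypothetical and is not part of transformers.

```python
# Rough weight-memory estimate. Ignores activations, buffers, quantization
# metadata, and modules that are kept in float16 rather than quantized.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Approximate weight storage in gigabytes (1 GB = 10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Assumption: nominal parameter count inferred from the checkpoint name.
NUM_PARAMS = 1_700_000_000

for dtype in ("float32", "float16", "int8"):
    print(f"{dtype}: {weight_memory_gb(NUM_PARAMS, dtype):.1f} GB")
```

On a machine where the model actually loads, calling model_8bit.get_memory_footprint() should give the real figure, which will likely differ somewhat from this estimate because not every module is quantized.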