There is also an fp16 branch that stores the fp16 weights, which can be used to further reduce RAM usage:
```python
from transformers import GPTJForCausalLM
import torch

device = "cuda"

# Load the fp16 weights from the "float16" revision of the repository
# and move the model to the GPU.
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
).to(device)
```
The model should fit on a 16 GB GPU for inference.
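
As a quick check that the fp16 model loaded above works for generation, a minimal sketch could look like the following; the prompt and sampling settings are only illustrative and not part of the original example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# Illustrative prompt; tokenize and move the input ids to the same device as the model.
prompt = "In a shocking finding, scientists discovered a herd of unicorns"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Sample up to 100 new tokens; the sampling parameters are arbitrary.
generated_ids = model.generate(input_ids, do_sample=True, temperature=0.9, max_new_tokens=100)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```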