Make sure to load the model in half-precision (e.g. `torch.float16`).
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
import torch
from transformers import OPTForCausalLM, GPT2Tokenizer

device = "cuda"  # the device to load the model onto

model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
          "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
          "there?")

# Tokenize the prompt, move everything to the GPU, and generate a completion
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
model.to(device)

generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=30, do_sample=False)
print(tokenizer.batch_decode(generated_ids)[0])
```