`torch.float16`)
To load and run a model using Flash Attention 2, refer to the snippet below:
import torch
from transformers import OPTForCausalLM, GPT2Tokenizer
device = "cuda" # the device to load the model onto
model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
"Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
"there?")