`torch.float16`)
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
import torch
from transformers import OPTForCausalLM, GPT2Tokenizer

device = "cuda"  # the device to load the model onto

# Load the model in half precision with Flash Attention 2 enabled
model = OPTForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

prompt = (
    "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
    "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
    "there?"
)
```
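
The snippet above assumes a CUDA device and an installed `flash_attn` package. On machines without them, it can help to choose the load settings before calling `from_pretrained`. The following is a minimal sketch of that idea, not part of the original example: the helper names `pick_attention_settings` and `detect_environment` are hypothetical, and the fallback to `"eager"` attention in `float32` is an assumption chosen for broad compatibility.

```python
import importlib.util


def pick_attention_settings(cuda_available: bool, flash_attn_installed: bool) -> tuple[str, str]:
    """Return (dtype name, attn_implementation) for the given environment.

    Hypothetical helper: Flash Attention 2 is selected only when both a CUDA
    device and the flash_attn package are present; otherwise fall back to
    eager attention in full precision (an assumption, not a requirement).
    """
    if cuda_available and flash_attn_installed:
        return "float16", "flash_attention_2"
    return "float32", "eager"


def detect_environment() -> tuple[bool, bool]:
    """Report whether CUDA is usable and whether flash_attn can be imported."""
    flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
    try:
        import torch  # imported lazily so the helper also works without torch

        cuda_available = bool(torch.cuda.is_available())
    except ImportError:
        cuda_available = False
    return cuda_available, flash_attn_installed


dtype_name, attn_impl = pick_attention_settings(*detect_environment())
print(dtype_name, attn_impl)
```

The result can then be passed to the loader as `torch_dtype=getattr(torch, dtype_name)` and `attn_implementation=attn_impl`, so the same script runs with or without Flash Attention 2.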