`torch.float16`)

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
import torch
from transformers import OPTForCausalLM, GPT2Tokenizer

device = "cuda"  # the device to load the model onto

# Load the model in half precision with the Flash Attention 2 backend
model = OPTForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

prompt = (
    "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
    "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
    "there?"
)
```
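Because `attn_implementation="flash_attention_2"` depends on the optional `flash_attn` package and on a half-precision dtype, it can help to choose the backend before loading the model. The following is a minimal sketch, not part of the Transformers API: `pick_attn_implementation` and its `has_flash_attn` parameter are hypothetical names introduced here for illustration, and the choice of `"sdpa"` as the fallback is an assumption.

```python
import importlib.util
from typing import Optional

# Dtype names for which Flash Attention 2 is usable (half precision only).
HALF_PRECISION_DTYPES = {"float16", "bfloat16"}


def pick_attn_implementation(dtype_name: str, has_flash_attn: Optional[bool] = None) -> str:
    """Hypothetical helper: return a value for `attn_implementation`.

    dtype_name is the dtype as a plain string, e.g. "float16".
    has_flash_attn overrides package detection; leave it as None to
    check whether the `flash_attn` package is importable.
    """
    if has_flash_attn is None:
        # find_spec returns None when the package is not installed.
        has_flash_attn = importlib.util.find_spec("flash_attn") is not None

    if has_flash_attn and dtype_name in HALF_PRECISION_DTYPES:
        return "flash_attention_2"
    # Assumed fallback: PyTorch scaled-dot-product attention.
    return "sdpa"
```

The returned string could then be passed as `attn_implementation=pick_attn_implementation("float16")` in the `from_pretrained` call, so the same script still runs on machines without `flash_attn` installed.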