To load and run a model using Flash Attention 2, refer to the snippet below:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # the device to load the model onto

# Flash Attention 2 needs the flash-attn package, a supported NVIDIA GPU and fp16/bf16 weights
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to(device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

prompt = "def hello_world():"
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
completion = tokenizer.batch_decode(generated_ids)[0]
# completion is a sampled string such as the one below (yours will differ because do_sample=True):
# "def hello_world():\n >>> run_script("hello.py")\n >>> exit(0)\n<|endoftext|>"
```
Expected speedups
The diagram below shows the expected speedup in pure inference time for the EleutherAI/gpt-neo-2.7B checkpoint, comparing the native attention implementation in Transformers with the Flash Attention 2 version of the model.
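To run a similar comparison on your own hardware, you need to time generation under each attention implementation. The helper below is a minimal, framework-agnostic sketch. The function names `median_latency` and `speedup` are hypothetical and are not part of Transformers. It discards warm-up calls so that one-off costs do not skew the result, and it accepts an optional `sync` callback because CUDA launches kernels asynchronously. Without that callback, the clock could stop before the GPU has finished.

```python
import statistics
import time
from typing import Callable, Optional


def median_latency(
    fn: Callable[[], object],
    warmup: int = 2,
    runs: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock time, in seconds, of calling fn().

    Warm-up calls are discarded so that one-off costs (allocation, kernel
    setup) do not distort the result. If sync is given (for example
    torch.cuda.synchronize), it is called so that pending GPU work has
    finished before the clock is read.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work is done before timing starts
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for asynchronous work launched by fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_s: float, candidate_s: float) -> float:
    """Return how many times faster the candidate is than the baseline."""
    if candidate_s <= 0:
        raise ValueError("candidate time must be positive")
    return baseline_s / candidate_s
```

For example, you might load the checkpoint twice, once with `attn_implementation="eager"` (the native implementation) and once with `attn_implementation="flash_attention_2"`. For each copy, pass `lambda: model.generate(**model_inputs, max_new_tokens=100, min_new_tokens=100)` together with `sync=torch.cuda.synchronize`. Fixing the generated length and leaving sampling off means both runs do the same amount of work. The measured ratio depends on your GPU, batch size and sequence length, so it will not necessarily match the diagram.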