File size: 933 Bytes
5fa1a76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
torch.float16)
To load and run a model using Flash Attention 2, refer to the snippet below:
thon

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
prompt = "My favourite condiment is"
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
model.to(device)
generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
"The expected output"

Expected speedups
Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using mistralai/Mistral-7B-v0.1 checkpoint and the Flash Attention 2 version of the model.