Make sure to load your model in half-precision (e.g. `torch.float16`).
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto
model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
prompt = "def hello_world():"
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
model.to(device)
generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False)
print(tokenizer.batch_decode(generated_ids)[0])
# 'def hello_world():\n print("hello world")\n\nif __name__ == "__main__":\n print("hello world")\n<|endoftext|>'
```
## Expected speedups
Below is an expected speedup diagram comparing pure inference time between the native implementation in transformers using the `bigcode/starcoder` checkpoint and the Flash Attention 2 version of the model, for two different sequence lengths.
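If you want to get a rough sense of the speedup on your own hardware, you can time `generate` with and without Flash Attention 2. The sketch below is illustrative only: it assumes a CUDA GPU with the `flash-attn` package installed, reuses the smaller `bigcode/gpt_bigcode-santacoder` checkpoint from the snippet above instead of `bigcode/starcoder`, and its prompt, run count, and generation length are arbitrary choices rather than the settings used for the official figures. The `benchmark` helper is not part of transformers; `"eager"` and `"flash_attention_2"` are both valid values of `attn_implementation`.

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/gpt_bigcode-santacoder"  # smaller stand-in for bigcode/starcoder
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
prompt = "def hello_world():"
inputs = tokenizer([prompt], return_tensors="pt").to(device)

def benchmark(attn_implementation, n_runs=5, max_new_tokens=256):
    """Average greedy-generation latency for a given attention implementation."""
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to(device)
    # warm-up run so kernel compilation/caching is not included in the timing
    model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs

eager = benchmark("eager")
fa2 = benchmark("flash_attention_2")
print(f"eager: {eager:.2f}s  flash_attention_2: {fa2:.2f}s  speedup: {eager / fa2:.2f}x")
```

As in the official comparison, the gap tends to widen with longer sequence lengths, since attention cost grows quadratically with sequence length while Flash Attention 2 avoids materializing the full attention matrix.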