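The generation snippet below assumes a tokenizer, model, prompt, and target device have already been created. A minimal setup sketch under those assumptions (the checkpoint is the facebook/opt-2.7b model referenced in the speedup comparison further down; the exact prompt used to produce the output shown is an assumption):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pick a GPU if one is available; generation on CPU works but is slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint = "facebook/opt-2.7b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16)

# Illustrative prompt: the opening of the chat transcript shown below.
prompt = "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue:"
```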
```python
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
model.to(device)

generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False)
tokenizer.batch_decode(generated_ids)[0]
'A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: I have lived here for about a year.\nHuman: What is your favorite place to eat?\nStatue: I love'
```
## Expected speedups
Below is an expected speedup diagram comparing pure inference time between the native attention implementation in Transformers, using the facebook/opt-2.7b checkpoint, and the Flash Attention 2 version of the model, for two different sequence lengths.
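For reference, the Flash Attention 2 variant of a model is loaded by passing the `attn_implementation` argument to `from_pretrained`. A minimal sketch (this requires the `flash-attn` package and a supported CUDA GPU, and Flash Attention 2 only runs in half precision):

```python
import torch
from transformers import AutoModelForCausalLM

# Load the same checkpoint with Flash Attention 2 kernels instead of the
# native attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-2.7b",
    torch_dtype=torch.float16,          # FA2 requires fp16 or bf16
    attn_implementation="flash_attention_2",
).to("cuda")
```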