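Flash Attention 2 only supports half-precision weights, so load the model with `torch_dtype=torch.float16` (or `torch.bfloat16`).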
To load and run a model using Flash Attention 2, refer to the snippet below:
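```python
import torch
from transformers import AutoTokenizer, AutoModel

device = "cuda"  # Flash Attention 2 runs only on a CUDA GPU
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
# Load the weights in half precision and select the Flash Attention 2 backend
model = AutoModel.from_pretrained(
    "distilbert/distilbert-base-uncased",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(device)

text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt").to(device)
output = model(**encoded_input)
```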
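Flash Attention 2 can only be used when a CUDA GPU and the `flash-attn` package are both available. The following is a minimal sketch, assuming you want the same script to degrade gracefully on other machines. The helpers `pick_attn_implementation` and `detect_environment` are hypothetical and are not part of `transformers`. They choose the `attn_implementation` string and fall back to `"sdpa"` by default. If your `transformers` version does not support SDPA for this model, pass `fallback="eager"` instead.

```python
import importlib.util
from typing import Tuple


def pick_attn_implementation(
    cuda_available: bool,
    flash_attn_installed: bool,
    half_precision: bool,
    fallback: str = "sdpa",
) -> str:
    """Hypothetical helper: choose a value for from_pretrained(attn_implementation=...)."""
    # Flash Attention 2 needs a CUDA GPU, the flash-attn package, and fp16/bf16 weights
    if cuda_available and flash_attn_installed and half_precision:
        return "flash_attention_2"
    # Otherwise use the fallback ("sdpa" = PyTorch scaled_dot_product_attention, or "eager")
    return fallback


def detect_environment() -> Tuple[bool, bool]:
    """Hypothetical helper: probe for CUDA and flash-attn without failing if torch is missing."""
    # find_spec returns None for a top-level package that is not installed
    flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
    cuda_available = False
    if importlib.util.find_spec("torch") is not None:
        import torch

        cuda_available = torch.cuda.is_available()
    return cuda_available, flash_attn_installed


if __name__ == "__main__":
    cuda_ok, flash_ok = detect_environment()
    print(pick_attn_implementation(cuda_ok, flash_ok, half_precision=True))
```

The returned string can then be passed as `attn_implementation=` to `AutoModel.from_pretrained`. Remember to keep `torch_dtype=torch.float16` only when Flash Attention 2 is actually selected, since half precision on CPU may be slow or unsupported for some operations.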