torch.float16)
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
import torch
from transformers import AutoTokenizer, AutoModel

device = "cuda"  # the device to load the model onto

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
# Load in half precision (required for Flash Attention 2) with the FA2 kernel
model = AutoModel.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")

text = "Replace me by any text you'd like."