To load and run a model using Flash Attention 2, refer to the snippet below:

```python
import torch
from transformers import AutoTokenizer, AutoModel

device = "cuda"  # the device to load the model onto

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
model = AutoModel.from_pretrained(
    "distilbert/distilbert-base-uncased",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
model.to(device)

text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt").to(device)
output = model(**encoded_input)
```