torch.float16), since it results in almost no degradation in model quality but significantly lower memory usage and faster inference:
```python
import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

device = "cuda"  # Flash Attention 2 requires a supported CUDA GPU
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```
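Once loaded this way, the model can be used for generation as usual. The snippet below is a minimal illustrative sketch; the prompt text and `max_new_tokens` value are arbitrary and not taken from the original documentation:

```python
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")

# Tokenize a prompt, move it to the same device as the model, and generate.
inputs = tokenizer("GPTNeoX20B is a model that", return_tensors="pt").to(device)
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```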
Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the stockmark/gpt-neox-japanese-1.4b checkpoint and the Flash Attention 2 version of the model, using a sequence length of 2048.
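As a rough illustration of how such a comparison can be run, the sketch below times an fp16 forward pass over a random 2048-token input for both the eager and the Flash Attention 2 implementations. This is a minimal sketch under stated assumptions, not the script behind the official figure: it assumes the checkpoint supports `attn_implementation="flash_attention_2"`, and absolute timings depend heavily on the GPU, batch size, and number of runs.

```python
import time

import torch
from transformers import AutoModelForCausalLM

def time_forward(attn_implementation, checkpoint="stockmark/gpt-neox-japanese-1.4b", seq_len=2048, n_runs=5):
    """Average time (seconds) of an fp16 forward pass over random token ids."""
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to("cuda")
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")
    with torch.no_grad():
        model(input_ids)  # warm-up run, excluded from timing
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_runs):
            model(input_ids)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs

print("eager:", time_forward("eager"))
print("flash_attention_2:", time_forward("flash_attention_2"))
```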