When using Flash Attention 2 via attn_implementation="flash_attention_2", don't pass torch_dtype to the from_pretrained class method; instead, use Automatic Mixed-Precision training. The weights then stay in the default fp32 dtype while autocast runs the forward pass in fp16 or bf16, which Flash Attention 2 requires.
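
For example, here is a minimal sketch of this setup (the checkpoint name and the tiny training step are illustrative, not from the original text): the model is loaded without torch_dtype, and a torch.autocast context casts the computation to bf16 so Flash Attention 2 receives half-precision inputs.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# No torch_dtype argument: weights remain in fp32.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
).to("cuda")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
inputs = tokenizer("Hello, world!", return_tensors="pt").to("cuda")

model.train()
# AMP: autocast runs matmuls and attention in bf16 while the
# fp32 master weights are untouched, satisfying Flash Attention 2's
# fp16/bf16 requirement.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(**inputs, labels=inputs["input_ids"])
    loss = outputs.loss

# Backward and the optimizer step run outside the autocast context.
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

With bf16 autocast no gradient scaler is needed; for fp16 you would typically add a torch.cuda.amp.GradScaler. When training with Trainer, enabling AMP amounts to setting fp16=True or bf16=True in TrainingArguments.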