Import the `torch.neuron` framework extension to access the components of the Neuron SDK through a Python API:
```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
```
You only need to modify the following line:
```diff
- torch.jit.trace(model, [tokens_tensor, segments_tensors])
+ torch.neuron.trace(model, [tokens_tensor, segments_tensors])
```
This enables the Neuron SDK to trace the model and optimize it for Inf1 instances.
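
If the same script also has to run on machines without the Neuron SDK, the one-line swap can be wrapped in a small helper. The following is a minimal sketch, not part of the Neuron SDK: `neuron_available` and `trace_model` are hypothetical names, and the check assumes the `torch-neuron` package installs a top-level `torch_neuron` module.

```python
import importlib.util


def neuron_available() -> bool:
    """Return True if the torch-neuron package appears to be installed."""
    # Assumption: torch-neuron exposes a top-level `torch_neuron` module.
    return importlib.util.find_spec("torch_neuron") is not None


def trace_model(model, example_inputs):
    """Trace `model` with Neuron when available, else with plain TorchScript."""
    import torch  # imported lazily so the availability check works without torch

    inputs = tuple(example_inputs)
    if neuron_available():
        import torch.neuron  # makes torch.neuron.trace accessible
        return torch.neuron.trace(model, inputs)
    # Fallback: standard TorchScript tracing, with no Inf1-specific optimization.
    return torch.jit.trace(model, inputs)
```

With the tensors from the earlier steps, the call would look like `traced_model = trace_model(model, [tokens_tensor, segments_tensors])`. Only the Neuron branch produces a model optimized for Inf1 instances.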