Import the `torch.neuron` framework extension to access the components of the Neuron SDK through a Python API:
```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
```
You only need to modify the following line:
```diff
- torch.jit.trace(model, [tokens_tensor, segments_tensors])
+ torch.neuron.trace(model, [tokens_tensor, segments_tensors])
```
This enables the Neuron SDK to trace the model and optimize it for Inf1 instances.
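The following is a minimal end-to-end sketch of that swap, not an official recipe. It assumes `torch` and `transformers` are installed. It builds a tiny, randomly initialized `BertModel` from a `BertConfig` (the layer sizes are arbitrary illustrative values) so no pretrained weights are downloaded, and it falls back to `torch.jit.trace` when the `torch.neuron` extension is not installed so the script also runs outside a Neuron environment. When `torch.neuron` *is* importable, the sketch assumes it is running on an Inf1 instance, because the compiled model needs Neuron hardware to execute. Inputs are passed positionally in the order `BertModel.forward` expects them: `input_ids`, then `attention_mask`.

```python
import torch
from transformers import BertConfig, BertModel

try:
    # Provided by the torch-neuron package from the AWS Neuron SDK.
    # Assumption: if this import succeeds, we are on an Inf1 instance.
    import torch.neuron
    trace = torch.neuron.trace
except ImportError:
    # Fallback so the sketch still runs where the Neuron SDK is absent.
    trace = torch.jit.trace

# Tiny randomly initialized BERT (illustrative sizes, no download needed).
# torchscript=True makes the model return tuples, which tracing requires.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    torchscript=True,
)
model = BertModel(config)
model.eval()

# Dummy example inputs with a fixed shape of (batch=1, seq_len=8).
tokens_tensor = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(tokens_tensor)

# The only difference from a plain TorchScript export is which trace is used.
traced_model = trace(model, [tokens_tensor, attention_mask])

# BertModel returns (last_hidden_state, pooler_output) as a tuple.
with torch.no_grad():
    last_hidden_state, pooler_output = traced_model(tokens_tensor, attention_mask)
```

Tracing records the operations for the example input shape, so the traced model should be called with inputs of the same shape it was traced with.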