# DLF/trains/subNets/BertTextEncoder.py
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer

__all__ = ['BertTextEncoder']
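
# Map each supported model family name to its (model class, tokenizer class) pair.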
TRANSFORMERS_MAP = {
'bert': (BertModel, BertTokenizer),
'roberta': (RobertaModel, RobertaTokenizer),
}


class BertTextEncoder(nn.Module):
    def __init__(self, use_finetune=False, transformers='bert', pretrained='bert-base-uncased'):
        """
        use_finetune: if True, gradients flow through the pretrained encoder
            so it can be fine-tuned; if False, it acts as a frozen feature
            extractor (see forward()).
        transformers: key into TRANSFORMERS_MAP ('bert' or 'roberta').
        pretrained: Hugging Face checkpoint name or local path.
        """
        super().__init__()
        model_class, tokenizer_class = TRANSFORMERS_MAP[transformers]
        self.tokenizer = tokenizer_class.from_pretrained(pretrained)
        self.model = model_class.from_pretrained(pretrained)
        self.use_finetune = use_finetune

    def get_tokenizer(self):
        return self.tokenizer

    def from_text(self, text):
        """
        Encode raw text without gradient tracking.

        text: a raw string or list of strings
        """
        # The get_id() helper referenced in the original draft is not defined
        # on this class; the tokenizer provides the same raw-text-to-ids step.
        input_ids = self.tokenizer(text, padding=True, return_tensors='pt')['input_ids']
        with torch.no_grad():
            last_hidden_states = self.model(input_ids)[0]  # model outputs are tuples
        return last_hidden_states.squeeze()

    def forward(self, text):
        """
        text: (batch_size, 3, seq_len), where the second dimension stacks
            input_ids, attention_mask (input_mask), and token_type_ids
            (segment_ids).
        returns: last hidden states, (batch_size, seq_len, hidden_size)
        """
        input_ids = text[:, 0, :].long()
        input_mask = text[:, 1, :].float()
        segment_ids = text[:, 2, :].long()
        if self.use_finetune:
            # Fine-tuning: let gradients flow through the pretrained encoder.
            last_hidden_states = self.model(input_ids=input_ids,
                                            attention_mask=input_mask,
                                            token_type_ids=segment_ids)[0]  # model outputs are tuples
        else:
            # Feature extraction: keep the pretrained encoder frozen.
            with torch.no_grad():
                last_hidden_states = self.model(input_ids=input_ids,
                                                attention_mask=input_mask,
                                                token_type_ids=segment_ids)[0]  # model outputs are tuples
        return last_hidden_states
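

# A minimal usage sketch (an illustration, not part of the training pipeline):
# it assumes the 'bert-base-uncased' checkpoint is available and shows how to
# build the (batch_size, 3, seq_len) tensor that forward() expects.
if __name__ == '__main__':
    encoder = BertTextEncoder(use_finetune=False, transformers='bert',
                              pretrained='bert-base-uncased')
    tokenizer = encoder.get_tokenizer()
    batch = tokenizer(['hello world', 'a second, longer sentence'],
                      padding=True, return_tensors='pt')
    # Stack ids / mask / segments along dim=1 to get (batch_size, 3, seq_len).
    text = torch.stack([batch['input_ids'],
                        batch['attention_mask'],
                        batch['token_type_ids']], dim=1)
    features = encoder(text)
    print(features.shape)  # torch.Size([2, seq_len, 768]) for bert-base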