Keep only the answer in the generated text
Browse files- pipeline.py +7 -4
pipeline.py
CHANGED
@@ -14,13 +14,16 @@ class PreTrainedPipeline():
|
|
14 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
15 |
|
16 |
def __call__(self, inputs: str) -> List[Dict]:
|
17 |
-
|
18 |
text = inputs + self.tokenizer.eos_token
|
19 |
start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
|
20 |
-
|
21 |
results = self.generator.generate_batch([start_tokens])
|
22 |
output = results[0].sequences[0]
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
25 |
|
26 |
return [{"generated_text": generated_text}]
|
|
|
14 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
15 |
|
16 |
def __call__(self, inputs: str) -> List[Dict]:
|
17 |
+
# Get input tokens
|
18 |
text = inputs + self.tokenizer.eos_token
|
19 |
start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
|
20 |
+
# generate
|
21 |
results = self.generator.generate_batch([start_tokens])
|
22 |
output = results[0].sequences[0]
|
23 |
+
# Keep only the answer: drop everything up to and including the first EOS token
|
24 |
+
tokens = self.tokenizer.convert_tokens_to_ids(output)
|
25 |
+
eos_index = tokens.index(self.tokenizer.eos_token_id)
|
26 |
+
answer_tokens = tokens[eos_index+1:]
|
27 |
+
generated_text = self.tokenizer.decode(answer_tokens)
|
28 |
|
29 |
return [{"generated_text": generated_text}]
|