theodotus committed on
Commit
1ca335c
·
1 Parent(s): 651c45f

left only answers

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -4
pipeline.py CHANGED
@@ -14,13 +14,16 @@ class PreTrainedPipeline():
14
  self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
15
 
16
  def __call__(self, inputs: str) -> List[Dict]:
17
-
18
  text = inputs + self.tokenizer.eos_token
19
  start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
20
-
21
  results = self.generator.generate_batch([start_tokens])
22
  output = results[0].sequences[0]
23
-
24
- generated_text = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(output))
 
 
 
25
 
26
  return [{"generated_text": generated_text}]
 
14
  self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
15
 
16
  def __call__(self, inputs: str) -> List[Dict]:
17
+ # Get input tokens
18
  text = inputs + self.tokenizer.eos_token
19
  start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
20
+ # generate
21
  results = self.generator.generate_batch([start_tokens])
22
  output = results[0].sequences[0]
23
+ # left only answers
24
+ tokens = self.tokenizer.convert_tokens_to_ids(output)
25
+ eos_index = tokens.index(self.tokenizer.eos_token_id)
26
+ answer_tokens = tokens[eos_index+1:]
27
+ generated_text = self.tokenizer.decode(answer_tokens)
28
 
29
  return [{"generated_text": generated_text}]