|
import transformers |
|
from transformers import TextStreamer |
|
import torch |
|
from transformers.generation.streamers import BaseStreamer |
|
|
|
|
|
class TokenStreamer(BaseStreamer):
    """
    Simple token streamer that prints every received token on its own line.

    Each token id is decoded individually with the tokenizer and written to
    stdout as `=` followed by the `repr` of the decoded text (for example
    `='Hello'`), so whitespace and special characters stay visible.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `True`):
            Whether to skip the prompt tokens in the output. Useful for chatbots.
    """

    def __init__(self, tokenizer, skip_prompt=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        # The first `put` call after construction (or after `end`) is treated
        # as carrying the prompt tokens.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives token ids and prints each decoded token as `=<repr>` on its own line.

        Parameters:
            value:
                Token ids exposing `.shape` and `.tolist()` (presumably a
                `torch.Tensor` handed over by `generate`). A leading batch
                dimension of size 1 is stripped.

        Raises:
            ValueError: If `value` has more than one dimension and its first
                dimension (the batch) is larger than 1.
        """
        if len(value.shape) > 1:
            if value.shape[0] > 1:
                raise ValueError("TokenStreamer only supports batch size 1")
            value = value[0]

        # When prompts are skipped, the first call is swallowed entirely.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        token_ids = value.tolist()
        if not isinstance(token_ids, list):
            # A 0-d input makes `tolist()` return a bare scalar; wrap it so
            # the loop below does not fail trying to iterate a number.
            token_ids = [token_ids]

        for token_id in token_ids:
            # Decode one id at a time so every token is shown separately.
            token_text = self.tokenizer.decode([token_id])
            print(f"={token_text!r}", flush=True)

    def end(self):
        """Re-arms prompt skipping for the next generation and prints a newline."""
        self.next_tokens_are_prompt = True
        print()
|
|
|
|
|
|
|
|
|
# Location of the model checkpoint to load (the current working directory).
model_id = "./"

# NOTE(review): the tokenizer is fetched from the Llama-3 Instruct repo while
# the weights come from `model_id` — presumably the local checkpoint shares
# that vocabulary; confirm.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    trust_remote_code=True,
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Wrap the already-instantiated model and tokenizer in a text-generation pipeline.
pipeline = transformers.pipeline(
    "text-generation",
    tokenizer=tokenizer,
    model=model,
    trust_remote_code=True,
    device_map="auto",
    model_kwargs={"torch_dtype": torch.bfloat16},
)

# Single-turn chat prompt. The leading and trailing newlines are kept from the
# original triple-quoted literal.
messages = [
    {
        "role": "user",
        "content": (
            "\n"
            "Jan has three times the number of pets as Marcia. "
            "Marcia has two more pets than Cindy. "
            "If Cindy has four pets, how many total pets do the three have?"
            "\n"
        ),
    },
]

# Generation stops on either the tokenizer's EOS id or the id of "<|eot_id|>".
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Tokens are printed by the streamer as they are produced.
streamer = TokenStreamer(tokenizer)
outputs = pipeline(
    messages,
    streamer=streamer,
    eos_token_id=terminators,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
)