# text_generation.py
import torch
import torch.nn.functional as F
from tokenxxx import *  # expected to provide END_OF_TEXT_TOKEN and tokenizer helpers
from main import *      # expected to provide SYSTEM_PROMPT, MAX_XDD, the loaded models, and device
from duckduckgo_search import DDGS
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a (batch, vocab) logits tensor using top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove every logit smaller than the k-th largest one.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once the cumulative probability exceeds top_p, shifting the
        # mask right so the first token above the threshold is still kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original (unsorted) vocabulary order; plain
        # boolean indexing here would index the batch dimension instead of the vocab.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
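
# A minimal sanity check of the filter (hypothetical values): with top_k=2 on a
# 5-way distribution, only the two largest logits survive; the rest become -inf.
#   _demo = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0]])
#   top_k_top_p_filtering(_demo.clone(), top_k=2)
#   -> tensor([[-inf, 3.0, -inf, 2.0, -inf]])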
def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = enc.encode(prompt)
    input_ids = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past_key_values = None
    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids, past_key_values=past_key_values)
            next_token_logits = outputs[0][:, -1, :]
            past_key_values = outputs[1]
            if temperature > 0:  # avoid dividing by zero in greedy mode
                next_token_logits = next_token_logits / temperature
            # Penalize already generated tokens (CTRL-style: scale positive logits
            # towards zero and negative logits away from zero).
            for token_index in set(generated):
                if next_token_logits[0, token_index] < 0:
                    next_token_logits[0, token_index] *= repetition_penalty
                else:
                    next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            token = next_token.item()
            generated.append(token)
            # With a KV cache primed, only the newly sampled token is fed next step.
            input_ids = next_token
            yield enc.decode([token])
            if token == enc.encoder[END_OF_TEXT_TOKEN]:
                yield "<END_STREAM>"
                return
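
# Sketch of driving the generator directly (assumes model_gpt2 and enc are the
# GPT-2 model and encoder loaded in main.py):
#   for piece in sample_sequence("Hello", model_gpt2, enc, length=50,
#                                temperature=0.8, top_k=40, top_p=0.9, device=device):
#       if piece == "<END_STREAM>":
#           break
#       print(piece, end="", flush=True)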
def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = tokenizer.encode(prompt)
    input_ids = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past_key_values = None
    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids=input_ids, past_key_values=past_key_values, labels=None)
            next_token_logits = outputs[0][:, -1, :]
            past_key_values = outputs[1]
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            for token_index in set(generated):
                if next_token_logits[0, token_index] < 0:
                    next_token_logits[0, token_index] *= repetition_penalty
                else:
                    next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            token = next_token.item()
            generated.append(token)
            input_ids = next_token  # feed only the new token once the KV cache is primed
            yield tokenizer.decode([token])
            if token == 50256:  # GPT-2/CodeGen <|endoftext|> token id
                yield "<END_STREAM>"
                return
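
# Same contract as sample_sequence: decoded pieces are yielded one at a time and
# the literal sentinel "<END_STREAM>" marks the end (here triggered by id 50256).
# Hypothetical usage, assuming codegen_model/codegen_tokenizer from main.py:
#   code = "".join(p for p in sample_sequence_codegen(
#       "def fib(n):", codegen_model, codegen_tokenizer, length=128, device=device)
#       if p != "<END_STREAM>")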
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
    prompt_text = SYSTEM_PROMPT + "\n\n"
    prompt_text += "User: " + text_input + "\n"
    # Append web search context before the Assistant turn so the model actually sees it.
    ddgs = DDGS()
    search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
    if search_results:
        prompt_text += "\nWeb Search Results:\n"
        for result in search_results:
            prompt_text += f"- {result['body']}\n"
        prompt_text += "\n"
    prompt_text += "Assistant:"
    reasoning_prompt = prompt_text
    generated_text_stream = []
    stream_type = "text"
    if "code" in text_input.lower() or "program" in text_input.lower():
        if codegen_model and codegen_tokenizer:
            generated_text_stream = sample_sequence_codegen(
                prompt=reasoning_prompt,
                model=codegen_model,
                tokenizer=codegen_tokenizer,
                length=999999999,  # effectively unbounded; stops at the end-of-text token
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                device=device
            )
    elif "summarize" in text_input.lower() or "summary" in text_input.lower():
        if summarization_model:
            summary = summarize_text(text_input)
            yield f"SUMMARY_TEXT:{summary}"
            yield "<END_STREAM>"
            return
        stream_type = "summary"
    else:
        if model_gpt2 and enc:
            generated_text_stream = sample_sequence(
                prompt=reasoning_prompt,
                model=model_gpt2,
                enc=enc,
                length=999999999,  # effectively unbounded; stops at the end-of-text token
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                device=device
            )
    accumulated_text = ""
    if stream_type == "text":
        for token in generated_text_stream:
            if token == "<END_STREAM>":
                yield accumulated_text
                yield "<END_STREAM>"
                return
            if token == END_OF_TEXT_TOKEN:
                accumulated_text += END_OF_TEXT_TOKEN
                continue
            if token:
                accumulated_text += token
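
# Hypothetical end-to-end call: the generator yields the full accumulated answer
# followed by "<END_STREAM>" (or "SUMMARY_TEXT:..." for summarization requests).
#   for chunk in perform_reasoning_stream("Explain KV caching", temperature=0.7,
#                                         top_k=40, top_p=0.9, repetition_penalty=1.2):
#       if chunk == "<END_STREAM>":
#           break
#       print(chunk)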