import torch
import torch.nn.functional as F
from tqdm import trange
import time
# tokenxxx and main are local modules of this Space; between them they are
# expected to provide enc, END_OF_TEXT_TOKEN, SYSTEM_PROMPT, MAX_XDD,
# summarize_text, the loaded models, and the device used below.
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # Filter a batch of logits with top-k and/or nucleus (top-p) filtering.
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Drop every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens whose cumulative probability exceeds the threshold,
        # shifting the mask right so the first token above it is still kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original vocabulary order; the previous
        # advanced-indexing version was only correct for 1-D logits.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
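
# Usage sketch (illustrative values, not taken from the app): with top_k=2 only
# the two largest logits survive; the rest become -inf and get zero probability.
#
#   logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#   filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
#   probs = F.softmax(filtered, dim=-1)  # ~[0.00, 0.00, 0.27, 0.73]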
def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = enc.encode(prompt)
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = context_tokens
    past_key_values = None
    with torch.no_grad():
        for _ in range(length):
            outputs = model(context_tokens_tensor, past_key_values=past_key_values)
            # Avoid dividing by zero; temperature == 0 selects greedy decoding below.
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.0)
            past_key_values = outputs[1]
            for token_index in set(generated):
                # CTRL-style repetition penalty: divide positive logits, multiply
                # negative ones, so repeated tokens always become less likely.
                if next_token_logits[0, token_index] > 0:
                    next_token_logits[0, token_index] /= repetition_penalty
                else:
                    next_token_logits[0, token_index] *= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated += next_token.tolist()[0]
            # With a KV cache, only the newly sampled token is fed on the next step;
            # the original re-fed the full prompt every iteration.
            context_tokens_tensor = next_token
            token = next_token.tolist()[0][0]
            yield enc.decode([token])
            if token == enc.encoder[END_OF_TEXT_TOKEN]:
                yield "<END_STREAM>"
                return
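
# Consumption sketch: sample_sequence is a generator, so the UI can stream
# tokens as they arrive. model_gpt2 and enc are assumed to be the GPT-2 model
# and encoder loaded in main.py.
#
#   for piece in sample_sequence("Hello", model_gpt2, enc, length=50,
#                                temperature=0.8, top_k=40, top_p=0.9):
#       if piece == "<END_STREAM>":
#           break
#       print(piece, end="", flush=True)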
def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = tokenizer.encode(prompt)
    # [context_tokens] is already batched; the extra .unsqueeze(0) in the
    # original produced a spurious 3-D input of shape (1, 1, seq_len).
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = context_tokens
    past_key_values = None
    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids=context_tokens_tensor, past_key_values=past_key_values, labels=None)
            # Avoid dividing by zero; temperature == 0 selects greedy decoding below.
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.0)
            past_key_values = outputs[1]
            for token_index in set(generated):
                # CTRL-style repetition penalty (see sample_sequence above).
                if next_token_logits[0, token_index] > 0:
                    next_token_logits[0, token_index] /= repetition_penalty
                else:
                    next_token_logits[0, token_index] *= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated.append(next_token.tolist()[0][0])
            # With a KV cache, only the newly sampled token is fed on the next step.
            context_tokens_tensor = next_token
            token = next_token.tolist()[0][0]
            yield tokenizer.decode([token])
            if token == 50256:  # GPT-2/CodeGen <|endoftext|> id
                yield "<END_STREAM>"
                return
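
# The codegen stream is consumed the same way; a one-shot sketch (assumes
# codegen_model and codegen_tokenizer were loaded in main.py):
#
#   code = "".join(t for t in sample_sequence_codegen(
#       "# fibonacci in python", codegen_model, codegen_tokenizer, length=200)
#       if t != "<END_STREAM>")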
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
    prompt_text = SYSTEM_PROMPT + "\n\n"
    prompt_text += "User: " + text_input + "\nAssistant:"
    # Augment the prompt with DuckDuckGo snippets when the search returns any.
    ddgs = DDGS()
    search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
    if search_results:
        prompt_text += "\nWeb Search Results:\n"
        for result in search_results:
            prompt_text += f"- {result['body']}\n"
        prompt_text += "\n"
    # Snapshot the fully assembled prompt. The original took this snapshot
    # before appending the search results, so they never reached the model.
    reasoning_prompt = prompt_text
    generated_text_stream = []
    stream_type = "text"
    if "code" in text_input.lower() or "program" in text_input.lower():
        if codegen_model and codegen_tokenizer:
            generated_text_stream = sample_sequence_codegen(
                prompt=reasoning_prompt,
                model=codegen_model,
                tokenizer=codegen_tokenizer,
                length=999999999,  # effectively unbounded; the stream stops at the EOS token
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                device=device
            )
    elif "summarize" in text_input.lower() or "summary" in text_input.lower():
        if summarization_model:
            summary = summarize_text(text_input)
            yield f"SUMMARY_TEXT:{summary}"
            yield "<END_STREAM>"
            # Return immediately; the original fell through to the text loop.
            return
    else:
        if model_gpt2 and enc:
            generated_text_stream = sample_sequence(
                prompt=reasoning_prompt,
                model=model_gpt2,
                enc=enc,
                length=999999999,  # effectively unbounded; the stream stops at the EOS token
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                device=device
            )
accumulated_text = "" | |
if stream_type == "text": | |
for token in generated_text_stream: | |
if token == "<END_STREAM>": | |
yield accumulated_text | |
yield "<END_STREAM>" | |
return | |
if token == END_OF_TEXT_TOKEN: | |
accumulated_text += END_OF_TEXT_TOKEN | |
continue | |
if token: | |
accumulated_text += token |
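
# Minimal smoke test (a sketch; assumes the star-imports above supply the
# models, enc, SYSTEM_PROMPT, MAX_XDD, and device). Run this module directly
# to stream one answer to stdout.
if __name__ == "__main__":
    for chunk in perform_reasoning_stream(
        "Explain top-p sampling",
        temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2,
    ):
        if chunk == "<END_STREAM>":
            break
        print(chunk)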