import torch
import torch.nn.functional as F
from tqdm import trange
import time
from tokenxxx import *
from main import *
#from main import model_gpt2, enc, codegen_model, codegen_tokenizer, summarization_model, device, system_prompt, MAX_LENGTH, summarize_text as summarize_func
from duckduckgo_search import DDGS
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a batch of next-token logits with top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove every token whose logit falls below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once the cumulative probability exceeds top_p, always keeping the most likely token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original vocabulary order so it lines up with `logits`.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
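# Usage sketch (illustrative values, not tuned defaults): keep the 40 most likely tokens and
# the smallest nucleus whose cumulative probability exceeds 0.9, then sample from the rest:
#   filtered = top_k_top_p_filtering(next_token_logits, top_k=40, top_p=0.9)
#   next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)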
def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    """Stream decoded GPT-2 tokens one at a time, terminating the stream with "<END_STREAM>"."""
    start_time = time.time()
    context_tokens = enc.encode(prompt)
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past = None
    text_generated_count = 0
    with torch.no_grad():
        while True:
            outputs = model(context_tokens_tensor, past_key_values=past)
            next_token_logits = outputs[0][:, -1, :]
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            past = outputs[1]
            # Penalize tokens that already appear in the context or in the generation so far.
            for token_index in set(generated):
                next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            token = next_token.tolist()[0][0]
            generated.append(token)
            text_generated_count += 1
            yield enc.decode([token])
            if token == enc.encoder[END_OF_TEXT_TOKEN]:
                yield "<END_STREAM>"
                return
            if text_generated_count > length:
                yield "<END_STREAM>"
                return
            if (time.time() - start_time) * 1000 > 5000:
                yield "<END_STREAM>"
                return
            # With cached past_key_values, only the newly sampled token is fed on the next step.
            context_tokens_tensor = next_token
def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    """Stream decoded CodeGen tokens one at a time, terminating the stream with "<END_STREAM>"."""
    start_time = time.time()
    context_tokens = tokenizer.encode(prompt)
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past = None
    text_generated_count = 0
    with torch.no_grad():
        while True:
            outputs = model(input_ids=context_tokens_tensor, past_key_values=past, labels=None)
            next_token_logits = outputs[0][:, -1, :]
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            past = outputs[1]
            for token_index in set(generated):
                next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            token = next_token.tolist()[0][0]
            generated.append(token)
            text_generated_count += 1
            yield tokenizer.decode([token])
            if token == 50256:  # <|endoftext|> id in the GPT-2/CodeGen vocabulary
                yield "<END_STREAM>"
                return
            if text_generated_count > length:
                yield "<END_STREAM>"
                return
            if (time.time() - start_time) * 1000 > 5000:
                yield "<END_STREAM>"
                return
            context_tokens_tensor = next_token
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
    """Route a user request to the code, summarization, or chat model and stream the response."""
    try:
        prompt_text = system_prompt + "\n\n"
        prompt_text += "User: " + text_input + "\nCyrah: "
        # Augment the prompt with DuckDuckGo search snippets when any are available.
        ddgs = DDGS()
        search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
        if search_results:
            prompt_text += "\nWeb Search Results:\n"
            for result in search_results:
                prompt_text += f"- {result['body']}\n"
            prompt_text += "\n"
        reasoning_prompt = prompt_text
        generated_text_stream = []
        stream_type = "text"
        if "code" in text_input.lower() or "program" in text_input.lower():
            if codegen_model and codegen_tokenizer:
                generated_text_stream = sample_sequence_codegen(
                    prompt=reasoning_prompt,
                    model=codegen_model,
                    tokenizer=codegen_tokenizer,
                    length=MAX_LENGTH,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    device=device
                )
                stream_type = "text"
        elif "summarize" in text_input.lower() or "summary" in text_input.lower():
            if summarization_model:
                summary = summarize_func(text_input)
                yield f"SUMMARY_TEXT:{summary}"
                yield "<END_STREAM>"
                stream_type = "summary"
        else:
            if model_gpt2 and enc:
                generated_text_stream = sample_sequence(
                    prompt=reasoning_prompt,
                    model=model_gpt2,
                    enc=enc,
                    length=MAX_LENGTH,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    device=device
                )
                stream_type = "text"
        accumulated_text = ""
        if stream_type == "text":
            for token in generated_text_stream:
                if token == "<END_STREAM>":
                    yield accumulated_text
                    yield "<END_STREAM>"
                    return
                if token == END_OF_TEXT_TOKEN:
                    accumulated_text += END_OF_TEXT_TOKEN
                    continue
                if token:
                    accumulated_text += token
    except Exception as e:
        print(f"Reasoning Error: {e}")
        yield "Error during reasoning. Please try again."
        yield "<END_STREAM>"