import torch
import torch.nn.functional as F
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS
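
# NOTE: tokenxxx and main are project-local modules. The star imports are
# assumed to provide SYSTEM_PROMPT, END_OF_TEXT_TOKEN, MAX_XDD, summarize_text,
# and the loaded model handles (model_gpt2, enc, codegen_model,
# codegen_tokenizer, summarization_model, device) used throughout this file.
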
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a [batch, vocab] logits tensor using top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Drop every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once cumulative probability exceeds top_p, shifting the
        # mask right so the first token past the threshold is always kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the sorted-order mask back to vocabulary order; the previous
        # fancy indexing selected rows, which is wrong for batched 2-D logits.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

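# A minimal usage sketch (the tensor values below are illustrative only):
#
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#   filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
#   probs = F.softmax(filtered, dim=-1)  # all mass now on the two largest logits
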
def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    """Stream decoded tokens from the GPT-2 model one piece at a time."""
    context_tokens = enc.encode(prompt)
    input_ids = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past_key_values = None
    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids, past_key_values=past_key_values)
            # Guard the division: temperature == 0 means greedy decoding.
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.0)
            past_key_values = outputs[1]
            for token_index in set(generated):
                # Penalize already-generated tokens; negative logits must be
                # multiplied, not divided, or the penalty would boost them.
                if next_token_logits[0, token_index] < 0:
                    next_token_logits[0, token_index] *= repetition_penalty
                else:
                    next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            token = next_token[0, 0].item()
            generated.append(token)
            # With a KV cache, only the newly sampled token is fed next step;
            # re-feeding the full context would corrupt the cached positions.
            input_ids = next_token
            yield enc.decode([token])
            if token == enc.encoder[END_OF_TEXT_TOKEN]:
                yield "<END_STREAM>"
                return

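# Both samplers are meant to be consumed as generators, roughly:
#
#   for piece in sample_sequence("Hello", model_gpt2, enc, length=64,
#                                temperature=0.8, top_k=40, top_p=0.9):
#       if piece == "<END_STREAM>":
#           break
#       print(piece, end="", flush=True)
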
def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    """Stream decoded tokens from the CodeGen model; mirrors sample_sequence with an HF-style call."""
    context_tokens = tokenizer.encode(prompt)
    # The token list is already wrapped in a batch dimension; the extra
    # unsqueeze(0) in the earlier draft produced an invalid [1, 1, seq] shape.
    input_ids = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past_key_values = None
    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids=input_ids, past_key_values=past_key_values, labels=None)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.0)
            past_key_values = outputs[1]
            for token_index in set(generated):
                if next_token_logits[0, token_index] < 0:
                    next_token_logits[0, token_index] *= repetition_penalty
                else:
                    next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            token = next_token[0, 0].item()
            generated.append(token)
            input_ids = next_token  # feed only the new token; the KV cache holds the rest
            yield tokenizer.decode([token])
            if token == 50256:  # GPT-2/CodeGen end-of-text id
                yield "<END_STREAM>"
                return

def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
    """Route a user request to the right model and stream the response."""
    prompt_text = SYSTEM_PROMPT + "\n\n"
    prompt_text += "User: " + text_input + "\nAssistant:"
    # Gather web context before freezing the prompt; the earlier draft captured
    # the prompt before appending results, so the search was never used.
    try:
        search_results = list(DDGS().text(text_input, max_results=MAX_XDD))
    except Exception:
        search_results = []  # a failed search should not kill the stream
    if search_results:
        prompt_text += "\nWeb Search Results:\n"
        for result in search_results:
            prompt_text += f"- {result['body']}\n"
        prompt_text += "\n"
    reasoning_prompt = prompt_text
    generated_text_stream = []
    if "code" in text_input.lower() or "program" in text_input.lower():
        if codegen_model and codegen_tokenizer:
            generated_text_stream = sample_sequence_codegen(
                prompt=reasoning_prompt,
                model=codegen_model,
                tokenizer=codegen_tokenizer,
                length=999999999,  # effectively unbounded; the EOS token ends the stream
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                device=device,
            )
    elif "summarize" in text_input.lower() or "summary" in text_input.lower():
        if summarization_model:
            summary = summarize_text(text_input)
            yield f"SUMMARY_TEXT:{summary}"
            yield "<END_STREAM>"
            return  # the summary path is complete; do not fall through
    else:
        if model_gpt2 and enc:
            generated_text_stream = sample_sequence(
                prompt=reasoning_prompt,
                model=model_gpt2,
                enc=enc,
                length=999999999,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                device=device,
            )
    accumulated_text = ""
    for token in generated_text_stream:
        if token == "<END_STREAM>":
            yield accumulated_text
            yield "<END_STREAM>"
            return
        if token == END_OF_TEXT_TOKEN:
            accumulated_text += END_OF_TEXT_TOKEN
            continue
        if token:
            accumulated_text += token
    # The stream ended without an explicit marker (e.g. no model was loaded);
    # flush whatever accumulated and close the stream anyway.
    yield accumulated_text
    yield "<END_STREAM>"
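
if __name__ == "__main__":
    # A minimal smoke test, assuming the models loaded by main/tokenxxx are
    # available at import time. The sampling parameters are illustrative
    # defaults, not tuned values.
    for chunk in perform_reasoning_stream(
        "Explain what top-p sampling does.",
        temperature=0.8,
        top_k=40,
        top_p=0.9,
        repetition_penalty=1.2,
    ):
        if chunk == "<END_STREAM>":
            break
        print(chunk)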