import torch
import torch.nn.functional as F
from tokenxxx import *
from main import *
from duckduckgo_search import DDGS


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    # Top-k / nucleus (top-p) filtering over batched logits of shape (batch, vocab).
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Mask every logit strictly below the k-th largest value.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens whose cumulative probability exceeds top_p, always
        # keeping at least the single most likely token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask from sorted order back to vocabulary order so the
        # masked assignment works on 2-D (batched) logits.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = enc.encode(prompt)
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past = None
    text_generated_count = 0
    with torch.no_grad():
        while text_generated_count < length:
            outputs = model(context_tokens_tensor, past_key_values=past)
            next_token_logits = outputs[0][:, -1, :] / temperature
            past = outputs[1]
            # Penalize tokens that already appear in the context or output.
            for token_index in set(generated):
                next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated += next_token.tolist()[0]
            text_generated_count += 1
            token = next_token.tolist()[0][0]
            yield enc.decode([token])
            if token == enc.encoder[END_OF_TEXT_TOKEN]:
                yield ""
                break
            # With past_key_values cached, only the new token is fed next step.
            context_tokens_tensor = next_token


def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
    context_tokens = tokenizer.encode(prompt)
    # The list wrap already adds the batch dimension; an extra unsqueeze(0)
    # would produce an invalid 3-D input.
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = list(context_tokens)
    past = None
    text_generated_count = 0
    with torch.no_grad():
        while text_generated_count < length:
            outputs = model(input_ids=context_tokens_tensor, past_key_values=past, labels=None)
            next_token_logits = outputs[0][:, -1, :] / temperature
            past = outputs[1]
            for token_index in set(generated):
                next_token_logits[0, token_index] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated.append(next_token.tolist()[0][0])
            text_generated_count += 1
            token = next_token.tolist()[0][0]
            yield tokenizer.decode([token])
            if token == 50256:  # GPT-2/CodeGen <|endoftext|> token id
                yield ""
                break
            context_tokens_tensor = next_token
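# Usage sketch (not part of the original module): how the two samplers above
# are meant to be consumed. It assumes `model_gpt2` and `enc`, star-imported
# from `main`, are a GPT-2 model and its matching tokenizer.
#
#     stream = sample_sequence("Hello, world", model_gpt2, enc, length=40,
#                              temperature=0.8, top_k=40, top_p=0.9,
#                              repetition_penalty=1.2, device=device)
#     for piece in stream:
#         if piece == "":
#             break  # end-of-text sentinel
#         print(piece, end="", flush=True)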
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
    try:
        prompt_text = system_prompt + "\n\n"
        prompt_text += "User: " + text_input + "\nCyrah: "
        ddgs = DDGS()
        search_results = list(ddgs.text(text_input, max_results=MAX_XDD))
        if search_results:
            prompt_text += "\nWeb Search Results:\n"
            for result in search_results:
                prompt_text += f"- {result['body']}\n"
            prompt_text += "\n"
        # Capture the prompt only after the search results are appended, so the
        # generators actually see them.
        reasoning_prompt = prompt_text
        generated_text_stream = []
        stream_type = "text"
        if "code" in text_input.lower() or "program" in text_input.lower():
            if codegen_model and codegen_tokenizer:
                generated_text_stream = sample_sequence_codegen(
                    prompt=reasoning_prompt, model=codegen_model, tokenizer=codegen_tokenizer,
                    length=999999999, temperature=temperature, top_k=top_k, top_p=top_p,
                    repetition_penalty=repetition_penalty, device=device
                )
        elif "summarize" in text_input.lower() or "summary" in text_input.lower():
            if summarization_model:
                summary = summarize_func(text_input)
                yield f"SUMMARY_TEXT:{summary}"
                yield ""
                stream_type = "summary"
        else:
            if model_gpt2 and enc:
                generated_text_stream = sample_sequence(
                    prompt=reasoning_prompt, model=model_gpt2, enc=enc,
                    length=999999999, temperature=temperature, top_k=top_k, top_p=top_p,
                    repetition_penalty=repetition_penalty, device=device
                )
        accumulated_text = ""
        if stream_type == "text":
            for token in generated_text_stream:
                if token == "":
                    # Empty string is the samplers' end-of-stream sentinel.
                    yield accumulated_text
                    yield ""
                    return
                if token == END_OF_TEXT_TOKEN:
                    accumulated_text += END_OF_TEXT_TOKEN
                    continue
                if token:
                    accumulated_text += token
    except Exception as e:
        print(f"Reasoning Error: {e}")
        yield "Error during reasoning. Please try again."
        yield ""
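# Hedged smoke test: assumes the globals star-imported from `main`
# (model_gpt2/enc, codegen_model/codegen_tokenizer, summarization_model,
# summarize_func, device, system_prompt, MAX_XDD) are initialized by that
# module. It only illustrates how the reasoning stream is consumed.
if __name__ == "__main__":
    for chunk in perform_reasoning_stream(
        "Explain binary search.",
        temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2,
    ):
        if chunk == "":
            break  # end-of-stream sentinel
        print(chunk)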