Spaces:
Running
Running
import asyncio | |
from typing import List, Dict, Any | |
from ..config.config import Config | |
from ..utils.llm import create_chat_completion | |
from ..utils.logger import get_formatted_logger | |
from ..prompts import ( | |
generate_report_introduction, | |
generate_draft_titles_prompt, | |
generate_report_conclusion, | |
get_prompt_by_report_type, | |
) | |
from ..utils.enum import Tone | |
logger = get_formatted_logger() | |
async def write_report_introduction(
    query: str,
    context: str,
    agent_role_prompt: str,
    config: Config,
    websocket=None,
    cost_callback: callable = None
) -> str:
    """
    Ask the smart LLM to write the report's introduction.

    Args:
        query (str): The research query.
        context (str): Context for the report.
        agent_role_prompt (str): Text sent as the system message.
        config (Config): Configuration object; supplies the smart LLM model,
            provider, token limit and extra LLM kwargs.
        websocket: Forwarded unchanged to create_chat_completion.
        cost_callback (callable, optional): Forwarded unchanged to
            create_chat_completion.

    Returns:
        str: Whatever create_chat_completion returns, or "" if any
            Exception was raised while building or sending the request.
    """
    try:
        # Built as a dict literal so the values are evaluated in the same
        # order as the keyword arguments they become.
        request_kwargs = {
            "model": config.smart_llm_model,
            "messages": [
                {"role": "system", "content": f"{agent_role_prompt}"},
                {
                    "role": "user",
                    "content": generate_report_introduction(query, context),
                },
            ],
            "temperature": 0.25,
            "llm_provider": config.smart_llm_provider,
            "stream": True,
            "websocket": websocket,
            "max_tokens": config.smart_token_limit,
            "llm_kwargs": config.llm_kwargs,
            "cost_callback": cost_callback,
        }
        generated_text = await create_chat_completion(**request_kwargs)
    except Exception as exc:
        # Best-effort: log the failure and hand back an empty introduction.
        logger.error(f"Error in generating report introduction: {exc}")
        return ""
    else:
        return generated_text
async def write_conclusion(
    query: str,
    context: str,
    agent_role_prompt: str,
    config: Config,
    websocket=None,
    cost_callback: callable = None
) -> str:
    """
    Ask the smart LLM to write the report's conclusion.

    Args:
        query (str): The research query.
        context (str): Context for the report.
        agent_role_prompt (str): Text sent as the system message.
        config (Config): Configuration object; supplies the smart LLM model,
            provider, token limit and extra LLM kwargs.
        websocket: Forwarded unchanged to create_chat_completion.
        cost_callback (callable, optional): Forwarded unchanged to
            create_chat_completion.

    Returns:
        str: Whatever create_chat_completion returns, or "" if any
            Exception was raised while building or sending the request.
    """
    try:
        # Built as a dict literal so the values are evaluated in the same
        # order as the keyword arguments they become.
        request_kwargs = {
            "model": config.smart_llm_model,
            "messages": [
                {"role": "system", "content": f"{agent_role_prompt}"},
                {
                    "role": "user",
                    "content": generate_report_conclusion(query, context),
                },
            ],
            "temperature": 0.25,
            "llm_provider": config.smart_llm_provider,
            "stream": True,
            "websocket": websocket,
            "max_tokens": config.smart_token_limit,
            "llm_kwargs": config.llm_kwargs,
            "cost_callback": cost_callback,
        }
        generated_text = await create_chat_completion(**request_kwargs)
    except Exception as exc:
        # Best-effort: log the failure and hand back an empty conclusion.
        logger.error(f"Error in writing conclusion: {exc}")
        return ""
    else:
        return generated_text
async def summarize_url(
    url: str,
    content: str,
    role: str,
    config: Config,
    websocket=None,
    cost_callback: callable = None
) -> str:
    """
    Ask the smart LLM to summarize content taken from a URL.

    Args:
        url (str): The URL the content came from; quoted in the user prompt.
        content (str): The content to summarize.
        role (str): Text sent as the system message.
        config (Config): Configuration object; supplies the smart LLM model,
            provider, token limit and extra LLM kwargs.
        websocket: Forwarded unchanged to create_chat_completion.
        cost_callback (callable, optional): Forwarded unchanged to
            create_chat_completion.

    Returns:
        str: Whatever create_chat_completion returns, or "" if any
            Exception was raised while building or sending the request.
    """
    try:
        # Built as a dict literal so the values are evaluated in the same
        # order as the keyword arguments they become.
        request_kwargs = {
            "model": config.smart_llm_model,
            "messages": [
                {"role": "system", "content": f"{role}"},
                {
                    "role": "user",
                    "content": f"Summarize the following content from {url}:\n\n{content}",
                },
            ],
            "temperature": 0.25,
            "llm_provider": config.smart_llm_provider,
            "stream": True,
            "websocket": websocket,
            "max_tokens": config.smart_token_limit,
            "llm_kwargs": config.llm_kwargs,
            "cost_callback": cost_callback,
        }
        summary_text = await create_chat_completion(**request_kwargs)
    except Exception as exc:
        # Best-effort: log the failure and hand back an empty summary.
        logger.error(f"Error in summarizing URL: {exc}")
        return ""
    else:
        return summary_text
async def generate_draft_section_titles(
    query: str,
    current_subtopic: str,
    context: str,
    role: str,
    config: Config,
    websocket=None,
    cost_callback: callable = None
) -> List[str]:
    """
    Ask the smart LLM for draft section titles and split them per line.

    Args:
        query (str): The research query.
        current_subtopic (str): Subtopic passed to the draft-titles prompt.
        context (str): Context for the report.
        role (str): Text sent as the system message.
        config (Config): Configuration object; supplies the smart LLM model,
            provider, token limit and extra LLM kwargs.
        websocket: Accepted but not forwarded; the completion call is always
            made with websocket=None.
        cost_callback (callable, optional): Forwarded unchanged to
            create_chat_completion.

    Returns:
        List[str]: The completion text split on "\\n" (blank lines are kept
            as empty strings), or [] if any Exception was raised.
    """
    try:
        # Built as a dict literal so the values are evaluated in the same
        # order as the keyword arguments they become.
        request_kwargs = {
            "model": config.smart_llm_model,
            "messages": [
                {"role": "system", "content": f"{role}"},
                {
                    "role": "user",
                    "content": generate_draft_titles_prompt(
                        current_subtopic, query, context
                    ),
                },
            ],
            "temperature": 0.25,
            "llm_provider": config.smart_llm_provider,
            "stream": True,
            "websocket": None,
            "max_tokens": config.smart_token_limit,
            "llm_kwargs": config.llm_kwargs,
            "cost_callback": cost_callback,
        }
        raw_titles = await create_chat_completion(**request_kwargs)
        # The split stays inside the try so a non-string result is also
        # reported and turned into an empty list.
        title_lines = raw_titles.split("\n")
    except Exception as exc:
        logger.error(f"Error in generating draft section titles: {exc}")
        return []
    else:
        return title_lines
async def generate_report(
    query: str,
    context,
    agent_role_prompt: str,
    report_type: str,
    tone: Tone,
    report_source: str,
    websocket,
    cfg,
    main_topic: str = "",
    existing_headers: list = None,
    relevant_written_contents: list = None,
    cost_callback: callable = None,
    headers=None,
):
    """
    Generate the final report with the smart LLM.

    The request is first sent with the agent role as a system message. If
    that attempt raises, it is retried once with the role text prepended to
    the user message instead.

    Args:
        query (str): The research query.
        context: Research context handed to the prompt builder.
        agent_role_prompt (str): Agent role text (system message on the
            first attempt, prefix of the user message on the retry).
        report_type (str): Selects the prompt builder via
            get_prompt_by_report_type; "subtopic_report" uses the subtopic
            argument layout.
        tone (Tone): Tone passed to the prompt builder.
        report_source (str): Passed to the prompt builder for every report
            type other than "subtopic_report".
        websocket: Forwarded unchanged to create_chat_completion.
        cfg: Configuration object; supplies report format, word count,
            language and the smart LLM settings.
        main_topic (str): Main topic, used only for "subtopic_report".
        existing_headers (list, optional): Headers already written, used
            only for "subtopic_report". Defaults to an empty list.
        relevant_written_contents (list, optional): Previously written
            content, used only for "subtopic_report". Defaults to an empty
            list.
        cost_callback (callable, optional): Forwarded unchanged to
            create_chat_completion.
        headers: Accepted for interface compatibility; not used here.

    Returns:
        The value returned by create_chat_completion, or "" if both
        attempts failed.
    """
    # Fresh lists per call instead of shared mutable defaults.
    if existing_headers is None:
        existing_headers = []
    if relevant_written_contents is None:
        relevant_written_contents = []

    generate_prompt = get_prompt_by_report_type(report_type)

    prompt_options = {
        "report_format": cfg.report_format,
        "tone": tone,
        "total_words": cfg.total_words,
        "language": cfg.language,
    }
    if report_type == "subtopic_report":
        content = f"{generate_prompt(query, existing_headers, relevant_written_contents, main_topic, context, **prompt_options)}"
    else:
        content = f"{generate_prompt(query, context, report_source, **prompt_options)}"

    # Attempt 1 uses a system message; attempt 2 folds the role into the
    # user message.
    message_layouts = (
        [
            {"role": "system", "content": f"{agent_role_prompt}"},
            {"role": "user", "content": content},
        ],
        [
            {"role": "user", "content": f"{agent_role_prompt}\n\n{content}"},
        ],
    )
    last_attempt = len(message_layouts) - 1

    report = ""
    for attempt, messages in enumerate(message_layouts):
        try:
            report = await create_chat_completion(
                model=cfg.smart_llm_model,
                messages=messages,
                temperature=0.35,
                llm_provider=cfg.smart_llm_provider,
                stream=True,
                websocket=websocket,
                max_tokens=cfg.smart_token_limit,
                llm_kwargs=cfg.llm_kwargs,
                cost_callback=cost_callback,
            )
            break
        # `except Exception` rather than a bare `except`: task cancellation
        # (asyncio.CancelledError, a BaseException) and KeyboardInterrupt
        # must propagate instead of triggering a retry.
        except Exception as e:
            if attempt < last_attempt:
                logger.error(
                    f"Error in generate_report, retrying without system message: {e}"
                )
            else:
                logger.error(f"Error in generate_report: {e}")

    return report