# gpt_researcher/actions/report_generation.py
from typing import Callable, List, Optional
from ..config.config import Config
from ..utils.llm import create_chat_completion
from ..utils.logger import get_formatted_logger
from ..prompts import (
generate_report_introduction,
generate_draft_titles_prompt,
generate_report_conclusion,
get_prompt_by_report_type,
)
from ..utils.enum import Tone

logger = get_formatted_logger()


async def write_report_introduction(
    query: str,
    context: str,
    agent_role_prompt: str,
    config: Config,
    websocket=None,
    cost_callback: Optional[Callable] = None,
) -> str:
    """
    Generate an introduction for the report.

    Args:
        query (str): The research query.
        context (str): Context for the report.
        agent_role_prompt (str): The role prompt for the agent.
        config (Config): Configuration object.
        websocket: WebSocket connection for streaming output.
        cost_callback (Callable, optional): Callback for calculating LLM costs.

    Returns:
        str: The generated introduction.
    """
try:
introduction = await create_chat_completion(
model=config.smart_llm_model,
messages=[
{"role": "system", "content": f"{agent_role_prompt}"},
{"role": "user", "content": generate_report_introduction(
query, context)},
],
temperature=0.25,
llm_provider=config.smart_llm_provider,
stream=True,
websocket=websocket,
max_tokens=config.smart_token_limit,
llm_kwargs=config.llm_kwargs,
cost_callback=cost_callback,
)
return introduction
except Exception as e:
logger.error(f"Error in generating report introduction: {e}")
return ""


async def write_conclusion(
    query: str,
    context: str,
    agent_role_prompt: str,
    config: Config,
    websocket=None,
    cost_callback: Optional[Callable] = None,
) -> str:
    """
    Write a conclusion for the report.

    Args:
        query (str): The research query.
        context (str): Context for the report.
        agent_role_prompt (str): The role prompt for the agent.
        config (Config): Configuration object.
        websocket: WebSocket connection for streaming output.
        cost_callback (Callable, optional): Callback for calculating LLM costs.

    Returns:
        str: The generated conclusion.
    """
try:
conclusion = await create_chat_completion(
model=config.smart_llm_model,
messages=[
{"role": "system", "content": f"{agent_role_prompt}"},
{"role": "user", "content": generate_report_conclusion(query, context)},
],
temperature=0.25,
llm_provider=config.smart_llm_provider,
stream=True,
websocket=websocket,
max_tokens=config.smart_token_limit,
llm_kwargs=config.llm_kwargs,
cost_callback=cost_callback,
)
return conclusion
except Exception as e:
logger.error(f"Error in writing conclusion: {e}")
return ""


async def summarize_url(
    url: str,
    content: str,
    role: str,
    config: Config,
    websocket=None,
    cost_callback: Optional[Callable] = None,
) -> str:
    """
    Summarize the content of a URL.

    Args:
        url (str): The URL to summarize.
        content (str): The content of the URL.
        role (str): The role prompt for the agent.
        config (Config): Configuration object.
        websocket: WebSocket connection for streaming output.
        cost_callback (Callable, optional): Callback for calculating LLM costs.

    Returns:
        str: The summarized content.
    """
try:
summary = await create_chat_completion(
model=config.smart_llm_model,
messages=[
{"role": "system", "content": f"{role}"},
{"role": "user", "content": f"Summarize the following content from {url}:\n\n{content}"},
],
temperature=0.25,
llm_provider=config.smart_llm_provider,
stream=True,
websocket=websocket,
max_tokens=config.smart_token_limit,
llm_kwargs=config.llm_kwargs,
cost_callback=cost_callback,
)
return summary
except Exception as e:
logger.error(f"Error in summarizing URL: {e}")
return ""


async def generate_draft_section_titles(
    query: str,
    current_subtopic: str,
    context: str,
    role: str,
    config: Config,
    websocket=None,
    cost_callback: Optional[Callable] = None,
) -> List[str]:
    """
    Generate draft section titles for the report.

    Args:
        query (str): The research query.
        current_subtopic (str): The subtopic to generate section titles for.
        context (str): Context for the report.
        role (str): The role prompt for the agent.
        config (Config): Configuration object.
        websocket: WebSocket connection for streaming output.
        cost_callback (Callable, optional): Callback for calculating LLM costs.

    Returns:
        List[str]: A list of generated section titles.
    """
try:
section_titles = await create_chat_completion(
model=config.smart_llm_model,
messages=[
{"role": "system", "content": f"{role}"},
{"role": "user", "content": generate_draft_titles_prompt(
current_subtopic, query, context)},
],
temperature=0.25,
llm_provider=config.smart_llm_provider,
stream=True,
            websocket=None,  # deliberately unset so draft titles are not streamed to the client
max_tokens=config.smart_token_limit,
llm_kwargs=config.llm_kwargs,
cost_callback=cost_callback,
)
return section_titles.split("\n")
except Exception as e:
logger.error(f"Error in generating draft section titles: {e}")
return []
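
# The raw completion is split on newlines, so the returned list may contain
# empty strings or bullet/heading prefixes, depending on how the model formats
# its answer. A small post-processing sketch (the prefix handling is an
# assumption about typical model output, not a guarantee):
#
#     titles = await generate_draft_section_titles(
#         query, current_subtopic, context, role, config
#     )
#     titles = [t.lstrip("-# ").strip() for t in titles if t.strip()]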


async def generate_report(
    query: str,
    context,
    agent_role_prompt: str,
    report_type: str,
    tone: Tone,
    report_source: str,
    websocket,
    cfg: Config,
    main_topic: str = "",
    existing_headers: Optional[list] = None,
    relevant_written_contents: Optional[list] = None,
    cost_callback: Optional[Callable] = None,
    headers=None,
):
    """
    Generate the final report.

    Args:
        query (str): The research query.
        context: Context for the report.
        agent_role_prompt (str): The role prompt for the agent.
        report_type (str): The type of report to generate (e.g. "subtopic_report").
        tone (Tone): The tone to write the report in.
        report_source (str): The source of the research (e.g. web or documents).
        websocket: WebSocket connection for streaming output.
        cfg (Config): Configuration object.
        main_topic (str): The main topic, used for subtopic reports.
        existing_headers (list, optional): Headers already written, used for subtopic reports.
        relevant_written_contents (list, optional): Previously written content relevant to the report.
        cost_callback (Callable, optional): Callback for calculating LLM costs.
        headers: Currently unused in this function; accepted for caller compatibility.

    Returns:
        str: The generated report.
    """
    # Avoid the mutable-default-argument pitfall by normalizing None to empty lists.
    existing_headers = existing_headers or []
    relevant_written_contents = relevant_written_contents or []
    generate_prompt = get_prompt_by_report_type(report_type)
    report = ""

    if report_type == "subtopic_report":
        content = generate_prompt(
            query,
            existing_headers,
            relevant_written_contents,
            main_topic,
            context,
            report_format=cfg.report_format,
            tone=tone,
            total_words=cfg.total_words,
            language=cfg.language,
        )
    else:
        content = generate_prompt(
            query,
            context,
            report_source,
            report_format=cfg.report_format,
            tone=tone,
            total_words=cfg.total_words,
            language=cfg.language,
        )
try:
report = await create_chat_completion(
model=cfg.smart_llm_model,
messages=[
{"role": "system", "content": f"{agent_role_prompt}"},
{"role": "user", "content": content},
],
temperature=0.35,
llm_provider=cfg.smart_llm_provider,
stream=True,
websocket=websocket,
max_tokens=cfg.smart_token_limit,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback,
)
    except Exception as e:
        # Fall back to a single user message, e.g. for providers that do not
        # support a separate system message.
        logger.warning(f"Retrying report generation without system message: {e}")
        try:
report = await create_chat_completion(
model=cfg.smart_llm_model,
messages=[
{"role": "user", "content": f"{agent_role_prompt}\n\n{content}"},
],
temperature=0.35,
llm_provider=cfg.smart_llm_provider,
stream=True,
websocket=websocket,
max_tokens=cfg.smart_token_limit,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback,
)
        except Exception as e:
            logger.error(f"Error in generate_report: {e}")
return report
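
# Usage sketch for a standard (non-subtopic) report. The report_type value
# must be one that get_prompt_by_report_type recognizes; "research_report"
# and the Tone member below are the common defaults in this project, but
# verify against your installed version:
#
#     report = await generate_report(
#         query=query,
#         context=context,
#         agent_role_prompt=agent_role_prompt,
#         report_type="research_report",
#         tone=Tone.Objective,
#         report_source="web",
#         websocket=websocket,
#         cfg=config,
#         cost_callback=cost_callback,
#     )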