import logging
from typing import Any, Callable, Dict, List, Optional

import json_repair

from ..config import Config
from ..prompts import generate_search_queries_prompt
from ..utils.llm import create_chat_completion

logger = logging.getLogger(__name__)

async def get_search_results(query: str, retriever: Any) -> List[Dict[str, Any]]:
    """

    Get web search results for a given query.

    

    Args:

        query: The search query

        retriever: The retriever instance

    

    Returns:

        A list of search results

    """
    search_retriever = retriever(query)
    return search_retriever.search()
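

# The retriever is duck-typed: get_search_results only assumes a class whose
# constructor takes the query string and whose instances expose a synchronous
# search() method returning a list of result dicts. A minimal sketch of a
# conforming retriever; _ExampleStaticRetriever is illustrative only and not
# part of this package:
class _ExampleStaticRetriever:
    """Hypothetical retriever used to illustrate the expected interface."""

    def __init__(self, query: str):
        self.query = query

    def search(self) -> List[Dict[str, Any]]:
        # A real retriever would call a search API here.
        return [{"href": "https://example.com", "body": f"Stub result for: {self.query}"}]
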

async def generate_sub_queries(
    query: str,
    parent_query: str,
    report_type: str,
    context: List[Dict[str, Any]],
    cfg: Config,
    cost_callback: Optional[Callable] = None,
) -> List[str]:
    """

    Generate sub-queries using the specified LLM model.

    

    Args:

        query: The original query

        parent_query: The parent query

        report_type: The type of report

        max_iterations: Maximum number of research iterations

        context: Search results context

        cfg: Configuration object

        cost_callback: Callback for cost calculation

    

    Returns:

        A list of sub-queries

    """
    gen_queries_prompt = generate_search_queries_prompt(
        query,
        parent_query,
        report_type,
        max_iterations=cfg.max_iterations or 1,
        context=context
    )

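    # Three-tier fallback: try the strategic LLM with no token cap; if that
    # fails (some providers reject max_tokens=None), retry it with an explicit
    # cap; if the capped retry also fails, fall back to the smart LLM.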
    try:
        response = await create_chat_completion(
            model=cfg.strategic_llm_model,
            messages=[{"role": "user", "content": gen_queries_prompt}],
            temperature=1,
            llm_provider=cfg.strategic_llm_provider,
            max_tokens=None,
            llm_kwargs=cfg.llm_kwargs,
            cost_callback=cost_callback,
        )
    except Exception as e:
        logger.warning(f"Error with strategic LLM: {e}. Retrying with max_tokens={cfg.strategic_token_limit}.")
        logger.warning(f"See https://github.com/assafelovic/gpt-researcher/issues/1022")
        try:
            response = await create_chat_completion(
                model=cfg.strategic_llm_model,
                messages=[{"role": "user", "content": gen_queries_prompt}],
                temperature=1,
                llm_provider=cfg.strategic_llm_provider,
                max_tokens=cfg.strategic_token_limit,
                llm_kwargs=cfg.llm_kwargs,
                cost_callback=cost_callback,
            )
            logger.warning(f"Retrying with max_tokens={cfg.strategic_token_limit} successful.")
        except Exception as e:
            logger.warning(f"Retrying with max_tokens={cfg.strategic_token_limit} failed.")
            logger.warning(f"Error with strategic LLM: {e}. Falling back to smart LLM.")
            response = await create_chat_completion(
                model=cfg.smart_llm_model,
                messages=[{"role": "user", "content": gen_queries_prompt}],
                temperature=cfg.temperature,
                max_tokens=cfg.smart_token_limit,
                llm_provider=cfg.smart_llm_provider,
                llm_kwargs=cfg.llm_kwargs,
                cost_callback=cost_callback,
            )

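    # The prompt asks the LLM for a JSON list of query strings; json_repair
    # parses the response while tolerating common LLM formatting glitches
    # (trailing commas, single quotes) that would break a strict json.loads().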
    return json_repair.loads(response)

async def plan_research_outline(
    query: str,
    search_results: List[Dict[str, Any]],
    agent_role_prompt: str,
    cfg: Config,
    parent_query: str,
    report_type: str,
    cost_callback: Optional[Callable] = None,
) -> List[str]:
    """

    Plan the research outline by generating sub-queries.

    

    Args:

        query: Original query

        retriever: Retriever instance

        agent_role_prompt: Agent role prompt

        cfg: Configuration object

        parent_query: Parent query

        report_type: Report type

        cost_callback: Callback for cost calculation

    

    Returns:

        A list of sub-queries

    """
    
    sub_queries = await generate_sub_queries(
        query,
        parent_query,
        report_type,
        search_results,
        cfg,
        cost_callback
    )

    return sub_queries
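

# Hedged usage sketch showing how these helpers compose. It assumes the
# hypothetical _ExampleStaticRetriever defined above and a default Config();
# the flow is wrapped in a coroutine so nothing executes at import time.
async def _example_plan_flow() -> List[str]:
    cfg = Config()
    results = await get_search_results("impact of solar energy", _ExampleStaticRetriever)
    return await plan_research_outline(
        query="impact of solar energy",
        search_results=results,
        agent_role_prompt="",  # accepted for API parity; unused by this function
        cfg=cfg,
        parent_query="",
        report_type="research_report",
    )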