darrenphodgson76's picture
Update app.py
1d91d99 verified
raw
history blame
2.96 kB
from smolagents import CodeAgent, HfApiModel, tool
import yaml
from tools.final_answer import FinalAnswerTool
import wikipedia
from Gradio_UI import GradioUI
# Wikipedia search tool exposed to the agent via smolagents' @tool decorator.
@tool
def wikipedia_search(query: str, sentences: int = 2) -> str:
    """Search Wikipedia and return a short summary.

    Args:
        query: The search term for Wikipedia.
        sentences: The number of sentences to return from the summary.

    Returns:
        The summary text, or a human-readable message describing why no
        summary could be produced.
    """
    # Failures are reported as plain strings rather than raised so the agent
    # can read the message and decide how to proceed.
    query = "" if query is None else str(query).strip()
    if not query:
        return "Please provide a non-empty search query."
    try:
        # `wikipedia.summary` treats sentences=0 as "no sentence limit" (whole
        # intro) and negative counts are meaningless; keep at least one.
        sentence_count = max(1, int(sentences))
        return wikipedia.summary(query, sentences=sentence_count)
    except wikipedia.exceptions.DisambiguationError as e:
        shown = e.options[:5]
        # Only signal truncation when some options were actually omitted.
        suffix = "..." if len(e.options) > len(shown) else ""
        return f"Multiple results found: {', '.join(shown)}{suffix}"
    except wikipedia.exceptions.PageError:
        return "No Wikipedia page found for that query."
    except Exception as e:
        # Deliberate best-effort boundary: any other failure (network, API,
        # bad argument) is surfaced to the agent as text.
        return f"An error occurred: {e}"
# Tool the agent uses to hand back its final answer.
final_answer = FinalAnswerTool()

# Language model backing the agent (Qwen2.5-Coder-32B-Instruct via smolagents'
# HfApiModel).
# NOTE(review): max_tokens=2096 may have been intended as 2048 — confirm.
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    custom_role_conversions=None,
)

# Prompt templates are loaded from "prompts.yaml" relative to the current
# working directory. An explicit UTF-8 encoding keeps non-ASCII prompt text
# from depending on the platform's locale default; safe_load avoids executing
# arbitrary YAML tags.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

# Code-writing agent that can search Wikipedia and return a final answer,
# limited to 6 steps per run.
agent = CodeAgent(
    model=model,
    tools=[final_answer, wikipedia_search],
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates,
)
# Custom GradioUI that resets the agent's context every `max_messages`
# handled user messages (default 4).
class CustomGradioUI(GradioUI):
    """GradioUI variant that periodically clears the agent's context.

    NOTE(review): the periodic reset only takes effect if the base GradioUI
    routes user messages through ``process_user_input``; that routing is not
    visible in this file — confirm in Gradio_UI.py.
    """

    def __init__(self, agent, max_messages=4, **kwargs):
        """Create the UI around ``agent``.

        Args:
            agent: The agent driven by the chat UI.
            max_messages: Number of handled messages after which the agent's
                context is reset.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``GradioUI.__init__``.
        """
        super().__init__(agent, **kwargs)
        self.max_messages = max_messages
        self.message_count = 0

    def _reset_agent_context(self):
        """Clear the agent's context using the first mechanism it exposes."""
        if hasattr(self.agent, 'reset'):
            self.agent.reset()  # Agent-level reset method, if provided.
        elif hasattr(self.agent, 'conversation_history'):
            self.agent.conversation_history.clear()
        else:
            # smolagents agents keep their steps in `agent.memory`, which has a
            # `reset()` method in recent versions; guarded because that API
            # depends on the installed smolagents version.
            memory = getattr(self.agent, 'memory', None)
            if memory is not None and hasattr(memory, 'reset'):
                memory.reset()

    def process_user_input(self, user_input):
        """Run the agent on one user message and periodically reset its context.

        NOTE(review): smolagents' ``run()`` defaults to ``reset=True``, which
        would already start each call from fresh memory and make this
        periodic reset redundant — confirm for the installed version.

        Args:
            user_input: The user's message, passed straight to ``agent.run``.

        Returns:
            Whatever ``agent.run`` returns for this message.
        """
        response = self.agent.run(user_input)
        self.message_count += 1
        if self.message_count >= self.max_messages:
            self._reset_agent_context()
            # Restart the count even if no reset mechanism was found, so the
            # lookup is not repeated on every subsequent message.
            self.message_count = 0
        return response

    def launch(self, *args, **kwargs):
        """Start the Gradio app.

        All positional and keyword arguments are forwarded unchanged to
        ``GradioUI.launch``.
        """
        super().launch(*args, **kwargs)
# Build the custom chat UI around the configured agent and start serving it.
chat_ui = CustomGradioUI(agent)
chat_ui.launch()