mbudisic commited on
Commit
b22a9a8
·
1 Parent(s): 457c025

feat: Add interrupt handling and YesNoDecision class for user queries

Browse files

- Implemented `handle_interrupt` function to manage user approval for interrupted searches in `app.py`.
- Introduced `YesNoDecision` class with a method to parse string inputs into decision instances in `nodes.py`.
- Updated the `search_help` function to utilize the new decision structure for user permissions.

Files changed (2) hide show
  1. app.py +69 -7
  2. pstuts_rag/pstuts_rag/nodes.py +22 -4
app.py CHANGED
@@ -11,13 +11,20 @@ import chainlit as cl
11
  import httpx
12
  import nest_asyncio
13
  from dotenv import load_dotenv
 
14
  from langchain_core.documents import Document
15
  from langchain_core.runnables import Runnable
16
  from langgraph.checkpoint.memory import MemorySaver
 
17
 
18
  from pstuts_rag.configuration import Configuration
19
  from pstuts_rag.datastore import Datastore
20
- from pstuts_rag.nodes import FinalAnswer, TutorialState, initialize
 
 
 
 
 
21
  from pstuts_rag.utils import get_unique
22
 
23
  # Track the single active session
@@ -203,9 +210,6 @@ async def format_url_reference(url_ref):
203
  )
204
 
205
 
206
- from langchain.callbacks.base import BaseCallbackHandler
207
-
208
-
209
  class ChainlitCallbackHandler(BaseCallbackHandler):
210
  """
211
  Custom callback handler for Chainlit to visualize the execution of LangChain chains/graphs.
@@ -293,6 +297,42 @@ class ChainlitCallbackHandler(BaseCallbackHandler):
293
  # TODO Add buttons with pregenerated queries
294
 
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  @cl.on_message
297
  async def message_handler(input_message: cl.Message):
298
  """
@@ -329,11 +369,25 @@ async def message_handler(input_message: cl.Message):
329
  config = configuration.to_runnable_config()
330
  config["callbacks"] = [ChainlitCallbackHandler()]
331
 
332
- response = cast(
333
- TutorialState,
334
- await ai_graph.ainvoke({"query": input_message.content}, config),
335
  )
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  for msg in response["messages"]:
338
  if isinstance(msg, FinalAnswer):
339
  # Stream the final answer token-by-token for a typing effect
@@ -346,10 +400,18 @@ async def message_handler(input_message: cl.Message):
346
  if final_msg:
347
  await final_msg.update()
348
 
 
 
 
 
349
  # Send all unique video references as separate messages
350
  for v in get_unique(response["video_references"]):
351
  await format_video_reference(v).send()
352
 
 
 
 
 
353
  # Send all unique URL references as separate messages (with screenshots if available)
354
  url_reference_tasks = [
355
  format_url_reference(u) for u in get_unique(response["url_references"])
 
11
  import httpx
12
  import nest_asyncio
13
  from dotenv import load_dotenv
14
+ from langchain.callbacks.base import BaseCallbackHandler
15
  from langchain_core.documents import Document
16
  from langchain_core.runnables import Runnable
17
  from langgraph.checkpoint.memory import MemorySaver
18
+ from langgraph.types import Command
19
 
20
  from pstuts_rag.configuration import Configuration
21
  from pstuts_rag.datastore import Datastore
22
+ from pstuts_rag.nodes import (
23
+ FinalAnswer,
24
+ TutorialState,
25
+ initialize,
26
+ YesNoDecision,
27
+ )
28
  from pstuts_rag.utils import get_unique
29
 
30
  # Track the single active session
 
210
  )
211
 
212
 
 
 
 
213
  class ChainlitCallbackHandler(BaseCallbackHandler):
214
  """
215
  Custom callback handler for Chainlit to visualize the execution of LangChain chains/graphs.
 
297
  # TODO Add buttons with pregenerated queries
298
 
299
 
300
+ async def handle_interrupt(query: str) -> YesNoDecision:
301
+
302
+ try:
303
+ user_input = await cl.AskActionMessage(
304
+ content="Search has been interrupted. Do you approve query: '%s' to be sent to Adobe Help?"
305
+ % query,
306
+ timeout=30,
307
+ raise_on_timeout=True,
308
+ actions=[
309
+ cl.Action(
310
+ name="approve",
311
+ payload={"value": "yes"},
312
+ label="✅ Approve",
313
+ ),
314
+ cl.Action(
315
+ name="cancel",
316
+ payload={"value": "cancel"},
317
+ label="❌ Cancel web search",
318
+ ),
319
+ ],
320
+ ).send()
321
+ if user_input and user_input.get("payload").get("value") == "yes":
322
+ return YesNoDecision(decision="yes")
323
+ else:
324
+ return YesNoDecision(decision="no")
325
+
326
+ except TimeoutError:
327
+ await cl.Message(
328
+ "Timeout: No response from user. Canceling search."
329
+ ).send()
330
+ return YesNoDecision(decision="no")
331
+
332
+
333
+ from pstuts_rag.nodes import YesNoDecision
334
+
335
+
336
  @cl.on_message
337
  async def message_handler(input_message: cl.Message):
338
  """
 
369
  config = configuration.to_runnable_config()
370
  config["callbacks"] = [ChainlitCallbackHandler()]
371
 
372
+ raw_response = await ai_graph.ainvoke(
373
+ {"query": input_message.content}, config
 
374
  )
375
 
376
+ if "__interrupt__" in raw_response:
377
+ logging.warning("*** INTERRUPT ***")
378
+
379
+ logging.info(raw_response["__interrupt__"])
380
+
381
+ answer: YesNoDecision = await handle_interrupt(
382
+ raw_response["__interrupt__"][-1].value["query"]
383
+ )
384
+
385
+ raw_response = await ai_graph.ainvoke(
386
+ Command(resume=answer.decision), config
387
+ )
388
+
389
+ response = cast(TutorialState, raw_response)
390
+
391
  for msg in response["messages"]:
392
  if isinstance(msg, FinalAnswer):
393
  # Stream the final answer token-by-token for a typing effect
 
400
  if final_msg:
401
  await final_msg.update()
402
 
403
+ await cl.Message(
404
+ content=f"Formatting {len(response['video_references'])} video references."
405
+ ).send()
406
+
407
  # Send all unique video references as separate messages
408
  for v in get_unique(response["video_references"]):
409
  await format_video_reference(v).send()
410
 
411
+ await cl.Message(
412
+ content=f"Formatting {len(response['url_references'])} website references."
413
+ ).send()
414
+
415
  # Send all unique URL references as separate messages (with screenshots if available)
416
  url_reference_tasks = [
417
  format_url_reference(u) for u in get_unique(response["url_references"])
pstuts_rag/pstuts_rag/nodes.py CHANGED
@@ -185,10 +185,10 @@ async def search_help(state: TutorialState, config: RunnableConfig):
185
  logging.info("search_help: asking permission")
186
 
187
  response = interrupt(
188
- (
189
- f"Do you allow Internet search for query '{query}'?"
190
- "Answer 'yes' will perform the search, any other answer will skip it."
191
- )
192
  )
193
 
194
  logging.info(f"Permission response '{response}'")
@@ -306,6 +306,24 @@ class YesNoDecision(BaseModel):
306
 
307
  decision: Literal["yes", "no"] = Field(description="Yes or no decision.")
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
  class URLReference(BaseModel):
311
  """Model for URL reference with summary.
 
185
  logging.info("search_help: asking permission")
186
 
187
  response = interrupt(
188
+ {
189
+ "message": "Do you allow Internet search for this query?",
190
+ "query": query,
191
+ }
192
  )
193
 
194
  logging.info(f"Permission response '{response}'")
 
306
 
307
  decision: Literal["yes", "no"] = Field(description="Yes or no decision.")
308
 
309
+ @classmethod
310
+ def from_string(cls, value: str) -> "YesNoDecision":
311
+ """Parse a string and return a YesNoDecision instance, mapping common affirmatives to 'yes', others to 'no'."""
312
+ affirmatives = {
313
+ "yes",
314
+ "y",
315
+ "true",
316
+ "ok",
317
+ "okay",
318
+ "sure",
319
+ "1",
320
+ "fine",
321
+ "alright",
322
+ }
323
+ if value.strip().lower() in affirmatives:
324
+ return cls(decision="yes")
325
+ return cls(decision="no")
326
+
327
 
328
  class URLReference(BaseModel):
329
  """Model for URL reference with summary.