Spaces:
Running
Running
Chandima Prabhath
commited on
Commit
·
db2f80b
1
Parent(s):
cc969ee
Refactor intent handling to use Pydantic models for strict parsing; update GenerateImageIntent to inherit from BaseModel and improve route_intent function for better error handling and data extraction.
Browse files
app.py
CHANGED
@@ -330,7 +330,7 @@ class PollResultsIntent(BaseIntent):
|
|
330 |
class PollEndIntent(BaseIntent):
|
331 |
action: Literal["poll_end"]
|
332 |
|
333 |
-
class GenerateImageIntent(
|
334 |
action: Literal["generate_image"]
|
335 |
prompt: str
|
336 |
count: int = Field(default=1, ge=1)
|
@@ -341,7 +341,8 @@ class SendTextIntent(BaseIntent):
|
|
341 |
action: Literal["send_text"]
|
342 |
message: str
|
343 |
|
344 |
-
|
|
|
345 |
SummarizeIntent, TranslateIntent, JokeIntent, WeatherIntent,
|
346 |
InspireIntent, MemeIntent, PollCreateIntent, PollVoteIntent,
|
347 |
PollResultsIntent, PollEndIntent, GenerateImageIntent, SendTextIntent
|
@@ -364,7 +365,7 @@ ACTION_HANDLERS = {
|
|
364 |
|
365 |
# --- Intent Routing with Fallback ------------------------------------------
|
366 |
|
367 |
-
def route_intent(user_input: str, chat_id: str, sender: str)
|
368 |
history_text = get_history_text(chat_id, sender)
|
369 |
sys_prompt = (
|
370 |
"You are Eve. You can either chat or call one of these functions:\n"
|
@@ -384,53 +385,56 @@ def route_intent(user_input: str, chat_id: str, sender: str) -> IntentUnion:
|
|
384 |
" {\"action\":\"generate_image\",\"prompt\":\"a red fox\",\"count\":3,\"width\":512,\"height\":512}\n"
|
385 |
"Otherwise, use send_text to reply with plain chat.\n"
|
386 |
)
|
387 |
-
prompt =
|
388 |
-
f"{sys_prompt}\n"
|
389 |
-
f"Conversation so far:\n{history_text}\n\n"
|
390 |
-
f"User: {user_input}"
|
391 |
-
)
|
392 |
raw = generate_llm(prompt)
|
393 |
|
394 |
-
# 1)
|
395 |
try:
|
396 |
parsed = json.loads(raw)
|
397 |
-
intent = IntentUnion.parse_obj(parsed)
|
398 |
-
return intent
|
399 |
-
except (json.JSONDecodeError, ValidationError) as e:
|
400 |
-
logger.warning(f"Strict parse failed: {e}. Falling back to lenient.")
|
401 |
-
|
402 |
-
# 2) Lenient: basic JSON get + defaults
|
403 |
-
try:
|
404 |
-
data = json.loads(raw)
|
405 |
except json.JSONDecodeError:
|
406 |
return SendTextIntent(action="send_text", message=raw)
|
407 |
|
408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
if action in ACTION_HANDLERS:
|
|
|
410 |
kwargs = {}
|
411 |
if action == "generate_image":
|
412 |
-
kwargs["prompt"] = data.get("prompt",
|
413 |
kwargs["count"] = int(data.get("count", BotConfig.DEFAULT_IMAGE_COUNT))
|
414 |
kwargs["width"] = data.get("width")
|
415 |
kwargs["height"] = data.get("height")
|
416 |
elif action == "send_text":
|
417 |
-
kwargs["message"] = data.get("message",
|
418 |
elif action == "translate":
|
419 |
-
kwargs["lang"] = data.get("lang",
|
420 |
-
kwargs["text"] = data.get("text",
|
421 |
elif action == "summarize":
|
422 |
-
kwargs["text"] = data.get("text",
|
423 |
elif action == "weather":
|
424 |
-
kwargs["location"] = data.get("location",
|
425 |
elif action == "meme":
|
426 |
-
kwargs["text"] = data.get("text",
|
427 |
elif action == "poll_create":
|
428 |
-
kwargs["question"] = data.get("question",
|
429 |
-
kwargs["options"] = data.get("options",
|
430 |
elif action == "poll_vote":
|
431 |
kwargs["voter"] = sender
|
432 |
-
kwargs["choice"] = int(data.get("choice",
|
433 |
-
|
|
|
|
|
|
|
|
|
434 |
|
435 |
return SendTextIntent(action="send_text", message=raw)
|
436 |
|
@@ -531,9 +535,7 @@ async def whatsapp_webhook(request: Request):
|
|
531 |
if tmd.get("contextInfo", {}).get("mentionedJidList"):
|
532 |
return {"success": True}
|
533 |
|
534 |
-
# Handle quoted replies if needed...
|
535 |
effective = body
|
536 |
-
|
537 |
intent = route_intent(effective, chat_id, sender)
|
538 |
handler = ACTION_HANDLERS.get(intent.action)
|
539 |
if handler:
|
|
|
330 |
class PollEndIntent(BaseIntent):
|
331 |
action: Literal["poll_end"]
|
332 |
|
333 |
+
class GenerateImageIntent(BaseModel):
|
334 |
action: Literal["generate_image"]
|
335 |
prompt: str
|
336 |
count: int = Field(default=1, ge=1)
|
|
|
341 |
action: Literal["send_text"]
|
342 |
message: str
|
343 |
|
344 |
+
# list of all intent models
|
345 |
+
INTENT_MODELS = [
|
346 |
SummarizeIntent, TranslateIntent, JokeIntent, WeatherIntent,
|
347 |
InspireIntent, MemeIntent, PollCreateIntent, PollVoteIntent,
|
348 |
PollResultsIntent, PollEndIntent, GenerateImageIntent, SendTextIntent
|
|
|
365 |
|
366 |
# --- Intent Routing with Fallback ------------------------------------------
|
367 |
|
368 |
+
def route_intent(user_input: str, chat_id: str, sender: str):
|
369 |
history_text = get_history_text(chat_id, sender)
|
370 |
sys_prompt = (
|
371 |
"You are Eve. You can either chat or call one of these functions:\n"
|
|
|
385 |
" {\"action\":\"generate_image\",\"prompt\":\"a red fox\",\"count\":3,\"width\":512,\"height\":512}\n"
|
386 |
"Otherwise, use send_text to reply with plain chat.\n"
|
387 |
)
|
388 |
+
prompt = f"{sys_prompt}\nConversation so far:\n{history_text}\n\nUser: {user_input}"
|
|
|
|
|
|
|
|
|
389 |
raw = generate_llm(prompt)
|
390 |
|
391 |
+
# 1) Strict: try each Pydantic model
|
392 |
try:
|
393 |
parsed = json.loads(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
except json.JSONDecodeError:
|
395 |
return SendTextIntent(action="send_text", message=raw)
|
396 |
|
397 |
+
for M in INTENT_MODELS:
|
398 |
+
try:
|
399 |
+
intent = M.parse_obj(parsed)
|
400 |
+
return intent
|
401 |
+
except ValidationError:
|
402 |
+
continue
|
403 |
+
|
404 |
+
logger.warning("Strict parse failed for all models, falling back to lenient")
|
405 |
+
|
406 |
+
# 2) Lenient JSON get
|
407 |
+
action = parsed.get("action")
|
408 |
if action in ACTION_HANDLERS:
|
409 |
+
data = parsed
|
410 |
kwargs = {}
|
411 |
if action == "generate_image":
|
412 |
+
kwargs["prompt"] = data.get("prompt","")
|
413 |
kwargs["count"] = int(data.get("count", BotConfig.DEFAULT_IMAGE_COUNT))
|
414 |
kwargs["width"] = data.get("width")
|
415 |
kwargs["height"] = data.get("height")
|
416 |
elif action == "send_text":
|
417 |
+
kwargs["message"] = data.get("message","")
|
418 |
elif action == "translate":
|
419 |
+
kwargs["lang"] = data.get("lang","")
|
420 |
+
kwargs["text"] = data.get("text","")
|
421 |
elif action == "summarize":
|
422 |
+
kwargs["text"] = data.get("text","")
|
423 |
elif action == "weather":
|
424 |
+
kwargs["location"] = data.get("location","")
|
425 |
elif action == "meme":
|
426 |
+
kwargs["text"] = data.get("text","")
|
427 |
elif action == "poll_create":
|
428 |
+
kwargs["question"] = data.get("question","")
|
429 |
+
kwargs["options"] = data.get("options",[])
|
430 |
elif action == "poll_vote":
|
431 |
kwargs["voter"] = sender
|
432 |
+
kwargs["choice"] = int(data.get("choice",0))
|
433 |
+
# parse into Pydantic for uniformity
|
434 |
+
try:
|
435 |
+
return next(M for M in INTENT_MODELS if getattr(M, "__fields__", {}).get("action").default == action).parse_obj({"action":action,**kwargs})
|
436 |
+
except Exception:
|
437 |
+
return SendTextIntent(action="send_text", message=raw)
|
438 |
|
439 |
return SendTextIntent(action="send_text", message=raw)
|
440 |
|
|
|
535 |
if tmd.get("contextInfo", {}).get("mentionedJidList"):
|
536 |
return {"success": True}
|
537 |
|
|
|
538 |
effective = body
|
|
|
539 |
intent = route_intent(effective, chat_id, sender)
|
540 |
handler = ACTION_HANDLERS.get(intent.action)
|
541 |
if handler:
|