Chandima Prabhath commited on
Commit
cc969ee
·
1 Parent(s): 869bca1

Refactor app.py to improve function signatures and enhance intent routing; update requirements.txt to ensure pydantic is included.

Browse files
Files changed (2) hide show
  1. app.py +162 -185
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,11 +4,14 @@ import requests
4
  import logging
5
  import queue
6
  import json
7
- import random
8
  from collections import defaultdict, deque
9
  from concurrent.futures import ThreadPoolExecutor
 
10
  from fastapi import FastAPI, Request, HTTPException
11
  from fastapi.responses import PlainTextResponse
 
 
12
  from FLUX import generate_image
13
  from VoiceReply import generate_voice_reply
14
  from polLLM import generate_llm
@@ -39,7 +42,7 @@ _thread_ctx = threading.local()
39
  def set_thread_context(chat_id, sender, message_id):
40
  _thread_ctx.chat_id = chat_id
41
  _thread_ctx.sender = sender
42
- _thread_ctx.message_id = message_id
43
 
44
  def get_thread_context():
45
  return (
@@ -50,7 +53,6 @@ def get_thread_context():
50
 
51
  # --- Conversation History -------------------------------------------------
52
 
53
- # keep last 20 messages per (chat_id, sender)
54
  history = defaultdict(lambda: deque(maxlen=20))
55
 
56
  def record_user_message(chat_id, sender, message):
@@ -135,7 +137,6 @@ class BotClient:
135
  files = [("file",(os.path.basename(file_path),f,mime))]
136
  return self.send(endpoint, payload, files=files)
137
 
138
- # Validate env & init client
139
  BotConfig.validate()
140
  client = BotClient(BotConfig)
141
 
@@ -150,21 +151,9 @@ def worker():
150
  task = task_queue.get()
151
  try:
152
  if task["type"] == "image":
153
- _fn_generate_images(
154
- task["message_id"],
155
- task["chat_id"],
156
- task["prompt"],
157
- task.get("num_images", 1),
158
- task.get("width"),
159
- task.get("height")
160
- )
161
-
162
  elif task["type"] == "audio":
163
- _fn_voice_reply(
164
- task["message_id"],
165
- task["chat_id"],
166
- task["prompt"]
167
- )
168
  except Exception as e:
169
  logger.error(f"Worker error {task}: {e}")
170
  finally:
@@ -176,7 +165,6 @@ for _ in range(4):
176
  # --- Basic Tool Functions -------------------------------------------------
177
 
178
  def _fn_send_text(mid, cid, message):
179
- """Send text + record + queue voice."""
180
  client.send_message(mid, cid, message)
181
  chat_id, sender, _ = get_thread_context()
182
  if chat_id and sender:
@@ -189,7 +177,6 @@ def _fn_send_text(mid, cid, message):
189
  })
190
 
191
  def _fn_send_accept(mid, cid, message):
192
- """Send text + record, but no voice."""
193
  client.send_message(mid, cid, message)
194
  chat_id, sender, _ = get_thread_context()
195
  if chat_id and sender:
@@ -269,23 +256,24 @@ def _fn_poll_end(mid, cid):
269
  )
270
  _fn_send_text(mid, cid, txt)
271
 
272
- def _fn_generate_images(mid, cid, prompt, count=1, width=None, height=None):
 
 
273
  for i in range(1, count+1):
274
  try:
275
  img, path, ret_p, url = generate_image(
276
- prompt, mid, mid, BotConfig.IMAGE_DIR,
277
  width=width, height=height
278
  )
279
  formatted = "\n\n".join(f"_{p.strip()}_" for p in ret_p.split("\n\n") if p.strip())
280
  cap = f"✨ Image {i}/{count}: {url}\n>{chr(8203)} {formatted}"
281
- client.send_media(mid, cid, path, cap, media_type="image")
282
  os.remove(path)
283
  except Exception as e:
284
  logger.warning(f"Img {i}/{count} failed: {e}")
285
- _fn_send_text(mid, cid, f"😢 Failed to generate image {i}/{count}.")
286
-
287
 
288
- def _fn_voice_reply(mid, cid, prompt):
289
  proc = (
290
  f"Just say this exactly as written in a flirty, friendly, playful, "
291
  f"happy and helpful but a little bit clumsy-cute way: {prompt}"
@@ -293,144 +281,158 @@ def _fn_voice_reply(mid, cid, prompt):
293
  res = generate_voice_reply(proc, model="openai-audio", voice="coral", audio_dir=BotConfig.AUDIO_DIR)
294
  if res and res[0]:
295
  path, _ = res
296
- client.send_media(mid, cid, path, "", media_type="audio")
297
  os.remove(path)
298
  else:
299
- _fn_send_text(mid, cid, prompt)
300
-
301
- # --- Intent Dispatcher ----------------------------------------------------
302
-
303
- FUNCTION_SCHEMA = {
304
- "summarize": {"description":"Summarize text", "params":["text"]},
305
- "translate": {"description":"Translate text", "params":["lang","text"]},
306
- "joke": {"description":"Tell a joke", "params":[]},
307
- "weather": {"description":"Weather report", "params":["location"]},
308
- "inspire": {"description":"Inspirational quote","params":[]},
309
- "meme": {"description":"Generate meme", "params":["text"]},
310
- "poll_create": {"description":"Create poll", "params":["question","options"]},
311
- "poll_vote": {"description":"Vote poll", "params":["choice"]},
312
- "poll_results": {"description":"Show poll results", "params":[]},
313
- "poll_end": {"description":"End poll", "params":[]},
314
- "generate_image": {
315
- "description":"Generate images",
316
- "params":["prompt","count","width","height"]
317
- },
318
- "send_text": {"description":"Send plain text", "params":["message"]}}
319
-
320
- class IntentDispatcher:
321
- def __init__(self):
322
- self.handlers = {}
323
-
324
- def register(self, action):
325
- def decorator(fn):
326
- self.handlers[action] = fn
327
- return fn
328
- return decorator
329
-
330
- def dispatch(self, action, mid, cid, intent):
331
- fn = self.handlers.get(action)
332
- if not fn:
333
- return False
334
- fn(mid, cid, intent)
335
- return True
336
-
337
- dispatcher = IntentDispatcher()
338
-
339
- def validate_intent(action, intent):
340
- schema = FUNCTION_SCHEMA.get(action)
341
- if not schema:
342
- return False
343
- for p in schema["params"]:
344
- if p not in intent:
345
- logger.warning(f"Missing param '{p}' for action '{action}'")
346
- return False
347
- return True
348
-
349
- @dispatcher.register("summarize")
350
- def _h_summarize(mid, cid, intent):
351
- _fn_summarize(mid, cid, intent["text"])
352
-
353
- @dispatcher.register("translate")
354
- def _h_translate(mid, cid, intent):
355
- _fn_translate(mid, cid, intent["lang"], intent["text"])
356
-
357
- @dispatcher.register("joke")
358
- def _h_joke(mid, cid, intent):
359
- _fn_joke(mid, cid)
360
-
361
- @dispatcher.register("weather")
362
- def _h_weather(mid, cid, intent):
363
- _fn_weather(mid, cid, intent["location"])
364
-
365
- @dispatcher.register("inspire")
366
- def _h_inspire(mid, cid, intent):
367
- _fn_inspire(mid, cid)
368
-
369
- @dispatcher.register("meme")
370
- def _h_meme(mid, cid, intent):
371
- _fn_meme(mid, cid, intent["text"])
372
-
373
- @dispatcher.register("poll_create")
374
- def _h_poll_create(mid, cid, intent):
375
- _fn_poll_create(mid, cid, intent["question"], intent["options"])
376
-
377
- @dispatcher.register("poll_vote")
378
- def _h_poll_vote(mid, cid, intent):
379
- _fn_poll_vote(mid, cid, intent["voter"], intent["choice"])
380
-
381
- @dispatcher.register("poll_results")
382
- def _h_poll_results(mid, cid, intent):
383
- _fn_poll_results(mid, cid)
384
-
385
- @dispatcher.register("poll_end")
386
- def _h_poll_end(mid, cid, intent):
387
- _fn_poll_end(mid, cid)
388
-
389
- @dispatcher.register("generate_image")
390
- def _h_generate_image(mid, cid, intent):
391
- prompt = intent["prompt"]
392
- count = intent.get("count", 1)
393
- width = intent.get("width")
394
- height = intent.get("height")
395
- _fn_send_accept(mid, cid, f"✨ Generating {count} image(s)…")
396
- task_queue.put({
397
- "type": "image",
398
- "message_id": mid,
399
- "chat_id": cid,
400
- "prompt": prompt,
401
- "num_images": count,
402
- "width": width,
403
- "height": height
404
- })
405
-
406
-
407
- @dispatcher.register("send_text")
408
- def _h_send_text(mid, cid, intent):
409
- _fn_send_text(mid, cid, intent["message"])
410
-
411
- # --- Intent Routing --------------------------------------------------------
412
-
413
- def route_intent(user_input, chat_id, sender):
414
  history_text = get_history_text(chat_id, sender)
415
  sys_prompt = (
416
  "You are Eve. You can either chat or call one of these functions:\n"
417
- + "\n".join(f"- {n}: {f['description']}" for n,f in FUNCTION_SCHEMA.items())
418
- + "\n\nTo call a function, return JSON with \"action\":\"<name>\", plus its parameters.\n"
419
- + "Here’s an example for generating images:\n"
420
- + " {\"action\":\"generate_image\",\"prompt\":\"a red fox\",\"count\":3,\"width\":512,\"height\":512}\n"
421
- + "Otherwise return JSON with \"action\":\"send_text\",\"message\":\"...\".\n"
422
- "Return only raw JSON."
 
 
 
 
 
 
 
 
 
423
  )
424
  prompt = (
425
- f"{sys_prompt}\n\n"
426
  f"Conversation so far:\n{history_text}\n\n"
427
  f"User: {user_input}"
428
  )
429
  raw = generate_llm(prompt)
 
 
430
  try:
431
- return json.loads(raw)
432
- except:
433
- return {"action":"send_text","message":raw}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  # --- FastAPI & Webhook ----------------------------------------------------
436
 
@@ -470,12 +472,8 @@ async def whatsapp_webhook(request: Request):
470
  return {"success": True}
471
 
472
  body = (tmd.get("textMessage") or tmd.get("text","")).strip()
473
- ctx = tmd.get("contextInfo", {})
474
-
475
- # record user message
476
  record_user_message(chat_id, sender, body)
477
 
478
- # Slash commands
479
  low = body.lower()
480
  if low == "/help":
481
  _fn_send_text(mid, chat_id, help_text)
@@ -524,45 +522,24 @@ async def whatsapp_webhook(request: Request):
524
  "message_id": mid,
525
  "chat_id": chat_id,
526
  "prompt": pr,
527
- "num_images": ct,
528
  "width": width,
529
  "height": height
530
  })
531
  return {"success": True}
532
 
533
-
534
- # Skip mentions
535
- if ctx.get("mentionedJidList"):
536
  return {"success": True}
537
 
538
- # Build effective text (handle quoted replies to the bot)
539
- if md.get("typeMessage") == "quotedMessage":
540
- ext = md["extendedTextMessageData"]
541
- quoted = md["quotedMessage"]
542
- if ext.get("participant") == BotConfig.BOT_JID:
543
- effective = (
544
- f"Quoted: {quoted.get('textMessage','')}\n"
545
- f"User: {ext.get('text','')}"
546
- )
547
- else:
548
- effective = body
549
- else:
550
- effective = body
551
 
552
- # Route intent & dispatch
553
  intent = route_intent(effective, chat_id, sender)
554
- action = intent.get("action")
555
-
556
- if action in FUNCTION_SCHEMA:
557
- if not validate_intent(action, intent):
558
- _fn_send_text(mid, chat_id, f"❗ Missing parameter(s) for `{action}`.")
559
- else:
560
- dispatched = dispatcher.dispatch(action, mid, chat_id, intent)
561
- if not dispatched:
562
- _fn_send_text(mid, chat_id, intent.get("message","Sorry, I couldn't handle that."))
563
  else:
564
- # fallback chat
565
- _fn_send_text(mid, chat_id, intent.get("message","Sorry, I didn't get that."))
566
 
567
  return {"success": True}
568
 
 
4
  import logging
5
  import queue
6
  import json
7
+ from typing import List, Optional, Union, Literal
8
  from collections import defaultdict, deque
9
  from concurrent.futures import ThreadPoolExecutor
10
+
11
  from fastapi import FastAPI, Request, HTTPException
12
  from fastapi.responses import PlainTextResponse
13
+ from pydantic import BaseModel, Field, ValidationError
14
+
15
  from FLUX import generate_image
16
  from VoiceReply import generate_voice_reply
17
  from polLLM import generate_llm
 
42
  def set_thread_context(chat_id, sender, message_id):
43
  _thread_ctx.chat_id = chat_id
44
  _thread_ctx.sender = sender
45
+ _thread_ctx.message_id = message_id
46
 
47
  def get_thread_context():
48
  return (
 
53
 
54
  # --- Conversation History -------------------------------------------------
55
 
 
56
  history = defaultdict(lambda: deque(maxlen=20))
57
 
58
  def record_user_message(chat_id, sender, message):
 
137
  files = [("file",(os.path.basename(file_path),f,mime))]
138
  return self.send(endpoint, payload, files=files)
139
 
 
140
  BotConfig.validate()
141
  client = BotClient(BotConfig)
142
 
 
151
  task = task_queue.get()
152
  try:
153
  if task["type"] == "image":
154
+ _fn_generate_images(**task)
 
 
 
 
 
 
 
 
155
  elif task["type"] == "audio":
156
+ _fn_voice_reply(**task)
 
 
 
 
157
  except Exception as e:
158
  logger.error(f"Worker error {task}: {e}")
159
  finally:
 
165
  # --- Basic Tool Functions -------------------------------------------------
166
 
167
  def _fn_send_text(mid, cid, message):
 
168
  client.send_message(mid, cid, message)
169
  chat_id, sender, _ = get_thread_context()
170
  if chat_id and sender:
 
177
  })
178
 
179
  def _fn_send_accept(mid, cid, message):
 
180
  client.send_message(mid, cid, message)
181
  chat_id, sender, _ = get_thread_context()
182
  if chat_id and sender:
 
256
  )
257
  _fn_send_text(mid, cid, txt)
258
 
259
+ def _fn_generate_images(message_id: str, chat_id: str, prompt: str,
260
+ count: int = 1, width: Optional[int] = None,
261
+ height: Optional[int] = None, **_):
262
  for i in range(1, count+1):
263
  try:
264
  img, path, ret_p, url = generate_image(
265
+ prompt, message_id, message_id, BotConfig.IMAGE_DIR,
266
  width=width, height=height
267
  )
268
  formatted = "\n\n".join(f"_{p.strip()}_" for p in ret_p.split("\n\n") if p.strip())
269
  cap = f"✨ Image {i}/{count}: {url}\n>{chr(8203)} {formatted}"
270
+ client.send_media(message_id, chat_id, path, cap, media_type="image")
271
  os.remove(path)
272
  except Exception as e:
273
  logger.warning(f"Img {i}/{count} failed: {e}")
274
+ _fn_send_text(message_id, chat_id, f"😢 Failed to generate image {i}/{count}.")
 
275
 
276
+ def _fn_voice_reply(message_id: str, chat_id: str, prompt: str, **_):
277
  proc = (
278
  f"Just say this exactly as written in a flirty, friendly, playful, "
279
  f"happy and helpful but a little bit clumsy-cute way: {prompt}"
 
281
  res = generate_voice_reply(proc, model="openai-audio", voice="coral", audio_dir=BotConfig.AUDIO_DIR)
282
  if res and res[0]:
283
  path, _ = res
284
+ client.send_media(message_id, chat_id, path, "", media_type="audio")
285
  os.remove(path)
286
  else:
287
+ _fn_send_text(message_id, chat_id, prompt)
288
+
289
+ # --- Pydantic Models for Function Calling --------------------------------
290
+
291
+ class BaseIntent(BaseModel):
292
+ action: str
293
+
294
+ class SummarizeIntent(BaseIntent):
295
+ action: Literal["summarize"]
296
+ text: str
297
+
298
+ class TranslateIntent(BaseIntent):
299
+ action: Literal["translate"]
300
+ lang: str
301
+ text: str
302
+
303
+ class JokeIntent(BaseIntent):
304
+ action: Literal["joke"]
305
+
306
+ class WeatherIntent(BaseIntent):
307
+ action: Literal["weather"]
308
+ location: str
309
+
310
+ class InspireIntent(BaseIntent):
311
+ action: Literal["inspire"]
312
+
313
+ class MemeIntent(BaseIntent):
314
+ action: Literal["meme"]
315
+ text: str
316
+
317
+ class PollCreateIntent(BaseIntent):
318
+ action: Literal["poll_create"]
319
+ question: str
320
+ options: List[str]
321
+
322
+ class PollVoteIntent(BaseIntent):
323
+ action: Literal["poll_vote"]
324
+ voter: str
325
+ choice: int
326
+
327
+ class PollResultsIntent(BaseIntent):
328
+ action: Literal["poll_results"]
329
+
330
+ class PollEndIntent(BaseIntent):
331
+ action: Literal["poll_end"]
332
+
333
+ class GenerateImageIntent(BaseIntent):
334
+ action: Literal["generate_image"]
335
+ prompt: str
336
+ count: int = Field(default=1, ge=1)
337
+ width: Optional[int]
338
+ height: Optional[int]
339
+
340
+ class SendTextIntent(BaseIntent):
341
+ action: Literal["send_text"]
342
+ message: str
343
+
344
+ IntentUnion = Union[
345
+ SummarizeIntent, TranslateIntent, JokeIntent, WeatherIntent,
346
+ InspireIntent, MemeIntent, PollCreateIntent, PollVoteIntent,
347
+ PollResultsIntent, PollEndIntent, GenerateImageIntent, SendTextIntent
348
+ ]
349
+
350
+ ACTION_HANDLERS = {
351
+ "summarize": lambda mid,cid,**i: _fn_summarize(mid,cid,i["text"]),
352
+ "translate": lambda mid,cid,**i: _fn_translate(mid,cid,i["lang"],i["text"]),
353
+ "joke": lambda mid,cid,**i: _fn_joke(mid,cid),
354
+ "weather": lambda mid,cid,**i: _fn_weather(mid,cid,i["location"]),
355
+ "inspire": lambda mid,cid,**i: _fn_inspire(mid,cid),
356
+ "meme": lambda mid,cid,**i: _fn_meme(mid,cid,i["text"]),
357
+ "poll_create": lambda mid,cid,**i: _fn_poll_create(mid,cid,i["question"],i["options"]),
358
+ "poll_vote": lambda mid,cid,**i: _fn_poll_vote(mid,cid,i["voter"],i["choice"]),
359
+ "poll_results": lambda mid,cid,**i: _fn_poll_results(mid,cid),
360
+ "poll_end": lambda mid,cid,**i: _fn_poll_end(mid,cid),
361
+ "generate_image": _fn_generate_images,
362
+ "send_text": lambda mid,cid,**i: _fn_send_text(mid,cid,i["message"]),
363
+ }
364
+
365
+ # --- Intent Routing with Fallback ------------------------------------------
366
+
367
+ def route_intent(user_input: str, chat_id: str, sender: str) -> IntentUnion:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  history_text = get_history_text(chat_id, sender)
369
  sys_prompt = (
370
  "You are Eve. You can either chat or call one of these functions:\n"
371
+ "- summarize(text)\n"
372
+ "- translate(lang, text)\n"
373
+ "- joke()\n"
374
+ "- weather(location)\n"
375
+ "- inspire()\n"
376
+ "- meme(text)\n"
377
+ "- poll_create(question, options)\n"
378
+ "- poll_vote(voter, choice)\n"
379
+ "- poll_results()\n"
380
+ "- poll_end()\n"
381
+ "- generate_image(prompt, count, width, height)\n"
382
+ "- send_text(message)\n\n"
383
+ "Return only raw JSON matching one of these shapes. For example:\n"
384
+ " {\"action\":\"generate_image\",\"prompt\":\"a red fox\",\"count\":3,\"width\":512,\"height\":512}\n"
385
+ "Otherwise, use send_text to reply with plain chat.\n"
386
  )
387
  prompt = (
388
+ f"{sys_prompt}\n"
389
  f"Conversation so far:\n{history_text}\n\n"
390
  f"User: {user_input}"
391
  )
392
  raw = generate_llm(prompt)
393
+
394
+ # 1) Try strict Pydantic parse
395
  try:
396
+ parsed = json.loads(raw)
397
+ intent = IntentUnion.parse_obj(parsed)
398
+ return intent
399
+ except (json.JSONDecodeError, ValidationError) as e:
400
+ logger.warning(f"Strict parse failed: {e}. Falling back to lenient.")
401
+
402
+ # 2) Lenient: basic JSON get + defaults
403
+ try:
404
+ data = json.loads(raw)
405
+ except json.JSONDecodeError:
406
+ return SendTextIntent(action="send_text", message=raw)
407
+
408
+ action = data.get("action")
409
+ if action in ACTION_HANDLERS:
410
+ kwargs = {}
411
+ if action == "generate_image":
412
+ kwargs["prompt"] = data.get("prompt", "")
413
+ kwargs["count"] = int(data.get("count", BotConfig.DEFAULT_IMAGE_COUNT))
414
+ kwargs["width"] = data.get("width")
415
+ kwargs["height"] = data.get("height")
416
+ elif action == "send_text":
417
+ kwargs["message"] = data.get("message", "")
418
+ elif action == "translate":
419
+ kwargs["lang"] = data.get("lang", "")
420
+ kwargs["text"] = data.get("text", "")
421
+ elif action == "summarize":
422
+ kwargs["text"] = data.get("text", "")
423
+ elif action == "weather":
424
+ kwargs["location"] = data.get("location", "")
425
+ elif action == "meme":
426
+ kwargs["text"] = data.get("text", "")
427
+ elif action == "poll_create":
428
+ kwargs["question"] = data.get("question", "")
429
+ kwargs["options"] = data.get("options", [])
430
+ elif action == "poll_vote":
431
+ kwargs["voter"] = sender
432
+ kwargs["choice"] = int(data.get("choice", 0))
433
+ return IntentUnion.parse_obj({"action": action, **kwargs})
434
+
435
+ return SendTextIntent(action="send_text", message=raw)
436
 
437
  # --- FastAPI & Webhook ----------------------------------------------------
438
 
 
472
  return {"success": True}
473
 
474
  body = (tmd.get("textMessage") or tmd.get("text","")).strip()
 
 
 
475
  record_user_message(chat_id, sender, body)
476
 
 
477
  low = body.lower()
478
  if low == "/help":
479
  _fn_send_text(mid, chat_id, help_text)
 
522
  "message_id": mid,
523
  "chat_id": chat_id,
524
  "prompt": pr,
525
+ "count": ct,
526
  "width": width,
527
  "height": height
528
  })
529
  return {"success": True}
530
 
531
+ if tmd.get("contextInfo", {}).get("mentionedJidList"):
 
 
532
  return {"success": True}
533
 
534
+ # Handle quoted replies if needed...
535
+ effective = body
 
 
 
 
 
 
 
 
 
 
 
536
 
 
537
  intent = route_intent(effective, chat_id, sender)
538
+ handler = ACTION_HANDLERS.get(intent.action)
539
+ if handler:
540
+ handler(mid, chat_id, **intent.dict(exclude={"action"}))
 
 
 
 
 
 
541
  else:
542
+ _fn_send_text(mid, chat_id, "Sorry, I didn't understand that.")
 
543
 
544
  return {"success": True}
545
 
requirements.txt CHANGED
@@ -3,4 +3,5 @@ uvicorn[standard]
3
  openai
4
  pillow
5
  requests
6
- supabase
 
 
3
  openai
4
  pillow
5
  requests
6
+ supabase
7
+ pydantic