arjunanand13 committed on
Commit
7598edf
·
verified ·
1 Parent(s): b5ac2d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -65
app.py CHANGED
@@ -4,19 +4,11 @@ import requests
4
  import openai
5
  import gradio as gr
6
  import asyncio
7
- from langgraph import Graph, FunctionNode, RouterNode
8
  from gtts import gTTS
 
 
9
 
10
- def stt_agent(audio_path: str) -> str:
11
- """Convert speech to text using OpenAI Whisper API"""
12
- with open(audio_path, "rb") as afile:
13
- transcript = openai.audio.transcriptions.create(
14
- model="whisper-1",
15
- file=afile
16
- )
17
- return transcript.text.strip()
18
-
19
- # Load API keys from environment
20
  openai.api_key = os.getenv("OPENAI_API_KEY")
21
 
22
  # --- Business Logic Functions ---
@@ -24,24 +16,22 @@ def db_agent(query: str) -> str:
24
  try:
25
  conn = sqlite3.connect("shop.db")
26
  cur = conn.cursor()
27
- if "max revenue" in query.lower():
28
- cur.execute(
29
- """
30
- SELECT product, SUM(amount) AS revenue
31
- FROM transactions
32
- WHERE date = date('now')
33
- GROUP BY product
34
- ORDER BY revenue DESC
35
- LIMIT 1
36
- """
37
- )
38
- row = cur.fetchone()
39
- if row:
40
- return f"Top product today: {row[0]} with ₹{row[1]:,.2f}"
41
- return "No transactions found for today."
42
- return None
43
  except sqlite3.OperationalError as e:
44
- return f"Database error: {e}. Please initialize 'transactions' table in shop.db."
45
 
46
  def web_search_agent(query: str) -> str:
47
  try:
@@ -67,54 +57,80 @@ def llm_agent(query: str) -> str:
67
  )
68
  return response.choices[0].message.content.strip()
69
 
70
- # Text-to-Speech
 
 
 
 
 
 
71
 
72
  def tts_agent(text: str, lang: str = 'en') -> str:
73
- """Convert text to speech mp3 and return filepath"""
74
  tts = gTTS(text=text, lang=lang)
75
  out_path = "response_audio.mp3"
76
  tts.save(out_path)
77
  return out_path
78
 
79
- # --- LangGraph Multi-Agent Setup ---
80
- router_node = RouterNode(
81
- name="router",
82
- routes=[
83
- (lambda q: any(k in q.lower() for k in ["max revenue", "revenue"]), "db"),
84
- (lambda q: any(k in q.lower() for k in ["who", "what", "when", "where"]), "web"),
85
- (lambda q: True, "llm"),
86
- ]
87
- )
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- db_node = FunctionNode(func=db_agent, name="db")
90
- web_node = FunctionNode(func=web_search_agent, name="web")
91
- llm_node = FunctionNode(func=llm_agent, name="llm")
92
-
93
- # Build Graph
94
- graph = Graph("shop-assistant")
95
- graph.add_nodes([router_node, db_node, web_node, llm_node])
96
- graph.add_edge("router", "db", condition=lambda r: r == "db")
97
- graph.add_edge("router", "web", condition=lambda r: r == "web")
98
- graph.add_edge("router", "llm", condition=lambda r: r == "llm")
99
-
100
- async def graph_handler(query: str) -> str:
101
- # If audio file path passed, convert to text first
102
- if query.startswith("audio://"):
103
- audio_path = query.replace("audio://", "")
104
- query = stt_agent(audio_path)
105
- text_resp = await graph.run(input=query, start_node="router")
106
- return text_resp
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def handle_query(audio_or_text: str):
109
- # Determine output type
110
  is_audio = audio_or_text.endswith('.wav') or audio_or_text.endswith('.mp3')
111
- text_input = f"audio://{audio_or_text}" if is_audio else audio_or_text
112
- text_resp = asyncio.run(graph_handler(text_input))
113
  if is_audio:
114
- # Return both text and audio
115
- audio_path = tts_agent(text_resp)
116
- return text_resp, audio_path
117
- return text_resp
 
 
 
 
 
 
 
118
 
119
  # --- Gradio UI ---
120
  with gr.Blocks() as demo:
@@ -123,7 +139,6 @@ with gr.Blocks() as demo:
123
  out_text = gr.Textbox(label="Answer (text)")
124
  out_audio = gr.Audio(label="Answer (speech)")
125
  submit = gr.Button("Submit")
126
- # Examples
127
  gr.Examples(
128
  examples=[
129
  ["What is the max revenue product today?"],
 
4
  import openai
5
  import gradio as gr
6
  import asyncio
 
7
  from gtts import gTTS
8
+ from typing_extensions import TypedDict
9
+ from langgraph.graph import StateGraph, START, END
10
 
11
+ # Load API keys
 
 
 
 
 
 
 
 
 
12
  openai.api_key = os.getenv("OPENAI_API_KEY")
13
 
14
  # --- Business Logic Functions ---
 
16
  try:
17
  conn = sqlite3.connect("shop.db")
18
  cur = conn.cursor()
19
+ cur.execute(
20
+ """
21
+ SELECT product, SUM(amount) AS revenue
22
+ FROM transactions
23
+ WHERE date = date('now')
24
+ GROUP BY product
25
+ ORDER BY revenue DESC
26
+ LIMIT 1
27
+ """
28
+ )
29
+ row = cur.fetchone()
30
+ if row:
31
+ return f"Top product today: {row[0]} with ₹{row[1]:,.2f}"
32
+ return "No transactions found for today."
 
 
33
  except sqlite3.OperationalError as e:
34
+ return f"Database error: {e}. Please initialize 'transactions' table in shop.db."
35
 
36
  def web_search_agent(query: str) -> str:
37
  try:
 
57
  )
58
  return response.choices[0].message.content.strip()
59
 
60
def stt_agent(audio_path: str) -> str:
    """Transcribe a local audio file to text via OpenAI's Whisper API.

    Args:
        audio_path: Path to the audio file to transcribe.

    Returns:
        The transcript text with surrounding whitespace stripped.
    """
    with open(audio_path, "rb") as audio_file:
        result = openai.audio.transcriptions.create(
            model="whisper-1",
            file=audio_file,
        )
    return result.text.strip()
67
 
68
def tts_agent(text: str, lang: str = 'en') -> str:
    """Render *text* as speech and save it to an mp3 file.

    Args:
        text: The text to synthesize.
        lang: gTTS language code (defaults to 'en').

    Returns:
        The path of the generated mp3 file.
    """
    out_path = "response_audio.mp3"
    gTTS(text=text, lang=lang).save(out_path)
    return out_path
73
 
74
# --- LangGraph State and Nodes ---
class State(TypedDict):
    # query: the user's (possibly transcribed) question
    # result: the answer produced by whichever worker node ran
    query: str
    result: str

# Routing logic based on query
def route_fn(state: State) -> str:
    """Return the name of the worker node for this query.

    Revenue-related keywords go to the database node, generic
    question words to web search, and everything else to the LLM.
    """
    text = state["query"].lower()
    routes = (
        ("db", ("max revenue", "revenue")),
        ("web", ("who", "what", "when", "where")),
    )
    for destination, keywords in routes:
        if any(word in text for word in keywords):
            return destination
    return "llm"

# Node implementations

def router_node(state: State) -> dict:
    """Pass-through node; routing happens on its outgoing edges."""
    return {"query": state["query"]}

def db_node(state: State) -> dict:
    """Answer the query from the local transactions database."""
    return {"result": db_agent(state["query"])}

def web_node(state: State) -> dict:
    """Answer the query via web search."""
    return {"result": web_search_agent(state["query"])}

def llm_node(state: State) -> dict:
    """Answer the query with the general-purpose LLM."""
    return {"result": llm_agent(state["query"])}
101
# Build the LangGraph: START -> router -> (db | web | llm) -> END.
#
# NOTE(fix): the previous wiring called set_entry_point("router"),
# set_conditional_entry_point(route_fn, ...) AND add_edge(START, "router"),
# which attached START three conflicting ways and left the "router" node
# with no outgoing edges (a dead end) — the conditional entry point
# bypassed the router entirely. The correct pattern is a single
# unconditional START -> router edge plus conditional edges *out of*
# the router chosen by route_fn.
builder = StateGraph(State)
builder.add_node("router", router_node)
builder.add_node("db", db_node)
builder.add_node("web", web_node)
builder.add_node("llm", llm_node)
builder.add_edge(START, "router")
builder.add_conditional_edges(
    "router",
    route_fn,
    path_map={"db": "db", "web": "web", "llm": "llm"},
)
builder.add_edge("db", END)
builder.add_edge("web", END)
builder.add_edge("llm", END)
graph = builder.compile()
118
# Handler integrates STT/TTS and graph execution
def handle_query(audio_or_text: str):
    """Run one user query through the agent graph.

    Accepts either plain text or a path ending in .wav/.mp3. Audio
    input is transcribed before routing, and the answer is additionally
    rendered to speech, in which case a (text, audio_path) pair is
    returned; text input yields just the answer string.
    """
    is_audio = audio_or_text.endswith(('.wav', '.mp3'))
    query = stt_agent(audio_or_text) if is_audio else audio_or_text

    final_state = graph.invoke({"query": query})
    answer = final_state["result"]

    if not is_audio:
        return answer
    return answer, tts_agent(answer)
134
 
135
  # --- Gradio UI ---
136
  with gr.Blocks() as demo:
 
139
  out_text = gr.Textbox(label="Answer (text)")
140
  out_audio = gr.Audio(label="Answer (speech)")
141
  submit = gr.Button("Submit")
 
142
  gr.Examples(
143
  examples=[
144
  ["What is the max revenue product today?"],