Update app.py
Browse files
app.py
CHANGED
@@ -4,19 +4,11 @@ import requests
|
|
4 |
import openai
|
5 |
import gradio as gr
|
6 |
import asyncio
|
7 |
-
from langgraph import Graph, FunctionNode, RouterNode
|
8 |
from gtts import gTTS
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
"""Convert speech to text using OpenAI Whisper API"""
|
12 |
-
with open(audio_path, "rb") as afile:
|
13 |
-
transcript = openai.audio.transcriptions.create(
|
14 |
-
model="whisper-1",
|
15 |
-
file=afile
|
16 |
-
)
|
17 |
-
return transcript.text.strip()
|
18 |
-
|
19 |
-
# Load API keys from environment
|
20 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
21 |
|
22 |
# --- Business Logic Functions ---
|
@@ -24,24 +16,22 @@ def db_agent(query: str) -> str:
|
|
24 |
try:
|
25 |
conn = sqlite3.connect("shop.db")
|
26 |
cur = conn.cursor()
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
return "No transactions found for today."
|
42 |
-
return None
|
43 |
except sqlite3.OperationalError as e:
|
44 |
-
return f"Database error: {e}. Please initialize 'transactions' table in shop.db."
|
45 |
|
46 |
def web_search_agent(query: str) -> str:
|
47 |
try:
|
@@ -67,54 +57,80 @@ def llm_agent(query: str) -> str:
|
|
67 |
)
|
68 |
return response.choices[0].message.content.strip()
|
69 |
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
def tts_agent(text: str, lang: str = 'en') -> str:
    """Render ``text`` as speech with gTTS and return the path of the MP3 written.

    The file name is fixed (``response_audio.mp3``), so every call overwrites
    the output of the previous one.
    """
    destination = "response_audio.mp3"
    gTTS(text=text, lang=lang).save(destination)
    return destination
|
78 |
|
79 |
-
# --- LangGraph
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
]
|
87 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
llm_node = FunctionNode(func=llm_agent, name="llm")
|
92 |
-
|
93 |
-
# Build Graph
|
94 |
-
graph = Graph("shop-assistant")
|
95 |
-
graph.add_nodes([router_node, db_node, web_node, llm_node])
|
96 |
-
graph.add_edge("router", "db", condition=lambda r: r == "db")
|
97 |
-
graph.add_edge("router", "web", condition=lambda r: r == "web")
|
98 |
-
graph.add_edge("router", "llm", condition=lambda r: r == "llm")
|
99 |
-
|
100 |
-
async def graph_handler(query: str) -> str:
    """Run ``query`` through the graph, transcribing it first when it is audio.

    A query that starts with ``audio://`` is treated as a path to an audio
    file: the marker is stripped and the file is passed to ``stt_agent``;
    the resulting transcript becomes the query sent to the graph.
    """
    audio_marker = "audio://"
    if query.startswith(audio_marker):
        query = stt_agent(query.replace(audio_marker, ""))
    return await graph.run(input=query, start_node="router")
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
def handle_query(audio_or_text: str):
|
109 |
-
# Determine output type
|
110 |
is_audio = audio_or_text.endswith('.wav') or audio_or_text.endswith('.mp3')
|
111 |
-
text_input = f"audio://{audio_or_text}" if is_audio else audio_or_text
|
112 |
-
text_resp = asyncio.run(graph_handler(text_input))
|
113 |
if is_audio:
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
# --- Gradio UI ---
|
120 |
with gr.Blocks() as demo:
|
@@ -123,7 +139,6 @@ with gr.Blocks() as demo:
|
|
123 |
out_text = gr.Textbox(label="Answer (text)")
|
124 |
out_audio = gr.Audio(label="Answer (speech)")
|
125 |
submit = gr.Button("Submit")
|
126 |
-
# Examples
|
127 |
gr.Examples(
|
128 |
examples=[
|
129 |
["What is the max revenue product today?"],
|
|
|
4 |
import openai
|
5 |
import gradio as gr
|
6 |
import asyncio
|
|
|
7 |
from gtts import gTTS
|
8 |
+
from typing_extensions import TypedDict
|
9 |
+
from langgraph.graph import StateGraph, START, END
|
10 |
|
11 |
+
# Load API keys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
13 |
|
14 |
# --- Business Logic Functions ---
|
|
|
16 |
try:
|
17 |
conn = sqlite3.connect("shop.db")
|
18 |
cur = conn.cursor()
|
19 |
+
cur.execute(
|
20 |
+
"""
|
21 |
+
SELECT product, SUM(amount) AS revenue
|
22 |
+
FROM transactions
|
23 |
+
WHERE date = date('now')
|
24 |
+
GROUP BY product
|
25 |
+
ORDER BY revenue DESC
|
26 |
+
LIMIT 1
|
27 |
+
"""
|
28 |
+
)
|
29 |
+
row = cur.fetchone()
|
30 |
+
if row:
|
31 |
+
return f"Top product today: {row[0]} with ₹{row[1]:,.2f}"
|
32 |
+
return "No transactions found for today."
|
|
|
|
|
33 |
except sqlite3.OperationalError as e:
|
34 |
+
return f"Database error: {e}. Please initialize 'transactions' table in shop.db."
|
35 |
|
36 |
def web_search_agent(query: str) -> str:
|
37 |
try:
|
|
|
57 |
)
|
58 |
return response.choices[0].message.content.strip()
|
59 |
|
60 |
+
def stt_agent(audio_path: str) -> str:
    """Transcribe the audio file at ``audio_path`` using the ``whisper-1`` model.

    Returns the transcript text with leading and trailing whitespace removed.
    """
    with open(audio_path, "rb") as audio_file:
        transcription = openai.audio.transcriptions.create(
            model="whisper-1",
            file=audio_file,
        )
    return transcription.text.strip()
|
67 |
|
68 |
def tts_agent(text: str, lang: str = 'en') -> str:
    """Render ``text`` as speech with gTTS and return the path of the MP3 written.

    The file name is fixed (``response_audio.mp3``), so every call overwrites
    the output of the previous one.
    """
    destination = "response_audio.mp3"
    gTTS(text=text, lang=lang).save(destination)
    return destination
|
73 |
|
74 |
+
# --- LangGraph State and Nodes ---
|
75 |
+
class State(TypedDict):
    """State dictionary carried through the graph."""

    # Text of the question; read by the router and by every agent node.
    query: str
    # Answer text; written by the db / web / llm node that handled the query.
    result: str
|
78 |
+
|
79 |
+
# Routing logic based on query
|
80 |
+
def route_fn(state: State) -> str:
    """Choose which agent node should handle the query.

    Matching is case-insensitive and by plain substring, checked in order:
    revenue keywords select ``"db"``, question words select ``"web"``, and
    anything else falls through to ``"llm"``.
    """
    text = state["query"].lower()
    keyword_routes = (
        ("db", ("max revenue", "revenue")),
        ("web", ("who", "what", "when", "where")),
    )
    for target, keywords in keyword_routes:
        for keyword in keywords:
            if keyword in text:
                return target
    return "llm"
|
87 |
+
|
88 |
+
# Node implementations
|
89 |
+
|
90 |
+
def router_node(state: State) -> dict:
    """Return a state update that re-emits the incoming query unchanged."""
    unchanged_query = state["query"]
    return {"query": unchanged_query}
|
92 |
+
|
93 |
+
def db_node(state: State) -> dict:
    """Answer the query with ``db_agent`` and return it as the ``result`` update."""
    answer = db_agent(state["query"])
    return {"result": answer}
|
95 |
|
96 |
+
def web_node(state: State) -> dict:
    """Answer the query with ``web_search_agent`` and return it as the ``result`` update."""
    answer = web_search_agent(state["query"])
    return {"result": answer}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
+
def llm_node(state: State) -> dict:
    """Answer the query with ``llm_agent`` and return it as the ``result`` update."""
    answer = llm_agent(state["query"])
    return {"result": answer}
|
101 |
+
|
102 |
+
# Build the LangGraph
|
103 |
+
# Topology: START -> router -> (db | web | llm) -> END.
builder = StateGraph(State)

builder.add_node("router", router_node)
builder.add_node("db", db_node)
builder.add_node("web", web_node)
builder.add_node("llm", llm_node)

# Exactly one entry declaration. The previous version combined
# set_entry_point, set_conditional_entry_point and add_edge(START, ...),
# which made routing branch from START and left "router" with no way out.
builder.add_edge(START, "router")

# The string returned by route_fn selects the agent node that runs next.
builder.add_conditional_edges(
    "router",
    route_fn,
    {"db": "db", "web": "web", "llm": "llm"},
)

# Each agent produces the final answer, so every branch terminates the run.
builder.add_edge("db", END)
builder.add_edge("web", END)
builder.add_edge("llm", END)

graph = builder.compile()
|
118 |
+
|
119 |
+
# Handler integrates STT/TTS and graph execution
|
120 |
def handle_query(audio_or_text: str):
    """Answer a typed question or a recorded one.

    Args:
        audio_or_text: Either the question itself, or a path ending in
            ``.wav`` / ``.mp3`` (any letter case) pointing at a recording of it.

    Returns:
        The answer text for a typed question. For audio input, a tuple
        ``(answer_text, audio_path)`` where ``audio_path`` is the spoken
        answer produced by ``tts_agent``.
    """
    # endswith accepts a tuple of suffixes; lower-casing first means
    # recordings named "clip.WAV" or "clip.MP3" are recognised as audio too.
    is_audio = audio_or_text.lower().endswith((".wav", ".mp3"))

    # Audio input is transcribed first so the graph always receives text.
    query = stt_agent(audio_or_text) if is_audio else audio_or_text

    response = graph.invoke({"query": query})["result"]

    if is_audio:
        return response, tts_agent(response)

    # NOTE(review): this branch returns a bare string while the audio branch
    # returns a 2-tuple. If the Gradio event is wired to both the text and
    # audio outputs, confirm that a single return value is acceptable here.
    return response
|
134 |
|
135 |
# --- Gradio UI ---
|
136 |
with gr.Blocks() as demo:
|
|
|
139 |
out_text = gr.Textbox(label="Answer (text)")
|
140 |
out_audio = gr.Audio(label="Answer (speech)")
|
141 |
submit = gr.Button("Submit")
|
|
|
142 |
gr.Examples(
|
143 |
examples=[
|
144 |
["What is the max revenue product today?"],
|