Ashhar committed on
Commit
9f9844d
·
1 Parent(s): 1270915

first commit

Browse files
Files changed (6) hide show
  1. .gitignore +10 -0
  2. .streamlit/config.toml +5 -0
  3. app.py +328 -0
  4. clients/openRouter.py +172 -0
  5. requirements.txt +11 -0
  6. utils.py +41 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ .venv
3
+ __pycache__/
4
+ .gitattributes
5
+ gradio_cached_examples/
6
+ app_*.py
7
+ soup_dump*.html
8
+ soup_dump.html
9
+ system_prompt.txt
10
+ scratch.py
.streamlit/config.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [client]
2
+ showSidebarNavigation = false
3
+
4
+ [theme]
5
+ base="dark"
app.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import pandas as pd
4
+ from typing import Literal, TypedDict
5
+ from sqlalchemy import create_engine, inspect
6
+ import json
7
+ from transformers import AutoTokenizer
8
+ from utils import pprint
9
+ import time
10
+ import re
11
+
12
+ from openai import OpenAI
13
+ import anthropic
14
+ from clients.openRouter import OpenRouter
15
+
16
+ # Load environment variables
17
+ from dotenv import load_dotenv
18
+ load_dotenv()
19
+
20
# Identifiers for every model family the app knows about. Only the keys that
# are present in MODEL_CONFIG below are selectable in the UI.
ModelType = Literal["GPT_4o", "GPT_o1", "CLAUDE", "LLAMA", "DEEPSEEK", "DEEPSEEK_R1", "DEEPSEEK_R1_DISTILL"]
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | anthropic.Anthropic | OpenRouter,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer
})


@st.cache_resource(show_spinner=False)
def _load_tokenizer(name: str):
    """
    Load a Hugging Face tokenizer once and reuse it.

    Streamlit re-executes this script on every interaction, so without the
    cache each rerun would call `AutoTokenizer.from_pretrained` again for
    every entry in MODEL_CONFIG. Entries that name the same tokenizer also
    share a single instance.

    :param name: Hugging Face repository id of the tokenizer
    :return: The loaded tokenizer
    """
    return AutoTokenizer.from_pretrained(name)


# Registry of selectable models: API client, provider model id, context
# budget, and the tokenizer used for local token counting.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-haiku-20241022",
        # "model": "claude-3-5-sonnet-20241022",
        # "model": "claude-3-5-sonnet-20240620",
        "max_context": 40000,
        "tokenizer": _load_tokenizer("Xenova/claude-tokenizer")
    },
    "GPT_4o": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o",
        "max_context": 15000,
        "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    },
    # "GPT_o1": {
    #     "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
    #     "model": "o1-preview",
    #     "max_context": 15000,
    #     "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    # },
    "DEEPSEEK": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-chat",
        "max_context": 30000,
        # The GPT-4o tokenizer is used for the DeepSeek entries as well, so
        # their token counts are presumably approximate.
        "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    },
    "DEEPSEEK_R1": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-r1",
        "max_context": 30000,
        "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    },
}
66
+
67
+
68
def get_model_type():
    """
    Let the user pick a model in the Streamlit sidebar.

    The select box is populated with the MODEL_CONFIG keys and uses
    `format_func` to display each entry's provider model name. Because the
    widget returns the key itself, no reverse lookup from label to key is
    needed, and two entries sharing a model name can no longer resolve to
    the wrong key.

    :return: The MODEL_CONFIG key of the selected model (first entry by default)
    """
    return st.sidebar.selectbox(
        "Select AI Model",
        list(MODEL_CONFIG),
        index=0,
        format_func=lambda model_type: MODEL_CONFIG[model_type]["model"],
    )
95
+
96
+
97
# Resolve the sidebar selection into the module-level handles used by the
# rest of the script.
modelType = get_model_type()
_selected_config = MODEL_CONFIG[modelType]

client = _selected_config["client"]
MODEL = _selected_config["model"]
# An entry may optionally name a separate "tools_model"; fall back to MODEL.
TOOLS_MODEL = _selected_config.get("tools_model") or MODEL
MAX_CONTEXT = _selected_config["max_context"]
tokenizer = _selected_config["tokenizer"]

# The Anthropic SDK has a different call shape than the OpenAI-style clients.
isClaudeModel = modelType == "CLAUDE"
isDeepSeekModel = modelType.startswith("DEEPSEEK")
108
+
109
+
110
def __countTokens(text):
    """
    Count the tokens the active tokenizer produces for `text`.

    :param text: Any value; it is converted with `str()` before encoding
    :return: Number of tokens, excluding special tokens
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))
114
+
115
+
116
# Seed per-session state on the first run; later reruns keep existing values.
if "ipAddress" not in st.session_state:
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
for _state_key in (
    "connection_string",
    "selected_table",
    "table_schema",
    "sample_data",
    "engine",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
129
+
130
+
131
def connect_to_db(connection_string):
    """
    Create a SQLAlchemy engine, verify it, and store it in session state.

    The engine is only stored after a test connection succeeds. An engine
    that fails the test is disposed, and so is the engine being replaced,
    so connection pools are not leaked each time the connection string
    changes.

    :param connection_string: SQLAlchemy database URL
    :return: True when the connection works, False otherwise (the error is
        shown in the UI)
    """
    engine = None
    try:
        engine = create_engine(connection_string)
        # Test the connection
        with engine.connect():
            pass
    except Exception as e:
        if engine is not None:
            engine.dispose()
        st.error(f"Failed to connect to database: {str(e)}")
        return False

    previous_engine = st.session_state.get("engine")
    st.session_state.engine = engine
    if previous_engine is not None and previous_engine is not engine:
        # Release the pooled connections of the engine we just replaced.
        previous_engine.dispose()
    return True
142
+
143
+
144
def get_table_schema(table_name):
    """
    Describe the columns of `table_name` using the session's engine.

    :param table_name: Name of the table to inspect
    :return: Mapping of column name to the string form of its type, or
        None when no engine is connected
    """
    engine = st.session_state.engine
    if not engine:
        return None

    schema = {}
    for column in inspect(engine).get_columns(table_name):
        schema[column['name']] = str(column['type'])
    return schema
151
+
152
+
153
def get_sample_data(table_name):
    """
    Fetch up to three rows of `table_name`, ordered by its first column descending.

    The table name is quoted with the dialect's identifier preparer before
    it is placed in the SQL text, so names that need quoting (mixed case,
    reserved words, special characters) work and cannot alter the statement.

    :param table_name: Name of the table to sample
    :return: DataFrame with the sampled rows, or None when no engine is
        connected or the query fails (the error is shown in the UI)
    """
    engine = st.session_state.engine
    if not engine:
        return None

    try:
        quoted_table = engine.dialect.identifier_preparer.quote(table_name)
        query = f"SELECT * FROM {quoted_table} ORDER BY 1 DESC LIMIT 3"
        with engine.connect() as conn:
            df = pd.read_sql(query, conn)
        return df
    except Exception as e:
        st.error(f"Error fetching sample data: {str(e)}")
        return None
165
+
166
+
167
def clean_sql_response(response: str) -> str:
    """
    Extract a clean SQL query from a potentially formatted model response.

    If the response contains a Markdown code fence, the content of the
    first fence is returned. The fence may be untagged or tagged with
    sql/postgresql/postgres/pgsql in any letter case, and both LF and CRLF
    line endings are accepted. Without such a fence, the whole response is
    returned. The result is always stripped of surrounding whitespace.

    :param response: Raw text returned by the model (None is treated as empty)
    :return: The SQL text, or an empty string for an empty response
    """
    if not response:
        return ""

    fenced = re.search(
        r"```[ \t]*(?:sql|postgresql|postgres|pgsql)?[ \t]*\r?\n(.*?)```",
        response,
        re.DOTALL | re.IGNORECASE,
    )
    if fenced:
        return fenced.group(1).strip()
    return response.strip()
174
+
175
+
176
def execute_query(query):
    """
    Run `query` against the connected database and time the execution.

    :param query: SQL text to execute
    :return: DataFrame with the result set, or None when no engine is
        connected or execution fails (the error is shown in the UI)
    """
    engine = st.session_state.engine
    if not engine:
        return None

    try:
        started_at = time.time()
        with st.spinner("Executing SQL query..."), engine.connect() as conn:
            result_frame = pd.read_sql(query, conn)
        elapsed = time.time() - started_at
        pprint(f"[Query Execution] Latency: {elapsed:.2f}s")
        return result_frame
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None
191
+
192
+
193
def _format_sample_rows(sample_data):
    """Render the sample rows as prompt text, tolerating missing data."""
    if sample_data is None:
        # get_sample_data returns None when the sample query failed.
        return "(no sample rows available)"
    try:
        return sample_data.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown needs the optional `tabulate` package;
        # fall back to plain-text rendering when it is not installed.
        return sample_data.to_string(index=False)


def generate_sql_query(user_query):
    """
    Ask the selected model to translate a plain-English request into SQL.

    The prompt carries the selected table's name, schema and sample rows
    from session state. The prompt is shown in a debug expander, and token
    count and latency are logged through `pprint`.

    :param user_query: The user's request in natural language
    :return: The SQL text extracted from the model's reply (empty string
        when the model returned no text)
    """
    sample_rows = _format_sample_rows(st.session_state.sample_data)

    prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.

Table Name: {st.session_state.selected_table}

Table Schema:
{json.dumps(st.session_state.table_schema, indent=2)}

Sample Data:
{sample_rows}

Important:
1. Only return the SQL query, nothing else
2. The query should be valid PostgreSQL syntax
3. Do not include any explanations or comments
4. Make sure to handle NULL values appropriately
5. Use the table name '{st.session_state.selected_table}' in your query

User Query: {user_query}
"""

    prompt_tokens = __countTokens(prompt)
    pprint(f"[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")

    # Debug prompt in a Streamlit expander for better organization
    with st.expander("Debug: Prompt Generation"):
        st.write(f"\nUser Query: {user_query}")
        st.write("\nFull Prompt:")
        st.code(prompt, language="text")

    messages = [
        {"role": "user", "content": prompt},
    ]

    start_time = time.time()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        if isClaudeModel:
            # The Anthropic Messages API requires an explicit max_tokens.
            response = client.messages.create(
                model=MODEL,
                max_tokens=1000,
                messages=messages,
            )
            raw_response = response.content[0].text
        else:
            response = client.chat.completions.create(
                model=MODEL,
                messages=messages,
            )
            raw_response = response.choices[0].message.content

    generation_time = time.time() - start_time
    pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")

    # Chat-completion content can be None; normalise it before cleaning.
    return clean_sql_response(raw_response or "")
247
+
248
+
249
# UI Components
st.title("SQL Query Assistant")

# Step 1: database connection. The input is masked because the URL
# typically embeds credentials.
st.header("1. Database Connection")
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string or "",
    type="password",
)

# Only (re)connect when the user supplied a string that differs from the
# one already in use.
_is_new_connection = (
    bool(connection_string)
    and connection_string != st.session_state.connection_string
)
if _is_new_connection and connect_to_db(connection_string):
    st.session_state.connection_string = connection_string
    st.success("Successfully connected to database!")

# Step 2: table selection, available once a connection is established.
if st.session_state.connection_string:
    st.header("2. Table Selection")
    inspector = inspect(st.session_state.engine)
    tables = inspector.get_table_names()

    # Preselect 'lsq_leads' when the database has it, else the first table.
    try:
        default_index = tables.index('lsq_leads')
    except ValueError:
        default_index = 0
    selected_table = st.selectbox("Select a table", tables, index=default_index)

    # Containers fix the on-page order: schema first, then sample rows.
    schema_container = st.container()
    data_container = st.container()

    if selected_table:
        # Remember the selection across reruns.
        if st.session_state.selected_table != selected_table:
            st.session_state.selected_table = selected_table

        # Schema and sample rows are refreshed on every rerun.
        st.session_state.table_schema = get_table_schema(selected_table)
        st.session_state.sample_data = get_sample_data(selected_table)

        with schema_container:
            if st.session_state.table_schema:
                st.subheader("Table Schema")
                # Force immediate rendering with an empty element
                st.empty()
                st.json(st.session_state.table_schema)

        with data_container:
            if st.session_state.sample_data is not None:
                st.subheader("Sample Data (Last 3 rows)")
                # Force immediate rendering with an empty element
                st.empty()
                st.dataframe(
                    st.session_state.sample_data,
                    use_container_width=True,
                    hide_index=True,
                )

# Step 3: natural-language query, available once a table is selected.
if st.session_state.selected_table:
    st.header("3. Query Input")
    user_query = st.text_area("Enter your query in plain English")

    if st.button("Generate and Execute Query") and user_query:
        # Translate the request into SQL and show it before running it.
        sql_query = generate_sql_query(user_query)
        st.subheader("Generated SQL Query")
        st.code(sql_query, language="sql")

        # Run the generated SQL and show the rows when it succeeds.
        results = execute_query(sql_query)
        if results is not None:
            st.subheader("Query Results")
            st.dataframe(results)
327
+
328
+
clients/openRouter.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ from typing import List, Dict, Optional
4
+
5
+
6
class ResponseWrapper:
    """
    Read-only view over a response dictionary that supports both
    dict-style (`resp["choices"]`) and attribute-style (`resp.choices`)
    access. Nested dictionaries, including those inside lists, are wrapped
    on access.

    Keys that collide with this class's own method names (`get`, `keys`,
    `items`) are only reachable through dict-style access.
    """

    def __init__(self, response_data):
        """
        Wrap the response data to support both dict-like and attribute-like access

        :param response_data: The raw response dictionary from OpenRouter
        """
        self._data = response_data

    def __getattr__(self, name):
        """
        Allow attribute-style access to the response data

        Python calls this only after normal attribute lookup has failed.
        The backing dict is read through `__dict__` rather than
        `self._data`: on an instance whose `_data` is not set yet (for
        example one created by `copy` via `__new__`), `self._data` would
        re-enter `__getattr__` and recurse without end.

        :param name: Attribute name to access
        :return: Corresponding value from the response data
        :raises AttributeError: If the name is not a key of the response data
        """
        data = self.__dict__.get("_data")
        if data is not None and name in data:
            return self._wrap(data[name])
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __getitem__(self, key):
        """
        Allow dictionary-style access to the response data

        :param key: Key to access
        :return: Corresponding value from the response data
        :raises KeyError: If the key is not present
        """
        return self._wrap(self._data[key])

    def __contains__(self, key):
        """
        Support `key in wrapper` with a direct lookup on the wrapped data
        """
        return key in self._data

    def _wrap(self, value):
        """
        Recursively wrap dictionaries and lists to support attribute access

        :param value: Value to wrap
        :return: Wrapped value
        """
        if isinstance(value, dict):
            return ResponseWrapper(value)
        if isinstance(value, list):
            return [self._wrap(item) for item in value]
        return value

    def __iter__(self):
        """
        Allow iteration over the keys of the wrapped dictionary
        """
        return iter(self._data)

    def get(self, key, default=None):
        """
        Provide a get method similar to dictionary

        :param key: Key to look up
        :param default: Value returned (wrapped if it is a dict or list)
            when the key is missing
        """
        return self._wrap(self._data.get(key, default))

    def keys(self):
        """
        Return dictionary keys
        """
        return self._data.keys()

    def items(self):
        """
        Return dictionary items as a list of (key, wrapped value) pairs
        """
        return [(k, self._wrap(v)) for k, v in self._data.items()]

    def __str__(self):
        """
        Return a JSON string representation of the response data

        :return: JSON-formatted string of the response
        """
        return json.dumps(self._data, indent=2)

    def __repr__(self):
        """
        Return a string representation for debugging

        :return: Representation of the ResponseWrapper
        """
        return f"ResponseWrapper({json.dumps(self._data, indent=2)})"
89
+
90
+
91
class OpenRouter:
    """
    Minimal OpenRouter client exposing an OpenAI-style call path:
    `client.chat.completions.create(...)`.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://openrouter.ai/api/v1",
        timeout: float = 300.0,
    ):
        """
        Initialize OpenRouter client

        :param api_key: API key for OpenRouter
        :param base_url: Base URL for OpenRouter API (default is standard endpoint)
        :param timeout: Default number of seconds to wait for the HTTP
            request before giving up; can be overridden per call
        """
        self.api_key = api_key
        self.base_url = base_url
        self.timeout = timeout
        self.chat = self.ChatNamespace(self)

    class ChatNamespace:
        """Provides the `chat` level of the `chat.completions.create` path."""

        def __init__(self, client):
            self._client = client
            self.completions = self.CompletionsNamespace(client)

        class CompletionsNamespace:
            """Provides the `completions` level holding `create`."""

            def __init__(self, client):
                self._client = client

            def create(
                self,
                model: str,
                messages: List[Dict[str, str]],
                temperature: float = 0.7,
                max_tokens: Optional[int] = None,
                **kwargs
            ):
                """
                Create a chat completion request

                :param model: Model to use
                :param messages: List of message dictionaries
                :param temperature: Sampling temperature
                :param max_tokens: Maximum number of tokens to generate
                :param kwargs: `http_referer` and `x_title` set the matching
                    request headers, `timeout` overrides the client's
                    request timeout, and everything else is merged into
                    the JSON payload
                :return: Wrapped response object
                :raises Exception: If the HTTP request fails or returns an
                    error status; the underlying requests exception is
                    attached as the cause
                """
                # Pull out the options that are not part of the payload.
                http_referer = kwargs.pop("http_referer", "https://your-app-domain.com")
                x_title = kwargs.pop("x_title", "AI Ad Generator")
                timeout = kwargs.pop("timeout", self._client.timeout)

                headers = {
                    "Authorization": f"Bearer {self._client.api_key}",
                    "Content-Type": "application/json",
                    "HTTP-Referer": http_referer,
                    "X-Title": x_title
                }

                payload = {
                    "model": model,
                    "messages": messages,
                    "temperature": temperature,
                }

                if model.startswith("deepseek"):
                    # Pin DeepSeek models to this provider order and do
                    # not allow fallbacks to other providers.
                    payload["provider"] = {
                        "order": [
                            "DeepSeek",
                            "DeepInfra",
                            "Fireworks",
                        ],
                        "allow_fallbacks": False
                    }

                if max_tokens is not None:
                    payload["max_tokens"] = max_tokens

                # Add any additional parameters
                payload.update(kwargs)

                try:
                    # requests waits indefinitely unless a timeout is given.
                    response = requests.post(
                        f"{self._client.base_url}/chat/completions",
                        headers=headers,
                        data=json.dumps(payload),
                        timeout=timeout
                    )

                    response.raise_for_status()

                    # Wrap the response data
                    return ResponseWrapper(response.json())

                except requests.RequestException as e:
                    raise Exception(f"OpenRouter API request failed: {e}") from e
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # streamlit
2
+ # pandas
3
+
4
+ python-dotenv
5
+ # groq
6
+ openai
7
+ transformers
8
+ # gradio_client
9
+ anthropic
10
+ sqlalchemy
11
+ psycopg2-binary
utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime as DT
2
+ import pytz
3
+ import streamlit as st
4
+
5
+
6
# Google Fonts "family" specifications used to build the stylesheet URL.
# Commented-out entries are alternatives that are currently disabled.
# NOTE(review): several axis ranges had been garbled into "[email protected]"
# by e-mail obfuscation and were restored from the fonts' published weight
# ranges — confirm against Google Fonts before relying on them.
FONTS = [
    # "Poppins:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900",
    # "Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900",
    # "Raleway:ital,wght@0,100..900;1,100..900",
    # "Lato:ital,wght@0,100;0,300;0,400;0,700;0,900;1,100;1,300;1,400;1,700;1,900",
    # "Nunito:ital,wght@0,200..1000;1,200..1000",
    # "Quicksand:wght@300..700",
    "Montserrat:ital,wght@0,100..900;1,100..900",
    # "Edu+AU+VIC+WA+NT+Dots:wght@400..700",
    "Whisper",
    # "Merienda:wght@300..900",
    "Playwrite+DE+Grund:wght@100..400",
    # "Roboto+Slab:wght@100..900",
    # "Open+Sans:ital,wght@0,300..800;1,300..800",
    # "Nunito+Sans:ital,opsz,wght@0,6..12,200..1000;1,6..12,200..1000",
    # "Ubuntu:ital,wght@0,300;0,400;0,500;0,700;1,300;1,400;1,500;1,700",
]
23
+
24
+
25
def __nowInIST() -> DT.datetime:
    """Return the current timezone-aware time in the Asia/Kolkata zone."""
    india_tz = pytz.timezone("Asia/Kolkata")
    return DT.datetime.now(india_tz)
27
+
28
+
29
def pprint(log: str):
    """
    Print a log line prefixed with the IST timestamp and the client IP.

    The IP is read from session state with a lookup that tolerates a
    missing key, so logging still works (printing `None` for the IP) when
    `ipAddress` has not been stored in the session yet.

    :param log: Message to print
    """
    timestamp = __nowInIST().strftime("%Y-%m-%d %H:%M:%S")
    ip_address = st.session_state.get("ipAddress")
    print(f"[{timestamp}] [{ip_address}] {log}")
33
+
34
+
35
def getFontsUrl():
    """
    Build the Google Fonts CSS2 stylesheet URL for every entry in FONTS.

    :return: URL with one `family=` parameter per font followed by
        `display=swap`
    """
    family_params = [f"family={font}" for font in FONTS]
    query = "&".join(family_params) + "&display=swap"
    return f"https://fonts.googleapis.com/css2?{query}"
41
+ return fontsUrl