DrishtiSharma committed on
Commit
5ab3e50
·
verified ·
1 Parent(s): 419fe53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -2
app.py CHANGED
@@ -89,7 +89,7 @@ if st.session_state.df is not None and st.session_state.show_preview:
89
 
90
 
91
 
92
- def ask_gpt4o_for_visualization(query, df, llm):
93
  columns = ', '.join(df.columns)
94
  prompt = f"""
95
  Analyze the query and suggest one or more relevant visualizations.
@@ -110,7 +110,85 @@ def ask_gpt4o_for_visualization(query, df, llm):
110
  return json.loads(response)
111
  except json.JSONDecodeError:
112
  st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
113
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def add_stats_to_figure(fig, df, y_axis, chart_type):
116
  """
 
89
 
90
 
91
 
92
+ """def ask_gpt4o_for_visualization(query, df, llm):
93
  columns = ', '.join(df.columns)
94
  prompt = f"""
95
  Analyze the query and suggest one or more relevant visualizations.
 
110
  return json.loads(response)
111
  except json.JSONDecodeError:
112
  st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
113
+ return None"""
114
+
115
+ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
116
+ import json
117
+
118
+ # Identify numeric and categorical columns
119
+ numeric_columns = df.select_dtypes(include='number').columns.tolist()
120
+ categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
121
+
122
+ # Enhanced Prompt with Clear Instructions
123
+ prompt = f"""
124
+ Analyze the following query and suggest the most suitable visualization(s) using the dataset.
125
+
126
+ **Query:** "{query}"
127
+
128
+ **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
129
+ **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
130
+
131
+ Suggest visualizations in this exact JSON format:
132
+ [
133
+ {{
134
+ "chart_type": "bar/box/line/scatter/pie/heatmap",
135
+ "x_axis": "categorical_or_time_column",
136
+ "y_axis": "numeric_column",
137
+ "group_by": "optional_column_for_grouping",
138
+ "title": "Title of the chart",
139
+ "description": "Why this chart is suitable"
140
+ }}
141
+ ]
142
+
143
+ **Examples:**
144
+ - For salary distribution:
145
+ {{
146
+ "chart_type": "box",
147
+ "x_axis": "job_title",
148
+ "y_axis": "salary_in_usd",
149
+ "group_by": "experience_level",
150
+ "title": "Salary Distribution by Job Title and Experience",
151
+ "description": "A box plot showing salary ranges across job titles and experience levels."
152
+ }}
153
+
154
+ - For trend analysis:
155
+ {{
156
+ "chart_type": "line",
157
+ "x_axis": "year",
158
+ "y_axis": "revenue",
159
+ "group_by": null,
160
+ "title": "Revenue Growth Over Years",
161
+ "description": "A line chart showing the trend of revenue over the years."
162
+ }}
163
+
164
+ Only suggest visualizations that make sense for the data and the query.
165
+ """
166
+
167
+ for attempt in range(retries + 1):
168
+ try:
169
+ # Generate response from the model
170
+ response = llm.generate(prompt)
171
+
172
+ # Load JSON response
173
+ suggestions = json.loads(response)
174
+
175
+ # Validate response structure
176
+ if isinstance(suggestions, list):
177
+ valid_suggestions = [
178
+ s for s in suggestions if all(k in s for k in ["chart_type", "x_axis", "y_axis"])
179
+ ]
180
+ if valid_suggestions:
181
+ return valid_suggestions
182
+ else:
183
+ st.warning("⚠️ GPT-4o did not suggest valid visualizations.")
184
+ return None
185
+
186
+ elif isinstance(suggestions, dict):
187
+ if all(k in suggestions for k in ["chart_type", "x_axis", "y_axis"]):
188
+ return [suggestions]
189
+ else:
190
+ st.warning("⚠️ GPT-4o's suggestion is
191
+
192
 
193
  def add_stats_to_figure(fig, df, y_axis, chart_type):
194
  """