hertogateis commited on
Commit
0a87f1d
·
verified ·
1 Parent(s): 981c3bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -76
app.py CHANGED
@@ -3,7 +3,6 @@ import streamlit as st
3
  from st_aggrid import AgGrid
4
  import pandas as pd
5
  from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
- import plotly.express as px
7
 
8
  # Set the page layout for Streamlit
9
  st.set_page_config(layout="wide")
@@ -29,9 +28,6 @@ st.markdown("""
29
  - **max()**: For "max", it computes the maximum value in the DataFrame.
30
  - **min()**: For "min", it computes the minimum value in the DataFrame.
31
  - **count()**: For "count", it counts the non-null values in the entire DataFrame.
32
- - **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.
33
-
34
- Upload your data and ask questions to get both answers and visualizations.
35
  """)
36
 
37
  # File uploader in the sidebar
@@ -77,82 +73,26 @@ else:
77
  # User input for the question
78
  question = st.text_input('Type your question')
79
 
80
- # Check if the question is about generating a graph
81
- is_graph_query = False
82
- is_count_query = False
83
-
84
- # Check if the question contains "graph"
85
-
86
- if 'graph' in question.lower():
87
- is_graph_query = True
88
-
89
  # Process the answer using TAPAS and T5
90
  with st.spinner():
91
  if st.button('Answer'):
92
  try:
93
- if not is_graph_query:
94
- # Process TAPAS-related questions if it's not a graph query
95
- raw_answer = tqa(table=df, query=question, truncation=True)
96
-
97
- # Display raw answer from TAPAS on the screen
98
- st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Raw TAPAS Answer: </p>", unsafe_allow_html=True)
99
- st.write(raw_answer) # Display the raw TAPAS output
100
-
101
- # Extract relevant values for Plotly
102
- answer = raw_answer.get('answer', '')
103
- coordinates = raw_answer.get('coordinates', [])
104
- cells = raw_answer.get('cells', [])
105
-
106
- st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Relevant Data for Plotly: </p>", unsafe_allow_html=True)
107
- st.write(f"Answer: {answer}")
108
- st.write(f"Coordinates: {coordinates}")
109
- st.write(f"Cells: {cells}")
110
-
111
- # If TAPAS is returning a list of numbers for "graph" like you mentioned
112
- if "graph" in question.lower() and cells:
113
- # Assuming cells are numeric values that can be plotted in a graph
114
- plot_data = [float(cell) for cell in cells] # Convert cells to numeric data
115
-
116
- # Create a DataFrame for Plotly
117
- plot_df = pd.DataFrame({ 'Index': list(range(1, len(plot_data) + 1)), 'Value': plot_data })
118
-
119
- # Generate a graph using Plotly
120
- fig = px.line(plot_df, x='Index', y='Value', title=f"Graph for '{question}'")
121
- st.plotly_chart(fig, use_container_width=True)
122
-
123
- else:
124
- st.write(f"No data to plot for the question: '{question}'")
125
-
126
- else:
127
- # Handle graph-related questions
128
- if is_count_query:
129
- # Extract the column name to count
130
- column_name = question.split('count')[-1].strip()
131
-
132
-
133
-
134
- if column_name in df.columns:
135
- # Ask TAPAS to count the rows for this specific column
136
- count_result = tqa(table=df, query=f"count of {column_name}")
137
- st.write(f"Count for column '{column_name}': {count_result['answer']}")
138
- else:
139
- st.warning(f"Column '{column_name}' not found in the dataset.")
140
- elif 'between' in question.lower() and 'and' in question.lower():
141
- columns = question.split('between')[-1].split('and')
142
- columns = [col.strip() for col in columns]
143
- if len(columns) == 2 and all(col in df.columns for col in columns):
144
- fig = px.scatter(df, x=columns[0], y=columns[1], title=f"Graph between {columns[0]} and {columns[1]}")
145
- st.plotly_chart(fig, use_container_width=True)
146
- st.success(f"Here is the graph between '{columns[0]}' and '{columns[1]}'.")
147
- else:
148
- st.warning("Columns not found in the dataset.")
149
- elif 'column' in question.lower():
150
- column = question.split('of')[-1].strip()
151
- if column in df.columns:
152
- fig = px.line(df, x=df.index, y=column, title=f"Graph of column '{column}'")
153
- st.plotly_chart(fig, use_container_width=True)
154
-
155
- st.stop() # This halts further execution
156
 
157
  except Exception as e:
158
  st.warning(f"Error processing question or generating answer: {str(e)}")
 
3
  from st_aggrid import AgGrid
4
  import pandas as pd
5
  from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
 
6
 
7
  # Set the page layout for Streamlit
8
  st.set_page_config(layout="wide")
 
28
  - **max()**: For "max", it computes the maximum value in the DataFrame.
29
  - **min()**: For "min", it computes the minimum value in the DataFrame.
30
  - **count()**: For "count", it counts the non-null values in the entire DataFrame.
 
 
 
31
  """)
32
 
33
  # File uploader in the sidebar
 
73
  # User input for the question
74
  question = st.text_input('Type your question')
75
 
 
 
 
 
 
 
 
 
 
76
  # Process the answer using TAPAS and T5
77
  with st.spinner():
78
  if st.button('Answer'):
79
  try:
80
+ # Process TAPAS-related questions
81
+ raw_answer = tqa(table=df, query=question, truncation=True)
82
+
83
+ # Display raw answer from TAPAS on the screen
84
+ st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Raw TAPAS Answer: </p>", unsafe_allow_html=True)
85
+ st.write(raw_answer) # Display the raw TAPAS output
86
+
87
+ # Extract relevant values for Plotly
88
+ answer = raw_answer.get('answer', '')
89
+ coordinates = raw_answer.get('coordinates', [])
90
+ cells = raw_answer.get('cells', [])
91
+
92
+ st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Relevant Data for Plotly: </p>", unsafe_allow_html=True)
93
+ st.write(f"Answer: {answer}")
94
+ st.write(f"Coordinates: {coordinates}")
95
+ st.write(f"Cells: {cells}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  except Exception as e:
98
  st.warning(f"Error processing question or generating answer: {str(e)}")