hertogateis commited on
Commit
0bb3639
·
verified ·
1 Parent(s): 3ebbb9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -44
app.py CHANGED
@@ -1,26 +1,37 @@
 
1
  import streamlit as st
 
2
  import pandas as pd
 
3
  import plotly.express as px
4
- from transformers import pipeline
5
 
6
  # Set the page layout for Streamlit
7
  st.set_page_config(layout="wide")
8
 
9
- # Initialize TAPAS pipeline for table-based question answering (multilingual)
10
  tqa = pipeline(task="table-question-answering",
11
  model="google/tapas-large-finetuned-wtq",
12
- device=0) # Assuming GPU is available, otherwise set device="cpu"
 
 
 
 
13
 
14
  # Title and Introduction
15
- st.title("Data Table with TAPAS NLP Integration")
16
  st.markdown("""
17
- This app allows you to upload a table (CSV or Excel) and ask questions to extract information from the data.
18
- Using **TAPAS**, the app can interpret your questions and provide the corresponding answers.
19
 
20
  ### Available Features:
21
- - **Table Question Answering**: Ask questions related to the uploaded table.
 
 
 
 
 
22
 
23
- Upload your data and ask questions to extract answers.
24
  """)
25
 
26
  # File uploader in the sidebar
@@ -41,61 +52,107 @@ else:
41
  df = None
42
 
43
  if df is not None:
44
- # Convert object columns to numeric where possible
45
- df = df.apply(pd.to_numeric, errors='ignore')
 
46
 
47
  st.write("Original Data:")
48
  st.write(df)
49
 
50
- # Display a sample of data for user reference
51
- st.write("Sample data:")
52
- st.write(df.head())
53
-
 
 
 
 
 
 
 
 
54
  except Exception as e:
55
  st.error(f"Error reading file: {str(e)}")
56
 
57
  # User input for the question
58
- question = st.text_input(f'Ask your question related to the table')
 
 
 
 
59
 
 
 
 
 
 
 
 
60
  with st.spinner():
61
- if st.button('Get Answer'):
62
  try:
63
- # Ensure the question is a valid string
64
- if not question or not isinstance(question, str):
65
- st.error("Please enter a valid question.")
66
- else:
67
- # Use TAPAS model to process the question
68
- result = tqa(table=df, query=question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- # Display the raw output from TAPAS
71
- st.write("TAPAS Raw Output (Response):")
72
- st.write(result) # This will display the raw output from TAPAS
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- # If the user asked for a count of a column or specific data:
75
- if "count" in question.lower():
76
- # Ask TAPAS to count rows of a specific column
77
- column_name = question.split("count")[-1].strip() # Extract column name
78
  if column_name in df.columns:
 
79
  count_result = tqa(table=df, query=f"count of {column_name}")
80
  st.write(f"Count for column '{column_name}': {count_result['answer']}")
81
  else:
82
  st.warning(f"Column '{column_name}' not found in the dataset.")
83
-
84
- elif isinstance(result.get("answer"), list):
85
- # Handle structured data for graphing (e.g., scatter plot or other visualizations)
86
- answer_data = result["answer"]
87
- if answer_data and isinstance(answer_data, list) and isinstance(answer_data[0], dict):
88
- # Extract column data for x and y axes for Plotly
89
- x_data = [item.get("column1") for item in answer_data] # Replace column1 with actual column name
90
- y_data = [item.get("column2") for item in answer_data] # Replace column2 with actual column name
91
-
92
- # Create a scatter plot using Plotly
93
- fig = px.scatter(x=x_data, y=y_data, title="Scatter Plot based on TAPAS Answer")
94
  st.plotly_chart(fig, use_container_width=True)
95
-
96
- elif isinstance(result.get("answer"), str):
97
- # Handle simple answers (e.g., sums, counts, etc.)
98
- st.write(f"TAPAS Answer: {result['answer']}")
 
 
 
 
 
 
99
 
100
  except Exception as e:
101
  st.warning(f"Error processing question or generating answer: {str(e)}")
 
 
1
+ import os
2
  import streamlit as st
3
+ from st_aggrid import AgGrid
4
  import pandas as pd
5
+ from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
  import plotly.express as px
 
7
 
8
  # Set the page layout for Streamlit
9
  st.set_page_config(layout="wide")
10
 
11
+ # Initialize TAPAS pipeline
12
  tqa = pipeline(task="table-question-answering",
13
  model="google/tapas-large-finetuned-wtq",
14
+ device="cpu")
15
+
16
+ # Initialize T5 tokenizer and model for text generation
17
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
18
+ t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
19
 
20
  # Title and Introduction
21
+ st.title("Table Question Answering and Data Analysis App")
22
  st.markdown("""
23
+ This app allows you to upload a table (CSV or Excel) and ask questions about the data.
24
+ Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.
25
 
26
  ### Available Features:
27
+ - **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
28
+ - **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
29
+ - **max()**: For "max", it computes the maximum value in the DataFrame.
30
+ - **min()**: For "min", it computes the minimum value in the DataFrame.
31
+ - **count()**: For "count", it counts the non-null values in the entire DataFrame.
32
+ - **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.
33
 
34
+ Upload your data and ask questions to get both answers and visualizations.
35
  """)
36
 
37
  # File uploader in the sidebar
 
52
  df = None
53
 
54
  if df is not None:
55
+ numeric_columns = df.select_dtypes(include=['object']).columns
56
+ for col in numeric_columns:
57
+ df[col] = pd.to_numeric(df[col], errors='ignore')
58
 
59
  st.write("Original Data:")
60
  st.write(df)
61
 
62
+ df_numeric = df.copy()
63
+ df = df.astype(str)
64
+
65
+ # Display the first 5 rows of the dataframe in an editable grid
66
+ grid_response = AgGrid(
67
+ df.head(5),
68
+ fit_columns_on_grid_load=True, # Correct parameter to fit columns on grid load
69
+ editable=True,
70
+ height=300,
71
+ width='100%',
72
+ )
73
+
74
  except Exception as e:
75
  st.error(f"Error reading file: {str(e)}")
76
 
77
  # User input for the question
78
+ question = st.text_input('Type your question')
79
+
80
+ # Check if the question is about generating a graph
81
+ is_graph_query = False
82
+ is_count_query = False
83
 
84
+ # Check if the question contains "count"
85
+ if 'count' in question.lower():
86
+ is_count_query = True
87
+ elif 'graph' in question.lower():
88
+ is_graph_query = True
89
+
90
+ # Process the answer using TAPAS and T5
91
  with st.spinner():
92
+ if st.button('Answer'):
93
  try:
94
+ if not is_graph_query:
95
+ # Process TAPAS-related questions if it's not a graph query
96
+ raw_answer = tqa(table=df, query=question, truncation=True)
97
+
98
+ # Display raw answer from TAPAS on the screen
99
+ st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Raw TAPAS Answer: </p>", unsafe_allow_html=True)
100
+ st.write(raw_answer) # Display the raw TAPAS output
101
+
102
+ # Extract relevant values for Plotly
103
+ answer = raw_answer.get('answer', '')
104
+ coordinates = raw_answer.get('coordinates', [])
105
+ cells = raw_answer.get('cells', [])
106
+
107
+ st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Relevant Data for Plotly: </p>", unsafe_allow_html=True)
108
+ st.write(f"Answer: {answer}")
109
+ st.write(f"Coordinates: {coordinates}")
110
+ st.write(f"Cells: {cells}")
111
+
112
+ # If TAPAS is returning a list of numbers for "average" like you mentioned
113
+ if "average" in question.lower() and cells:
114
+ # Assuming cells are numeric values that can be plotted in a graph
115
+ plot_data = [float(cell) for cell in cells] # Convert cells to numeric data
116
 
117
+ # Create a DataFrame for Plotly
118
+ plot_df = pd.DataFrame({ 'Index': list(range(1, len(plot_data) + 1)), 'Value': plot_data })
119
+
120
+ # Generate a graph using Plotly
121
+ fig = px.line(plot_df, x='Index', y='Value', title=f"Graph for '{question}'")
122
+ st.plotly_chart(fig, use_container_width=True)
123
+
124
+ else:
125
+ st.write(f"No data to plot for the question: '{question}'")
126
+
127
+ else:
128
+ # Handle graph-related questions
129
+ if is_count_query:
130
+ # Extract the column name to count
131
+ column_name = question.split('count')[-1].strip()
132
 
 
 
 
 
133
  if column_name in df.columns:
134
+ # Ask TAPAS to count the rows for this specific column
135
  count_result = tqa(table=df, query=f"count of {column_name}")
136
  st.write(f"Count for column '{column_name}': {count_result['answer']}")
137
  else:
138
  st.warning(f"Column '{column_name}' not found in the dataset.")
139
+ elif 'between' in question.lower() and 'and' in question.lower():
140
+ columns = question.split('between')[-1].split('and')
141
+ columns = [col.strip() for col in columns]
142
+ if len(columns) == 2 and all(col in df.columns for col in columns):
143
+ fig = px.scatter(df, x=columns[0], y=columns[1], title=f"Graph between {columns[0]} and {columns[1]}")
 
 
 
 
 
 
144
  st.plotly_chart(fig, use_container_width=True)
145
+ st.success(f"Here is the graph between '{columns[0]}' and '{columns[1]}'.")
146
+ else:
147
+ st.warning("Columns not found in the dataset.")
148
+ elif 'column' in question.lower():
149
+ column = question.split('of')[-1].strip()
150
+ if column in df.columns:
151
+ fig = px.line(df, x=df.index, y=column, title=f"Graph of column '{column}'")
152
+ st.plotly_chart(fig, use_container_width=True)
153
+
154
+ st.stop() # This halts further execution
155
 
156
  except Exception as e:
157
  st.warning(f"Error processing question or generating answer: {str(e)}")
158
+ st.warning("Please retype your question and make sure to use the column name and cell value correctly.")