hertogateis commited on
Commit
3ebbb9a
·
verified ·
1 Parent(s): b1c60f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import pandas as pd
 
3
  from transformers import pipeline
4
 
5
  # Set the page layout for Streamlit
@@ -70,16 +71,31 @@ else:
70
  st.write("TAPAS Raw Output (Response):")
71
  st.write(result) # This will display the raw output from TAPAS
72
 
73
- # Optionally, you can output the raw output as plain text:
74
- st.text("Raw TAPAS Output (Plain Text):")
75
- st.text(str(result)) # This will display raw output as plain text
76
-
77
- # Check if TAPAS is returning the expected answer
78
- answer = result.get('answer', None)
79
- if answer:
80
- st.write(f"TAPAS Answer: {answer}")
81
- else:
82
- st.warning("TAPAS did not return a valid answer.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  except Exception as e:
85
  st.warning(f"Error processing question or generating answer: {str(e)}")
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import plotly.express as px
4
  from transformers import pipeline
5
 
6
  # Set the page layout for Streamlit
 
71
  st.write("TAPAS Raw Output (Response):")
72
  st.write(result) # This will display the raw output from TAPAS
73
 
74
+ # If the user asked for a count of a column or specific data:
75
+ if "count" in question.lower():
76
+ # Ask TAPAS to count rows of a specific column
77
+ column_name = question.split("count")[-1].strip() # Extract column name
78
+ if column_name in df.columns:
79
+ count_result = tqa(table=df, query=f"count of {column_name}")
80
+ st.write(f"Count for column '{column_name}': {count_result['answer']}")
81
+ else:
82
+ st.warning(f"Column '{column_name}' not found in the dataset.")
83
+
84
+ elif isinstance(result.get("answer"), list):
85
+ # Handle structured data for graphing (e.g., scatter plot or other visualizations)
86
+ answer_data = result["answer"]
87
+ if answer_data and isinstance(answer_data, list) and isinstance(answer_data[0], dict):
88
+ # Extract column data for x and y axes for Plotly
89
+ x_data = [item.get("column1") for item in answer_data] # Replace column1 with actual column name
90
+ y_data = [item.get("column2") for item in answer_data] # Replace column2 with actual column name
91
+
92
+ # Create a scatter plot using Plotly
93
+ fig = px.scatter(x=x_data, y=y_data, title="Scatter Plot based on TAPAS Answer")
94
+ st.plotly_chart(fig, use_container_width=True)
95
+
96
+ elif isinstance(result.get("answer"), str):
97
+ # Handle simple answers (e.g., sums, counts, etc.)
98
+ st.write(f"TAPAS Answer: {result['answer']}")
99
 
100
  except Exception as e:
101
  st.warning(f"Error processing question or generating answer: {str(e)}")