File size: 5,262 Bytes
0bb3639
873ae70
0bb3639
873ae70
0bb3639
9aaa1ae
873ae70
 
 
 
0bb3639
ec0498e
 
0bb3639
 
 
 
 
ec0498e
873ae70
0bb3639
873ae70
0bb3639
 
873ae70
 
0bb3639
 
 
 
 
873ae70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb3639
 
 
b1c60f6
873ae70
 
 
0bb3639
 
 
 
 
 
 
 
 
 
 
 
873ae70
 
 
ec0498e
0bb3639
 
 
b1c60f6
0bb3639
b1c60f6
0a87f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873ae70
9aaa1ae
 
23651f1
 
 
 
 
 
 
 
 
9aaa1ae
 
23651f1
9aaa1ae
 
23651f1
 
9aaa1ae
 
b1c60f6
 
0bb3639
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import streamlit as st
from st_aggrid import AgGrid
import pandas as pd
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
import matplotlib.pyplot as plt

# Set the page layout for Streamlit (must run before any other Streamlit call)
st.set_page_config(layout="wide")


@st.cache_resource
def _load_tapas_pipeline():
    """Build the TAPAS table-question-answering pipeline on CPU.

    Streamlit re-executes this script on every widget interaction; caching
    the pipeline as a resource means the model is instantiated once per
    server process instead of once per rerun.
    """
    return pipeline(
        task="table-question-answering",
        model="google/tapas-large-finetuned-wtq",
        device="cpu",
    )


@st.cache_resource
def _load_t5():
    """Load the t5-small tokenizer and model once and return them as a pair."""
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    return tokenizer, model


# Initialize TAPAS pipeline
tqa = _load_tapas_pipeline()

# Initialize T5 tokenizer and model for text generation
# NOTE(review): neither T5 object is referenced by the code visible in this
# file; confirm whether they are still needed before paying the load cost.
t5_tokenizer, t5_model = _load_t5()

# Page heading
st.title("Table Question Answering and Data Analysis App")

# Introductory Markdown shown under the heading.
# NOTE(review): the mean/sum/max/min/count features listed below have no
# matching implementation in the code visible here; confirm the text is
# still accurate.
intro_markdown = """ 
    This app allows you to upload a table (CSV or Excel) and ask questions about the data.
    Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.

    ### Available Features:
    - **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
    - **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
    - **max()**: For "max", it computes the maximum value in the DataFrame.
    - **min()**: For "min", it computes the minimum value in the DataFrame.
    - **count()**: For "count", it counts the non-null values in the entire DataFrame.
"""
st.markdown(intro_markdown)

# Sidebar upload widget, restricted to the two supported extensions
accepted_extensions = ['csv', 'xlsx']
file_name = st.sidebar.file_uploader("Upload file:", type=accepted_extensions)

# File processing and question answering


def _read_uploaded_table(uploaded):
    """Read an uploaded file into a DataFrame based on its file-name suffix.

    Args:
        uploaded: The object returned by the Streamlit file uploader; its
            ``name`` attribute is used to pick the parser.

    Returns:
        The parsed DataFrame, or ``None`` when the name ends in neither
        ``.csv`` nor ``.xlsx``.

    Raises:
        Whatever the pandas reader raises for malformed input.
    """
    # Compare case-insensitively so that e.g. "DATA.CSV" is accepted as well.
    lowered_name = uploaded.name.lower()
    if lowered_name.endswith('.csv'):
        # CSV input is parsed as semicolon-separated, ISO-8859-1 encoded text.
        return pd.read_csv(uploaded, sep=';', encoding='ISO-8859-1')
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(uploaded, engine='openpyxl')
    return None


def _coerce_object_columns(frame):
    """Convert object-dtype columns of ``frame`` to numeric where possible.

    A column is converted only if all of its values parse; otherwise it is
    left exactly as it was. The frame is modified in place and returned.
    """
    for column in frame.select_dtypes(include=['object']).columns:
        try:
            frame[column] = pd.to_numeric(frame[column])
        except (ValueError, TypeError):
            # Same outcome as the deprecated errors='ignore': keep the column.
            continue
    return frame


def _numeric_cells(cells):
    """Return, in order, the cells that parse as finite floats.

    Unlike a ``str.isdigit`` check this accepts signed and exponent notation,
    and it never raises on text that merely looks numeric. Non-finite results
    (e.g. the string ``'nan'``) are skipped because they cannot be plotted.
    """
    import math  # local import: only this helper needs it

    values = []
    for cell in cells:
        try:
            value = float(cell)
        except (TypeError, ValueError):
            continue
        if math.isfinite(value):
            values.append(value)
    return values


def _plot_cell_values(values):
    """Render ``values`` as a bar chart in the Streamlit app."""
    labels = [f"Row {position}" for position in range(1, len(values) + 1)]
    fig, ax = plt.subplots()
    try:
        ax.bar(labels, values)
        ax.set_xlabel('Rows')
        ax.set_ylabel('Values')
        ax.set_title('Bar Plot of TAPAS Answer')
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each answer would leak one figure in the server process.
        plt.close(fig)


def _show_answer(raw_answer):
    """Display the raw TAPAS output, its main fields and, if possible, a plot."""
    st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Raw TAPAS Answer: </p>", unsafe_allow_html=True)
    st.write(raw_answer)

    # Pull out the individual fields, tolerating missing keys.
    answer = raw_answer.get('answer', '')
    coordinates = raw_answer.get('coordinates', [])
    cells = raw_answer.get('cells', [])

    st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Relevant Data for Plotly: </p>", unsafe_allow_html=True)
    st.write(f"Answer: {answer}")
    st.write(f"Coordinates: {coordinates}")
    st.write(f"Cells: {cells}")

    # Only plot when at least one returned cell is numeric.
    numeric_values = _numeric_cells(cells)
    if numeric_values:
        _plot_cell_values(numeric_values)


if file_name is None:
    st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
else:
    # Stays None unless the upload was read and prepared successfully, so the
    # question UI below never runs against a missing or half-processed table.
    df = None
    try:
        loaded = _read_uploaded_table(file_name)
        if loaded is None:
            st.error("Unsupported file type")
        else:
            loaded = _coerce_object_columns(loaded)

            st.write("Original Data:")
            st.write(loaded)

            # The table handed to the TAPAS pipeline is fully stringified.
            df = loaded.astype(str)

            # Display the first 5 rows of the dataframe in an editable grid
            AgGrid(
                df.head(5),
                fit_columns_on_grid_load=True,
                editable=True,
                height=300,
                width='100%',
            )
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the app.
        st.error(f"Error reading file: {str(e)}")

    if df is not None:
        # User input for the question
        question = st.text_input('Type your question')

        if st.button('Answer'):
            if not question.strip():
                st.warning("Please type a question first.")
            else:
                try:
                    # Show the spinner only while the model is actually running.
                    with st.spinner():
                        raw_answer = tqa(table=df, query=question, truncation=True)
                    _show_answer(raw_answer)
                except Exception as e:
                    st.warning(f"Error processing question or generating answer: {str(e)}")
                    st.warning("Please retype your question and make sure to use the column name and cell value correctly.")