File size: 5,993 Bytes
873ae70
 
 
ec0498e
873ae70
 
 
 
d41df51
ec0498e
 
 
 
873ae70
ec0498e
873ae70
ec0498e
 
873ae70
 
ec0498e
 
873ae70
ec0498e
873ae70
 
d41df51
 
 
 
 
 
873ae70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a670d11
873ae70
 
 
ec0498e
873ae70
 
 
 
 
 
ec0498e
d41df51
873ae70
2279e40
 
 
 
 
 
 
 
 
e72b6dc
2279e40
 
 
 
 
 
 
 
 
 
 
 
873ae70
2279e40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31de781
2279e40
 
 
 
 
 
 
 
 
 
 
 
0dea963
2279e40
 
 
873ae70
2279e40
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import re

import pandas as pd
import plotly.express as px
import streamlit as st
from transformers import pipeline

# Configure the Streamlit page with the "wide" layout before anything is rendered
st.set_page_config(
    layout="wide",
)

# Initialize TAPAS pipeline for table-based question answering.
# NOTE(review): the checkpoint name indicates fine-tuning on WTQ
# (WikiTableQuestions); whether it handles non-English questions well is not
# established anywhere in this file -- confirm before advertising it as
# multilingual.
@st.cache_resource
def _load_tqa():
    """Build the TAPAS table-question-answering pipeline once and reuse it.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached ``pipeline(...)`` call would reload the model each time.

    Returns:
        The ``table-question-answering`` pipeline, placed on the first CUDA
        device when one is available and on the CPU otherwise.
    """
    try:
        # Imported lazily: the file does not otherwise depend on torch directly.
        import torch

        device = 0 if torch.cuda.is_available() else -1
    except ImportError:
        # No torch installed: let transformers run on the CPU (-1).
        device = -1

    return pipeline(task="table-question-answering",
                    model="google/tapas-large-finetuned-wtq",
                    device=device)


tqa = _load_tqa()

# Page heading plus a Markdown description of what the app offers
APP_TITLE = "Data Visualization App with TAPAS NLP Integration"
APP_INTRO = """ 
    This app allows you to upload a table (CSV or Excel) and ask questions to generate graphs visualizing the data.
    Using **TAPAS**, the app can interpret your questions and generate the corresponding graphs.

    ### Available Features:
    - **Scatter Plot**: Visualize relationships between two columns.
    - **Line Graph**: Visualize a single column over time.
    
    Upload your data and ask questions about the data to generate visualizations.
"""

st.title(APP_TITLE)
st.markdown(APP_INTRO)

# Let the user pick the language in which the question will be asked
LANGUAGE_CHOICES = ("English", "German", "French", "Spanish", "Italian", "Others")
language = st.selectbox("Select the language of your question", LANGUAGE_CHOICES)

# Upload widget placed in the sidebar; restricted to csv and xlsx files
ACCEPTED_EXTENSIONS = ['csv', 'xlsx']
file_name = st.sidebar.file_uploader("Upload file:", type=ACCEPTED_EXTENSIONS)

# File processing and question answering

# Whole-word, case-insensitive keyword matchers. Plain ``str.split('and')``
# would also cut inside words such as "brand", and ``split('of')`` inside
# "profit".
_BETWEEN_WORD = re.compile(r'\bbetween\b', re.IGNORECASE)
_AND_WORD = re.compile(r'\band\b', re.IGNORECASE)
_OF_WORD = re.compile(r'\bof\b', re.IGNORECASE)


def _load_table(uploaded_file):
    """Read the uploaded CSV or XLSX file into a DataFrame.

    Args:
        uploaded_file: The object returned by the Streamlit file uploader.

    Returns:
        The parsed DataFrame, or None (after showing an error) when the file
        extension is neither ``.csv`` nor ``.xlsx``.

    Raises:
        Whatever pandas raises when the file cannot be parsed; the caller
        reports it to the user.
    """
    lowered_name = uploaded_file.name.lower()  # accept "DATA.CSV" as well
    if lowered_name.endswith('.csv'):
        # CSV files are read as semicolon-separated, ISO-8859-1 encoded text.
        return pd.read_csv(uploaded_file, sep=';', encoding='ISO-8859-1')  # Adjust encoding if needed
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(uploaded_file, engine='openpyxl')  # Use openpyxl to read .xlsx files
    st.error("Unsupported file type")
    return None


def _find_column(raw_name, table):
    """Return the column of *table* that *raw_name* refers to, or None.

    The exact (whitespace-stripped) name is tried first. If that fails, the
    name is retried without surrounding quotes / trailing punctuation, and
    finally without a leading "column " word, so that questions such as
    "graph of column price?" still resolve to ``price``.
    """
    exact = raw_name.strip()
    unquoted = exact.strip("'\"?.! ")
    without_prefix = re.sub(r'^column\s+', '', unquoted, flags=re.IGNORECASE).strip("'\" ")
    for candidate in (exact, unquoted, without_prefix):
        if candidate in table.columns:
            return candidate
    return None


def _answer_question(table, question):
    """Run TAPAS on *table* for *question* and display its output."""
    # The TAPAS tokenizer only accepts text cells, so query a string copy and
    # leave the caller's DataFrame (with its numeric dtypes) untouched.
    text_table = table.astype(str)
    text_table.columns = text_table.columns.astype(str)

    # Show the spinner while the model is actually running.
    with st.spinner():
        result = tqa(table=text_table, query=question)

    # Display the raw output from TAPAS
    st.write("TAPAS Raw Output (Response):")
    st.write(result)

    # Same output again, as plain text
    st.text("Raw TAPAS Output (Plain Text):")
    st.text(str(result))

    # A single query normally yields a dict; guard against any other shape.
    answer = result.get('answer', None) if isinstance(result, dict) else None
    if answer:
        st.write(f"TAPAS Answer: {answer}")
    else:
        st.warning("TAPAS did not return a valid answer.")


def _render_graph(table, question):
    """Draw the scatter or line plot that *question* asks for.

    "... between <col A> and <col B>" produces a scatter plot of the two
    columns; any other question containing "column" produces a line plot of
    the column named after the last "of". Column names that themselves
    contain the word "and" cannot be parsed by this scheme.
    """
    if _BETWEEN_WORD.search(question) and _AND_WORD.search(question):
        # This is a request for a scatter plot (two columns)
        tail = _BETWEEN_WORD.split(question)[-1]
        requested = [part.strip() for part in _AND_WORD.split(tail)]
        resolved = [_find_column(name, table) for name in requested]
        if len(resolved) == 2 and all(name is not None for name in resolved):
            x_name, y_name = resolved
            # Keep only rows where BOTH values are present. Dropping NaNs per
            # column and truncating would pair values from different rows.
            paired = table[x_name].notna() & table[y_name].notna()
            fig = px.scatter(x=table.loc[paired, x_name],
                             y=table.loc[paired, y_name],
                             title=f"Scatter Plot between {x_name} and {y_name}")
            st.plotly_chart(fig, use_container_width=True)
            st.success(f"Here is the scatter plot between '{x_name}' and '{y_name}'.")
        else:
            st.warning("Columns not found in the dataset or the question format is incorrect.")
    elif 'column' in question.lower():
        # This is a request for a line graph (single column)
        requested = _OF_WORD.split(question)[-1].strip()  # Handle 'of' keyword
        column = _find_column(requested, table)
        if column is not None:
            column_data = table[column].dropna()  # Drop NaN values
            fig = px.line(x=column_data.index, y=column_data, title=f"Graph of column '{column}'")
            st.plotly_chart(fig, use_container_width=True)
            st.success(f"Here is the graph of column '{column}'.")
        else:
            st.warning(f"Column '{requested}' not found in the data.")
    else:
        st.warning("Please ask a valid graph-related question (e.g., 'make a graph between column1 and column2').")


if file_name is None:
    st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
else:
    # Bind df up front so a failed read leaves it as None instead of unbound.
    df = None
    try:
        df = _load_table(file_name)

        if df is not None:
            # Show the original data with text columns intact
            st.write("Original Data:")
            st.write(df)

            # Display a sample of data for graph generation
            st.write("Sample data for graph generation:")
            st.write(df.head())

    except Exception as e:
        df = None
        st.error(f"Error reading file: {str(e)}")

    # Only offer the question box when there is a table to ask about.
    if df is not None:
        # User input for the question
        question = st.text_input(f'Ask your graph-related question in {language}')

        if st.button('Generate Graph'):
            try:
                # Ensure the question is a non-blank string
                if not isinstance(question, str) or not question.strip():
                    st.error("Please enter a valid question in the form of text.")
                else:
                    _answer_question(df, question)
                    _render_graph(df, question)
            except Exception as e:
                # Top-level UI boundary: report the problem instead of crashing the app.
                st.warning(f"Error processing question or generating graph: {str(e)}")