Spaces:
Runtime error
Runtime error
import gradio as gr | |
from gradio_client import Client, handle_file | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import os | |
import pandas as pd | |
from io import StringIO, BytesIO | |
import base64 | |
import json | |
import plotly.graph_objects as go | |
# import plotly.io as pio | |
# from linePlot import plot_stacked_time_series, plot_emotion_topic_grid | |
# Hugging Face access token, read from the environment (e.g. a Space secret)
# so it never has to live in source control. If the variable is unset this
# is None, and that value is passed to the client unchanged.
HF_TOKEN = os.environ.get("HF_TOKEN")

# One shared Gradio client for the remote comparison-agent Space; every
# request handler in this module issues its `predict` calls through it.
client = Client(
    "mangoesai/Elections_Comparison_Agent_V4.1",
    hf_token=HF_TOKEN,
)
def stream_chat_with_rag(
    message: str,
    history: list,
    year: str
):
    """Answer one chat message by querying the remote RAG endpoint.

    Used as the ``gr.ChatInterface`` callback; ``year`` is supplied by the
    election-year radio through ``additional_inputs``.

    Args:
        message: The user's question.
        history: Conversation so far, as passed in by ``gr.ChatInterface``.
            Gradio owns and updates this list itself, so it is deliberately
            not modified here.
        year: Value of the election-year selector (e.g. "2024 Election").

    Returns:
        The answer returned by the remote ``/process_query`` endpoint.
    """
    # The endpoint returns (answer, sources); only the answer is shown.
    answer, _sources = client.predict(
        query=message,
        election_year=year,
        api_name="/process_query",
    )
    # Debugging output for the Space logs.
    print("Raw answer from API:")
    print(answer)
    # Do not append to `history`: ChatInterface(type="messages") stores
    # role/content dicts and appends the reply itself, so pushing a tuple
    # here would put a wrongly-formatted entry into the chat history.
    return answer
def topic_plot_gener(message: str, year: str):
    """Fetch the top-words plot for a question from the remote agent.

    Args:
        message: The question whose retrieved sources should be topicalized.
        year: Value of the election-year selector, forwarded to the endpoint.

    Returns:
        A ``plotly.graph_objects.Figure`` rebuilt from the serialized plot.
    """
    payload = client.predict(
        query=message,
        election_year=year,
        api_name="/topics_plot_genera",
    )
    # Debugging output for the Space logs.
    print(payload)
    # The endpoint returns a dict whose "plot" entry is a JSON string holding
    # a Plotly figure ("data" traces and, presumably, a "layout").
    plot_json = json.loads(payload["plot"])
    # Rebuild traces AND layout so the remote figure's titles, axis labels and
    # sizing survive. skip_invalid tolerates properties unknown to the local
    # plotly version (the remote Space may run a different release).
    return go.Figure(
        data=plot_json.get("data", []),
        layout=plot_json.get("layout"),
        skip_invalid=True,
    )
def heatmap(top_n):
    """Render the emotion-vs-topic heatmap returned by the remote agent.

    Args:
        top_n: Number of top items requested from ``/get_heatmap_pivot_table``;
            also shown in the plot title.

    Returns:
        A matplotlib ``Figure`` containing the seaborn heatmap.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )
    # Debugging output for the Space logs.
    print(pivot_table)
    print(type(pivot_table))
    # pivot_table is a dict like:
    # {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
    #  'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
    #           ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
    #           ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
    #           ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
    #  'metadata': None}
    # The "Index" column holds the emotion labels and becomes the row index.
    df = pd.DataFrame(pivot_table["data"], columns=pivot_table["headers"])
    df = df.set_index("Index")

    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure", which is global state shared by concurrent request handlers.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap="YlOrRd",
        cbar_kws={"label": "Weighted Frequency"},
        square=True,
        ax=ax,
    )
    ax.set_title(f"Top {top_n} Emotions vs Topics Weighted Frequency")
    ax.set_xlabel("Topics")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    # Unregister the figure from pyplot so repeated clicks don't accumulate
    # open figures; the Figure object itself can still be saved by gr.Plot.
    plt.close(fig)
    return fig
def linePlot(viz_type, weight, top_n):
    """Fetch the time-series visualization and wrap it in a matplotlib Figure.

    Args:
        viz_type: Visualization type chosen in the UI, forwarded to the endpoint.
        weight: Weight forwarded to the endpoint (UI label: "Score vs. Frequency").
        top_n: Number of top items; also scales the figure height.

    Returns:
        ``(figure, description)``: a matplotlib Figure displaying the PNG the
        remote Space rendered, and the accompanying text it returned.
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1",
    )
    # result is a tuple: (plot payload dict, description string). The payload's
    # "plot" entry is a base64-encoded PNG, normally as a data URI.
    plot_payload, description = result[0], result[1]
    encoded = plot_payload["plot"]
    # Drop an optional "data:image/png;base64," header; a bare base64 string
    # (no comma) is decoded as-is instead of raising IndexError.
    png_bytes = base64.b64decode(encoded.split(",", 1)[-1])
    image = plt.imread(BytesIO(png_bytes), format="PNG")

    fig, ax = plt.subplots(figsize=(12, 2 * top_n), dpi=150)
    ax.imshow(image)
    ax.axis("off")
    # No plt.show(): this runs inside a web-server handler, where opening an
    # interactive window is useless and could block. Close the figure so it is
    # not kept alive by pyplot; gr.Plot can still render the Figure object.
    plt.close(fig)
    return fig, description
# Create Gradio interface
with gr.Blocks(title="Reddit Election Analysis") as demo:
    # --- Emotion-vs-topic heatmap ---------------------------------------
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # A default value guarantees heatmap() never receives None when the
            # user clicks "Refresh Heatmap" before choosing a number.
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N",
            )
            fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="heatmap-plot",  # Styled by the CSS block below
            )

    # --- Time series ----------------------------------------------------
    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel (every control here is wired to linePlot below)
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)",
            )
            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items",
            )
            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display",
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    # --- RAG chat ---------------------------------------------------------
    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")
    with gr.Row():
        with gr.Column(scale=1):
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2024 Election",
            )
            gr.Markdown("""
            ## Example Questions:
            - Are there any comments that dislike the election results?
            - Summarize the main discussions about voting process
            - What're the common opinions about candidates?
            - What're common opinions about immigrant topic?
            """)
        with gr.Column(scale=2):
            # year_selector is passed to stream_chat_with_rag as its 3rd argument.
            gr.ChatInterface(
                stream_chat_with_rag,
                type="messages",
                additional_inputs=[year_selector],
            )

    # --- Top words of the retrieved sources ------------------------------
    gr.Markdown("## Top words of the relevant Q&A")
    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Your Question For Topicalize",
                placeholder="Copy and paste your question here to visualize the top words of the relevant topic",
            )
            topic_btn = gr.Button("Topicalize the RAG sources")
        with gr.Column(scale=2):
            topic_plot = gr.Plot(
                label="Top Words Distribution",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="topic-plot",  # Styled by the CSS block below
            )

    # Custom CSS for the classes assigned via elem_classes above.
    gr.HTML("""
    <style>
    .heatmap-plot {
        min-height: 400px;
        width: 100%;
        margin: auto;
    }
    .topic-plot {
        min-width: 600px;
        height: 100%;
        margin: auto;
    }
    </style>
    """)

    # --- Event wiring -----------------------------------------------------
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )
    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text],
    )
    topic_btn.click(
        fn=topic_plot_gener,
        inputs=[query_input, year_selector],
        outputs=topic_plot,
    )
if __name__ == "__main__":
    # Start the Gradio server. share=True additionally asks Gradio for a
    # public share link when the app is launched this way.
    launch_options = {"share": True}
    demo.launch(**launch_options)