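"""Offline Sensitive Survey Data Analysis.

A Gradio app that chats with a local Ollama inference server
(http://localhost:11434 by default) to analyze uploaded CSV, Parquet,
and text files without sending sensitive data off-machine.
"""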
import gradio as gr
import os
import pandas as pd
import requests
import json
import chardet
# -- LLM Client Class --
class OllamaClient:
    def __init__(self, model_name: str = "phi3:latest", base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url

    def list_models(self):
        """List all available models from the Ollama server."""
        try:
            # A short timeout keeps the UI responsive if the server is down
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                data = response.json()
                return [model['name'] for model in data.get('models', [])]
            return []
        except Exception as e:
            print(f"Error listing models: {e}")
            return []
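    # For reference, GET /api/tags on a stock Ollama server returns JSON shaped
    # roughly like (abridged):
    #   {"models": [{"name": "phi3:latest", "size": 2318920000, ...}]}
    # which is why the parser above only reads each entry's "name" field.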
    def chat_completion(self, messages, max_tokens=4000, stream=True, temperature=0.3, top_p=0.7):
        # Convert messages to Ollama format
        ollama_messages = []
        for msg in messages:
            if msg["role"] == "system":
                ollama_messages.append({"role": "system", "content": msg["content"]})
            elif msg["role"] in ["user", "assistant"]:
                ollama_messages.append({"role": msg["role"], "content": msg["content"]})
        # Prepare the request data
        data = {
            "model": self.model_name,
            "messages": ollama_messages,
            "options": {
                "temperature": temperature,
                "top_p": top_p,
                "num_predict": max_tokens
            },
            "stream": stream
        }
        # Make the request to the Ollama API
        response = requests.post(
            f"{self.base_url}/api/chat",
            json=data,
            stream=stream
        )
        if response.status_code != 200:
            raise Exception(f"Ollama API error: {response.text}")
        if stream:
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    try:
                        chunk = json.loads(decoded_line)
                        if "message" in chunk and "content" in chunk["message"]:
                            yield {"content": chunk["message"]["content"]}
                    except json.JSONDecodeError:
                        continue
        else:
            result = response.json()
            yield {"content": result["message"]["content"]}
# -- Check content --
def analyze_file_content(content, file_type):
    """Analyze file content and return a structural summary."""
    if file_type in ['parquet', 'csv']:
        try:
            # For parquet the content is a markdown table; for csv it is the
            # metadata summary built below, so these counts are approximate
            lines = content.split('\n')
            header = lines[0]
            columns = header.count('|') - 1 if '|' in header else len(header.split(','))
            # Subtract the header and separator lines of the markdown table
            rows = len(lines) - 2
            return f"📊 Dataset Structure: {columns} columns, {rows} data samples"
        except Exception:
            return "❌ Dataset structure analysis failed"
    lines = content.split('\n')
    total_lines = len(lines)
    if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
        functions = len([line for line in lines if 'def ' in line])
        classes = len([line for line in lines if 'class ' in line])
        imports = len([line for line in lines if 'import ' in line or 'from ' in line])
        return f"💻 Code Structure: {total_lines} lines (Functions: {functions}, Classes: {classes}, Imports: {imports})"
    paragraphs = content.count('\n\n') + 1
    words = len(content.split())
    return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, ~{words} words"
# -- Basic stats on content --
def get_column_stats(df, col):
    stats = {
        'type': str(df[col].dtype),
        'missing': df[col].isna().sum(),
        'unique': df[col].nunique()
    }
    if pd.api.types.is_numeric_dtype(df[col]):
        stats.update({
            'min': df[col].min(),
            'max': df[col].max(),
            'mean': df[col].mean()
        })
    else:
        stats['examples'] = df[col].dropna().head(3).tolist()
    return stats
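# Illustrative example (hypothetical values): for a numeric column this yields
#   {'type': 'int64', 'missing': 0, 'unique': 42, 'min': 1, 'max': 99, 'mean': 50.3}
# while a non-numeric column gets an 'examples' list of up to three non-null values.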
# -- Identify encoding --
def detect_file_encoding(file_path):
    """Improved encoding detection with fallback options."""
    try:
        with open(file_path, 'rb') as f:
            rawdata = f.read(100000)  # Read more data for better detection
        # Try chardet first
        result = chardet.detect(rawdata)
        encoding = result['encoding']
        confidence = result['confidence']
        # If confidence is low, try some common encodings
        if confidence < 0.9:
            for test_encoding in ['utf-8', 'utf-16', 'latin1', 'cp1252']:
                try:
                    rawdata.decode(test_encoding)
                    return test_encoding
                except UnicodeDecodeError:
                    continue
        return encoding if encoding else 'utf-8'
    except Exception as e:
        print(f"Encoding detection error: {e}")
        return 'utf-8'  # Default fallback
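# For reference, chardet.detect() returns a dict like
#   {'encoding': 'utf-8', 'confidence': 0.99, 'language': ''}
# so the low-confidence branch above re-tests common encodings by decoding the
# sampled bytes directly.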
# -- Read file --
def read_uploaded_file(file):
    if file is None:
        return "", ""
    # gr.File(type="filepath") passes a plain path string; older Gradio
    # versions pass a tempfile wrapper with a .name attribute
    file_path = file if isinstance(file, str) else file.name
    try:
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.parquet':
            df = pd.read_parquet(file_path, engine='pyarrow')
            content = df.head(10).to_markdown(index=False)
            return content, "parquet"
        if file_ext == '.csv':
            # First try to detect the encoding
            try:
                encoding = detect_file_encoding(file_path)
                # Try reading with different delimiters
                delimiters = [',', ';', '\t', '|']
                df = None
                best_delimiter = ','
                max_columns = 1
                # First pass to find the best delimiter
                for delimiter in delimiters:
                    try:
                        with open(file_path, 'r', encoding=encoding) as f:
                            first_line = f.readline()
                        current_columns = len(first_line.split(delimiter))
                        if current_columns > max_columns:
                            max_columns = current_columns
                            best_delimiter = delimiter
                    except Exception:
                        continue
                # Now read with the best delimiter found
                try:
                    df = pd.read_csv(
                        file_path,
                        encoding=encoding,
                        delimiter=best_delimiter,
                        on_bad_lines='warn',
                        engine='python',
                        quotechar='"'
                    )
                except Exception:
                    # Fall back to pandas auto-detection
                    df = pd.read_csv(file_path, encoding=encoding, on_bad_lines='warn')
                if df is None or len(df.columns) < 1:
                    return "❌ Could not parse CSV file - no valid columns detected", "error"
                # Generate a comprehensive data summary
                content = "📊 CSV Metadata:\n"
                content += f"- Rows: {len(df):,}\n"
                content += f"- Columns: {len(df.columns):,}\n"
                content += f"- Missing Values: {df.isna().sum().sum():,}\n\n"
                content += "📋 Column Details:\n"
                for col in df.columns:
                    stats = get_column_stats(df, col)
                    content += f"### {col}\n"
                    content += f"- Type: {stats['type']}\n"
                    content += f"- Unique: {stats['unique']}\n"
                    content += f"- Missing: {stats['missing']}\n"
                    if 'examples' in stats:
                        content += f"- Examples: {stats['examples']}\n"
                    else:
                        content += (
                            f"- Range: {stats['min']} to {stats['max']}\n"
                            f"- Mean: {stats['mean']:.2f}\n"
                        )
                    content += "\n"
                content += "📝 Sample Data (First 3 Rows):\n"
                content += df.head(3).to_markdown(index=False)
                return content, "csv"
            except Exception as e:
                return f"❌ Error reading CSV file: {str(e)}", "error"
        else:
            encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
            for encoding in encodings:
                try:
                    with open(file_path, 'r', encoding=encoding) as f:
                        content = f.read()
                    return content, "text"
                except UnicodeDecodeError:
                    continue
            # UnicodeDecodeError requires five positional arguments, so raise a
            # plain ValueError with a readable message instead
            raise ValueError(f"❌ Unable to read file with supported encodings ({', '.join(encodings)})")
    except Exception as e:
        return f"❌ Error reading file: {str(e)}", "error"
# Convert tuple-style (user, assistant) history into the messages format.
# Kept as a helper; chat() below also accepts dict-style history directly.
def format_history(history):
    formatted_history = []
    for user_msg, assistant_msg in history:
        formatted_history.append({"role": "user", "content": user_msg})
        if assistant_msg:
            formatted_history.append({"role": "assistant", "content": assistant_msg})
    return formatted_history
def chat(message,
         history,
         uploaded_file,
         system_message="",
         max_tokens=4000,
         temperature=0.3,
         top_p=0.9,
         selected_model="phi3:latest"):
    system_prefix = """
You are an AI Data Scientist designed to provide expert guidance in data analysis, machine learning, and big data technologies, suitable for a wide range of users seeking data-driven insights and solutions.
Analyze the uploaded file in depth from the following perspectives:
1. 📁 Overall file structure and format
2. ⭐ Data quality and completeness evaluation
3. 💡 Suggested data fixes and improvements
4. 📊 Data characteristics, meaning, and patterns
5. 🔍 Key component analysis and potential segmentations
6. 🎯 Insights and suggested persuasive storytelling
Provide detailed and structured analysis from an expert perspective, but explain it in an easy-to-understand way.
Format the analysis results in Markdown and include specific examples where possible.
"""
    if uploaded_file:
        content, file_type = read_uploaded_file(uploaded_file)
        if file_type == "error":
            # chat() is a generator, so report the error via yield, not return
            yield "", (history or []) + [{"role": "user", "content": message}, {"role": "assistant", "content": content}]
            return
        file_summary = analyze_file_content(content, file_type)
        if file_type in ['parquet', 'csv']:
            system_message += f"\n\nFile Content:\n```markdown\n{content}\n```"
        else:
            system_message += f"\n\nFile Content:\n```\n{content}\n```"
        # Must match the hidden auto-analysis trigger defined in the UI below
        if message == "Analyze this file":
            message = f"""[Structure Analysis] {file_summary}
Please provide detailed analysis from these perspectives:
1. 📁 Overall file structure and format
2. ⭐ Data quality and completeness evaluation
3. 💡 Suggested data fixes and improvements
4. 📊 Data characteristics, meaning, and patterns
5. 🔍 Key component analysis and potential segmentations
6. 🎯 Insights and suggested persuasive storytelling"""
    messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}]
    # Convert history to message format (accepts dicts or (user, assistant) tuples)
    if history is not None:
        for item in history:
            if isinstance(item, dict):
                messages.append(item)
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                messages.append({"role": "user", "content": item[0]})
                if item[1]:
                    messages.append({"role": "assistant", "content": item[1]})
    messages.append({"role": "user", "content": message})
    try:
        client = OllamaClient(model_name=selected_model)
        partial_message = ""
        for response in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = response.get('content', '')
            if token:
                partial_message += token
                # messages[1:] is the prior conversation plus the current user
                # turn, so earlier exchanges stay visible while streaming
                yield "", messages[1:] + [{"role": "assistant", "content": partial_message}]
    except Exception as e:
        error_msg = f"❌ Inference error: {str(e)}"
        yield "", messages[1:] + [{"role": "assistant", "content": error_msg}]
css = """ | |
footer {visibility: hidden} | |
""" | |
with gr.Blocks(theme="gstaff/xkcd", | |
css=css, | |
title="Offline Sensitive Survey Data Analysis") as demo: | |
gr.HTML( | |
""" | |
<div style="text-align: center; max-width: 1000px; margin: 0 auto;"> | |
<h1 style="font-size: 3em; font-weight: 600; margin: 0.5em;">Offline Sensitive Survey Data Analysis</h1> | |
<h3 style="font-size: 1.2em; margin: 1em;">Leveraging your Local Ollama Inference Server</h3> | |
</div> | |
""" | |
) | |
    # Store the current model in a state variable
    current_model = gr.State("phi3:latest")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height=500,
                label="Chat Interface",
                type="messages"
            )
            msg = gr.Textbox(
                label="Type your message",
                show_label=False,
                placeholder="Ask me anything about the uploaded data file... ",
                container=False
            )
            with gr.Row():
                send = gr.Button("Send")
                clear = gr.ClearButton([msg, chatbot])
        with gr.Column(scale=1):
            gr.Markdown("### Upload File\nSupported: CSV, Parquet, and text files")
            file_upload = gr.File(
                label="Upload File",
                file_types=[".csv", ".parquet", ".txt"],
                type="filepath"
            )
            with gr.Accordion("Model Settings", open=False):
                model_dropdown = gr.Dropdown(
                    label="Available Models",
                    choices=[],
                    interactive=True
                )
                refresh_models = gr.Button("Refresh Models")
            with gr.Accordion("Advanced Settings ⚙️", open=False):
                system_message = gr.Textbox(label="Override System Message 📝", value="")
                max_tokens = gr.Slider(minimum=1, maximum=8000, value=4000, label="Max Tokens (maximum number of tokens in the generated response)")
                temperature = gr.Slider(minimum=0, maximum=1, value=0.3, label="Temperature (higher = more creative)")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.7, label="Top P (word choices by probability threshold)")
    # Function to load available models
    def load_models():
        client = OllamaClient()
        models = client.list_models()
        return gr.Dropdown(choices=models, value=models[0] if models else "phi3:latest")

    # Refresh models button click handler
    refresh_models.click(
        load_models,
        outputs=model_dropdown
    )
    # Model dropdown change handler
    model_dropdown.change(
        lambda x: x,
        inputs=model_dropdown,
        outputs=current_model
    )
    # Load models when the app starts
    demo.load(
        load_models,
        outputs=model_dropdown
    )
    # Event bindings
    msg.submit(
        chat,
        inputs=[msg, chatbot, file_upload, system_message, max_tokens, temperature, top_p, current_model],
        outputs=[msg, chatbot],
        queue=True
    ).then(
        lambda: gr.update(interactive=True),
        None,
        [msg]
    )
    send.click(
        chat,
        inputs=[msg, chatbot, file_upload, system_message, max_tokens, temperature, top_p, current_model],
        outputs=[msg, chatbot],
        queue=True
    ).then(
        lambda: gr.update(interactive=True),
        None,
        [msg]
    )
    # Auto-analysis on file upload, driven by this hidden component; its value
    # must match the trigger string checked inside chat()
    auto_analyze_trigger = gr.Textbox(value="Analyze this file", visible=False)
    file_upload.change(
        lambda: [],  # Clear chat history
        outputs=[chatbot],
        queue=True
    ).then(
        chat,
        inputs=[auto_analyze_trigger, chatbot, file_upload, system_message, max_tokens, temperature, top_p, current_model],
        outputs=[msg, chatbot],
        queue=True
    )
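    # The upload chain above first clears the chat, then .then() invokes chat()
    # with the hidden trigger message, so each new file starts a fresh
    # auto-analysis conversation.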
    # Example queries
    with gr.Column():
        gr.Markdown("### Potential Follow-up Queries")
        with gr.Row():
            example_btns = [
                gr.Button("Analyze open-ended responses for sentiment and recurring themes", size="lg", variant="secondary"),
                gr.Button("Compare responses between different groups and identify potential segmentation or cluster analysis", size="lg", variant="secondary"),
                gr.Button("Identify potential outcome variables and suggest a predictive model for them", size="lg", variant="secondary"),
                gr.Button("Generate a Quarto notebook in Python to process this dataset", size="lg", variant="secondary"),
                gr.Button("Generate an Rmd notebook in R to process this dataset", size="lg", variant="secondary"),
            ]
        # Add click handlers; binding each button's label as a default argument
        # avoids the late-binding closure pitfall in the loop
        for btn in example_btns:
            btn.click(
                lambda v=btn.value: v,
                inputs=None,
                outputs=msg
            )

if __name__ == "__main__":
    demo.launch()