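"""Gradio app for benchmarking LLM responses on creativity and stability.

Loads a CSV of prompts and per-model answers, scores each model via the
evaluate_creativity and evaluate_stability modules, and serves an
interactive leaderboard.
"""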
import os

import gradio as gr
import pandas as pd
# Import scoring helpers from our modules. get_leaderboard_data is not
# imported from evaluate_stability because this app defines its own
# hardcoded version below, which would otherwise shadow the import.
from evaluate_creativity import evaluate_creativity
from evaluate_stability import (
    evaluate_stability,
    evaluate_combined_score,
    create_radar_chart,
    create_bar_chart,
)
def list_available_models(csv_file):
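    """Return model names inferred from '<model>_answers' columns in the CSV."""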
    try:
        df = pd.read_csv(csv_file)
        model_columns = [col for col in df.columns if col.endswith('_answers')]
        models = [col.replace('_answers', '') for col in model_columns]
        return models
    except Exception as e:
        print(f"Error listing models: {str(e)}")
        return []
def evaluate_models(file_path, api_key, prompt_col, selected_models=None, progress=gr.Progress()):
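    """Score each selected model for creativity and stability, persist the
    results under results/, and return the sorted benchmark table plus
    chart paths for the Gradio UI."""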
    os.makedirs('results', exist_ok=True)
    progress(0, desc="Loading data...")
    df = pd.read_csv(file_path)

    # Determine which models to evaluate
    if selected_models:
        answer_cols = [f"{model}_answers" for model in selected_models]
        models = selected_models
    else:
        answer_cols = [col for col in df.columns if col.endswith('_answers')]
        models = [col.replace('_answers', '') for col in answer_cols]
    model_mapping = dict(zip(models, answer_cols))
    progress(0.1, desc=f"Found {len(model_mapping)} models to evaluate")

    all_results = {}
    all_creativity_dfs = {}

    # Load the running benchmark file, or start a fresh one if it is
    # missing or unreadable.
    benchmark_file = 'results/benchmark_results.csv'
    benchmark_columns = [
        'model', 'creativity_score', 'stability_score',
        'combined_score', 'evaluation_timestamp'
    ]
    if os.path.exists(benchmark_file):
        try:
            benchmark_df = pd.read_csv(benchmark_file)
        except Exception:
            benchmark_df = pd.DataFrame(columns=benchmark_columns)
    else:
        benchmark_df = pd.DataFrame(columns=benchmark_columns)

    # Guard against an empty selection to avoid ZeroDivisionError
    progress_increment = 0.9 / max(len(model_mapping), 1)
    progress_current = 0.1
    for model, column in model_mapping.items():
        try:
            progress(progress_current, desc=f"Evaluating {model}...")

            # Evaluate creativity
            creativity_df = evaluate_creativity(api_key, df, prompt_col, column, batch_size=1, progress=progress)
            progress_current += progress_increment * 0.6
            progress(progress_current, desc=f"Evaluating stability for {model}...")

            # Evaluate stability
            stability_results = evaluate_stability(df, prompt_col, column, progress=progress)
            progress_current += progress_increment * 0.3
            progress(progress_current, desc=f"Calculating combined score for {model}...")

            # Calculate combined score
            combined_results = evaluate_combined_score(creativity_df, stability_results, model)

            # Save detailed results
            timestamp = pd.Timestamp.now().strftime('%Y-%m-%d_%H-%M-%S')
            output_file = f'results/evaluated_responses_{model}_{timestamp}.csv'
            creativity_df.to_csv(output_file, index=False)

            # Add to benchmark DataFrame
            result_row = {
                'model': model,
                'creativity_score': combined_results['creativity_score'],
                'stability_score': combined_results['stability_score'],
                'combined_score': combined_results['combined_score'],
                'evaluation_timestamp': combined_results['evaluation_timestamp']
            }
            benchmark_df = pd.concat([benchmark_df, pd.DataFrame([result_row])], ignore_index=True)

            all_results[model] = combined_results
            all_creativity_dfs[model] = creativity_df

            progress_current += progress_increment * 0.1
            progress(progress_current, desc=f"Finished evaluating {model}")
        except Exception as e:
            print(f"Error evaluating {model}: {str(e)}")
    # Save benchmark results (rolling file plus a timestamped snapshot)
    benchmark_df.to_csv(benchmark_file, index=False)
    timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
    combined_benchmark_path = f'results/benchmark_results_{timestamp}.csv'
    benchmark_df.to_csv(combined_benchmark_path, index=False)

    progress(0.95, desc="Creating visualizations...")
    radar_chart_path = create_radar_chart(all_results)
    bar_chart_path = create_bar_chart(all_results)

    progress(1.0, desc="Evaluation complete!")
    sorted_results = benchmark_df.sort_values(by='combined_score', ascending=False)
    return sorted_results, radar_chart_path, bar_chart_path, combined_benchmark_path
def get_leaderboard_data():
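    """Return a hardcoded snapshot of leaderboard rows:
    [model, creativity (0-10), stability (0-1), overall score (0-1)]."""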
    return [
        ["Vikhr", "7.75", "0.9363600260019302", "0.860"],
        ["Llama3", "7.30", "0.9410231244564057", "0.827"],
        ["Mistral", "6.95", "0.9459488660097122", "0.807"],
        ["Qwen", "6.93", "0.945682458281517", "0.800"],
        ["TinyLlama", "1.12", "0.945682458281517", "0.573"],
    ]
def create_gradio_interface():
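    """Build the two-tab Gradio UI: an evaluation tab that runs
    evaluate_models on an uploaded CSV, and a static leaderboard tab."""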
    with gr.Blocks(title="LLM Evaluation Tool") as app:
        gr.Markdown("# LLM Evaluation Tool")
        gr.Markdown("Evaluate models for creativity, diversity, relevance, and stability")
        with gr.Tab("Evaluate Models"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.File(label="Upload CSV with prompts and responses")
                    api_key_input = gr.Textbox(label="Gemini API Key", type="password")
                    prompt_col_input = gr.Textbox(label="Prompt Column Name", value="rus_prompt")
                    model_selection = gr.CheckboxGroup(
                        label="Select Models to Evaluate (leave empty to evaluate all)",
                        choices=[],
                        interactive=True
                    )
                    refresh_button = gr.Button("Refresh Model List")

                    @refresh_button.click(inputs=[file_input], outputs=[model_selection])
                    def update_model_list(file):
                        # Repopulate the checkbox choices from the uploaded CSV
                        if file:
                            models = list_available_models(file.name)
                            return gr.CheckboxGroup(choices=models)
                        return gr.CheckboxGroup(choices=[])
                    evaluate_button = gr.Button("Evaluate Models", variant="primary")

            with gr.Row():
                result_table = gr.Dataframe(label="Evaluation Results")

            with gr.Row():
                with gr.Column():
                    radar_chart = gr.Image(label="Radar Chart")
                with gr.Column():
                    bar_chart = gr.Image(label="Bar Chart")

            result_file = gr.File(label="Download Complete Results")

            evaluate_button.click(
                fn=evaluate_models,
                inputs=[file_input, api_key_input, prompt_col_input, model_selection],
                outputs=[result_table, radar_chart, bar_chart, result_file]
            )
        with gr.Tab("Leaderboard"):
            with gr.Row():
                leaderboard_table = gr.Dataframe(
                    label="Model Leaderboard",
                    headers=["Model", "Creativity", "Stability", "Overall Score"]
                )
            refresh_leaderboard = gr.Button("Refresh Leaderboard")

            @refresh_leaderboard.click(outputs=[leaderboard_table])
            def update_leaderboard():
                return get_leaderboard_data()

            with gr.Row():
                gr.Markdown("### Leaderboard Details")
                gr.Markdown("""
                - **Creativity**: Originality and inventiveness of the responses (scale up to 10)
                - **Stability**: Model stability coefficient (0-1)
                - **Overall Score**: Average combined performance metric (0-1)
                """)

    return app
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch()