import os
import time
import pandas as pd
import numpy as np
import gradio as gr
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import seaborn as sns
# Import functions from our modules
from evaluate_creativity import evaluate_creativity
from evaluate_stability import (
    evaluate_stability,
    evaluate_combined_score,
    create_radar_chart,
    create_bar_chart,
    get_leaderboard_data
)
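
# The input CSV is expected to hold one prompt column (e.g. "rus_prompt") plus one
# "<model>_answers" column per model, for example:
#   rus_prompt, Vikhr_answers, Llama3_answers, Mistral_answers
# (the column names above are illustrative; the code below only relies on the "_answers" suffix).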
def list_available_models(csv_file):
    try:
        df = pd.read_csv(csv_file)
        model_columns = [col for col in df.columns if col.endswith('_answers')]
        models = [col.replace('_answers', '') for col in model_columns]
        return models
    except Exception as e:
        print(f"Error listing models: {str(e)}")
        return []
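
# Run creativity and stability evaluation for every selected model, append the scores to
# results/benchmark_results.csv, and return the ranked results plus comparison charts.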
def evaluate_models(file_path, api_key, prompt_col, selected_models=None, progress=gr.Progress()):
    os.makedirs('results', exist_ok=True)
    progress(0, desc="Loading data...")
    # gr.File may hand over either a path string or a temp-file object depending on the
    # Gradio version, so normalise it before reading.
    if hasattr(file_path, 'name') and not isinstance(file_path, str):
        file_path = file_path.name
    df = pd.read_csv(file_path)
    # Determine which models to evaluate
    if selected_models:
        answer_cols = [f"{model}_answers" for model in selected_models]
        models = selected_models
    else:
        answer_cols = [col for col in df.columns if col.endswith('_answers')]
        models = [col.replace('_answers', '') for col in answer_cols]
    model_mapping = dict(zip(models, answer_cols))
    if not model_mapping:
        raise gr.Error("No '<model>_answers' columns found in the uploaded CSV.")
    progress(0.1, desc=f"Found {len(model_mapping)} models to evaluate")
    all_results = {}
    all_creativity_dfs = {}
    benchmark_file = 'results/benchmark_results.csv'
    if os.path.exists(benchmark_file):
        try:
            benchmark_df = pd.read_csv(benchmark_file)
        except Exception:
            benchmark_df = pd.DataFrame(columns=[
                'model', 'creativity_score', 'stability_score',
                'combined_score', 'evaluation_timestamp'
            ])
    else:
        benchmark_df = pd.DataFrame(columns=[
            'model', 'creativity_score', 'stability_score',
            'combined_score', 'evaluation_timestamp'
        ])
    progress_increment = 0.9 / len(model_mapping)
    progress_current = 0.1
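
    # Each model gets an equal slice of the remaining 90% of the progress bar,
    # split roughly 60/30/10 between creativity scoring, stability scoring and bookkeeping.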
    for model, column in model_mapping.items():
        try:
            progress(progress_current, desc=f"Evaluating {model}...")
            # Evaluate creativity
            creativity_df = evaluate_creativity(api_key, df, prompt_col, column, batch_size=1, progress=progress)
            progress_current += progress_increment * 0.6
            progress(progress_current, desc=f"Evaluating stability for {model}...")
            # Evaluate stability
            stability_results = evaluate_stability(df, prompt_col, column, progress=progress)
            progress_current += progress_increment * 0.3
            progress(progress_current, desc=f"Calculating combined score for {model}...")
            # Calculate combined score
            combined_results = evaluate_combined_score(creativity_df, stability_results, model)
            # Save detailed results
            timestamp = pd.Timestamp.now().strftime('%Y-%m-%d_%H-%M-%S')
            output_file = f'results/evaluated_responses_{model}_{timestamp}.csv'
            creativity_df.to_csv(output_file, index=False)
            # Add to benchmark DataFrame
            result_row = {
                'model': model,
                'creativity_score': combined_results['creativity_score'],
                'stability_score': combined_results['stability_score'],
                'combined_score': combined_results['combined_score'],
                'evaluation_timestamp': combined_results['evaluation_timestamp']
            }
            benchmark_df = pd.concat([benchmark_df, pd.DataFrame([result_row])], ignore_index=True)
            all_results[model] = combined_results
            all_creativity_dfs[model] = creativity_df
            progress_current += progress_increment * 0.1
            progress(progress_current, desc=f"Finished evaluating {model}")
        except Exception as e:
            print(f"Error evaluating {model}: {str(e)}")
    # Save benchmark results
    benchmark_df.to_csv(benchmark_file, index=False)
    timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
    combined_benchmark_path = f'results/benchmark_results_{timestamp}.csv'
    benchmark_df.to_csv(combined_benchmark_path, index=False)
    progress(0.95, desc="Creating visualizations...")
    radar_chart_path = create_radar_chart(all_results)
    bar_chart_path = create_bar_chart(all_results)
    progress(1.0, desc="Evaluation complete!")
    sorted_results = benchmark_df.sort_values(by='combined_score', ascending=False)
    return sorted_results, radar_chart_path, bar_chart_path, combined_benchmark_path
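
# NOTE: this hard-coded snapshot shadows the get_leaderboard_data imported from
# evaluate_stability above, so the leaderboard tab always shows these static values
# (creativity on a 0-10 scale, stability and overall score on 0-1).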
def get_leaderboard_data():
    return [
        ["Vikhr", "7.75", "0.9363600260019302", "0.860"],
        ["Llama3", "7.30", "0.9410231244564057", "0.827"],
        ["Mistral", "6.95", "0.9459488660097122", "0.807"],
        ["Owen", "6.93", "0.945682458281517", "0.800"],
        ["TinyLlama", "1.12", "0.945682458281517", "0.573"]
    ]
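
# Build the two-tab Gradio UI: an evaluation tab that runs evaluate_models on an uploaded
# CSV, and a leaderboard tab that shows the snapshot scores above.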
def create_gradio_interface():
    with gr.Blocks(title="LLM Evaluation Tool") as app:
        gr.Markdown("# LLM Evaluation Tool")
        gr.Markdown("Evaluate models for creativity, diversity, relevance, and stability")
        with gr.Tab("Evaluate Models"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.File(label="Upload CSV with prompts and responses")
                    api_key_input = gr.Textbox(label="Gemini API Key", type="password")
                    prompt_col_input = gr.Textbox(label="Prompt Column Name", value="rus_prompt")
                    model_selection = gr.CheckboxGroup(
                        label="Select Models to Evaluate (leave empty to evaluate all)",
                        choices=[],
                        interactive=True
                    )
                    refresh_button = gr.Button("Refresh Model List")

            def update_model_list(file):
                if file:
                    # file may be a path string or a temp-file object, mirroring evaluate_models
                    path = file.name if hasattr(file, 'name') and not isinstance(file, str) else file
                    models = list_available_models(path)
                    return gr.CheckboxGroup(choices=models)
                return gr.CheckboxGroup(choices=[])

            # Wire the refresh button and the file upload to repopulate the model list
            refresh_button.click(fn=update_model_list, inputs=[file_input], outputs=[model_selection])
            file_input.change(fn=update_model_list, inputs=[file_input], outputs=[model_selection])

            evaluate_button = gr.Button("Evaluate Models", variant="primary")
            with gr.Row():
                result_table = gr.Dataframe(label="Evaluation Results")
            with gr.Row():
                with gr.Column():
                    radar_chart = gr.Image(label="Radar Chart")
                with gr.Column():
                    bar_chart = gr.Image(label="Bar Chart")
            result_file = gr.File(label="Download Complete Results")

            evaluate_button.click(
                fn=evaluate_models,
                inputs=[file_input, api_key_input, prompt_col_input, model_selection],
                outputs=[result_table, radar_chart, bar_chart, result_file]
            )
with gr.Tab("Leaderboard"): | |
with gr.Row(): | |
leaderboard_table = gr.Dataframe( | |
label="Model Leaderboard", | |
headers=["Model", "Креативность", "Стабильность", "Общий балл"] | |
) | |
refresh_leaderboard = gr.Button("Refresh Leaderboard") | |
def update_leaderboard(): | |
return get_leaderboard_data() | |
with gr.Row(): | |
gr.Markdown("### Leaderboard Details") | |
gr.Markdown(""" | |
- **Креативность**: Оригинальность и инновационность ответов (шкала до 10) | |
- **Стабильность**: Коэффициент стабильности модели (0-1) | |
- **Общий балл**: Средний комбинированный показатель производительности (0-1) | |
""") | |
return app | |
if __name__ == "__main__": | |
app = create_gradio_interface() | |
app.launch() |
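
# A minimal sketch of how to run this Space locally, assuming app.py sits next to
# evaluate_creativity.py / evaluate_stability.py and a Gemini API key is at hand
# (the exact dependency list is an assumption based on the imports above):
#   pip install gradio pandas numpy matplotlib seaborn
#   python app.py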