Spaces:
Running
Running
import gradio as gr | |
import random | |
from datasets import load_dataset | |
import csv | |
from datetime import datetime | |
import os | |
import pandas as pd | |
import json | |
from huggingface_hub import CommitScheduler, HfApi, snapshot_download | |
import shutil | |
import uuid | |
import git | |
from pathlib import Path | |
from io import BytesIO | |
import PIL | |
import time # Add this import at the top | |
import re | |
# Authenticated Hub client (raises KeyError at import time if HF_TOKEN is unset).
# NOTE(review): `api` is not referenced anywhere else in this file — confirm it is still needed.
api = HfApi(token=os.environ["HF_TOKEN"])
# Dataset repo that receives the evaluation CSV (pushed by `scheduler`, merged by `sync_with_hub`).
RESULTS_BACKUP_REPO = "taesiri/PhotoEditBattleResults-Public"
# Dataset repo holding the AI-edit vs human-edit pairs shown to raters.
MAIN_DATASET_REPO = "taesiri/IERv2-BattlePairs"

# Load the experimental dataset (all columns, including the images).
dataset = load_dataset(MAIN_DATASET_REPO, split="train")
# Distinct post ids in the dataset. Derived from the split that is already
# loaded rather than calling load_dataset a second time; Dataset.unique
# should only touch the post_id column, so no images are decoded here.
dataset_post_ids = dataset.unique("post_id")
# Download existing data from hub
def sync_with_hub():
    """
    Synchronize local data with the hub by cloning the dataset repo.

    Steps:
    1. Read ./data/evaluation_results_exp.csv if it exists.
    2. Clone RESULTS_BACKUP_REPO into ./hub_data, or pull if a clone is
       already there.
    3. If the clone has data/evaluation_results_exp.csv, merge it into the
       local CSV (concatenate, then drop exact duplicate rows). With no
       local CSV the hub copy is used as-is. Other files in the clone's
       data/ folder are copied into ./data. Sub-directories are copied only
       when not already present locally.
    4. Delete ./hub_data.

    The clone is removed in a ``finally`` block. Its remote URL embeds
    HF_TOKEN, and a half-finished clone left on disk could break the next
    run's pull.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    local_csv_path = data_dir / "evaluation_results_exp.csv"

    # Read existing local data before anything from the hub is copied over it.
    local_data = None
    if local_csv_path.exists():
        local_data = pd.read_csv(local_csv_path)
        print(f"Found local data with {len(local_data)} entries")

    token = os.environ["HF_TOKEN"]
    username = "taesiri"
    repo_url = (
        f"https://{username}:{token}@huggingface.co/datasets/{RESULTS_BACKUP_REPO}"
    )
    hub_data_dir = Path("hub_data")
    try:
        _fetch_hub_repo(repo_url, hub_data_dir)
        _merge_hub_data(hub_data_dir / "data", data_dir, local_csv_path, local_data)
    finally:
        # Clean up cloned repo, also when fetching/merging failed.
        if hub_data_dir.exists():
            shutil.rmtree(hub_data_dir)
    print("Finished syncing with hub!")


def _fetch_hub_repo(repo_url, hub_data_dir):
    """Clone ``repo_url`` into ``hub_data_dir``, or pull if a clone already exists there."""
    if hub_data_dir.exists():
        print("Pulling latest changes...")
        repo = git.Repo(hub_data_dir)
        origin = repo.remotes.origin
        # Refresh the token-bearing HTTPS URL on an existing HTTPS remote.
        if "https://" in origin.url:
            origin.set_url(repo_url)
        origin.pull()
    else:
        print("Cloning repository...")
        git.Repo.clone_from(repo_url, hub_data_dir)


def _merge_hub_data(hub_data_source, data_dir, local_csv_path, local_data):
    """Merge the hub's evaluation CSV into ``local_csv_path`` and copy its other files into ``data_dir``."""
    if not hub_data_source.exists():
        return
    data_dir.mkdir(exist_ok=True)
    hub_csv_path = hub_data_source / "evaluation_results_exp.csv"
    if hub_csv_path.exists():
        hub_data = pd.read_csv(hub_csv_path)
        print(f"Found hub data with {len(hub_data)} entries")
        if local_data is not None:
            # Keep every row from both sides; drop only exact duplicates.
            merged_data = pd.concat([local_data, hub_data]).drop_duplicates()
            print(f"Merged data has {len(merged_data)} entries")
            merged_data.to_csv(local_csv_path, index=False)
        else:
            # If no local data exists, just copy hub data
            shutil.copy2(hub_csv_path, local_csv_path)
    # Copy any other files from hub; existing local directories are left alone.
    for item in hub_data_source.glob("*"):
        if item.is_file() and item.name != "evaluation_results_exp.csv":
            shutil.copy2(item, data_dir / item.name)
        elif item.is_dir():
            dest = data_dir / item.name
            if not dest.exists():
                shutil.copytree(item, dest)
# Scheduled committer that pushes the local ./data folder to data/ in the
# RESULTS_BACKUP_REPO dataset repo.
# NOTE(review): `every` is presumably in minutes (huggingface_hub's unit for
# CommitScheduler) — confirm against the installed version.
# NOTE(review): this is constructed at import time, before `sync_with_hub()`
# runs under `__main__`; confirm its first push cannot upload a stale local
# CSV ahead of the hub merge.
scheduler = CommitScheduler(
    repo_id=RESULTS_BACKUP_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)
def save_evaluation(
    post_id, model_a, model_b, verdict, username, start_time, end_time, dataset_idx
):
    """Save evaluation results to CSV including timing, username and dataset index information.

    Appends one row to data/evaluation_results_exp.csv, creating ./data and
    the header row as needed, and prints a one-line summary.

    Args:
        post_id: Id of the post whose two edits were compared.
        model_a: Label of the edit shown in slot A ("HUMAN" for the human edit).
        model_b: Label of the edit shown in slot B ("HUMAN" for the human edit).
        verdict: The confirmed choice, e.g. "A is better" or "Tie".
        username: The rater's (validated) email address.
        start_time: Timestamp when the pair was shown (time.time() as passed by evaluate).
        end_time: Timestamp when the verdict was confirmed.
        dataset_idx: Row index of the pair in the dataset.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    duration = end_time - start_time
    os.makedirs("data", exist_ok=True)
    filename = "data/evaluation_results_exp.csv"
    header = [
        "timestamp",
        "post_id",
        "model_a",
        "model_b",
        "verdict",
        "username",
        "start_time",
        "end_time",
        "duration_seconds",
        "dataset_idx",
    ]
    row = [
        timestamp,
        post_id,
        model_a,
        model_b,
        verdict,
        username,
        start_time,
        end_time,
        duration,
        dataset_idx,
    ]
    # Hold the CommitScheduler's lock while touching ./data so a background
    # commit never picks up a half-written row.
    with scheduler.lock:
        # A header is needed for a new file and also for an existing but empty
        # one; otherwise pd.read_csv would take the first data row as the header.
        needs_header = (
            not os.path.exists(filename) or os.path.getsize(filename) == 0
        )
        with open(filename, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(header)
            writer.writerow(row)
    print(
        f"Saved evaluation: {post_id} - Model A: {model_a} - Model B: {model_b} - Verdict: {verdict} - Duration: {duration:.2f}s"
    )
def get_annotated_indices(username):
    """Get the set of dataset indices already annotated by this user.

    Reads data/evaluation_results_exp.csv and returns the ``dataset_idx``
    values of rows whose ``username`` column equals ``username``.

    Returns an empty set when the file is missing, lacks either column, or
    cannot be read or parsed (the reason is printed).
    """
    filename = "data/evaluation_results_exp.csv"
    if not os.path.exists(filename):
        print(f"No annotations found for user {username} (file doesn't exist)")
        return set()
    try:
        df = pd.read_csv(filename)
    except (
        OSError,
        UnicodeDecodeError,
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
    ) as exc:
        # Unreadable, empty or malformed file: treat as "nothing annotated yet".
        print(f"Error reading annotations for user {username}: {exc}")
        return set()
    if "dataset_idx" not in df.columns or "username" not in df.columns:
        print(f"No annotations found for user {username} (missing columns)")
        return set()
    user_annotations = df.loc[df["username"] == username, "dataset_idx"].tolist()
    print(f"User {username} has already processed {len(user_annotations)} posts")
    return set(user_annotations)
def get_annotated_post_ids(username):
    """Get the set of post_ids already annotated by this user.

    Reads data/evaluation_results_exp.csv and returns the distinct
    ``post_id`` values of rows whose ``username`` column equals ``username``.

    Returns an empty set when the file is missing, lacks either column, or
    cannot be read or parsed (the reason is printed).
    """
    filename = "data/evaluation_results_exp.csv"
    if not os.path.exists(filename):
        print(f"No annotations found for user {username} (file doesn't exist)")
        return set()
    try:
        df = pd.read_csv(filename)
    except (
        OSError,
        UnicodeDecodeError,
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
    ) as exc:
        # Unreadable, empty or malformed file: treat as "nothing annotated yet".
        print(f"Error reading annotations for user {username}: {exc}")
        return set()
    if "post_id" not in df.columns or "username" not in df.columns:
        print(f"No annotations found for user {username} (missing columns)")
        return set()
    user_annotations = df.loc[df["username"] == username, "post_id"].tolist()
    print(f"User {username} has seen {len(set(user_annotations))} unique posts")
    return set(user_annotations)
def get_random_sample(username):
    """Draw a sample for ``username``, steering away from rows/posts they already rated.

    Rows the user has rated are skipped unless every row has been rated, in
    which case any row may repeat. Up to five random draws look for a row
    whose ``post_id`` the user has not rated; after that the last drawn row
    is used anyway. The AI edit and the human edit are then put in slots A
    and B in random order.

    Returns:
        dict with ``post_id``, ``instruction`` and ``simplified_instruction``
        (each wrapped in the "Request:" HTML banner), ``source_image``,
        ``image_a``, ``image_b``, ``model_a``, ``model_b`` (``"HUMAN"`` for
        the human edit) and ``dataset_idx``.
    """
    seen_rows = get_annotated_indices(username)
    seen_posts = get_annotated_post_ids(username)

    every_row = set(range(len(dataset)))
    # Fall back to the full pool once the user has rated every row.
    pool = list(every_row - seen_rows) or list(every_row)

    for _attempt in range(5):
        chosen_idx = random.choice(pool)
        sample = dataset[chosen_idx]
        if sample["post_id"] not in seen_posts:
            break
        # Post already seen: drop this row and retry while rows remain.
        pool.remove(chosen_idx)
        if not pool:
            break

    ai_side = (sample["ai_edited_image"], sample["model"])
    human_side = (sample["human_edited_image"], "HUMAN")
    # Coin flip decides whether the AI edit appears as A or as B.
    if random.choice([True, False]):
        first, second = ai_side, human_side
    else:
        first, second = human_side, ai_side
    (image_a, model_a), (image_b, model_b) = first, second

    # NOTE(review): instruction text goes into gr.HTML unescaped; if the
    # dataset text can contain markup, consider html.escape (check first
    # whether it is already entity-encoded).
    banner_open = '<div style="font-size: 1.8em; font-weight: bold; padding: 20px; background-color: white; border-radius: 10px; margin: 10px;"><span style="color: #888888;">Request:</span> <span style="color: black;">'
    banner_close = "</span></div>"

    return {
        "post_id": sample["post_id"],
        "instruction": banner_open + sample["instruction"] + banner_close,
        "simplified_instruction": banner_open
        + sample["simplified_instruction"]
        + banner_close,
        "source_image": sample["source_image"],
        "image_a": image_a,
        "image_b": image_b,
        "model_a": model_a,
        "model_b": model_b,
        "dataset_idx": chosen_idx,
    }
def evaluate(verdict, state):
    """Handle evaluation button clicks with timing.

    Saves the confirmed ``verdict`` for the pair described by ``state``.
    ``state`` is the dict built by ``initialize`` or a previous ``evaluate``
    call: the sample fields plus ``username`` and ``start_time``. It then
    draws the next sample for the same user.

    Returns:
        A 19-tuple in the order of the confirm button's outputs:
        - source image, image A, image B
        - instruction HTML, simplified instruction, model info
        - new state, selected verdict
        - the four "selected" flags (A, B, neither, tie)
        - the four verdict-button updates (A, B, neither, tie)
        - post id, simplified instruction (debug), username (debug)
    """
    # All four verdict buttons go back to the secondary variant.
    reset_buttons = tuple(gr.update(variant="secondary") for _ in range(4))

    if state is None:
        # No active sample: clear the outputs, using the same slot layout as
        # the normal return below.
        return (
            None,  # source_image
            None,  # image_a
            None,  # image_b
            None,  # instruction
            None,  # simplified_instruction
            None,  # model_info
            None,  # state
            None,  # selected_verdict
            False,  # a_better_selected
            False,  # b_better_selected
            False,  # neither_selected
            False,  # tie_selected
            *reset_buttons,  # A / B / neither / tie button styles
            None,  # post_id_display
            None,  # simplified_instruction_debug
            "",  # username_debug
        )

    # Record end time and save the evaluation
    end_time = time.time()
    save_evaluation(
        state["post_id"],
        state["model_a"],
        state["model_b"],
        verdict,
        state["username"],
        state["start_time"],
        end_time,
        state["dataset_idx"],
    )
    # Get next sample using username to avoid repeats
    next_sample = get_random_sample(state["username"])
    # Preserve username in state and restart the timer for the next pair.
    next_state = next_sample.copy()
    next_state["username"] = state["username"]
    next_state["start_time"] = time.time()
    return (
        next_sample["source_image"],
        next_sample["image_a"],
        next_sample["image_b"],
        next_sample["instruction"],
        next_sample["simplified_instruction"],
        f"Model A: {next_sample['model_a']} | Model B: {next_sample['model_b']}",
        next_state,  # includes username and start_time
        None,  # selected_verdict
        False,  # a_better_selected
        False,  # b_better_selected
        False,  # neither_selected
        False,  # tie_selected
        *reset_buttons,  # A / B / neither / tie button styles
        next_sample["post_id"],
        next_sample["simplified_instruction"],
        state["username"],  # Use username from state
    )
def select_verdict(verdict, state):
    """Record the first-step choice and flag which of the four verdict buttons it matches.

    Returns ``(verdict, is_a_better, is_b_better, is_neither, is_tie)``.
    With no active sample (``state is None``), nothing is selected.
    """
    if state is None:
        return None, False, False, False, False  # Ensure it returns 5 values
    choices = ("A is better", "B is better", "Neither is good", "Tie")
    return (verdict, *(verdict == choice for choice in choices))
def is_valid_email(email):
    """
    Validate email format and content more strictly.

    - Check basic email format
    - Prevent common injection attempts
    - Limit length (measured after surrounding whitespace is stripped)
    - Restrict to printable ASCII characters

    Args:
        email: Candidate address; anything that is not a non-empty str is rejected.

    Returns:
        True if the stripped address passes every check, otherwise False.
    """
    if not email or not isinstance(email, str):
        return False
    # Strip first so the length limits apply to the address actually being
    # validated, not to incidental padding around it.
    email = email.strip()
    if len(email) > 254:  # Maximum length per RFC 5321
        return False
    # Characters that could break a CSV row or be used for injection.
    dangerous_chars = [";", '"', "'", ",", "\\", "\n", "\r", "\t"]
    if any(char in email for char in dangerous_chars):
        return False
    # Ensure all characters are printable ASCII
    if not all(32 <= ord(char) <= 126 for char in email):
        return False
    # local@domain.tld with a letters-only TLD of at least 2 characters;
    # fullmatch anchors both ends.
    pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
    if not re.fullmatch(pattern, email):
        return False
    if ".." in email:  # No consecutive dots
        return False
    if email.count("@") != 1:  # Exactly one @ symbol
        return False
    # Validate lengths of local and domain parts
    local, domain = email.split("@")
    if len(local) > 64 or len(domain) > 255:  # RFC 5321 limits
        return False
    return True
def handle_username_submit(email, current_page):
    """Validate the email typed on page 1 and advance to page 2 on success.

    Returns a 4-tuple for (current_page, username_input, username_debug,
    username_state). On success it returns page 2, a cleared input box, and
    the cleaned email in both the debug box and the state. Otherwise it
    shows a gr.Warning, keeps the current page and sets the state to None.
    """

    def stay_on_page(input_value):
        # Failure result: same page, input box set to ``input_value``, debug cleared, no username.
        return current_page, gr.update(value=input_value), gr.update(value=""), None

    try:
        if not email:
            gr.Warning("Please enter an email address")
            return stay_on_page(email)
        cleaned = str(email).strip()
        if not is_valid_email(cleaned):
            gr.Warning("Please enter a valid email address (e.g., [email protected])")
            return stay_on_page(cleaned)
        # Drop quote characters before the value can reach the CSV.
        stored = cleaned.replace('"', "").replace("'", "")
        return 2, gr.update(value=""), gr.update(value=stored), stored
    except Exception as e:
        print(f"Error in handle_username_submit: {str(e)}")
        gr.Warning("An error occurred. Please try again.")
        return stay_on_page("")
def initialize(username):
    """Initialize the interface with a first sample for ``username``.

    Returns:
        A 15-tuple ordered like the outputs wired to demo.load and the page
        buttons:
        - source image, image A, image B
        - instruction HTML, simplified instruction, model info
        - session state (sample fields + ``username`` + ``start_time``)
        - selected verdict (None) and four selection flags (all False)
        - post id, simplified instruction (debug)
        - the username for the debug box ("" when ``username`` is falsy)
    """
    sample = get_random_sample(username)
    # Session state carries the sample plus who is rating and when they started.
    session_state = dict(sample)
    session_state.update(username=username, start_time=time.time())
    no_selection = (None, False, False, False, False)
    return (
        sample["source_image"],
        sample["image_a"],
        sample["image_b"],
        sample["instruction"],
        sample["simplified_instruction"],
        f"Model A: {sample['model_a']} | Model B: {sample['model_b']}",
        session_state,
        *no_selection,
        sample["post_id"],
        sample["simplified_instruction"],
        username or "",
    )
def update_button_styles(verdict):
    """Update button styles based on selection.

    Returns gr.update objects for the A-better, B-better, both-bad and tie
    buttons, in that order. Each one resets its button to the emoji label.

    ``verdict`` is accepted to match the event wiring, but it does not
    currently change any label: each label is the same whether or not its
    button is the selected one.
    """
    # NOTE(review): if the selected button should be highlighted, vary its
    # label or variant by ``verdict`` here.
    labels = ("☝️ A is better", "☝️ B is better", "👎 Both are bad", "🤝 Tie")
    return tuple(gr.update(value=label) for label in labels)
def create_instruction_page(html_content, image_path=None):
    """Render an instruction page inside the current Blocks context.

    Adds a Column containing ``html_content`` as HTML and, when
    ``image_path`` is truthy, an Image (``container=False``) below it.
    """
    with gr.Column():
        gr.HTML(html_content)
        if not image_path:
            return
        gr.Image(image_path, container=False)
def advance_page(current_page):
    """Return the number of the page that follows ``current_page``."""
    next_page = current_page + 1
    return next_page
# Build the UI. The three "pages" are Columns whose visibility is driven by
# the `current_page` State: 1 = email entry, 2 = instructions, 3 = evaluation.
with gr.Blocks() as demo:
    # Per-session page number and validated email (used as the username).
    current_page = gr.State(1)  # Start at page 1
    username_state = gr.State(None)  # Filled in by handle_username_submit
    # Create container for all pages
    with gr.Column() as page_container:
        # Page 1 - Username (email) collection
        with gr.Column(visible=True) as page1:
            create_instruction_page(
                """
<div style="text-align: center; padding: 20px;">
<h1>Welcome to the Image Edit Evaluation</h1>
<p>Help us evaluate different image edits for a given instruction.</p>
</div>
""",
                image_path="./instructions/home.jpg",
            )
            username_input = gr.Textbox(
                label="Please enter your email address (if you don't want to share your email, please enter a fake email)",
                placeholder="[email protected]",
            )
            start_btn = gr.Button("Start", variant="primary")
        # Page 2 - Instructions on how to evaluate
        with gr.Column(visible=False) as page2:
            create_instruction_page(
                """
<div style="text-align: center; padding: 20px;">
<h1>How to Evaluate Edits</h1>
</div>
""",
                image_path="./instructions/page2.jpg",  # NOTE(review): possibly a placeholder — confirm this file exists
            )
            next_btn1 = gr.Button(
                "Start Evaluation", variant="primary"
            )  # Moves to page 3 (see next_btn1.click below)
        # Page 3 - Main evaluation UI
        with gr.Column(visible=False) as main_ui:
            # Instruction panel at the top
            gr.HTML(
                """
<div style="padding: 0.8rem; margin-bottom: 0.8rem; border-radius: 0.5rem; color: white; text-align: center;">
<div style="font-size: 1.2rem; margin-bottom: 0.5rem;">Read the user instruction, look at the source image, then evaluate which edit (A or B) best satisfies the request better.</div>
<div style="font-size: 1rem;">
<strong>🤝 Tie</strong> |
<strong> A is better</strong> |
<strong> B is better</strong>
</div>
<div style="color: #ff4444; font-size: 0.9rem; margin-top: 0.5rem;">
Please ignore any watermark on the image. Your rating should not be affected by any watermark on the image.
</div>
</div>
"""
            )
            with gr.Row():
                simplified_instruction = gr.Textbox(
                    label="Simplified Instruction", show_label=True, visible=False
                )
                instruction = gr.HTML(label="Original Instruction", show_label=True)
            # Source image with the Tie button, then edits A and B with their buttons
            with gr.Row():
                with gr.Column():
                    source_image = gr.Image(
                        label="Source Image", show_label=True, height=500
                    )
                    gr.HTML("<h2 style='text-align: center;'>Source Image</h2>")
                    tie_btn = gr.Button("🤝 Tie", variant="secondary")
                with gr.Column():
                    image_a = gr.Image(label="Image A", show_label=True, height=500)
                    gr.HTML("<h2 style='text-align: center;'>Image A</h2>")
                    a_better_btn = gr.Button("☝️ A is better", variant="secondary")
                with gr.Column():
                    image_b = gr.Image(label="Image B", show_label=True, height=500)
                    gr.HTML("<h2 style='text-align: center;'>Image B</h2>")
                    b_better_btn = gr.Button("☝️ B is better", variant="secondary")
            # Confirmation button in its own row; hidden until a verdict is selected
            with gr.Row():
                confirm_btn = gr.Button(
                    "Confirm Selection", variant="primary", visible=False
                )
            # "Both are bad" is wired up below but created hidden
            with gr.Row():
                neither_btn = gr.Button(
                    "👎 Both are bad", variant="secondary", visible=False
                )
            # Hidden debug panel mirroring per-sample metadata
            with gr.Accordion("DEBUG", open=False, visible=False):
                with gr.Column():
                    post_id_display = gr.Textbox(
                        label="Post ID", show_label=True, interactive=False
                    )
                    model_info = gr.Textbox(label="Model Information", show_label=True)
                    simplified_instruction_debug = gr.Textbox(
                        label="Simplified Instruction",
                        show_label=True,
                        interactive=False,
                    )
                    username_debug = gr.Textbox(
                        label="Username", show_label=True, interactive=False
                    )
            # Per-session sample dict (sample fields + username + start_time)
            state = gr.State()
            # Verdict picked in step 1, consumed by confirm_btn
            selected_verdict = gr.State()
            # Hidden checkboxes flagging which verdict button is selected; their
            # change events drive the confirm button's visibility and label
            a_better_selected = gr.Checkbox(visible=False)
            b_better_selected = gr.Checkbox(visible=False)
            neither_selected = gr.Checkbox(visible=False)
            tie_selected = gr.Checkbox(visible=False)

    def update_confirm_visibility(a_better, b_better, neither, tie):
        """Show the confirm button labelled with the selected verdict; hide it when none is selected."""
        # Update button text based on selection
        if a_better:
            return gr.update(visible=True, value="Confirm A is better")
        elif b_better:
            return gr.update(visible=True, value="Confirm B is better")
        elif neither:
            return gr.update(visible=True, value="Confirm Neither is good")
        elif tie:
            return gr.update(visible=True, value="Confirm Tie")
        return gr.update(visible=False)

    # Initialize the interface on page load.
    # NOTE(review): this runs with no username while main_ui is still hidden,
    # and initialize runs again from the start/next buttons below, so this
    # sample appears to be replaced before anyone can rate it.
    demo.load(
        lambda: initialize(None),  # Pass None on initial load
        outputs=[
            source_image,
            image_a,
            image_b,
            instruction,
            simplified_instruction,
            model_info,
            state,
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
            post_id_display,
            simplified_instruction_debug,
            username_debug,
        ],
    )
    # Step 1: a verdict button records the selection, then refreshes the button labels
    a_better_btn.click(
        lambda state: select_verdict("A is better", state),
        inputs=[state],
        outputs=[
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
        ],
    ).then(
        update_button_styles,
        inputs=[selected_verdict],
        outputs=[a_better_btn, b_better_btn, neither_btn, tie_btn],
    )
    b_better_btn.click(
        lambda state: select_verdict("B is better", state),
        inputs=[state],
        outputs=[
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
        ],
    ).then(
        update_button_styles,
        inputs=[selected_verdict],
        outputs=[a_better_btn, b_better_btn, neither_btn, tie_btn],
    )
    neither_btn.click(
        lambda state: select_verdict("Neither is good", state),
        inputs=[state],
        outputs=[
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
        ],
    ).then(
        update_button_styles,
        inputs=[selected_verdict],
        outputs=[a_better_btn, b_better_btn, neither_btn, tie_btn],
    )
    tie_btn.click(
        lambda state: select_verdict("Tie", state),
        inputs=[state],
        outputs=[
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
        ],
    ).then(
        update_button_styles,
        inputs=[selected_verdict],
        outputs=[a_better_btn, b_better_btn, neither_btn, tie_btn],
    )
    # Re-evaluate the confirm button whenever any selection flag changes
    for checkbox in [
        a_better_selected,
        b_better_selected,
        neither_selected,
        tie_selected,
    ]:
        checkbox.change(
            update_confirm_visibility,
            inputs=[
                a_better_selected,
                b_better_selected,
                neither_selected,
                tie_selected,
            ],
            outputs=[confirm_btn],
        )
    # Step 2: confirming saves the verdict and loads the next pair
    confirm_btn.click(
        lambda verdict, state: evaluate(verdict, state),
        inputs=[selected_verdict, state],
        outputs=[
            source_image,
            image_a,
            image_b,
            instruction,
            simplified_instruction,
            model_info,
            state,
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
            a_better_btn,
            b_better_btn,
            neither_btn,
            tie_btn,
            post_id_display,
            simplified_instruction_debug,
            username_debug,
        ],
    )

    # Handle page visibility
    def update_page_visibility(page_num):
        """Return visibility updates for each page column"""
        return [
            gr.update(visible=(page_num == 1)),  # page1
            gr.update(visible=(page_num == 2)),  # page2
            gr.update(visible=(page_num == 3)),  # main_ui
        ]

    # Start: validate the email, switch pages, then load a sample for that user.
    # NOTE(review): on invalid input handle_username_submit returns normally with
    # username_state=None, so the `.then` steps presumably still run and
    # initialize(None) loads a sample for no user.
    start_btn.click(
        handle_username_submit,
        inputs=[username_input, current_page],
        outputs=[
            current_page,
            username_input,
            username_debug,
            username_state,
        ],
    ).then(
        update_page_visibility,
        inputs=[current_page],
        outputs=[page1, page2, main_ui],
    ).then(
        initialize,
        inputs=[username_state],
        outputs=[
            source_image,
            image_a,
            image_b,
            instruction,
            simplified_instruction,
            model_info,
            state,
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
            post_id_display,
            simplified_instruction_debug,
            username_debug,
        ],
    )
    # "Start Evaluation": jump to page 3 and load a fresh sample (restarting its timer)
    next_btn1.click(
        lambda x: 3,  # Force page 3 instead of using advance_page
        inputs=[current_page],
        outputs=current_page,
    ).then(
        update_page_visibility,
        inputs=[current_page],
        outputs=[page1, page2, main_ui],
    ).then(
        initialize,
        inputs=[username_state],
        outputs=[
            source_image,
            image_a,
            image_b,
            instruction,
            simplified_instruction,
            model_info,
            state,
            selected_verdict,
            a_better_selected,
            b_better_selected,
            neither_selected,
            tie_selected,
            post_id_display,
            simplified_instruction_debug,
            username_debug,
        ],
    )
if __name__ == "__main__":
    # Pull and merge the hub's results into ./data before serving.
    # The scheduler's lock is held meanwhile, so its background commits do
    # not list ./data while sync_with_hub is rewriting the CSV there. This is
    # the locking pattern huggingface_hub documents for CommitScheduler writers.
    # NOTE(review): `scheduler` is created at import time, before this sync;
    # confirm its first push cannot upload a stale local CSV ahead of the merge.
    with scheduler.lock:
        sync_with_hub()
    demo.launch()