import json
import os
import shutil
import threading
import time
from pathlib import Path

import git
import gradio as gr
from datasets import load_dataset
from huggingface_hub import CommitScheduler, HfApi

from utils import process_and_push_dataset  # used by the background pusher below

api = HfApi(token=os.environ["HF_TOKEN"])
VALID_DATASET = load_dataset("taesiri/IERv2-Subset-Validation-150", split="train")
VALID_DATASET_POST_IDS = (
    load_dataset(
        "taesiri/IERv2-Subset-Validation-150", split="train", columns=["post_id"]
    )
    .to_pandas()["post_id"]
    .tolist()
)
POST_ID_TO_ID_MAP = {post_id: idx for idx, post_id in enumerate(VALID_DATASET_POST_IDS)}
DATASET_REPO = "taesiri/AIImageEditingResults_Intemediate2"
FINAL_DATASET_REPO = "taesiri/AIImageEditingResults"
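# On-disk layout written by generate_json_files() below, mirrored to
# DATASET_REPO by the CommitScheduler, and eventually merged into
# FINAL_DATASET_REPO:
#
#   data/
#     <post_id>/
#       source_image.png
#       metadata.json
#       response_<idx>/
#         response_image.png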
# Download existing data from the hub
def sync_with_hub():
    """
    Synchronize local data with the hub by cloning/pulling the dataset repo,
    then merging in any post folders that do not yet exist locally.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    if data_dir.exists():
        # Back up existing data before merging
        backup_dir = Path("./data_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    # Clone/pull the latest data from the hub.
    # Embed the token in the URL for authentication, following HF's format.
    token = os.environ["HF_TOKEN"]
    username = DATASET_REPO.split("/")[0]  # repo owner
    repo_url = f"https://{username}:{token}@huggingface.co/datasets/{DATASET_REPO}"
    hub_data_dir = Path("hub_data")

    if hub_data_dir.exists():
        # The repo was cloned previously: pull the latest changes
        print("Pulling latest changes...")
        repo = git.Repo(hub_data_dir)
        origin = repo.remotes.origin
        # Refresh the remote URL so it carries the current token
        if "https://" in origin.url:
            origin.set_url(repo_url)
        origin.pull()
    else:
        # First run: clone the repo with the token
        print("Cloning repository...")
        git.Repo.clone_from(repo_url, hub_data_dir)

    # Merge hub data into local data
    hub_data_source = hub_data_dir / "data"
    if hub_data_source.exists():
        # Create the data dir if it doesn't exist
        data_dir.mkdir(exist_ok=True)
        # Copy post folders from the hub, keeping local versions when both exist
        for item in hub_data_source.glob("*"):
            if item.is_dir():
                dest = data_dir / item.name
                if not dest.exists():  # Only copy if it doesn't exist locally
                    shutil.copytree(item, dest)

    # Clean up the cloned repo
    if hub_data_dir.exists():
        shutil.rmtree(hub_data_dir)
    print("Finished syncing with hub!")
scheduler = CommitScheduler(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)
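# CommitScheduler runs in a background thread and commits the contents of
# ./data to DATASET_REPO (under "data/") on a schedule; `every` is expressed
# in minutes, so pushes happen roughly once a minute and the handlers below
# only need to write files locally.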
def load_question_data(question_id):
    """
    Load a specific question's data.
    Returns a list of values for all form fields.
    """
    if not question_id:
        return [None] * 11  # One None per form field

    # Extract the ID part before the colon from the dropdown selection
    question_id = (
        question_id.split(":")[0].strip() if ":" in question_id else question_id
    )

    json_path = os.path.join("./data", question_id, "question.json")
    if not os.path.exists(json_path):
        print(f"Question file not found: {json_path}")
        return [None] * 11

    try:
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.loads(f.read().strip())

        # Resolve a stored image path to a file on disk, or None if missing
        def load_image(image_path):
            if not image_path:
                return None
            full_path = os.path.join(
                "./data", question_id, os.path.basename(image_path)
            )
            return full_path if os.path.exists(full_path) else None

        question_images = data.get("question_images", [])
        rationale_images = data.get("rationale_images", [])

        return [
            (
                ",".join(data["question_categories"])
                if isinstance(data["question_categories"], list)
                else data["question_categories"]
            ),
            data["question"],
            data["final_answer"],
            data.get("rationale_text", ""),
            load_image(question_images[0] if question_images else None),
            load_image(question_images[1] if len(question_images) > 1 else None),
            load_image(question_images[2] if len(question_images) > 2 else None),
            load_image(question_images[3] if len(question_images) > 3 else None),
            load_image(rationale_images[0] if rationale_images else None),
            load_image(rationale_images[1] if len(rationale_images) > 1 else None),
            question_id,
        ]
    except Exception as e:
        print(f"Error loading question {question_id}: {str(e)}")
        return [None] * 11
def load_post_image(post_id):
    """
    Load the source image, instructions, and any previously saved responses
    for a post. Returns 33 values: source image, instruction, simplified
    instruction, then ten (image, model name, notes) triplets, flattened.
    """
    if not post_id:
        return [None] * 33

    idx = POST_ID_TO_ID_MAP[post_id]
    source_image = VALID_DATASET[idx]["image"]
    instruction = VALID_DATASET[idx]["instruction"]
    simplified_instruction = VALID_DATASET[idx]["simplified_instruction"]

    # Load existing responses, if any
    post_folder = os.path.join("./data", str(post_id))
    metadata_path = os.path.join(post_folder, "metadata.json")
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        # Initialize all ten response slots as empty
        responses = [(None, "", "")] * 10

        # Fill in existing responses
        for response in metadata["responses"]:
            response_idx = response["response_id"]
            if response_idx < 10:  # Ensure we don't exceed the UI limit
                image_path = os.path.join(post_folder, response["image_path"])
                responses[response_idx] = (
                    image_path,
                    response["answer_text"],
                    response.get("notes", ""),
                )

        # Flatten the triplets for the Gradio outputs
        flat_responses = [item for triplet in responses for item in triplet]
        return [source_image, instruction, simplified_instruction] + flat_responses

    # No saved responses: return the source data and empty response slots
    return [source_image, instruction, simplified_instruction] + [None] * 30
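# NOTE: the 33 return values above (source image, two instruction fields,
# then ten (image, model name, notes) triplets) must stay in the same order
# as the `outputs=` list wired to post_id_dropdown.change(...) below.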
def generate_json_files(source_image, responses, post_id):
    """
    Save the source image and multiple responses to the data directory.

    Args:
        source_image: Path to the source image (or an in-memory image)
        responses: List of (image, answer_text, notes) tuples
        post_id: The post ID from the dataset
    """
    # Create the parent data folder if it doesn't exist
    parent_data_folder = "./data"
    os.makedirs(parent_data_folder, exist_ok=True)

    # Create/clear the post_id folder
    post_folder = os.path.join(parent_data_folder, str(post_id))
    if os.path.exists(post_folder):
        shutil.rmtree(post_folder)
    os.makedirs(post_folder)

    # Save the source image
    source_image_path = os.path.join(post_folder, "source_image.png")
    if isinstance(source_image, str):
        shutil.copy2(source_image, source_image_path)
    else:
        gr.processing_utils.save_image(source_image, source_image_path)

    # Collect response metadata
    responses_data = []
    for idx, (response_image, answer_text, notes) in enumerate(responses):
        if response_image and answer_text:  # Only process if both image and text exist
            response_folder = os.path.join(post_folder, f"response_{idx}")
            os.makedirs(response_folder)

            # Save the response image
            response_image_path = os.path.join(response_folder, "response_image.png")
            if isinstance(response_image, str):
                shutil.copy2(response_image, response_image_path)
            else:
                gr.processing_utils.save_image(response_image, response_image_path)

            # Record this response's metadata
            responses_data.append(
                {
                    "response_id": idx,
                    "answer_text": answer_text,
                    "notes": notes,
                    "image_path": f"response_{idx}/response_image.png",
                }
            )

    # Create the metadata JSON
    metadata = {
        "post_id": post_id,
        "source_image": "source_image.png",
        "responses": responses_data,
    }

    # Save the metadata
    with open(os.path.join(post_folder, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)

    return post_folder
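# Illustrative metadata.json written above (field values are hypothetical):
#
#   {
#     "post_id": 12345,
#     "source_image": "source_image.png",
#     "responses": [
#       {
#         "response_id": 0,
#         "answer_text": "model-x",
#         "notes": "minor artifacts around the edges",
#         "image_path": "response_0/response_image.png"
#       }
#     ]
#   }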
def get_statistics():
    """
    Scan the data folder and return statistics about the collected responses.
    """
    data_dir = Path("./data")
    if not data_dir.exists():
        return "No data directory found"

    total_expected_posts = len(VALID_DATASET_POST_IDS)
    valid_post_ids = set(map(str, VALID_DATASET_POST_IDS))
    processed_post_ids = set()
    posts_with_responses = 0
    total_responses = 0
    responses_per_post = []  # Number of responses recorded for each post

    for metadata_file in data_dir.glob("*/metadata.json"):
        post_id = metadata_file.parent.name  # folder names are strings
        if post_id in valid_post_ids:  # Only count valid posts
            processed_post_ids.add(post_id)
            try:
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)
                num_responses = len(metadata.get("responses", []))
                responses_per_post.append(num_responses)
                if num_responses > 0:
                    posts_with_responses += 1
                    total_responses += num_responses
            except (OSError, json.JSONDecodeError):
                continue

    missing_posts = valid_post_ids - processed_post_ids
    total_processed = len(processed_post_ids)

    # Calculate distribution statistics
    if responses_per_post:
        responses_per_post.sort()
        median_responses = responses_per_post[len(responses_per_post) // 2]
        max_responses = max(responses_per_post)
        avg_responses = (
            total_responses / posts_with_responses if posts_with_responses > 0 else 0
        )
    else:
        median_responses = max_responses = avg_responses = 0

    stats = f"""
📊 Collection Statistics:

Dataset Coverage:
- Total Expected Posts: {total_expected_posts}
- Posts Processed: {total_processed}
- Missing Posts: {len(missing_posts)} ({', '.join(list(missing_posts)[:5])}{'...' if len(missing_posts) > 5 else ''})
- Coverage Rate: {(total_processed / total_expected_posts * 100):.2f}%

Response Statistics:
- Posts with Responses: {posts_with_responses}
- Posts without Responses: {total_processed - posts_with_responses}
- Total Individual Responses: {total_responses}

Response Distribution:
- Median Responses per Post: {median_responses}
- Average Responses per Post: {avg_responses:.2f}
- Maximum Responses for a Post: {max_responses}
"""
    return stats
# Build the Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Image Response Collector")

    # Source image selection at the top
    with gr.Row():
        with gr.Column():
            post_id_dropdown = gr.Dropdown(
                label="Select Post ID to Load Image",
                choices=VALID_DATASET_POST_IDS,
                type="value",
                allow_custom_value=False,
            )
            instruction_text = gr.Textbox(label="Instruction", interactive=False)
            simplified_instruction_text = gr.Textbox(
                label="Simplified Instruction", interactive=False
            )
            source_image = gr.Image(label="Source Image", type="filepath", height=300)

    # Responses in tabs
    with gr.Tabs() as response_tabs:
        responses = []
        for i in range(10):
            with gr.Tab(f"Response {i+1}"):
                img = gr.Image(
                    label=f"Response Image {i+1}", type="filepath", height=300
                )
                txt = gr.Textbox(label=f"Model Name {i+1}", lines=2)
                notes = gr.Textbox(label=f"Miscellaneous Notes {i+1}", lines=3)
                responses.append((img, txt, notes))

    with gr.Row():
        submit_btn = gr.Button("Submit All Responses")
        clear_btn = gr.Button("Clear Form")

    # Statistics accordion
    with gr.Accordion("Collection Statistics", open=False):
        stats_text = gr.Markdown("Loading statistics...")
        refresh_stats_btn = gr.Button("Refresh Statistics")

    refresh_stats_btn.click(fn=get_statistics, outputs=[stats_text])

    # Populate the statistics once on page load (must be inside the Blocks context)
    demo.load(
        fn=get_statistics,
        outputs=[stats_text],
    )
    def submit_responses(
        source_img, post_id, instruction, simplified_instruction, *response_data
    ):
        if not source_img:
            gr.Warning("Please select a source image first!")
            return
        if not post_id:
            gr.Warning("Please select a post ID first!")
            return

        # Convert the flat response_data into (image, text, notes) triplets
        response_triplets = list(
            zip(response_data[::3], response_data[1::3], response_data[2::3])
        )

        # Flag responses that have an image but no model name
        incomplete_responses = [
            i + 1
            for i, (img, txt, _) in enumerate(response_triplets)
            if img is not None and not (txt or "").strip()
        ]
        if incomplete_responses:
            gr.Warning(
                f"Please provide model names for responses: {', '.join(map(str, incomplete_responses))}!"
            )
            return

        # Keep only complete responses (both image and model name present)
        valid_responses = [
            (img, txt, notes)
            for img, txt, notes in response_triplets
            if img is not None and (txt or "").strip()
        ]
        if not valid_responses:
            gr.Warning("Please provide at least one response (image + model name)!")
            return

        # Write the images and metadata for the valid responses
        generate_json_files(source_img, valid_responses, post_id)
        gr.Info("Responses saved successfully! 🎉")

    def clear_form():
        # source image + 2 instruction fields + 10 (image, text, notes) triplets
        return [None] * (1 + 2 + 30)

    # Wire up the components
    post_id_dropdown.change(
        fn=load_post_image,
        inputs=[post_id_dropdown],
        outputs=[source_image, instruction_text, simplified_instruction_text]
        + [comp for triplet in responses for comp in triplet],
    )

    submit_inputs = [
        source_image,
        post_id_dropdown,
        instruction_text,
        simplified_instruction_text,
    ] + [comp for triplet in responses for comp in triplet]
    submit_btn.click(fn=submit_responses, inputs=submit_inputs)

    clear_outputs = [source_image, instruction_text, simplified_instruction_text] + [
        comp for triplet in responses for comp in triplet
    ]
    clear_btn.click(fn=clear_form, outputs=clear_outputs)
def process_thread():
    """
    Background worker intended to merge ./data and push it to the final
    dataset repo every two minutes.
    """
    while True:
        try:
            pass
            # process_and_push_dataset(
            #     "./data",
            #     FINAL_DATASET_REPO,
            #     token=os.environ["HF_TOKEN"],
            #     private=True,
            # )
        except Exception as e:
            print(f"Error in process thread: {e}")
        time.sleep(120)  # Sleep for 2 minutes
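# NOTE: with the call above commented out, the worker loop is a no-op; the
# CommitScheduler still mirrors ./data to DATASET_REPO on its own. Restoring
# the process_and_push_dataset(...) call would also publish the merged
# dataset to FINAL_DATASET_REPO every two minutes.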
if __name__ == "__main__":
    print("Initializing app...")
    sync_with_hub()  # Sync before launching the app
    print("Starting Gradio interface...")

    # Start the background processing thread alongside the app
    processing_thread = threading.Thread(target=process_thread, daemon=True)
    processing_thread.start()

    demo.launch()