taesiri committed
Commit a167ff0 · 1 Parent(s): 8940d76
Files changed (2)
  1. app.py +29 -16
  2. extract_samples.py +137 -0
app.py CHANGED
@@ -34,7 +34,6 @@ dataset_post_ids = list(
 photoexp = pd.read_csv("./photoexp_filtered.csv")
 valid_post_ids = set(photoexp.post_id.tolist())
 
-# filter RESULTS_BACKUP_REPO to include only valid_post_ids using batched processing
 dataset = dataset.filter(
     lambda xs: [x in valid_post_ids for x in xs["post_id"]],
     batched=True,
@@ -51,47 +50,61 @@ def sync_with_hub():
     """
     print("Starting sync with hub...")
     data_dir = Path("./data")
-    if data_dir.exists():
-        # Backup existing data
-        backup_dir = Path("./data_backup")
-        if backup_dir.exists():
-            shutil.rmtree(backup_dir)
-        shutil.copytree(data_dir, backup_dir)
+    local_csv_path = data_dir / "evaluation_results_exp.csv"
+
+    # Read existing local data if it exists
+    local_data = None
+    if local_csv_path.exists():
+        local_data = pd.read_csv(local_csv_path)
+        print(f"Found local data with {len(local_data)} entries")
 
     # Clone/pull latest data from hub
-    # Use token in the URL for authentication following HF's new format
     token = os.environ["HF_TOKEN"]
-    username = "taesiri"  # Extract from DATASET_REPO
+    username = "taesiri"
     repo_url = (
         f"https://{username}:{token}@huggingface.co/datasets/{RESULTS_BACKUP_REPO}"
     )
     hub_data_dir = Path("hub_data")
 
     if hub_data_dir.exists():
-        # If repo exists, do a git pull
         print("Pulling latest changes...")
         repo = git.Repo(hub_data_dir)
         origin = repo.remotes.origin
-        # Set the new URL with token
         if "https://" in origin.url:
             origin.set_url(repo_url)
         origin.pull()
     else:
-        # Clone the repo with token
        print("Cloning repository...")
         git.Repo.clone_from(repo_url, hub_data_dir)
 
     # Merge hub data with local data
     hub_data_source = hub_data_dir / "data"
     if hub_data_source.exists():
-        # Create data dir if it doesn't exist
         data_dir.mkdir(exist_ok=True)
+        hub_csv_path = hub_data_source / "evaluation_results_exp.csv"
+
+        if hub_csv_path.exists():
+            hub_data = pd.read_csv(hub_csv_path)
+            print(f"Found hub data with {len(hub_data)} entries")
+
+            if local_data is not None:
+                # Merge data, keeping all entries and removing exact duplicates
+                merged_data = pd.concat([local_data, hub_data]).drop_duplicates()
+                print(f"Merged data has {len(merged_data)} entries")
+
+                # Save merged data
+                merged_data.to_csv(local_csv_path, index=False)
+            else:
+                # If no local data exists, just copy hub data
+                shutil.copy2(hub_csv_path, local_csv_path)
 
-        # Copy files from hub
+        # Copy any other files from hub
         for item in hub_data_source.glob("*"):
-            if item.is_dir():
+            if item.is_file() and item.name != "evaluation_results_exp.csv":
+                shutil.copy2(item, data_dir / item.name)
+            elif item.is_dir():
                 dest = data_dir / item.name
-                if not dest.exists():  # Only copy if doesn't exist locally
+                if not dest.exists():
                     shutil.copytree(item, dest)
 
     # Clean up cloned repo
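The new sync path replaces the old whole-directory backup with a row-level CSV merge: pandas concatenates the local and hub tables and drops exact duplicate rows, so results recorded on either side survive a sync. A minimal sketch of the same pattern, with hypothetical file names:

import pandas as pd

# Both CSVs share the same columns; rows identical in both are kept once.
local = pd.read_csv("local_results.csv")  # hypothetical path
hub = pd.read_csv("hub_results.csv")      # hypothetical path
merged = pd.concat([local, hub]).drop_duplicates()
merged.to_csv("local_results.csv", index=False)

Note that drop_duplicates() with no arguments only removes rows that match on every column; two evaluations of the same post with different values are both kept.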
extract_samples.py ADDED
@@ -0,0 +1,137 @@
+import random
+from datasets import load_dataset
+import pandas as pd
+import os
+from pathlib import Path
+import requests
+from PIL import Image
+from io import BytesIO
+
+# Load the experimental dataset
+dataset = load_dataset("taesiri/IERv2-BattleResults_exp", split="train")
+dataset_post_ids = list(
+    set(
+        load_dataset(
+            "taesiri/IERv2-BattleResults_exp", columns=["post_id"], split="train"
+        )
+        .to_pandas()
+        .post_id.tolist()
+    )
+)
+
+# Load and filter photoexp dataset
+photoexp = pd.read_csv("./photoexp_filtered.csv")
+valid_post_ids = set(photoexp.post_id.tolist())
+
+# Filter dataset to include only valid_post_ids
+dataset = dataset.filter(
+    lambda xs: [x in valid_post_ids for x in xs["post_id"]],
+    batched=True,
+    batch_size=256,
+)
+
+
+def download_and_save_image(url, save_path):
+    """Download image from URL and save it to disk"""
+    try:
+        response = requests.get(url)
+        response.raise_for_status()
+        img = Image.open(BytesIO(response.content))
+        img.save(save_path)
+        return True
+    except Exception as e:
+        print(f"Error downloading image {url}: {e}")
+        return False
+
+
+def get_random_sample():
+    """Get a random sample by first selecting a post_id then picking random edits for that post."""
+    # First randomly select a post_id from valid posts
+    random_post_id = random.choice(list(valid_post_ids))
+
+    # Filter dataset for this post_id
+    post_edits = dataset.filter(
+        lambda xs: [x == random_post_id for x in xs["post_id"]],
+        batched=True,
+        batch_size=256,
+    )
+
+    # Get matching photoexp entries for this post_id
+    matching_photoexp_entries = photoexp[photoexp.post_id == random_post_id]
+
+    # Randomly select one edit from the dataset
+    idx = random.randint(0, len(post_edits) - 1)
+    sample = post_edits[idx]
+
+    # Randomly select one entry from the matching photoexp entries
+    if not matching_photoexp_entries.empty:
+        random_photoexp_entry = matching_photoexp_entries.sample(n=1).iloc[0]
+        additional_edited_image = random_photoexp_entry["edited_image"]
+        model_b = random_photoexp_entry.get("model")
+        if model_b is None:
+            model_b = f"REDDIT_{random_photoexp_entry['comment_id']}"
+    else:
+        return None
+
+    return {
+        "post_id": sample["post_id"],
+        "instruction": sample["instruction"],
+        "simplified_instruction": sample["simplified_instruction"],
+        "source_image": sample["source_image"],
+        "edit1_image": sample["edited_image"],
+        "edit1_model": sample["model"],
+        "edit2_image": additional_edited_image,
+        "edit2_model": model_b,
+    }
+
+
+def save_sample(sample, output_dir):
+    """Save a sample to disk with all its components"""
+    if sample is None:
+        return False
+
+    # Create directory structure
+    sample_dir = Path(output_dir) / str(sample["post_id"])
+    sample_dir.mkdir(parents=True, exist_ok=True)
+
+    # Save instruction and metadata
+    with open(sample_dir / "metadata.txt", "w") as f:
+        f.write(f"Post ID: {sample['post_id']}\n")
+        f.write(f"Original Instruction: {sample['instruction']}\n")
+        f.write(f"Simplified Instruction: {sample['simplified_instruction']}\n")
+        f.write(f"Edit 1 Model: {sample['edit1_model']}\n")
+        f.write(f"Edit 2 Model: {sample['edit2_model']}\n")
+
+    # Save images
+    success = True
+    success &= download_and_save_image(
+        sample["source_image"], sample_dir / "source.jpg"
+    )
+    success &= download_and_save_image(sample["edit1_image"], sample_dir / "edit1.jpg")
+    success &= download_and_save_image(sample["edit2_image"], sample_dir / "edit2.jpg")
+
+    return success
+
+
+def main():
+    output_dir = Path("extracted_samples")
+    output_dir.mkdir(exist_ok=True)
+
+    num_samples = 100  # Number of samples to extract
+    successful_samples = 0
+
+    print(f"Extracting {num_samples} samples...")
+
+    while successful_samples < num_samples:
+        sample = get_random_sample()
+        if sample and save_sample(sample, output_dir):
+            successful_samples += 1
+            print(f"Successfully saved sample {successful_samples}/{num_samples}")
+        else:
+            print("Failed to save sample, trying next...")
+
+    print(f"Successfully extracted {successful_samples} samples to {output_dir}")
+
+
+if __name__ == "__main__":
+    main()
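Both files lean on the batched form of datasets.Dataset.filter: the predicate receives a batch as a dict of column lists and must return one boolean per row. A self-contained sketch with a toy in-memory dataset (names are illustrative):

from datasets import Dataset

ds = Dataset.from_dict({"post_id": ["a", "b", "c"], "score": [1, 2, 3]})
keep = {"a", "c"}  # hypothetical allow-list of post ids

# With batched=True the lambda gets a batch dict (xs["post_id"] is a list of
# up to 256 ids here) and returns a list of booleans, one per row.
filtered = ds.filter(
    lambda xs: [x in keep for x in xs["post_id"]],
    batched=True,
    batch_size=256,
)
print(filtered["post_id"])  # ['a', 'c']

Batching the predicate avoids one Python call per row, which matters when the same filter runs repeatedly, as in get_random_sample().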