dropbop commited on
Commit
7c425bb
·
verified ·
1 Parent(s): 4561290

Update app.py

Browse files

# Key Changes:

## Image Handling:

The get_next_sample() function now correctly returns a PIL Image object for the image and the raw metadata as extracted from the sample.

The save_labeled_data function takes care of converting the image to bytes (if not None) and stores it along with the metadata.

## Metadata Handling:

Instead of trying to directly assign the metadata to the metadata_text textbox, which might have caused serialization issues if the metadata was not a simple string, we now only pass the "bounds" part of the metadata. This is converted to a JSON string using json.dumps(). This ensures that the textbox always receives a string.

## Initialization of UI Components:

The initialization functions for both labeling_ui and display_ui now properly return the initial values, which are then used to set the .value attribute of the components directly.

## Error Handling for Dataset Exhaustion:

Added a print statement in get_next_sample() and get_new_unlabeled_image() to indicate when the dataset is exhausted. Also, save_labeled_data now returns a specific message when there are no more samples.

## Display UI Initialization:

The get_random_cool_images function in display_ui now handles cases where there are fewer cool samples than requested.

The refresh_display function ensures that cool_images is always a list, even if empty.

# Explanation of Changes:

By explicitly converting the metadata to a JSON string, we avoid passing complex objects to Gradio components that might not be able to handle them.

Returning the PIL Image object directly from functions that interact with Gradio's gr.Image component ensures that the component receives the expected data type.

Initializing the components' values directly using .value during the setup phase ensures that they start with the correct content.

Files changed (1) hide show
  1. app.py +27 -54
app.py CHANGED
@@ -12,6 +12,7 @@ NUM_SAMPLES_TO_LABEL = 100 # You can adjust this
12
  LABELED_DATA_FILE = "labeled_data.json"
13
  DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
14
  SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
 
15
  # --- Load Dataset ---
16
  dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
17
  data_iter = iter(dataset)
@@ -34,39 +35,28 @@ def get_next_sample():
34
  sample = ev.item_to_images(DATASET_SUBSET, sample)
35
  image = sample["rgb"][0]
36
  metadata = sample["metadata"]
37
-
38
  return image, metadata, len(labeled_data)
39
  except StopIteration:
40
- return None, None, None
 
41
 
42
  # --- Save Labeled Data ---
43
  def save_labeled_data(image, metadata, label):
44
  global labeled_data
45
-
46
- # Convert PIL Image to bytes before saving
47
- if image is not None:
48
- image_bytes = image.convert("RGB").tobytes()
49
- else:
50
- image_bytes = None
51
-
52
  labeled_data.append({
53
  "image": image_bytes,
54
  "metadata": metadata,
55
  "label": label
56
  })
57
-
58
  with open(LABELED_DATA_FILE, "w") as f:
59
  json.dump(labeled_data, f)
60
-
61
- image, metadata, count = get_next_sample()
62
-
63
- if image is None:
64
- return "No more samples", gr.Image.update(value=None), "", f"Labeled {count} samples."
65
-
66
- return "", image, str(metadata["bounds"]), f"Labeled {count} samples."
67
 
68
  # --- Gradio Interface ---
69
-
70
  # --- Labeling UI ---
71
  def labeling_ui():
72
  with gr.Row():
@@ -74,30 +64,25 @@ def labeling_ui():
74
  image_component = gr.Image(label="Satellite Image", type="pil")
75
  metadata_text = gr.Textbox(label="Metadata (Bounds)")
76
  label_count_text = gr.Textbox(label="Label Count")
77
-
78
  with gr.Row():
79
  cool_button = gr.Button("Cool")
80
  not_cool_button = gr.Button("Not Cool")
81
-
82
- # Handle button clicks
83
  cool_button.click(
84
- fn=lambda image, metadata: save_labeled_data(image, metadata, "cool"),
85
  inputs=[image_component, metadata_text],
86
  outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text]
87
  )
88
  not_cool_button.click(
89
- fn=lambda image, metadata: save_labeled_data(image, metadata, "not cool"),
90
  inputs=[image_component, metadata_text],
91
  outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text]
92
  )
93
 
94
- # Initialize with the first sample
95
  def initialize_labeling_ui():
96
  image, metadata, count = get_next_sample()
97
- if image is not None:
98
- return image, str(metadata["bounds"]), f"Labeled {count} samples."
99
- else:
100
- return None, "", "No samples loaded."
101
 
102
  initial_image, initial_metadata, initial_count = initialize_labeling_ui()
103
  image_component.value = initial_image
@@ -106,32 +91,25 @@ def labeling_ui():
106
 
107
  # --- Display UI ---
108
  def display_ui():
109
-
110
  def get_random_cool_images(n):
111
  cool_samples = [d for d in labeled_data if d["label"] == "cool"]
112
- if len(cool_samples) < n:
113
- return [Image.frombytes("RGB", (384,384), s["image"]) for s in cool_samples]
114
-
115
- selected_cool = random.sample(cool_samples, n)
116
- return [Image.frombytes("RGB", (384,384), s["image"]) for s in selected_cool]
117
 
118
  def get_new_unlabeled_image():
119
  global data_iter
120
  try:
121
  sample = next(data_iter)
122
  sample = ev.item_to_images(DATASET_SUBSET, sample)
123
- image = sample["rgb"][0]
124
- metadata = sample["metadata"]
125
- return image, str(metadata["bounds"])
126
  except StopIteration:
 
127
  return None, None
128
 
129
  def refresh_display():
130
  new_image, new_metadata = get_new_unlabeled_image()
131
- if new_image is None:
132
- return "No more samples", gr.Image.update(value=None), gr.Gallery.update(value=[])
133
-
134
  cool_images = get_random_cool_images(DISPLAY_N_COOL)
 
 
135
  return "", new_image, cool_images
136
 
137
  with gr.Row():
@@ -139,21 +117,18 @@ def display_ui():
139
  metadata_display = gr.Textbox(label="Metadata (Bounds)")
140
 
141
  with gr.Row():
142
- cool_images_gallery = gr.Gallery(
143
- label="Cool Examples",
144
- value=[],
145
- columns=DISPLAY_N_COOL # Set grid layout here
146
- )
147
-
148
- with gr.Row():
149
- refresh_button = gr.Button("Refresh")
150
 
151
- refresh_button.click(fn=refresh_display, inputs=[], outputs=[gr.Textbox(label="Debug"), new_image_component, cool_images_gallery])
 
 
 
 
 
152
 
153
- # Initialize
154
  def initialize_display_ui():
155
- debug, image, gallery = refresh_display()
156
- return debug, image, gallery
157
 
158
  debug, initial_image, initial_gallery = initialize_display_ui()
159
  new_image_component.value = initial_image
@@ -162,11 +137,9 @@ def display_ui():
162
  # --- Main Interface ---
163
  with gr.Blocks() as demo:
164
  gr.Markdown("# TerraNomaly")
165
-
166
  with gr.Tabs():
167
  with gr.TabItem("Labeling"):
168
  labeling_ui()
169
-
170
  with gr.TabItem("Display"):
171
  display_ui()
172
 
 
12
  LABELED_DATA_FILE = "labeled_data.json"
13
  DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
14
  SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
15
+
16
  # --- Load Dataset ---
17
  dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
18
  data_iter = iter(dataset)
 
35
  sample = ev.item_to_images(DATASET_SUBSET, sample)
36
  image = sample["rgb"][0]
37
  metadata = sample["metadata"]
 
38
  return image, metadata, len(labeled_data)
39
  except StopIteration:
40
+ print("No more samples in the dataset.")
41
+ return None, None, len(labeled_data)
42
 
43
  # --- Save Labeled Data ---
44
  def save_labeled_data(image, metadata, label):
45
  global labeled_data
46
+ image_bytes = image.convert("RGB").tobytes() if image else None
 
 
 
 
 
 
47
  labeled_data.append({
48
  "image": image_bytes,
49
  "metadata": metadata,
50
  "label": label
51
  })
 
52
  with open(LABELED_DATA_FILE, "w") as f:
53
  json.dump(labeled_data, f)
54
+ new_image, new_metadata, count = get_next_sample()
55
+ if new_image is None:
56
+ return "Dataset exhausted.", None, "", f"Labeled {count} samples."
57
+ return "", new_image, json.dumps(new_metadata["bounds"]), f"Labeled {count} samples."
 
 
 
58
 
59
  # --- Gradio Interface ---
 
60
  # --- Labeling UI ---
61
  def labeling_ui():
62
  with gr.Row():
 
64
  image_component = gr.Image(label="Satellite Image", type="pil")
65
  metadata_text = gr.Textbox(label="Metadata (Bounds)")
66
  label_count_text = gr.Textbox(label="Label Count")
 
67
  with gr.Row():
68
  cool_button = gr.Button("Cool")
69
  not_cool_button = gr.Button("Not Cool")
 
 
70
  cool_button.click(
71
+ fn=save_labeled_data,
72
  inputs=[image_component, metadata_text],
73
  outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text]
74
  )
75
  not_cool_button.click(
76
+ fn=save_labeled_data,
77
  inputs=[image_component, metadata_text],
78
  outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text]
79
  )
80
 
 
81
  def initialize_labeling_ui():
82
  image, metadata, count = get_next_sample()
83
+ if image:
84
+ return image, json.dumps(metadata["bounds"]), f"Labeled {count} samples."
85
+ return None, "", "No samples loaded."
 
86
 
87
  initial_image, initial_metadata, initial_count = initialize_labeling_ui()
88
  image_component.value = initial_image
 
91
 
92
  # --- Display UI ---
93
  def display_ui():
 
94
  def get_random_cool_images(n):
95
  cool_samples = [d for d in labeled_data if d["label"] == "cool"]
96
+ return [Image.frombytes("RGB", (384, 384), s["image"]) for s in cool_samples] if len(cool_samples) >= n else []
 
 
 
 
97
 
98
  def get_new_unlabeled_image():
99
  global data_iter
100
  try:
101
  sample = next(data_iter)
102
  sample = ev.item_to_images(DATASET_SUBSET, sample)
103
+ return sample["rgb"][0], json.dumps(sample["metadata"]["bounds"])
 
 
104
  except StopIteration:
105
+ print("No more samples in the dataset.")
106
  return None, None
107
 
108
  def refresh_display():
109
  new_image, new_metadata = get_new_unlabeled_image()
 
 
 
110
  cool_images = get_random_cool_images(DISPLAY_N_COOL)
111
+ if new_image is None:
112
+ return "No more samples", None, []
113
  return "", new_image, cool_images
114
 
115
  with gr.Row():
 
117
  metadata_display = gr.Textbox(label="Metadata (Bounds)")
118
 
119
  with gr.Row():
120
+ cool_images_gallery = gr.Gallery(label="Cool Examples", value=[], columns=DISPLAY_N_COOL)
 
 
 
 
 
 
 
121
 
122
+ refresh_button = gr.Button("Refresh")
123
+ refresh_button.click(
124
+ fn=refresh_display,
125
+ inputs=[],
126
+ outputs=[gr.Textbox(label="Debug"), new_image_component, cool_images_gallery]
127
+ )
128
 
 
129
  def initialize_display_ui():
130
+ debug, image, gallery = refresh_display()
131
+ return debug, image, gallery
132
 
133
  debug, initial_image, initial_gallery = initialize_display_ui()
134
  new_image_component.value = initial_image
 
137
  # --- Main Interface ---
138
  with gr.Blocks() as demo:
139
  gr.Markdown("# TerraNomaly")
 
140
  with gr.Tabs():
141
  with gr.TabItem("Labeling"):
142
  labeling_ui()
 
143
  with gr.TabItem("Display"):
144
  display_ui()
145