Zack3D commited on
Commit
2841bef
·
verified ·
1 Parent(s): bc30d26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -152
app.py CHANGED
@@ -9,6 +9,7 @@ import numpy as np
9
  from PIL import Image
10
  import openai
11
 
 
12
  MODEL = "gpt-image-1"
13
  SIZE_CHOICES = ["auto", "1024x1024", "1536x1024", "1024x1536"]
14
  QUALITY_CHOICES = ["auto", "low", "medium", "high"]
@@ -46,37 +47,24 @@ def _common_kwargs(
46
  kwargs: Dict[str, Any] = dict(
47
  model=MODEL,
48
  n=n,
49
- # REMOVED: response_format="b64_json", # This parameter caused the BadRequestError
50
  )
51
-
52
- # Use API defaults if 'auto' is selected
53
  if size != "auto":
54
  kwargs["size"] = size
55
  if quality != "auto":
56
  kwargs["quality"] = quality
57
-
58
- # Prompt is optional for variations
59
  if prompt is not None:
60
  kwargs["prompt"] = prompt
61
-
62
- # Output format specific settings (API default is png)
63
  if out_fmt != "png":
64
  kwargs["output_format"] = out_fmt
65
-
66
- # Transparency via background parameter (png & webp only)
67
  if transparent_bg and out_fmt in {"png", "webp"}:
68
  kwargs["background"] = "transparent"
69
-
70
- # Compression for lossy formats (API expects integer 0-100)
71
  if out_fmt in {"jpeg", "webp"}:
72
- # Ensure compression is an integer as expected by the API
73
  kwargs["output_compression"] = int(compression)
74
-
75
  return kwargs
76
 
 
77
 
78
  # ---------- Generate ---------- #
79
-
80
  def generate(
81
  api_key: str,
82
  prompt: str,
@@ -90,7 +78,7 @@ def generate(
90
  """Calls the OpenAI image generation endpoint."""
91
  if not prompt:
92
  raise gr.Error("Please enter a prompt.")
93
- client = _client(api_key)
94
  try:
95
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
96
  resp = client.images.generate(**common_args)
@@ -101,18 +89,16 @@ def generate(
101
  except openai.RateLimitError:
102
  raise gr.Error("Rate limit exceeded. Please try again later.")
103
  except openai.BadRequestError as e:
104
- # Extract the specific error message if possible
105
  error_message = str(e)
106
  try:
107
- # Attempt to parse the error body if it's JSON-like
108
  import json
109
- body = json.loads(str(e.body)) # e.body might be bytes
110
  if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
111
  error_message = f"OpenAI Bad Request: {body['error']['message']}"
112
  else:
113
  error_message = f"OpenAI Bad Request: {e}"
114
  except:
115
- error_message = f"OpenAI Bad Request: {e}" # Fallback
116
  raise gr.Error(error_message)
117
  except Exception as e:
118
  raise gr.Error(f"An unexpected error occurred: {e}")
@@ -120,7 +106,6 @@ def generate(
120
 
121
 
122
  # ---------- Edit / Inpaint ---------- #
123
-
124
  def _bytes_from_numpy(arr: np.ndarray) -> bytes:
125
  """Convert RGBA/RGB uint8 numpy array to PNG bytes."""
126
  img = Image.fromarray(arr.astype(np.uint8))
@@ -128,43 +113,21 @@ def _bytes_from_numpy(arr: np.ndarray) -> bytes:
128
  img.save(out, format="PNG")
129
  return out.getvalue()
130
 
131
-
132
  def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
133
  """Handle ImageMask / ImageEditor return formats and extract a numpy mask array."""
134
- if mask_value is None:
135
- return None
136
-
137
- # If we already have a numpy array (ImageMask with type="numpy")
138
- if isinstance(mask_value, np.ndarray):
139
- mask_arr = mask_value
140
- # If it's an EditorValue dict coming from ImageEditor/ImageMask with type="file" or "pil"
141
- elif isinstance(mask_value, dict):
142
- # Prefer the composite (all layers merged) if present
143
  comp = mask_value.get("composite")
144
  if comp is not None and isinstance(comp, (Image.Image, np.ndarray)):
145
- mask_arr = np.array(comp) if isinstance(comp, Image.Image) else comp
146
- # Fallback to the mask if present (often from ImageMask)
147
  elif mask_value.get("mask") is not None and isinstance(mask_value["mask"], (Image.Image, np.ndarray)):
148
- mask_arr = np.array(mask_value["mask"]) if isinstance(mask_value["mask"], Image.Image) else mask_value["mask"]
149
- # Fallback to the topmost layer
150
  elif mask_value.get("layers"):
151
  top_layer = mask_value["layers"][-1]
152
  if isinstance(top_layer, (Image.Image, np.ndarray)):
153
- mask_arr = np.array(top_layer) if isinstance(top_layer, Image.Image) else top_layer
154
- else:
155
- return None # Cannot process layer format
156
- else:
157
- return None # No usable image data found in dict
158
- else:
159
- # Unknown format – ignore
160
- return None
161
-
162
- # Ensure mask_arr is a numpy array now
163
- if not isinstance(mask_arr, np.ndarray):
164
- return None # Should not happen after above checks, but safeguard
165
-
166
- return mask_arr
167
-
168
 
169
  def edit_image(
170
  api_key: str,
@@ -179,99 +142,52 @@ def edit_image(
179
  transparent_bg: bool,
180
  ):
181
  """Calls the OpenAI image edit endpoint."""
182
- if image_numpy is None:
183
- raise gr.Error("Please upload an image.")
184
- if not prompt:
185
- raise gr.Error("Please enter an edit prompt.")
186
 
187
  img_bytes = _bytes_from_numpy(image_numpy)
188
-
189
  mask_bytes: Optional[bytes] = None
190
  mask_numpy = _extract_mask_array(mask_value)
191
 
192
  if mask_numpy is not None:
193
- # Check if the mask seems empty (all black or fully transparent)
194
  is_empty = False
195
- if mask_numpy.ndim == 2: # Grayscale
196
- is_empty = np.all(mask_numpy == 0)
197
- elif mask_numpy.shape[-1] == 4: # RGBA
198
- is_empty = np.all(mask_numpy[:, :, 3] == 0)
199
- elif mask_numpy.shape[-1] == 3: # RGB
200
- is_empty = np.all(mask_numpy == 0)
201
 
202
  if is_empty:
203
- gr.Warning("The provided mask appears empty (all black/transparent). The API might edit the entire image or ignore the mask.")
204
- # Pass None if the mask is effectively empty, as per API docs (transparent areas are edited)
205
  mask_bytes = None
206
  else:
207
- # Convert the mask to the format required by the API:
208
- # A PNG image where TRANSPARENT areas indicate where the image should be edited.
209
- # Our Gradio mask uses WHITE to indicate the edit area.
210
- # So, we need to create an alpha channel where white pixels in the input mask become transparent (0),
211
- # and black/other pixels become opaque (255).
212
-
213
- if mask_numpy.ndim == 2: # Grayscale input mask
214
- # Assume white (255) means edit -> make transparent (0 alpha)
215
- # Assume black (0) means keep -> make opaque (255 alpha)
216
- alpha = (mask_numpy == 0).astype(np.uint8) * 255
217
- elif mask_numpy.shape[-1] == 4: # RGBA input mask (from gr.ImageMask)
218
- # Use the alpha channel directly if it exists and seems meaningful,
219
- # otherwise, treat non-black RGB as edit area.
220
- # gr.ImageMask often returns RGBA where painted area is white [255,255,255,255] and background is [0,0,0,0]
221
- # We want the painted (white) area to be transparent in the final mask.
222
- # We want the unpainted (transparent black) area to be opaque in the final mask.
223
- alpha = (mask_numpy[:, :, 3] == 0).astype(np.uint8) * 255
224
- elif mask_numpy.shape[-1] == 3: # RGB input mask
225
- # Assume white [255, 255, 255] means edit -> make transparent (0 alpha)
226
- # Assume black [0, 0, 0] or other colors mean keep -> make opaque (255 alpha)
227
  is_white = np.all(mask_numpy == [255, 255, 255], axis=-1)
228
  alpha = (~is_white).astype(np.uint8) * 255
229
- else:
230
- raise gr.Error("Unsupported mask format.")
231
 
232
- # Create a single-channel L mode image (grayscale/alpha) for the mask
233
  mask_img = Image.fromarray(alpha, mode='L')
234
-
235
- # The API expects an RGBA PNG where the alpha channel defines the mask.
236
- # Create a black image with the calculated alpha channel.
237
  rgba_mask = Image.new("RGBA", mask_img.size, (0, 0, 0, 0))
238
- black_opaque = Image.new("L", mask_img.size, 0) # Black base
239
- rgba_mask.putalpha(mask_img) # Use the calculated alpha
240
-
241
  out = io.BytesIO()
242
  rgba_mask.save(out, format="PNG")
243
  mask_bytes = out.getvalue()
244
-
245
- # Debug: Save mask locally to check
246
- # rgba_mask.save("debug_mask_sent_to_api.png")
247
-
248
  else:
249
- gr.Info("No mask provided. The API will attempt to edit the image based on the prompt without a specific mask.")
250
- mask_bytes = None # Explicitly pass None if no mask is usable
251
 
252
- client = _client(api_key)
253
  try:
254
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
255
- # The edit endpoint requires the prompt
256
- if "prompt" not in common_args:
257
- common_args["prompt"] = prompt # Should always be there via _common_kwargs, but safeguard
258
-
259
- # Ensure image and mask are passed correctly
260
- api_kwargs = {
261
- "image": img_bytes,
262
- **common_args
263
- }
264
- if mask_bytes is not None:
265
- api_kwargs["mask"] = mask_bytes
266
-
267
  resp = client.images.edit(**api_kwargs)
268
-
269
  except openai.AuthenticationError:
270
  raise gr.Error("Invalid OpenAI API key.")
271
  except openai.PermissionDeniedError:
272
- raise gr.Error("Permission denied. Check your API key permissions or complete required verification for gpt-image-1.")
273
  except openai.RateLimitError:
274
- raise gr.Error("Rate limit exceeded. Please try again later.")
275
  except openai.BadRequestError as e:
276
  error_message = str(e)
277
  try:
@@ -279,15 +195,10 @@ def edit_image(
279
  body = json.loads(str(e.body))
280
  if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
281
  error_message = f"OpenAI Bad Request: {body['error']['message']}"
282
- # Add specific advice based on common mask errors
283
- if "mask" in error_message.lower():
284
- error_message += " (Ensure mask is a valid PNG with an alpha channel and matches the image dimensions.)"
285
- elif "size" in error_message.lower():
286
- error_message += " (Ensure image and mask dimensions match and are supported.)"
287
- else:
288
- error_message = f"OpenAI Bad Request: {e}"
289
- except:
290
- error_message = f"OpenAI Bad Request: {e}" # Fallback
291
  raise gr.Error(error_message)
292
  except Exception as e:
293
  raise gr.Error(f"An unexpected error occurred: {e}")
@@ -295,7 +206,6 @@ def edit_image(
295
 
296
 
297
  # ---------- Variations ---------- #
298
-
299
  def variation_image(
300
  api_key: str,
301
  image_numpy: np.ndarray,
@@ -307,27 +217,19 @@ def variation_image(
307
  transparent_bg: bool,
308
  ):
309
  """Calls the OpenAI image variations endpoint."""
310
- # NOTE: Variations are only supported for DALL-E 2 according to docs.
311
- # This might fail with gpt-image-1. Consider adding a check or using DALL-E 2.
312
- gr.Warning("Note: Image variations are officially supported for DALL·E 2, not gpt-image-1. This may not work as expected.")
313
-
314
- if image_numpy is None:
315
- raise gr.Error("Please upload an image.")
316
  img_bytes = _bytes_from_numpy(image_numpy)
317
- client = _client(api_key)
318
  try:
319
- # Prompt is None for variations
320
  common_args = _common_kwargs(None, n, size, quality, out_fmt, compression, transparent_bg)
321
- resp = client.images.variations(
322
- image=img_bytes,
323
- **common_args,
324
- )
325
  except openai.AuthenticationError:
326
  raise gr.Error("Invalid OpenAI API key.")
327
  except openai.PermissionDeniedError:
328
- raise gr.Error("Permission denied. Check your API key permissions.")
329
  except openai.RateLimitError:
330
- raise gr.Error("Rate limit exceeded. Please try again later.")
331
  except openai.BadRequestError as e:
332
  error_message = str(e)
333
  try:
@@ -336,11 +238,9 @@ def variation_image(
336
  if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
337
  error_message = f"OpenAI Bad Request: {body['error']['message']}"
338
  if "model does not support variations" in error_message.lower():
339
- error_message += " (gpt-image-1 does not support variations, use DALL·E 2 instead)."
340
- else:
341
- error_message = f"OpenAI Bad Request: {e}"
342
- except:
343
- error_message = f"OpenAI Bad Request: {e}" # Fallback
344
  raise gr.Error(error_message)
345
  except Exception as e:
346
  raise gr.Error(f"An unexpected error occurred: {e}")
@@ -359,11 +259,12 @@ def build_ui():
359
  )
360
 
361
  with gr.Accordion("🔐 API key", open=False):
 
362
  api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-…")
363
 
364
  # Common controls
365
  with gr.Row():
366
- n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)", info="Max 4 for this demo.") # Limit n for stability/cost
367
  size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size", info="API default if 'auto'.")
368
  quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality", info="API default if 'auto'.")
369
  with gr.Row():
@@ -376,7 +277,8 @@ def build_ui():
376
 
377
  out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
378
 
379
- common_inputs = [api, n_slider, size, quality, out_fmt, compression, transparent]
 
380
 
381
  with gr.Tabs():
382
  # ----- Generate Tab ----- #
@@ -385,9 +287,11 @@ def build_ui():
385
  prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic ginger cat astronaut on Mars", scale=4)
386
  btn_gen = gr.Button("Generate 🚀", variant="primary", scale=1)
387
  gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
 
 
388
  btn_gen.click(
389
  generate,
390
- inputs=[prompt_gen] + common_inputs, # Prepend specific inputs
391
  outputs=gallery_gen,
392
  api_name="generate"
393
  )
@@ -397,19 +301,20 @@ def build_ui():
397
  gr.Markdown("Upload an image, then **paint the area to change** in the mask canvas below (white = edit area). The API requires the mask and image to have the same dimensions.")
398
  with gr.Row():
399
  img_edit = gr.Image(label="Source Image", type="numpy", height=400)
400
- # Use ImageMask component for interactive painting
401
  mask_canvas = gr.ImageMask(
402
  label="Mask – Paint White Where Image Should Change",
403
- type="numpy", # Get mask as numpy array
404
  height=400
405
  )
406
  with gr.Row():
407
  prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky with a starry night", scale=4)
408
  btn_edit = gr.Button("Edit 🖌️", variant="primary", scale=1)
409
  gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
 
 
410
  btn_edit.click(
411
  edit_image,
412
- inputs=[img_edit, mask_canvas, prompt_edit] + common_inputs, # Prepend specific inputs
413
  outputs=gallery_edit,
414
  api_name="edit"
415
  )
@@ -421,9 +326,11 @@ def build_ui():
421
  img_var = gr.Image(label="Source Image", type="numpy", height=400, scale=4)
422
  btn_var = gr.Button("Create Variations ✨", variant="primary", scale=1)
423
  gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
 
 
424
  btn_var.click(
425
  variation_image,
426
- inputs=[img_var] + common_inputs, # Prepend specific inputs
427
  outputs=gallery_var,
428
  api_name="variations"
429
  )
@@ -432,6 +339,4 @@ def build_ui():
432
 
433
  if __name__ == "__main__":
434
  app = build_ui()
435
- # Set share=True to create a public link (useful for Spaces)
436
- # Set debug=True for more detailed logs in the console
437
  app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=True)
 
9
  from PIL import Image
10
  import openai
11
 
12
+ # --- Constants and Helper Functions (Keep as before) ---
13
  MODEL = "gpt-image-1"
14
  SIZE_CHOICES = ["auto", "1024x1024", "1536x1024", "1024x1536"]
15
  QUALITY_CHOICES = ["auto", "low", "medium", "high"]
 
47
  kwargs: Dict[str, Any] = dict(
48
  model=MODEL,
49
  n=n,
 
50
  )
 
 
51
  if size != "auto":
52
  kwargs["size"] = size
53
  if quality != "auto":
54
  kwargs["quality"] = quality
 
 
55
  if prompt is not None:
56
  kwargs["prompt"] = prompt
 
 
57
  if out_fmt != "png":
58
  kwargs["output_format"] = out_fmt
 
 
59
  if transparent_bg and out_fmt in {"png", "webp"}:
60
  kwargs["background"] = "transparent"
 
 
61
  if out_fmt in {"jpeg", "webp"}:
 
62
  kwargs["output_compression"] = int(compression)
 
63
  return kwargs
64
 
65
+ # --- API Call Functions (Keep as corrected before) ---
66
 
67
  # ---------- Generate ---------- #
 
68
  def generate(
69
  api_key: str,
70
  prompt: str,
 
78
  """Calls the OpenAI image generation endpoint."""
79
  if not prompt:
80
  raise gr.Error("Please enter a prompt.")
81
+ client = _client(api_key) # API key used here
82
  try:
83
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
84
  resp = client.images.generate(**common_args)
 
89
  except openai.RateLimitError:
90
  raise gr.Error("Rate limit exceeded. Please try again later.")
91
  except openai.BadRequestError as e:
 
92
  error_message = str(e)
93
  try:
 
94
  import json
95
+ body = json.loads(str(e.body))
96
  if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
97
  error_message = f"OpenAI Bad Request: {body['error']['message']}"
98
  else:
99
  error_message = f"OpenAI Bad Request: {e}"
100
  except:
101
+ error_message = f"OpenAI Bad Request: {e}"
102
  raise gr.Error(error_message)
103
  except Exception as e:
104
  raise gr.Error(f"An unexpected error occurred: {e}")
 
106
 
107
 
108
  # ---------- Edit / Inpaint ---------- #
 
109
  def _bytes_from_numpy(arr: np.ndarray) -> bytes:
110
  """Convert RGBA/RGB uint8 numpy array to PNG bytes."""
111
  img = Image.fromarray(arr.astype(np.uint8))
 
113
  img.save(out, format="PNG")
114
  return out.getvalue()
115
 
 
116
  def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
117
  """Handle ImageMask / ImageEditor return formats and extract a numpy mask array."""
118
+ if mask_value is None: return None
119
+ if isinstance(mask_value, np.ndarray): return mask_value
120
+ if isinstance(mask_value, dict):
 
 
 
 
 
 
121
  comp = mask_value.get("composite")
122
  if comp is not None and isinstance(comp, (Image.Image, np.ndarray)):
123
+ return np.array(comp) if isinstance(comp, Image.Image) else comp
 
124
  elif mask_value.get("mask") is not None and isinstance(mask_value["mask"], (Image.Image, np.ndarray)):
125
+ return np.array(mask_value["mask"]) if isinstance(mask_value["mask"], Image.Image) else mask_value["mask"]
 
126
  elif mask_value.get("layers"):
127
  top_layer = mask_value["layers"][-1]
128
  if isinstance(top_layer, (Image.Image, np.ndarray)):
129
+ return np.array(top_layer) if isinstance(top_layer, Image.Image) else top_layer
130
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  def edit_image(
133
  api_key: str,
 
142
  transparent_bg: bool,
143
  ):
144
  """Calls the OpenAI image edit endpoint."""
145
+ if image_numpy is None: raise gr.Error("Please upload an image.")
146
+ if not prompt: raise gr.Error("Please enter an edit prompt.")
 
 
147
 
148
  img_bytes = _bytes_from_numpy(image_numpy)
 
149
  mask_bytes: Optional[bytes] = None
150
  mask_numpy = _extract_mask_array(mask_value)
151
 
152
  if mask_numpy is not None:
 
153
  is_empty = False
154
+ if mask_numpy.ndim == 2: is_empty = np.all(mask_numpy == 0)
155
+ elif mask_numpy.shape[-1] == 4: is_empty = np.all(mask_numpy[:, :, 3] == 0)
156
+ elif mask_numpy.shape[-1] == 3: is_empty = np.all(mask_numpy == 0)
 
 
 
157
 
158
  if is_empty:
159
+ gr.Warning("Mask appears empty. API might edit entire image or ignore mask.")
 
160
  mask_bytes = None
161
  else:
162
+ if mask_numpy.ndim == 2: alpha = (mask_numpy == 0).astype(np.uint8) * 255
163
+ elif mask_numpy.shape[-1] == 4: alpha = (mask_numpy[:, :, 3] == 0).astype(np.uint8) * 255
164
+ elif mask_numpy.shape[-1] == 3:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  is_white = np.all(mask_numpy == [255, 255, 255], axis=-1)
166
  alpha = (~is_white).astype(np.uint8) * 255
167
+ else: raise gr.Error("Unsupported mask format.")
 
168
 
 
169
  mask_img = Image.fromarray(alpha, mode='L')
 
 
 
170
  rgba_mask = Image.new("RGBA", mask_img.size, (0, 0, 0, 0))
171
+ rgba_mask.putalpha(mask_img)
 
 
172
  out = io.BytesIO()
173
  rgba_mask.save(out, format="PNG")
174
  mask_bytes = out.getvalue()
 
 
 
 
175
  else:
176
+ gr.Info("No mask provided. Editing without specific mask.")
177
+ mask_bytes = None
178
 
179
+ client = _client(api_key) # API key used here
180
  try:
181
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
182
+ api_kwargs = {"image": img_bytes, **common_args}
183
+ if mask_bytes is not None: api_kwargs["mask"] = mask_bytes
 
 
 
 
 
 
 
 
 
 
184
  resp = client.images.edit(**api_kwargs)
 
185
  except openai.AuthenticationError:
186
  raise gr.Error("Invalid OpenAI API key.")
187
  except openai.PermissionDeniedError:
188
+ raise gr.Error("Permission denied. Check API key permissions/verification.")
189
  except openai.RateLimitError:
190
+ raise gr.Error("Rate limit exceeded.")
191
  except openai.BadRequestError as e:
192
  error_message = str(e)
193
  try:
 
195
  body = json.loads(str(e.body))
196
  if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
197
  error_message = f"OpenAI Bad Request: {body['error']['message']}"
198
+ if "mask" in error_message.lower(): error_message += " (Check mask format/dimensions)"
199
+ elif "size" in error_message.lower(): error_message += " (Check image/mask dimensions)"
200
+ else: error_message = f"OpenAI Bad Request: {e}"
201
+ except: error_message = f"OpenAI Bad Request: {e}"
 
 
 
 
 
202
  raise gr.Error(error_message)
203
  except Exception as e:
204
  raise gr.Error(f"An unexpected error occurred: {e}")
 
206
 
207
 
208
  # ---------- Variations ---------- #
 
209
  def variation_image(
210
  api_key: str,
211
  image_numpy: np.ndarray,
 
217
  transparent_bg: bool,
218
  ):
219
  """Calls the OpenAI image variations endpoint."""
220
+ gr.Warning("Note: Variations may not work with gpt-image-1 (use DALL·E 2).")
221
+ if image_numpy is None: raise gr.Error("Please upload an image.")
 
 
 
 
222
  img_bytes = _bytes_from_numpy(image_numpy)
223
+ client = _client(api_key) # API key used here
224
  try:
 
225
  common_args = _common_kwargs(None, n, size, quality, out_fmt, compression, transparent_bg)
226
+ resp = client.images.variations(image=img_bytes, **common_args)
 
 
 
227
  except openai.AuthenticationError:
228
  raise gr.Error("Invalid OpenAI API key.")
229
  except openai.PermissionDeniedError:
230
+ raise gr.Error("Permission denied.")
231
  except openai.RateLimitError:
232
+ raise gr.Error("Rate limit exceeded.")
233
  except openai.BadRequestError as e:
234
  error_message = str(e)
235
  try:
 
238
  if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
239
  error_message = f"OpenAI Bad Request: {body['error']['message']}"
240
  if "model does not support variations" in error_message.lower():
241
+ error_message += " (gpt-image-1 does not support variations)."
242
+ else: error_message = f"OpenAI Bad Request: {e}"
243
+ except: error_message = f"OpenAI Bad Request: {e}"
 
 
244
  raise gr.Error(error_message)
245
  except Exception as e:
246
  raise gr.Error(f"An unexpected error occurred: {e}")
 
259
  )
260
 
261
  with gr.Accordion("🔐 API key", open=False):
262
+ # API key input component
263
  api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-…")
264
 
265
  # Common controls
266
  with gr.Row():
267
+ n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)", info="Max 4 for this demo.")
268
  size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size", info="API default if 'auto'.")
269
  quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality", info="API default if 'auto'.")
270
  with gr.Row():
 
277
 
278
  out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
279
 
280
+ # Define the list of common controls *excluding* the API key
281
+ common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
282
 
283
  with gr.Tabs():
284
  # ----- Generate Tab ----- #
 
287
  prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic ginger cat astronaut on Mars", scale=4)
288
  btn_gen = gr.Button("Generate 🚀", variant="primary", scale=1)
289
  gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
290
+
291
+ # CORRECTED inputs list for generate
292
  btn_gen.click(
293
  generate,
294
+ inputs=[api, prompt_gen] + common_controls, # API key first
295
  outputs=gallery_gen,
296
  api_name="generate"
297
  )
 
301
  gr.Markdown("Upload an image, then **paint the area to change** in the mask canvas below (white = edit area). The API requires the mask and image to have the same dimensions.")
302
  with gr.Row():
303
  img_edit = gr.Image(label="Source Image", type="numpy", height=400)
 
304
  mask_canvas = gr.ImageMask(
305
  label="Mask – Paint White Where Image Should Change",
306
+ type="numpy",
307
  height=400
308
  )
309
  with gr.Row():
310
  prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky with a starry night", scale=4)
311
  btn_edit = gr.Button("Edit 🖌️", variant="primary", scale=1)
312
  gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
313
+
314
+ # CORRECTED inputs list for edit_image
315
  btn_edit.click(
316
  edit_image,
317
+ inputs=[api, img_edit, mask_canvas, prompt_edit] + common_controls, # API key first
318
  outputs=gallery_edit,
319
  api_name="edit"
320
  )
 
326
  img_var = gr.Image(label="Source Image", type="numpy", height=400, scale=4)
327
  btn_var = gr.Button("Create Variations ✨", variant="primary", scale=1)
328
  gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
329
+
330
+ # CORRECTED inputs list for variation_image
331
  btn_var.click(
332
  variation_image,
333
+ inputs=[api, img_var] + common_controls, # API key first
334
  outputs=gallery_var,
335
  api_name="variations"
336
  )
 
339
 
340
  if __name__ == "__main__":
341
  app = build_ui()
 
 
342
  app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=True)