Zack3D commited on
Commit
0f41349
·
verified ·
1 Parent(s): a28bcc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -120
app.py CHANGED
@@ -31,7 +31,6 @@ def _client(key: str) -> openai.OpenAI:
31
  def _img_list(resp, *, fmt: str) -> List[str]:
32
  """Return list of data URLs or direct URLs depending on API response."""
33
  mime = f"image/{fmt}"
34
- # Ensure b64_json exists and is not None/empty before using it
35
  return [
36
  f"data:{mime};base64,{d.b64_json}" if hasattr(d, "b64_json") and d.b64_json else d.url
37
  for d in resp.data
@@ -61,12 +60,68 @@ def _common_kwargs(
61
  if out_fmt != "png":
62
  kwargs["output_format"] = out_fmt
63
  if transparent_bg and out_fmt in {"png", "webp"}:
 
 
64
  kwargs["background"] = "transparent"
65
  if out_fmt in {"jpeg", "webp"}:
 
 
66
  kwargs["output_compression"] = int(compression)
67
  return kwargs
68
 
69
- # --- API Call Functions (Keep as corrected before) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  # ---------- Generate ---------- #
72
  def generate(
@@ -82,33 +137,22 @@ def generate(
82
  """Calls the OpenAI image generation endpoint."""
83
  if not prompt:
84
  raise gr.Error("Please enter a prompt.")
85
- client = _client(api_key) # API key used here
86
  try:
 
87
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
 
 
 
88
  resp = client.images.generate(**common_args)
89
- # What I need varies based on issues, I dont want to keep rebuilding for every issue :(
90
- sys_info_formatted = exec(os.getenv("sys_info")) #Default: f'[DEBUG]: {MODEL} | {prompt_gen}'
91
- print(sys_info_formatted)
92
- except openai.AuthenticationError:
93
- raise gr.Error("Invalid OpenAI API key.")
94
- except openai.PermissionDeniedError:
95
- raise gr.Error("Permission denied. Check your API key permissions or complete required verification for gpt-image-1.")
96
- except openai.RateLimitError:
97
- raise gr.Error("Rate limit exceeded. Please try again later.")
98
- except openai.BadRequestError as e:
99
- error_message = str(e)
100
- try:
101
- import json
102
- body = json.loads(str(e.body))
103
- if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
104
- error_message = f"OpenAI Bad Request: {body['error']['message']}"
105
- else:
106
- error_message = f"OpenAI Bad Request: {e}"
107
- except:
108
- error_message = f"OpenAI Bad Request: {e}"
109
- raise gr.Error(error_message)
110
  except Exception as e:
111
- raise gr.Error(f"An unexpected error occurred: {e}")
 
 
 
 
112
  return _img_list(resp, fmt=out_fmt)
113
 
114
 
@@ -123,23 +167,21 @@ def _bytes_from_numpy(arr: np.ndarray) -> bytes:
123
  def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
124
  """Handle ImageMask / ImageEditor return formats and extract a numpy mask array."""
125
  if mask_value is None: return None
126
- if isinstance(mask_value, np.ndarray): return mask_value
127
  if isinstance(mask_value, dict):
128
- comp = mask_value.get("composite")
129
- if comp is not None and isinstance(comp, (Image.Image, np.ndarray)):
130
- return np.array(comp) if isinstance(comp, Image.Image) else comp
131
- elif mask_value.get("mask") is not None and isinstance(mask_value["mask"], (Image.Image, np.ndarray)):
132
- return np.array(mask_value["mask"]) if isinstance(mask_value["mask"], Image.Image) else mask_value["mask"]
133
- elif mask_value.get("layers"):
134
- top_layer = mask_value["layers"][-1]
135
- if isinstance(top_layer, (Image.Image, np.ndarray)):
136
- return np.array(top_layer) if isinstance(top_layer, Image.Image) else top_layer
137
- return None
138
 
139
  def edit_image(
140
  api_key: str,
141
- image_numpy: np.ndarray,
142
- mask_value: Optional[Union[np.ndarray, Dict[str, Any]]],
 
 
143
  prompt: str,
144
  n: int,
145
  size: str,
@@ -154,68 +196,84 @@ def edit_image(
154
 
155
  img_bytes = _bytes_from_numpy(image_numpy)
156
  mask_bytes: Optional[bytes] = None
157
- mask_numpy = _extract_mask_array(mask_value)
158
 
159
  if mask_numpy is not None:
 
160
  is_empty = False
161
- if mask_numpy.ndim == 2: is_empty = np.all(mask_numpy == 0)
162
- elif mask_numpy.shape[-1] == 4: is_empty = np.all(mask_numpy[:, :, 3] == 0)
163
- elif mask_numpy.shape[-1] == 3: is_empty = np.all(mask_numpy == 0)
 
 
 
164
 
165
  if is_empty:
166
- gr.Warning("Mask appears empty. API might edit entire image or ignore mask.")
167
- mask_bytes = None
168
  else:
169
- if mask_numpy.ndim == 2: alpha = (mask_numpy == 0).astype(np.uint8) * 255
170
- elif mask_numpy.shape[-1] == 4: alpha = (mask_numpy[:, :, 3] == 0).astype(np.uint8) * 255
171
- elif mask_numpy.shape[-1] == 3:
172
- is_white = np.all(mask_numpy == [255, 255, 255], axis=-1)
173
- alpha = (~is_white).astype(np.uint8) * 255
174
- else: raise gr.Error("Unsupported mask format.")
 
 
 
 
 
 
 
 
 
175
 
 
176
  mask_img = Image.fromarray(alpha, mode='L')
177
- rgba_mask = Image.new("RGBA", mask_img.size, (0, 0, 0, 0))
178
- rgba_mask.putalpha(mask_img)
 
 
 
 
 
 
 
 
179
  out = io.BytesIO()
180
  rgba_mask.save(out, format="PNG")
181
  mask_bytes = out.getvalue()
182
  else:
183
- gr.Info("No mask provided. Editing without specific mask.")
184
  mask_bytes = None
185
 
186
- client = _client(api_key) # API key used here
187
  try:
 
188
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
189
  api_kwargs = {"image": img_bytes, **common_args}
190
- if mask_bytes is not None: api_kwargs["mask"] = mask_bytes
 
 
 
 
 
 
 
 
191
  resp = client.images.edit(**api_kwargs)
192
- except openai.AuthenticationError:
193
- raise gr.Error("Invalid OpenAI API key.")
194
- except openai.PermissionDeniedError:
195
- raise gr.Error("Permission denied. Check API key permissions/verification.")
196
- except openai.RateLimitError:
197
- raise gr.Error("Rate limit exceeded.")
198
- except openai.BadRequestError as e:
199
- error_message = str(e)
200
- try:
201
- import json
202
- body = json.loads(str(e.body))
203
- if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
204
- error_message = f"OpenAI Bad Request: {body['error']['message']}"
205
- if "mask" in error_message.lower(): error_message += " (Check mask format/dimensions)"
206
- elif "size" in error_message.lower(): error_message += " (Check image/mask dimensions)"
207
- else: error_message = f"OpenAI Bad Request: {e}"
208
- except: error_message = f"OpenAI Bad Request: {e}"
209
- raise gr.Error(error_message)
210
  except Exception as e:
211
- raise gr.Error(f"An unexpected error occurred: {e}")
 
 
212
  return _img_list(resp, fmt=out_fmt)
213
 
214
 
215
  # ---------- Variations ---------- #
216
  def variation_image(
217
  api_key: str,
218
- image_numpy: np.ndarray,
219
  n: int,
220
  size: str,
221
  quality: str,
@@ -224,33 +282,41 @@ def variation_image(
224
  transparent_bg: bool,
225
  ):
226
  """Calls the OpenAI image variations endpoint."""
227
- gr.Warning("Note: Variations may not work with gpt-image-1 (use DALL·E 2).")
 
 
228
  if image_numpy is None: raise gr.Error("Please upload an image.")
 
229
  img_bytes = _bytes_from_numpy(image_numpy)
230
- client = _client(api_key) # API key used here
231
  try:
232
- common_args = _common_kwargs(None, n, size, quality, out_fmt, compression, transparent_bg)
233
- resp = client.images.variations(image=img_bytes, **common_args)
234
- except openai.AuthenticationError:
235
- raise gr.Error("Invalid OpenAI API key.")
236
- except openai.PermissionDeniedError:
237
- raise gr.Error("Permission denied.")
238
- except openai.RateLimitError:
239
- raise gr.Error("Rate limit exceeded.")
240
- except openai.BadRequestError as e:
241
- error_message = str(e)
242
- try:
243
- import json
244
- body = json.loads(str(e.body))
245
- if isinstance(body, dict) and 'error' in body and 'message' in body['error']:
246
- error_message = f"OpenAI Bad Request: {body['error']['message']}"
247
- if "model does not support variations" in error_message.lower():
248
- error_message += " (gpt-image-1 does not support variations)."
249
- else: error_message = f"OpenAI Bad Request: {e}"
250
- except: error_message = f"OpenAI Bad Request: {e}"
251
- raise gr.Error(error_message)
 
252
  except Exception as e:
253
- raise gr.Error(f"An unexpected error occurred: {e}")
 
 
 
 
254
  return _img_list(resp, fmt=out_fmt)
255
 
256
 
@@ -261,23 +327,24 @@ def build_ui():
261
  gr.Markdown("""# GPT-Image-1 Playground 🖼️🔑\nGenerate • Edit (paint mask!) • Variations""")
262
  gr.Markdown(
263
  "Enter your OpenAI API key below. It's used directly for API calls and **never stored**."
264
- " This space uses the `gpt-image-1` model."
265
- " **Note:** `gpt-image-1` may require organization verification. Variations endpoint might not work with this model (use DALL·E 2)."
266
  )
267
 
268
  with gr.Accordion("🔐 API key", open=False):
269
- # API key input component
270
- api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-…")
271
 
272
  # Common controls
273
  with gr.Row():
274
  n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)", info="Max 4 for this demo.")
275
- size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size", info="API default if 'auto'.")
276
- quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality", info="API default if 'auto'.")
277
  with gr.Row():
278
- out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Format", scale=1)
 
 
279
  compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False, scale=2)
280
- transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)", scale=1)
281
 
282
  def _toggle_compression(fmt):
283
  return gr.update(visible=fmt in {"jpeg", "webp"})
@@ -285,7 +352,11 @@ def build_ui():
285
  out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
286
 
287
  # Define the list of common controls *excluding* the API key
288
- common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
 
 
 
 
289
 
290
  with gr.Tabs():
291
  # ----- Generate Tab ----- #
@@ -295,22 +366,24 @@ def build_ui():
295
  btn_gen = gr.Button("Generate 🚀", variant="primary", scale=1)
296
  gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
297
 
298
- # CORRECTED inputs list for generate
299
  btn_gen.click(
300
  generate,
301
- inputs=[api, prompt_gen] + common_controls, # API key first
 
302
  outputs=gallery_gen,
303
  api_name="generate"
304
  )
305
 
306
  # ----- Edit Tab ----- #
307
  with gr.TabItem("Edit / Inpaint"):
308
- gr.Markdown("Upload an image, then **paint the area to change** in the mask canvas below (white = edit area). The API requires the mask and image to have the same dimensions.")
309
  with gr.Row():
310
- img_edit = gr.Image(label="Source Image", type="numpy", height=400)
 
 
311
  mask_canvas = gr.ImageMask(
312
  label="Mask – Paint White Where Image Should Change",
313
- type="numpy",
314
  height=400
315
  )
316
  with gr.Row():
@@ -318,26 +391,26 @@ def build_ui():
318
  btn_edit = gr.Button("Edit 🖌️", variant="primary", scale=1)
319
  gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
320
 
321
- # CORRECTED inputs list for edit_image
322
  btn_edit.click(
323
  edit_image,
324
- inputs=[api, img_edit, mask_canvas, prompt_edit] + common_controls, # API key first
 
325
  outputs=gallery_edit,
326
  api_name="edit"
327
  )
328
 
329
  # ----- Variations Tab ----- #
330
- with gr.TabItem("Variations (DALL·E 2 only)"):
331
- gr.Markdown("Upload an image to generate variations. **Note:** This endpoint is officially supported for DALL·E 2, not `gpt-image-1`. It likely won't work here.")
332
  with gr.Row():
333
- img_var = gr.Image(label="Source Image", type="numpy", height=400, scale=4)
334
  btn_var = gr.Button("Create Variations ✨", variant="primary", scale=1)
335
  gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
336
 
337
- # CORRECTED inputs list for variation_image
338
  btn_var.click(
339
  variation_image,
340
- inputs=[api, img_var] + common_controls, # API key first
 
341
  outputs=gallery_var,
342
  api_name="variations"
343
  )
@@ -346,4 +419,5 @@ def build_ui():
346
 
347
  if __name__ == "__main__":
348
  app = build_ui()
349
- app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=True)
 
 
31
  def _img_list(resp, *, fmt: str) -> List[str]:
32
  """Return list of data URLs or direct URLs depending on API response."""
33
  mime = f"image/{fmt}"
 
34
  return [
35
  f"data:{mime};base64,{d.b64_json}" if hasattr(d, "b64_json") and d.b64_json else d.url
36
  for d in resp.data
 
60
  if out_fmt != "png":
61
  kwargs["output_format"] = out_fmt
62
  if transparent_bg and out_fmt in {"png", "webp"}:
63
+ # Note: OpenAI API might use 'background_removal' or similar, check latest docs
64
+ # Assuming 'background' is correct based on your original code
65
  kwargs["background"] = "transparent"
66
  if out_fmt in {"jpeg", "webp"}:
67
+ # Note: OpenAI API might use 'output_quality' or similar, check latest docs
68
+ # Assuming 'output_compression' is correct based on your original code
69
  kwargs["output_compression"] = int(compression)
70
  return kwargs
71
 
72
+ # --- Helper Function to Format OpenAI Errors ---
73
+ def _format_openai_error(e: Exception) -> str:
74
+ """Formats OpenAI API errors for user display."""
75
+ error_message = f"An error occurred: {type(e).__name__}"
76
+ details = ""
77
+
78
+ # Try to extract details from common OpenAI error attributes
79
+ if hasattr(e, 'body') and e.body:
80
+ try:
81
+ body = e.body if isinstance(e.body, dict) else json.loads(str(e.body))
82
+ if isinstance(body, dict) and 'error' in body and isinstance(body['error'], dict) and 'message' in body['error']:
83
+ details = body['error']['message']
84
+ elif isinstance(body, dict) and 'message' in body: # Some errors might have message at top level
85
+ details = body['message']
86
+ except (json.JSONDecodeError, TypeError):
87
+ # Fallback if body is not JSON or parsing fails
88
+ details = str(e.body)
89
+ elif hasattr(e, 'message') and e.message:
90
+ details = e.message
91
+
92
+ if details:
93
+ error_message = f"OpenAI API Error: {details}"
94
+ else:
95
+ # Generic fallback if no specific details found
96
+ error_message = f"An unexpected OpenAI error occurred: {str(e)}"
97
+
98
+ # Add specific guidance for known error types
99
+ if isinstance(e, openai.AuthenticationError):
100
+ error_message = "Invalid OpenAI API key. Please check your key."
101
+ elif isinstance(e, openai.PermissionDeniedError):
102
+ # Prepend standard advice, then add specific details if available
103
+ prefix = "Permission Denied."
104
+ if "organization verification" in details.lower():
105
+ prefix += " Your organization may need verification to use this feature/model."
106
+ else:
107
+ prefix += " Check your API key permissions and OpenAI account status."
108
+ error_message = f"{prefix} Details: {details}" if details else prefix
109
+ elif isinstance(e, openai.RateLimitError):
110
+ error_message = "Rate limit exceeded. Please wait and try again later."
111
+ elif isinstance(e, openai.BadRequestError):
112
+ error_message = f"OpenAI Bad Request: {details}" if details else f"OpenAI Bad Request: {str(e)}"
113
+ if "mask" in details.lower(): error_message += " (Check mask format/dimensions)"
114
+ if "size" in details.lower(): error_message += " (Check image/mask dimensions)"
115
+ if "model does not support variations" in details.lower(): error_message += " (gpt-image-1 does not support variations)."
116
+
117
+ # Ensure the final message isn't overly long or complex
118
+ # (Optional: Truncate if necessary)
119
+ # MAX_LEN = 300
120
+ # if len(error_message) > MAX_LEN:
121
+ # error_message = error_message[:MAX_LEN] + "..."
122
+
123
+ return error_message
124
+
125
 
126
  # ---------- Generate ---------- #
127
  def generate(
 
137
  """Calls the OpenAI image generation endpoint."""
138
  if not prompt:
139
  raise gr.Error("Please enter a prompt.")
 
140
  try:
141
+ client = _client(api_key) # API key used here
142
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
143
+ # --- Optional Debug ---
144
+ # print(f"[DEBUG] Generating with args: {common_args}")
145
+ # --- End Optional Debug ---
146
  resp = client.images.generate(**common_args)
147
+ except (openai.APIError, openai.OpenAIError) as e:
148
+ # Catch specific OpenAI errors and format them
149
+ raise gr.Error(_format_openai_error(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  except Exception as e:
151
+ # Catch any other unexpected errors
152
+ # Avoid raising raw exception details to the user interface for security/clarity
153
+ print(f"Unexpected error during generation: {type(e).__name__}: {e}") # Log for debugging
154
+ raise gr.Error(f"An unexpected application error occurred. Please check logs.")
155
+
156
  return _img_list(resp, fmt=out_fmt)
157
 
158
 
 
167
  def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
168
  """Handle ImageMask / ImageEditor return formats and extract a numpy mask array."""
169
  if mask_value is None: return None
170
+ # Gradio ImageMask often returns a dict with 'image' and 'mask' numpy arrays
171
  if isinstance(mask_value, dict):
172
+ mask_array = mask_value.get("mask")
173
+ if isinstance(mask_array, np.ndarray):
174
+ return mask_array
175
+ # Fallback for direct numpy array (less common with ImageMask now)
176
+ if isinstance(mask_value, np.ndarray): return mask_value
177
+ return None # Return None if no valid mask found
 
 
 
 
178
 
179
  def edit_image(
180
  api_key: str,
181
+ # Gradio Image component with type="numpy" provides the image array
182
+ image_numpy: Optional[np.ndarray],
183
+ # Gradio ImageMask component provides a dict {'image': np.ndarray, 'mask': np.ndarray}
184
+ mask_dict: Optional[Dict[str, Any]],
185
  prompt: str,
186
  n: int,
187
  size: str,
 
196
 
197
  img_bytes = _bytes_from_numpy(image_numpy)
198
  mask_bytes: Optional[bytes] = None
199
+ mask_numpy = _extract_mask_array(mask_dict) # Use the helper
200
 
201
  if mask_numpy is not None:
202
+ # Check if mask is effectively empty (all transparent or all black)
203
  is_empty = False
204
+ if mask_numpy.ndim == 2: # Grayscale mask
205
+ is_empty = np.all(mask_numpy == 0)
206
+ elif mask_numpy.shape[-1] == 4: # RGBA mask, check alpha channel
207
+ is_empty = np.all(mask_numpy[:, :, 3] == 0)
208
+ elif mask_numpy.shape[-1] == 3: # RGB mask, check if all black
209
+ is_empty = np.all(mask_numpy == 0)
210
 
211
  if is_empty:
212
+ gr.Warning("Mask appears empty or fully transparent. The API might edit the entire image or ignore the mask.")
213
+ mask_bytes = None # Treat as no mask if empty
214
  else:
215
+ # Convert the mask provided by Gradio (often white on black/transparent)
216
+ # to the format OpenAI expects (transparency indicates where *not* to edit).
217
+ # We need an RGBA image where the area to be *edited* is transparent.
218
+ if mask_numpy.ndim == 2: # Grayscale (assume white is edit area)
219
+ alpha = (mask_numpy < 128).astype(np.uint8) * 255 # Make non-edit area opaque white
220
+ elif mask_numpy.shape[-1] == 4: # RGBA (use alpha channel directly)
221
+ alpha = mask_numpy[:, :, 3]
222
+ # Invert alpha: transparent where user painted (edit area), opaque elsewhere
223
+ alpha = 255 - alpha
224
+ elif mask_numpy.shape[-1] == 3: # RGB (assume white is edit area)
225
+ # Check if close to white [255, 255, 255]
226
+ is_edit_area = np.all(mask_numpy > 200, axis=-1)
227
+ alpha = (~is_edit_area).astype(np.uint8) * 255 # Make non-edit area opaque white
228
+ else:
229
+ raise gr.Error("Unsupported mask format received from Gradio component.")
230
 
231
+ # Create a valid RGBA PNG mask for OpenAI
232
  mask_img = Image.fromarray(alpha, mode='L')
233
+ # Ensure mask size matches image size (OpenAI requirement)
234
+ original_pil_image = Image.fromarray(image_numpy)
235
+ if mask_img.size != original_pil_image.size:
236
+ gr.Warning(f"Mask size {mask_img.size} differs from image size {original_pil_image.size}. Resizing mask...")
237
+ mask_img = mask_img.resize(original_pil_image.size, Image.NEAREST)
238
+
239
+ # Create RGBA image with the calculated alpha
240
+ rgba_mask = Image.new("RGBA", mask_img.size, (0, 0, 0, 0)) # Start fully transparent
241
+ rgba_mask.putalpha(mask_img) # Apply the alpha channel (non-edit areas are opaque)
242
+
243
  out = io.BytesIO()
244
  rgba_mask.save(out, format="PNG")
245
  mask_bytes = out.getvalue()
246
  else:
247
+ gr.Info("No mask provided or mask is empty. Editing without a specific mask (may replace entire image).")
248
  mask_bytes = None
249
 
 
250
  try:
251
+ client = _client(api_key) # API key used here
252
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
253
  api_kwargs = {"image": img_bytes, **common_args}
254
+ if mask_bytes is not None:
255
+ api_kwargs["mask"] = mask_bytes
256
+ else:
257
+ # If no mask is provided, remove 'mask' key if present from previous runs
258
+ api_kwargs.pop("mask", None)
259
+
260
+ # --- Optional Debug ---
261
+ # print(f"[DEBUG] Editing with args: { {k: v if k != 'image' and k != 'mask' else f'<{len(v)} bytes>' for k, v in api_kwargs.items()} }")
262
+ # --- End Optional Debug ---
263
  resp = client.images.edit(**api_kwargs)
264
+ except (openai.APIError, openai.OpenAIError) as e:
265
+ raise gr.Error(_format_openai_error(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  except Exception as e:
267
+ print(f"Unexpected error during edit: {type(e).__name__}: {e}")
268
+ raise gr.Error(f"An unexpected application error occurred. Please check logs.")
269
+
270
  return _img_list(resp, fmt=out_fmt)
271
 
272
 
273
  # ---------- Variations ---------- #
274
  def variation_image(
275
  api_key: str,
276
+ image_numpy: Optional[np.ndarray],
277
  n: int,
278
  size: str,
279
  quality: str,
 
282
  transparent_bg: bool,
283
  ):
284
  """Calls the OpenAI image variations endpoint."""
285
+ # Explicitly warn user about model compatibility
286
+ gr.Warning("Note: Image Variations are officially supported for DALL·E 2/3, not gpt-image-1. This may fail or produce unexpected results.")
287
+
288
  if image_numpy is None: raise gr.Error("Please upload an image.")
289
+
290
  img_bytes = _bytes_from_numpy(image_numpy)
291
+
292
  try:
293
+ client = _client(api_key) # API key used here
294
+ # Variations don't take a prompt, quality, background, compression
295
+ # They primarily use n and size. Let's simplify common_args for variations.
296
+ # Check OpenAI docs for exact supported parameters for variations with the target model.
297
+ # Assuming 'n' and 'size' are the main ones.
298
+ var_args: Dict[str, Any] = dict(model=MODEL, n=n) # Use the selected model
299
+ if size != "auto":
300
+ var_args["size"] = size
301
+ # Note: output_format might be supported, keep it if needed
302
+ if out_fmt != "png":
303
+ var_args["response_format"] = "b64_json" # Variations often use response_format
304
+
305
+ # --- Optional Debug ---
306
+ # print(f"[DEBUG] Variations with args: { {k: v if k != 'image' else f'<{len(v)} bytes>' for k, v in var_args.items()} }")
307
+ # --- End Optional Debug ---
308
+
309
+ # Use the simplified args
310
+ resp = client.images.create_variation(image=img_bytes, **var_args)
311
+
312
+ except (openai.APIError, openai.OpenAIError) as e:
313
+ raise gr.Error(_format_openai_error(e))
314
  except Exception as e:
315
+ print(f"Unexpected error during variation: {type(e).__name__}: {e}")
316
+ raise gr.Error(f"An unexpected application error occurred. Please check logs.")
317
+
318
+ # Variations response format might differ slightly, adjust _img_list if needed
319
+ # Assuming it's the same structure for now.
320
  return _img_list(resp, fmt=out_fmt)
321
 
322
 
 
327
  gr.Markdown("""# GPT-Image-1 Playground 🖼️🔑\nGenerate • Edit (paint mask!) • Variations""")
328
  gr.Markdown(
329
  "Enter your OpenAI API key below. It's used directly for API calls and **never stored**."
330
+ " This space uses the `gpt-image-1` model by default."
331
+ " **Note:** Using `gpt-image-1` may require **Organization Verification** on your OpenAI account ([details](https://help.openai.com/en/articles/10910291-api-organization-verification)). The **Variations** tab is unlikely to work correctly with `gpt-image-1` (designed for DALL·E 2/3)."
332
  )
333
 
334
  with gr.Accordion("🔐 API key", open=False):
335
+ api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
 
336
 
337
  # Common controls
338
  with gr.Row():
339
  n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)", info="Max 4 for this demo.")
340
+ size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size", info="API default if 'auto'. Affects Gen/Edit/Var.")
341
+ quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality", info="API default if 'auto'. Affects Gen/Edit.")
342
  with gr.Row():
343
+ out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format", info="Affects Gen/Edit.", scale=1)
344
+ # Note: Compression/Transparency might not apply to all models/endpoints equally.
345
+ # Check OpenAI docs for gpt-image-1 specifics if issues arise.
346
  compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False, scale=2)
347
+ transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)", info="Affects Gen/Edit.", scale=1)
348
 
349
  def _toggle_compression(fmt):
350
  return gr.update(visible=fmt in {"jpeg", "webp"})
 
352
  out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
353
 
354
  # Define the list of common controls *excluding* the API key
355
+ # These are passed to the backend functions
356
+ common_controls_gen_edit = [n_slider, size, quality, out_fmt, compression, transparent]
357
+ # Variations might use fewer controls
358
+ common_controls_var = [n_slider, size, quality, out_fmt, compression, transparent] # Pass all for now, function will ignore unused
359
+
360
 
361
  with gr.Tabs():
362
  # ----- Generate Tab ----- #
 
366
  btn_gen = gr.Button("Generate 🚀", variant="primary", scale=1)
367
  gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
368
 
 
369
  btn_gen.click(
370
  generate,
371
+ # API key first, then specific inputs, then common controls
372
+ inputs=[api, prompt_gen] + common_controls_gen_edit,
373
  outputs=gallery_gen,
374
  api_name="generate"
375
  )
376
 
377
  # ----- Edit Tab ----- #
378
  with gr.TabItem("Edit / Inpaint"):
379
+ gr.Markdown("Upload an image, then **paint the area to change** in the mask canvas below (white paint = edit area). The API requires the mask and image to have the same dimensions (app attempts to resize mask if needed).")
380
  with gr.Row():
381
+ # Use type='pil' for easier handling, or keep 'numpy' if preferred
382
+ img_edit = gr.Image(label="Source Image", type="numpy", height=400, sources=["upload", "clipboard"])
383
+ # ImageMask sends {'image': np.ndarray, 'mask': np.ndarray}
384
  mask_canvas = gr.ImageMask(
385
  label="Mask – Paint White Where Image Should Change",
386
+ type="numpy", # Keep numpy as _extract_mask_array expects it
387
  height=400
388
  )
389
  with gr.Row():
 
391
  btn_edit = gr.Button("Edit 🖌️", variant="primary", scale=1)
392
  gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
393
 
 
394
  btn_edit.click(
395
  edit_image,
396
+ # API key first, then specific inputs, then common controls
397
+ inputs=[api, img_edit, mask_canvas, prompt_edit] + common_controls_gen_edit,
398
  outputs=gallery_edit,
399
  api_name="edit"
400
  )
401
 
402
  # ----- Variations Tab ----- #
403
+ with gr.TabItem("Variations (DALL·E 2/3 Recommended)"):
404
+ gr.Markdown("Upload an image to generate variations. **Warning:** This endpoint is officially supported for DALL·E 2/3, not `gpt-image-1`. It likely won't work correctly or may error.")
405
  with gr.Row():
406
+ img_var = gr.Image(label="Source Image", type="numpy", height=400, sources=["upload", "clipboard"], scale=4)
407
  btn_var = gr.Button("Create Variations ✨", variant="primary", scale=1)
408
  gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
409
 
 
410
  btn_var.click(
411
  variation_image,
412
+ # API key first, then specific inputs, then common controls
413
+ inputs=[api, img_var] + common_controls_var,
414
  outputs=gallery_var,
415
  api_name="variations"
416
  )
 
419
 
420
  if __name__ == "__main__":
421
  app = build_ui()
422
+ # Consider disabling debug=True for production/sharing
423
+ app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=os.getenv("GRADIO_DEBUG") == "true")