Zack3D committed on
Commit
918e017
·
verified ·
1 Parent(s): 6aea94d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +380 -90
app.py CHANGED
@@ -34,13 +34,22 @@ def _img_list(resp) -> List[Union[np.ndarray, str]]:
34
  Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly.
35
  """
36
  imgs: List[Union[np.ndarray, str]] = []
 
 
 
 
37
  for d in resp.data:
38
  if hasattr(d, "b64_json") and d.b64_json:
39
- data = base64.b64decode(d.b64_json)
40
- img = Image.open(io.BytesIO(data))
41
- imgs.append(np.array(img))
 
 
 
42
  elif getattr(d, "url", None):
43
  imgs.append(d.url)
 
 
44
  return imgs
45
 
46
 
@@ -49,24 +58,38 @@ def _common_kwargs(
49
  n: int,
50
  size: str,
51
  quality: str,
52
- out_fmt: str,
53
- compression: int,
54
- transparent_bg: bool,
55
  ) -> Dict[str, Any]:
56
  """Prepare keyword args for OpenAI Images API."""
57
  kwargs: Dict[str, Any] = {
58
  "model": MODEL,
59
  "n": n,
 
 
60
  }
61
  if size != "auto":
62
  kwargs["size"] = size
 
 
63
  if quality != "auto":
64
- kwargs["quality"] = quality
 
 
 
 
 
 
65
  if prompt is not None:
66
  kwargs["prompt"] = prompt
67
- if transparent_bg and out_fmt in {"png", "webp"}:
68
- # Insert background removal flag when supported
69
- kwargs["background"] = "transparent"
 
 
 
 
70
  return kwargs
71
 
72
 
@@ -77,45 +100,98 @@ def convert_to_format(
77
  ) -> np.ndarray:
78
  """
79
  Convert a PIL numpy array to target_fmt (JPEG/WebP) and return as numpy array.
 
80
  """
 
 
 
 
81
  img = Image.fromarray(img_array.astype(np.uint8))
82
  buf = io.BytesIO()
83
- img.save(buf, format=target_fmt.upper(), quality=quality)
84
- buf.seek(0)
85
- img2 = Image.open(buf)
86
- return np.array(img2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  def _format_openai_error(e: Exception) -> str:
 
90
  error_message = f"An error occurred: {type(e).__name__}"
91
  details = ""
92
  if hasattr(e, 'body') and e.body:
93
  try:
94
- body = e.body if isinstance(e.body, dict) else json.loads(str(e.body))
 
95
  if isinstance(body, dict) and 'error' in body and isinstance(body['error'], dict) and 'message' in body['error']:
96
  details = body['error']['message']
97
- elif isinstance(body, dict) and 'message' in body:
98
  details = body['message']
99
- except Exception:
 
 
 
100
  details = str(e.body)
101
- elif hasattr(e, 'message') and e.message:
 
 
 
102
  details = e.message
 
103
  if details:
104
  error_message = f"OpenAI API Error: {details}"
 
 
 
 
 
105
  if isinstance(e, openai.AuthenticationError):
106
- error_message = "Invalid OpenAI API key. Please check your key."
107
  elif isinstance(e, openai.PermissionDeniedError):
108
  prefix = "Permission Denied."
109
- if "organization verification" in details.lower():
110
- prefix += " Your organization may need verification to use this feature/model."
 
 
111
  error_message = f"{prefix} Details: {details}" if details else prefix
112
  elif isinstance(e, openai.RateLimitError):
113
- error_message = "Rate limit exceeded. Please wait and try again later."
114
  elif isinstance(e, openai.BadRequestError):
115
  error_message = f"OpenAI Bad Request: {details or str(e)}"
116
- if "mask" in details.lower(): error_message += " (Check mask format/dimensions)"
117
- if "size" in details.lower(): error_message += " (Check image/mask dimensions)"
118
- if "model does not support variations" in details.lower(): error_message += " (gpt-image-1 does not support variations)."
 
 
 
 
 
 
 
 
119
  return error_message
120
 
121
 
@@ -128,86 +204,217 @@ def generate(
128
  quality: str,
129
  out_fmt: str,
130
  compression: int,
131
- transparent_bg: bool,
132
  ):
133
  if not prompt:
134
  raise gr.Error("Please enter a prompt.")
135
  try:
136
  client = _client(api_key)
 
137
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
 
 
 
138
  resp = client.images.generate(**common_args)
139
- imgs = _img_list(resp)
140
- if out_fmt in {"jpeg", "webp"}:
141
- imgs = [convert_to_format(img, out_fmt, compression) for img in imgs]
142
- return imgs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  except (openai.APIError, openai.OpenAIError) as e:
 
144
  raise gr.Error(_format_openai_error(e))
145
  except Exception as e:
146
  print(f"Unexpected error during generation: {type(e).__name__}: {e}")
 
 
147
  raise gr.Error("An unexpected application error occurred. Please check logs.")
148
 
149
 
150
  # ---------- Edit / Inpaint ---------- #
151
- def _bytes_from_numpy(arr: np.ndarray) -> bytes:
 
152
  img = Image.fromarray(arr.astype(np.uint8))
153
  buf = io.BytesIO()
154
- img.save(buf, format="PNG")
155
  return buf.getvalue()
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
 
159
  if mask_value is None:
 
160
  return None
 
 
 
161
  if isinstance(mask_value, dict):
162
  mask_array = mask_value.get("mask")
163
- if isinstance(mask_array, np.ndarray):
164
- return mask_array
165
- if isinstance(mask_value, np.ndarray):
166
- return mask_value
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  return None
168
 
169
 
170
  def edit_image(
171
  api_key: str,
172
  image_numpy: Optional[np.ndarray],
173
- mask_dict: Optional[Dict[str, Any]],
174
  prompt: str,
175
  n: int,
176
  size: str,
177
  quality: str,
178
  out_fmt: str,
179
  compression: int,
180
- transparent_bg: bool,
181
  ):
182
  if image_numpy is None:
183
  raise gr.Error("Please upload an image.")
184
  if not prompt:
185
  raise gr.Error("Please enter an edit prompt.")
186
 
187
- img_bytes = _bytes_from_numpy(image_numpy)
188
- mask_bytes: Optional[bytes] = None
189
- mask_numpy = _extract_mask_array(mask_dict)
 
 
 
 
 
 
 
 
 
190
 
191
- # (Mask handling code unchanged)
192
  if mask_numpy is not None:
193
- # existing mask-to-bytes logic...
194
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  try:
197
  client = _client(api_key)
 
198
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
199
- api_kwargs = {"image": img_bytes, **common_args}
200
- if mask_bytes is not None:
201
- api_kwargs["mask"] = mask_bytes
202
- resp = client.images.edit(**api_kwargs)
203
- imgs = _img_list(resp)
204
- if out_fmt in {"jpeg", "webp"}:
205
- imgs = [convert_to_format(img, out_fmt, compression) for img in imgs]
206
- return imgs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  except (openai.APIError, openai.OpenAIError) as e:
 
208
  raise gr.Error(_format_openai_error(e))
209
  except Exception as e:
210
  print(f"Unexpected error during edit: {type(e).__name__}: {e}")
 
 
211
  raise gr.Error("An unexpected application error occurred. Please check logs.")
212
 
213
 
@@ -217,80 +424,163 @@ def variation_image(
217
  image_numpy: Optional[np.ndarray],
218
  n: int,
219
  size: str,
220
- quality: str,
221
  out_fmt: str,
222
  compression: int,
223
- transparent_bg: bool,
224
  ):
225
- gr.Warning("Note: Image Variations are officially supported for DALL·E 2/3, not gpt-image-1. This may fail.")
 
226
  if image_numpy is None:
227
  raise gr.Error("Please upload an image.")
228
 
229
- img_bytes = _bytes_from_numpy(image_numpy)
 
 
 
 
 
 
 
 
 
230
  try:
231
  client = _client(api_key)
232
- var_args: Dict[str, Any] = {"model": MODEL, "n": n}
 
 
 
 
 
233
  if size != "auto":
234
  var_args["size"] = size
235
- resp = client.images.create_variation(image=img_bytes, **var_args)
236
- imgs = _img_list(resp)
237
- if out_fmt in {"jpeg", "webp"}:
238
- imgs = [convert_to_format(img, out_fmt, compression) for img in imgs]
239
- return imgs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  except (openai.APIError, openai.OpenAIError) as e:
241
- raise gr.Error(_format_openai_error(e))
 
 
 
 
 
242
  except Exception as e:
243
  print(f"Unexpected error during variation: {type(e).__name__}: {e}")
 
 
244
  raise gr.Error("An unexpected application error occurred. Please check logs.")
245
 
246
 
247
  # ---------- UI ---------- #
248
  def build_ui():
249
- with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
250
- gr.Markdown("""# GPT-Image-1 Playground 🖼️🔑\nGenerate • Edit • Variations""")
251
- with gr.Accordion("🔐 API key", open=False):
 
 
 
252
  api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
 
 
 
 
 
 
253
 
254
  with gr.Row():
255
  n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
256
- size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
257
- quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
258
  with gr.Row():
259
- out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format")
260
- compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False)
261
- transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)")
 
 
262
 
263
  def _toggle_compression(fmt):
264
  return gr.update(visible=fmt in {"jpeg", "webp"})
265
  out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
266
 
 
267
  common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
268
 
269
  with gr.Tabs():
270
  with gr.TabItem("Generate"):
271
- prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic..." )
272
- btn_gen = gr.Button("Generate 🚀")
273
- gallery_gen = gr.Gallery(columns=2, height="auto")
274
- btn_gen.click(generate, inputs=[api, prompt_gen] + common_controls, outputs=gallery_gen)
 
 
 
 
275
 
276
  with gr.TabItem("Edit / Inpaint"):
277
- gr.Markdown("Upload an image, then paint the area to change")
278
- img_edit = gr.Image(type="numpy", label="Source Image", height=400)
279
- mask_canvas = gr.ImageMask(type="numpy", label="Mask – paint white", height=400)
280
- prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky…")
281
- btn_edit = gr.Button("Edit 🖌️")
282
- gallery_edit = gr.Gallery(columns=2, height="auto")
283
- btn_edit.click(edit_image, inputs=[api, img_edit, mask_canvas, prompt_edit] + common_controls, outputs=gallery_edit)
284
-
285
- with gr.TabItem("Variations"):
286
- gr.Markdown("Upload an image to generate variations…")
287
- img_var = gr.Image(type="numpy", label="Source Image", height=400)
288
- btn_var = gr.Button("Create Variations ✨")
289
- gallery_var = gr.Gallery(columns=2, height="auto")
290
- btn_var.click(variation_image, inputs=[api, img_var] + common_controls, outputs=gallery_var)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  return demo
292
 
293
 
294
  if __name__ == "__main__":
 
 
 
 
295
  app = build_ui()
296
- app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=os.getenv("GRADIO_DEBUG") == "true")
 
 
 
 
 
 
 
34
  Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly.
35
  """
36
  imgs: List[Union[np.ndarray, str]] = []
37
+ if not resp or not hasattr(resp, 'data'):
38
+ print("Warning: Response object missing or has no 'data' attribute.")
39
+ return imgs # Return empty list if response is invalid
40
+
41
  for d in resp.data:
42
  if hasattr(d, "b64_json") and d.b64_json:
43
+ try:
44
+ data = base64.b64decode(d.b64_json)
45
+ img = Image.open(io.BytesIO(data))
46
+ imgs.append(np.array(img))
47
+ except Exception as decode_err:
48
+ print(f"Error decoding base64 image: {decode_err}")
49
  elif getattr(d, "url", None):
50
  imgs.append(d.url)
51
+ else:
52
+ print(f"Warning: Response item has neither b64_json nor url: {d}")
53
  return imgs
54
 
55
 
 
58
  n: int,
59
  size: str,
60
  quality: str,
61
+ out_fmt: str, # Note: out_fmt is used *after* generation for conversion, not directly in API call
62
+ compression: int, # Note: compression is used *after* generation for conversion
63
+ transparent_bg: bool, # Note: transparent_bg is used *after* generation for conversion if not directly supported
64
  ) -> Dict[str, Any]:
65
  """Prepare keyword args for OpenAI Images API."""
66
  kwargs: Dict[str, Any] = {
67
  "model": MODEL,
68
  "n": n,
69
+ # Always request PNG for maximum quality/editability before potential conversion
70
+ "response_format": "b64_json", # Request base64 to handle locally
71
  }
72
  if size != "auto":
73
  kwargs["size"] = size
74
+ # DALL-E 3 uses 'quality': 'standard' or 'hd'. DALL-E 2 doesn't have quality.
75
+ # Adapt this based on the actual model's capabilities. Assuming 'hd' for 'high'.
76
  if quality != "auto":
77
+ # Map your choices to OpenAI's expected values if needed
78
+ # Example for DALL-E 3:
79
+ # if quality == "high": kwargs["quality"] = "hd"
80
+ # elif quality == "medium": kwargs["quality"] = "standard" # or omit
81
+ # For now, pass it directly, but be aware it might not be supported by MODEL
82
+ kwargs["quality"] = quality
83
+
84
  if prompt is not None:
85
  kwargs["prompt"] = prompt
86
+
87
+ # Note: Background removal is not a standard parameter for generate/edit/variation.
88
+ # This would typically be a post-processing step or require a specific model/API.
89
+ # If transparent_bg is True, you might need to handle it after receiving the image.
90
+ # if transparent_bg and out_fmt in {"png", "webp"}:
91
+ # kwargs["background"] = "transparent" # This parameter is hypothetical
92
+
93
  return kwargs
94
 
95
 
 
100
  ) -> np.ndarray:
101
  """
102
  Convert a PIL numpy array to target_fmt (JPEG/WebP) and return as numpy array.
103
+ Handles PNG pass-through.
104
  """
105
+ if target_fmt.lower() == "png":
106
+ # No conversion needed if already PNG (assuming input from b64 is effectively PNG)
107
+ return img_array
108
+
109
  img = Image.fromarray(img_array.astype(np.uint8))
110
  buf = io.BytesIO()
111
+ save_kwargs = {}
112
+ fmt_upper = target_fmt.upper()
113
+
114
+ if fmt_upper in ["JPEG", "WEBP"]:
115
+ save_kwargs["quality"] = quality
116
+ # Handle transparency for WebP
117
+ if fmt_upper == "WEBP":
118
+ # Check if image has alpha channel
119
+ if img.mode == 'RGBA' or 'A' in img.getbands():
120
+ pass # WebP supports transparency inherently
121
+ else:
122
+ # If no alpha, don't need special handling unless forcing transparency loss
123
+ pass
124
+ # Handle transparency loss for JPEG
125
+ elif fmt_upper == "JPEG":
126
+ if img.mode == 'RGBA' or img.mode == 'LA' or (img.mode == 'P' and 'transparency' in img.info):
127
+ # Convert to RGB, losing transparency. Default white background.
128
+ img = img.convert('RGB')
129
+
130
+ try:
131
+ img.save(buf, format=fmt_upper, **save_kwargs)
132
+ buf.seek(0)
133
+ img2 = Image.open(buf)
134
+ return np.array(img2)
135
+ except Exception as e:
136
+ print(f"Error during image conversion to {target_fmt}: {e}")
137
+ # Fallback to returning original array on conversion error
138
+ return img_array
139
 
140
 
def _format_openai_error(e: Exception) -> str:
    """Format an OpenAI API exception as a user-friendly message.

    The detail text is taken from ``e.body`` (either ``body["error"]["message"]``
    or a top-level ``body["message"]``), falling back to ``e.message``.  The
    generic message built from it is then overridden for specific exception
    classes (authentication, permission, rate limit, bad request, connection,
    server error).

    Args:
        e: The exception raised by the OpenAI client.

    Returns:
        A single human-readable error string.
    """
    details = ""
    raw_body = getattr(e, "body", None)
    if raw_body:
        body = raw_body
        if not isinstance(body, dict):
            # Only non-dict bodies are parsed: str() of a dict is a Python
            # repr (single quotes), which json.loads always rejects, and that
            # would leave the raw repr as the "details" shown to the user.
            try:
                body = json.loads(str(raw_body))
            except ValueError:  # json.JSONDecodeError is a ValueError
                body = None

        if isinstance(body, dict):
            nested = body.get("error")
            if isinstance(nested, dict) and "message" in nested:
                details = nested["message"]
            elif "message" in body:  # Sometimes the message is top-level
                details = body["message"]
            else:
                details = str(raw_body)  # Fallback if structure is unexpected
        else:
            details = str(raw_body)
    elif getattr(e, "message", None):  # Fallback for older error structures
        details = e.message

    # Normalise so the substring checks below can safely call .lower().
    details = str(details) if details else ""
    lowered = details.lower()

    if details:
        error_message = f"OpenAI API Error: {details}"
    else:  # Keep the generic message if no details found
        error_message = f"An OpenAI API error occurred: {str(e)}"

    # Specific error type handling
    if isinstance(e, openai.AuthenticationError):
        error_message = "Invalid OpenAI API key. Please check your key and ensure it's active."
    elif isinstance(e, openai.PermissionDeniedError):
        prefix = "Permission Denied."
        if "organization verification" in lowered:
            prefix += " Your organization may need verification or payment method update to use this feature/model."
        elif "quota" in lowered:
            prefix += " You might have exceeded your usage quota."
        error_message = f"{prefix} Details: {details}" if details else prefix
    elif isinstance(e, openai.RateLimitError):
        error_message = "Rate limit exceeded. Please wait and try again later, or check your usage limits."
    elif isinstance(e, openai.BadRequestError):
        error_message = f"OpenAI Bad Request: {details or str(e)}"
        # Each matching hint is appended; several may apply to one message.
        if "mask" in lowered:
            error_message += " (Check mask format/dimensions/transparency)"
        if "size" in lowered:
            error_message += " (Check image/mask dimensions or requested size compatibility)"
        if "model does not support variations" in lowered:
            error_message += f" ({MODEL} does not support variations)."
        if "unsupported file format" in lowered or "unsupported mimetype" in lowered:
            error_message += " (Ensure input image is PNG, JPG, or WEBP)"
        if "prompt" in lowered and "policy" in lowered:
            error_message += " (Prompt may violate OpenAI's safety policies)"
    elif isinstance(e, openai.APIConnectionError):
        error_message = "Could not connect to OpenAI. Please check your network connection."
    elif isinstance(e, openai.InternalServerError):
        error_message = "OpenAI server error. Please try again later."

    return error_message
196
 
197
 
 
204
  quality: str,
205
  out_fmt: str,
206
  compression: int,
207
+ transparent_bg: bool, # Note: Transparency handled post-generation if needed
208
  ):
209
  if not prompt:
210
  raise gr.Error("Please enter a prompt.")
211
  try:
212
  client = _client(api_key)
213
+ # Request b64_json for local processing/conversion
214
  common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
215
+ common_args["response_format"] = "b64_json" # Ensure we get base64
216
+
217
+ print(f"Generating images with args: {common_args}") # Debug print
218
  resp = client.images.generate(**common_args)
219
+ imgs_np = _img_list(resp) # Should be list of numpy arrays
220
+
221
+ # Post-generation conversion
222
+ final_imgs = []
223
+ for img_np in imgs_np:
224
+ if isinstance(img_np, np.ndarray):
225
+ # Apply transparency removal or format conversion here if needed
226
+ # Note: True transparency generation isn't standard. This handles format conversion.
227
+ # If transparent_bg was intended for background removal, that needs a separate model/tool.
228
+ converted_img = convert_to_format(img_np, out_fmt, compression)
229
+ final_imgs.append(converted_img)
230
+ else:
231
+ # If we somehow got a URL (shouldn't with b64_json), append it directly
232
+ final_imgs.append(img_np)
233
+
234
+ if not final_imgs:
235
+ raise gr.Error("Failed to generate or process images. Check logs.")
236
+
237
+ return final_imgs
238
  except (openai.APIError, openai.OpenAIError) as e:
239
+ print(f"OpenAI API Error during generation: {type(e).__name__}: {e}")
240
  raise gr.Error(_format_openai_error(e))
241
  except Exception as e:
242
  print(f"Unexpected error during generation: {type(e).__name__}: {e}")
243
+ import traceback
244
+ traceback.print_exc() # Print full traceback for unexpected errors
245
  raise gr.Error("An unexpected application error occurred. Please check logs.")
246
 
247
 
248
  # ---------- Edit / Inpaint ---------- #
def _bytes_from_numpy(arr: np.ndarray, format: str = "PNG") -> bytes:
    """Encode a numpy image array as image-file bytes.

    The array is cast to ``uint8``, wrapped in a PIL image and saved to an
    in-memory buffer using the given PIL format name.

    Args:
        arr: Image data as a numpy array.
        format: PIL format name passed to ``Image.save`` (default ``"PNG"``).

    Returns:
        The encoded image file contents.
    """
    with io.BytesIO() as stream:
        Image.fromarray(arr.astype(np.uint8)).save(stream, format=format)
        return stream.getvalue()
255
 
def _ensure_rgba_for_mask(mask_array: np.ndarray) -> np.ndarray:
    """Return an RGBA version of a mask, converting grayscale/RGB if necessary.

    In the result, alpha == 0 marks the area to edit:

    * 2-D (grayscale): pixels equal to 255 become transparent, everything
      else opaque; the RGB channels are black.
    * H x W x 3 (RGB): pure white pixels become transparent, everything else
      opaque; the RGB channels are kept.
    * H x W x 4 (RGBA): if the alpha channel is mostly opaque (mean > 128)
      it is inverted, otherwise the array is returned as-is.

    The input array is never modified; the inverted RGBA case works on a copy
    so that calling this twice with the same array gives the same result.

    Args:
        mask_array: Mask as a 2-D or 3-D numpy array.

    Returns:
        An H x W x 4 array.

    Raises:
        ValueError: If the array is neither 2-D nor 3-D with 3 or 4 channels.
    """
    if mask_array.ndim == 2:  # Grayscale
        alpha = np.where(mask_array == 255, 0, 255).astype(np.uint8)
        rgb = np.zeros((*mask_array.shape, 3), dtype=np.uint8)  # Black RGB
        return np.dstack((rgb, alpha))

    if mask_array.ndim == 3:
        channels = mask_array.shape[2]
        if channels == 3:  # RGB
            is_white = np.all(mask_array == [255, 255, 255], axis=2)
            alpha = np.where(is_white, 0, 255).astype(np.uint8)
            return np.dstack((mask_array, alpha))
        if channels == 4:  # Already RGBA
            alpha_channel = mask_array[:, :, 3]
            if np.mean(alpha_channel) > 128:  # Heuristic: if mostly opaque
                print("Inverting mask alpha channel based on heuristic (mostly opaque).")
                # Copy first: writing into mask_array would flip the caller's
                # data and make a repeated call invert it back.
                inverted = mask_array.copy()
                inverted[:, :, 3] = 255 - alpha_channel
                return inverted
            return mask_array  # Assume it's correctly formatted otherwise

    raise ValueError("Unsupported mask format/dimensions")
282
+
283
 
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
    """Pull the mask ndarray out of the value delivered by the mask widget.

    Accepts either a dict carrying the array under the ``"mask"`` key or a bare
    ndarray.  The array is returned only if it is 2-D or 3-D and non-empty;
    every other input yields ``None``.  Diagnostic messages are printed along
    the way.
    """
    if mask_value is None:
        print("Mask input is None.")
        return None

    candidate = None
    if isinstance(mask_value, dict):
        candidate = mask_value.get("mask")
        print(f"Extracted mask from dict: type={type(candidate)}, shape={getattr(candidate, 'shape', 'N/A')}")
    elif isinstance(mask_value, np.ndarray):
        candidate = mask_value
        print(f"Received mask as ndarray directly: shape={candidate.shape}")

    # Anything that did not resolve to an ndarray is rejected up front.
    if not isinstance(candidate, np.ndarray):
        print(f"Could not extract ndarray mask from input type: {type(mask_value)}")
        return None

    # Basic validation: only 2-D / 3-D, non-empty arrays are usable.
    if candidate.ndim not in (2, 3):
        print(f"Warning: Unexpected mask dimensions: {candidate.ndim}")
        return None
    if candidate.size == 0:
        print("Warning: Received empty mask array.")
        return None

    print(f"Successfully extracted mask array, shape: {candidate.shape}, dtype: {candidate.dtype}, min/max: {np.min(candidate)}/{np.max(candidate)}")
    return candidate
312
 
313
 
def edit_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    mask_input: Optional[Union[np.ndarray, Dict[str, Any]]],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Edit (optionally inpaint) an uploaded image via the images edit endpoint.

    The source image and, when present, the painted mask are encoded as PNG
    and sent as ``(filename, bytes, mimetype)`` tuples.  Returned images are
    converted to ``out_fmt`` with ``convert_to_format``; any non-array items
    from ``_img_list`` (URLs) are passed through unchanged.

    Args:
        api_key: OpenAI API key used to build the client.
        image_numpy: Source image; required.
        mask_input: Mask widget value (dict with a ``"mask"`` array, or a bare
            array), or ``None`` to edit without a mask.
        prompt: Edit instruction; required.
        n, size, quality, out_fmt, compression, transparent_bg: Common
            controls forwarded to ``_common_kwargs`` / ``convert_to_format``.

    Returns:
        A list of images (numpy arrays or URL strings).

    Raises:
        gr.Error: On missing input, image/mask processing failure, a mask
            whose height/width differ from the image, an API error, or an
            empty result.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please enter an edit prompt.")

    # Convert source image to PNG bytes
    try:
        img_bytes = _bytes_from_numpy(image_numpy, format="PNG")
    except Exception as e:
        print(f"Error converting source image to bytes: {e}")
        raise gr.Error("Failed to process source image.") from e
    image_tuple = ("image.png", img_bytes, "image/png")
    print(f"Prepared source image: {image_tuple[0]}, size={len(image_tuple[1])} bytes, type={image_tuple[2]}")

    mask_tuple = None
    mask_numpy = _extract_mask_array(mask_input)
    if mask_numpy is not None:
        # Checked outside the try below on purpose: gr.Error is an ordinary
        # Exception, so raising it inside that try would be caught by the
        # generic handler and the user would only see "Failed to process mask."
        if image_numpy.shape[:2] != mask_numpy.shape[:2]:
            raise gr.Error(f"Mask dimensions ({mask_numpy.shape[:2]}) must match image dimensions ({image_numpy.shape[:2]}). Please repaint the mask.")

        try:
            # RGBA PNG where transparent pixels (alpha=0) mark the edit area.
            mask_rgba = _ensure_rgba_for_mask(mask_numpy)
            mask_bytes = _bytes_from_numpy(mask_rgba, format="PNG")
        except ValueError as e:
            print(f"Error processing mask: {e}")
            raise gr.Error(f"Failed to process mask: {e}") from e
        except Exception as e:
            print(f"Error converting mask to bytes: {e}")
            raise gr.Error("Failed to process mask.") from e

        mask_tuple = ("mask.png", mask_bytes, "image/png")
        print(f"Prepared mask: {mask_tuple[0]}, size={len(mask_tuple[1])} bytes, type={mask_tuple[2]}")
    else:
        # No mask: the request is still sent; the API decides whether an
        # edit without a mask is acceptable for the configured model.
        print("No valid mask provided or extracted. Proceeding without mask.")

    try:
        client = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        common_args["response_format"] = "b64_json"  # Ensure we get base64

        api_kwargs = {"image": image_tuple, **common_args}
        if mask_tuple is not None:
            api_kwargs["mask"] = mask_tuple

        # Log the request without dumping the raw image/mask bytes.
        loggable = {
            k: v if k not in ['image', 'mask'] else (v[0], f'{len(v[1])} bytes', v[2])
            for k, v in api_kwargs.items()
        }
        print(f"Editing image with args: {loggable}")
        resp = client.images.edit(**api_kwargs)

        # Post-generation conversion; URL strings are passed through as-is.
        final_imgs = [
            convert_to_format(item, out_fmt, compression) if isinstance(item, np.ndarray) else item
            for item in _img_list(resp)
        ]
    except (openai.APIError, openai.OpenAIError) as e:
        print(f"OpenAI API Error during edit: {type(e).__name__}: {e}")
        raise gr.Error(_format_openai_error(e))
    except Exception as e:
        print(f"Unexpected error during edit: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error("An unexpected application error occurred. Please check logs.")

    # Raised after the try so the specific message reaches the user instead
    # of being replaced by the generic "unexpected application error".
    if not final_imgs:
        raise gr.Error("Failed to edit or process images. Check logs.")

    return final_imgs
419
 
420
 
 
424
  image_numpy: Optional[np.ndarray],
425
  n: int,
426
  size: str,
427
+ quality: str, # Note: Quality may not be supported by variations endpoint
428
  out_fmt: str,
429
  compression: int,
430
+ transparent_bg: bool, # Note: Transparency handled post-generation if needed
431
  ):
432
+ # Explicit warning as gpt-image-1 is likely not the correct model for variations
433
+ gr.Warning(f"Note: Image Variations are officially supported for DALL·E 2. Using model '{MODEL}' may fail or produce unexpected results.")
434
  if image_numpy is None:
435
  raise gr.Error("Please upload an image.")
436
 
437
+ # Convert source image to PNG bytes
438
+ try:
439
+ img_bytes = _bytes_from_numpy(image_numpy, format="PNG")
440
+ # --- FIX: Provide image data as a tuple ---
441
+ image_tuple: Tuple[str, bytes, str] = ("image.png", img_bytes, "image/png")
442
+ print(f"Prepared source image for variation: {image_tuple[0]}, size={len(image_tuple[1])} bytes, type={image_tuple[2]}")
443
+ except Exception as e:
444
+ print(f"Error converting source image to bytes for variation: {e}")
445
+ raise gr.Error("Failed to process source image.")
446
+
447
  try:
448
  client = _client(api_key)
449
+ # Prepare args for variations endpoint
450
+ var_args: Dict[str, Any] = {
451
+ "model": MODEL, # Use the selected model, though it might fail
452
+ "n": n,
453
+ "response_format": "b64_json" # Request base64
454
+ }
455
  if size != "auto":
456
  var_args["size"] = size
457
+ # Quality parameter is generally NOT supported for variations
458
+ # if quality != "auto":
459
+ # var_args["quality"] = quality # This will likely cause an error
460
+
461
+ print(f"Creating variations with args: { {k: v if k != 'image' else (v[0], f'{len(v[1])} bytes', v[2]) for k, v in {**var_args, 'image': image_tuple}.items()} }") # Debug print
462
+
463
+ # Pass the tuple to the image parameter
464
+ resp = client.images.create_variation(image=image_tuple, **var_args)
465
+
466
+ imgs_np = _img_list(resp) # Should be list of numpy arrays
467
+
468
+ # Post-generation conversion
469
+ final_imgs = []
470
+ for img_np in imgs_np:
471
+ if isinstance(img_np, np.ndarray):
472
+ converted_img = convert_to_format(img_np, out_fmt, compression)
473
+ final_imgs.append(converted_img)
474
+ else:
475
+ final_imgs.append(img_np) # Append URL if received
476
+
477
+ if not final_imgs:
478
+ raise gr.Error("Failed to create variations or process images. Check logs.")
479
+
480
+ return final_imgs
481
+
482
  except (openai.APIError, openai.OpenAIError) as e:
483
+ print(f"OpenAI API Error during variation: {type(e).__name__}: {e}")
484
+ err_msg = _format_openai_error(e)
485
+ # Add specific check for variation incompatibility
486
+ if isinstance(e, openai.BadRequestError) and ("model does not support variations" in err_msg.lower() or "not supported" in err_msg.lower()):
487
+ raise gr.Error(f"As warned, the selected model ('{MODEL}') does not support the variations endpoint. Try using 'dall-e-2'.")
488
+ raise gr.Error(err_msg)
489
  except Exception as e:
490
  print(f"Unexpected error during variation: {type(e).__name__}: {e}")
491
+ import traceback
492
+ traceback.print_exc()
493
  raise gr.Error("An unexpected application error occurred. Please check logs.")
494
 
495
 
496
  # ---------- UI ---------- #
497
  def build_ui():
      """Assemble the Gradio Blocks interface and return it without launching.

      The layout consists of:
        * a header and a collapsed accordion holding the API-key textbox and
          notes about the configured ``MODEL``;
        * two rows of controls shared by every tab (image count, size, quality,
          output format, compression, transparency);
        * three tabs -- Generate, Edit / Inpaint, Variations -- whose buttons are
          wired to ``generate``, ``edit_image`` and ``variation_image``.

      Returns:
          The ``gr.Blocks`` instance bound to ``demo``; the caller is responsible
          for calling ``.launch()`` on it.
      """
      with gr.Blocks(title="OpenAI Image Playground (BYOK)") as demo:
          # Page header; MODEL is interpolated so the user sees which model the
          # code is configured to call.
          gr.Markdown(f"""# OpenAI Image Playground 🖼️🔑
          Generate • Edit • Variations (using your own API key)
          **Selected Model:** `{MODEL}` (Ensure your key has access)
          """)
          with gr.Accordion("🔐 API key & Model Info", open=False):
              # type="password" masks the key in the browser.
              api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
              gr.Markdown(f"""
              * **Model:** `{MODEL}` is configured in the code. This might be a placeholder; official models are typically `dall-e-3` or `dall-e-2`.
              * **Variations:** Officially only supported by `dall-e-2`. Using other models here will likely fail.
              * **Edit/Inpainting:** Requires a model supporting the `/images/edits` endpoint (like `dall-e-2`).
              * **Size/Quality:** Options shown may not be supported by all models. Check OpenAI documentation for `{MODEL}` if it's a real model. DALL-E 3 uses `quality` ('standard'/'hd'), DALL-E 2 does not.
              """)

          # ---- Controls shared by all three tabs ----
          with gr.Row():
              n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
              size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size (if supported)", info="Set target size. 'auto' uses model default.")
              quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality (if supported)", info="'auto' uses model default. DALL-E 3: 'standard'/'hd'. DALL-E 2 ignores this.")
          with gr.Row():
              out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format", info="Format for viewing/downloading generated images.")
              # Starts hidden because the default format is "png"; revealed by
              # _toggle_compression when a lossy format is picked.
              compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False, info="Lower value = smaller file, lower quality.")
              # Transparency generation is complex; this checkbox is mainly for format support.
              # Actual transparency depends on model/post-processing.
              # The checkbox is never shown, but it is still part of
              # common_controls, so handlers always receive its value (False).
              transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)", info="Request transparency if model supports, or save PNG/WebP with alpha if generated.", visible=False) # Hidden for now as it's not directly controllable via API param

          def _toggle_compression(fmt):
              # Show the compression slider only for the lossy formats.
              return gr.update(visible=fmt in {"jpeg", "webp"})
          out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)

          # Combine common controls for easier passing to functions.
          # These are appended, in this order, to each tab's input list below.
          common_controls = [n_slider, size, quality, out_fmt, compression, transparent]

          with gr.Tabs():
              with gr.TabItem("Generate"):
                  prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic image of..." )
                  btn_gen = gr.Button("Generate 🚀", variant="primary")
                  gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
                  # Submitting the prompt box and clicking the button both call
                  # the same handler with the same inputs.
                  inputs_gen = [api, prompt_gen] + common_controls
                  prompt_gen.submit(generate, inputs=inputs_gen, outputs=gallery_gen)
                  btn_gen.click(generate, inputs=inputs_gen, outputs=gallery_gen)


              with gr.TabItem("Edit / Inpaint"):
                  gr.Markdown("Upload an image, **paint white** over the area you want the AI to change, then provide an edit prompt.")
                  with gr.Row():
                      img_edit_src = gr.Image(type="numpy", label="Source Image", height=400, tool="select")
                      # Use ImageMask tool for painting
                      mask_canvas = gr.ImageMask(type="numpy", label="Mask – Paint Area to Edit (White)", height=400, brush_radius=20)
                      # NOTE(review): the source image is not linked to the mask
                      # canvas. A change-handler copying img_edit_src into
                      # mask_canvas was considered but left disabled because it
                      # might auto-clear the mask -- check the Gradio docs.

                  prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Example: Make the cat wear a wizard hat")
                  btn_edit = gr.Button("Edit Image 🖌️", variant="primary")
                  gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)

                  # Define inputs for the edit function
                  inputs_edit = [api, img_edit_src, mask_canvas, prompt_edit] + common_controls
                  prompt_edit.submit(edit_image, inputs=inputs_edit, outputs=gallery_edit)
                  btn_edit.click(edit_image, inputs=inputs_edit, outputs=gallery_edit)


              with gr.TabItem("Variations (DALL·E 2 only)"):
                  gr.Markdown("Upload an image to generate variations. **Warning:** This endpoint is officially supported only by DALL·E 2.")
                  img_var_src = gr.Image(type="numpy", label="Source Image", height=400)
                  btn_var = gr.Button("Create Variations ✨", variant="primary")
                  gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)

                  # Define inputs for the variation function
                  inputs_var = [api, img_var_src] + common_controls
                  # Variations don't use prompt, quality typically ignored
                  btn_var.click(variation_image, inputs=inputs_var, outputs=gallery_var)

      return demo
572
 
573
 
574
  if __name__ == "__main__":
575
+ # For debugging purposes, you can preload an API key from env vars
576
+ # Make sure to handle security appropriately if deploying publicly
577
+ # api_key_env = os.getenv("OPENAI_API_KEY")
578
+
579
  app = build_ui()
580
+ # Launch the Gradio app
581
+ app.launch(
582
+ share=os.getenv("GRADIO_SHARE") == "true",
583
+ debug=os.getenv("GRADIO_DEBUG") == "true",
584
+ server_name="0.0.0.0" # Bind to all interfaces for Docker compatibility
585
+ # auth=("user", "password") # Add basic auth if needed for sharing
586
+ )