Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -34,22 +34,13 @@ def _img_list(resp) -> List[Union[np.ndarray, str]]:
|
|
34 |
Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly.
|
35 |
"""
|
36 |
imgs: List[Union[np.ndarray, str]] = []
|
37 |
-
if not resp or not hasattr(resp, 'data'):
|
38 |
-
print("Warning: Response object missing or has no 'data' attribute.")
|
39 |
-
return imgs # Return empty list if response is invalid
|
40 |
-
|
41 |
for d in resp.data:
|
42 |
if hasattr(d, "b64_json") and d.b64_json:
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
imgs.append(np.array(img))
|
47 |
-
except Exception as decode_err:
|
48 |
-
print(f"Error decoding base64 image: {decode_err}")
|
49 |
elif getattr(d, "url", None):
|
50 |
imgs.append(d.url)
|
51 |
-
else:
|
52 |
-
print(f"Warning: Response item has neither b64_json nor url: {d}")
|
53 |
return imgs
|
54 |
|
55 |
|
@@ -58,38 +49,24 @@ def _common_kwargs(
|
|
58 |
n: int,
|
59 |
size: str,
|
60 |
quality: str,
|
61 |
-
out_fmt: str,
|
62 |
-
compression: int,
|
63 |
-
transparent_bg: bool,
|
64 |
) -> Dict[str, Any]:
|
65 |
"""Prepare keyword args for OpenAI Images API."""
|
66 |
kwargs: Dict[str, Any] = {
|
67 |
"model": MODEL,
|
68 |
"n": n,
|
69 |
-
# Always request PNG for maximum quality/editability before potential conversion
|
70 |
-
"response_format": "b64_json", # Request base64 to handle locally
|
71 |
}
|
72 |
if size != "auto":
|
73 |
kwargs["size"] = size
|
74 |
-
# DALL-E 3 uses 'quality': 'standard' or 'hd'. DALL-E 2 doesn't have quality.
|
75 |
-
# Adapt this based on the actual model's capabilities. Assuming 'hd' for 'high'.
|
76 |
if quality != "auto":
|
77 |
-
|
78 |
-
# Example for DALL-E 3:
|
79 |
-
# if quality == "high": kwargs["quality"] = "hd"
|
80 |
-
# elif quality == "medium": kwargs["quality"] = "standard" # or omit
|
81 |
-
# For now, pass it directly, but be aware it might not be supported by MODEL
|
82 |
-
kwargs["quality"] = quality
|
83 |
-
|
84 |
if prompt is not None:
|
85 |
kwargs["prompt"] = prompt
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
# If transparent_bg is True, you might need to handle it after receiving the image.
|
90 |
-
# if transparent_bg and out_fmt in {"png", "webp"}:
|
91 |
-
# kwargs["background"] = "transparent" # This parameter is hypothetical
|
92 |
-
|
93 |
return kwargs
|
94 |
|
95 |
|
@@ -100,98 +77,45 @@ def convert_to_format(
|
|
100 |
) -> np.ndarray:
|
101 |
"""
|
102 |
Convert a PIL numpy array to target_fmt (JPEG/WebP) and return as numpy array.
|
103 |
-
Handles PNG pass-through.
|
104 |
"""
|
105 |
-
if target_fmt.lower() == "png":
|
106 |
-
# No conversion needed if already PNG (assuming input from b64 is effectively PNG)
|
107 |
-
return img_array
|
108 |
-
|
109 |
img = Image.fromarray(img_array.astype(np.uint8))
|
110 |
buf = io.BytesIO()
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
save_kwargs["quality"] = quality
|
116 |
-
# Handle transparency for WebP
|
117 |
-
if fmt_upper == "WEBP":
|
118 |
-
# Check if image has alpha channel
|
119 |
-
if img.mode == 'RGBA' or 'A' in img.getbands():
|
120 |
-
pass # WebP supports transparency inherently
|
121 |
-
else:
|
122 |
-
# If no alpha, don't need special handling unless forcing transparency loss
|
123 |
-
pass
|
124 |
-
# Handle transparency loss for JPEG
|
125 |
-
elif fmt_upper == "JPEG":
|
126 |
-
if img.mode == 'RGBA' or img.mode == 'LA' or (img.mode == 'P' and 'transparency' in img.info):
|
127 |
-
# Convert to RGB, losing transparency. Default white background.
|
128 |
-
img = img.convert('RGB')
|
129 |
-
|
130 |
-
try:
|
131 |
-
img.save(buf, format=fmt_upper, **save_kwargs)
|
132 |
-
buf.seek(0)
|
133 |
-
img2 = Image.open(buf)
|
134 |
-
return np.array(img2)
|
135 |
-
except Exception as e:
|
136 |
-
print(f"Error during image conversion to {target_fmt}: {e}")
|
137 |
-
# Fallback to returning original array on conversion error
|
138 |
-
return img_array
|
139 |
|
140 |
|
141 |
def _format_openai_error(e: Exception) -> str:
|
142 |
-
"""Formats OpenAI API errors into user-friendly messages."""
|
143 |
error_message = f"An error occurred: {type(e).__name__}"
|
144 |
details = ""
|
145 |
if hasattr(e, 'body') and e.body:
|
146 |
try:
|
147 |
-
|
148 |
-
body = json.loads(str(e.body))
|
149 |
if isinstance(body, dict) and 'error' in body and isinstance(body['error'], dict) and 'message' in body['error']:
|
150 |
details = body['error']['message']
|
151 |
-
elif isinstance(body, dict) and 'message' in body:
|
152 |
details = body['message']
|
153 |
-
else:
|
154 |
-
details = str(e.body) # Fallback if structure is unexpected
|
155 |
-
except json.JSONDecodeError:
|
156 |
-
# If body is not JSON, use its string representation
|
157 |
-
details = str(e.body)
|
158 |
except Exception:
|
159 |
-
|
160 |
-
|
161 |
-
elif hasattr(e, 'message') and e.message: # Fallback for older error structures
|
162 |
details = e.message
|
163 |
-
|
164 |
if details:
|
165 |
error_message = f"OpenAI API Error: {details}"
|
166 |
-
else: # Keep the generic message if no details found
|
167 |
-
error_message = f"An OpenAI API error occurred: {str(e)}"
|
168 |
-
|
169 |
-
|
170 |
-
# Specific error type handling
|
171 |
if isinstance(e, openai.AuthenticationError):
|
172 |
-
error_message = "Invalid OpenAI API key. Please check your key
|
173 |
elif isinstance(e, openai.PermissionDeniedError):
|
174 |
prefix = "Permission Denied."
|
175 |
-
if
|
176 |
-
prefix += " Your organization may need verification
|
177 |
-
elif details and "quota" in details.lower():
|
178 |
-
prefix += " You might have exceeded your usage quota."
|
179 |
error_message = f"{prefix} Details: {details}" if details else prefix
|
180 |
elif isinstance(e, openai.RateLimitError):
|
181 |
-
error_message = "Rate limit exceeded. Please wait and try again later
|
182 |
elif isinstance(e, openai.BadRequestError):
|
183 |
error_message = f"OpenAI Bad Request: {details or str(e)}"
|
184 |
-
if details:
|
185 |
-
|
186 |
-
|
187 |
-
if "model does not support variations" in details.lower(): error_message += f" ({MODEL} does not support variations)."
|
188 |
-
if "unsupported file format" in details.lower() or "unsupported mimetype" in details.lower(): error_message += " (Ensure input image is PNG, JPG, or WEBP)"
|
189 |
-
if "prompt" in details.lower() and "policy" in details.lower(): error_message += " (Prompt may violate OpenAI's safety policies)"
|
190 |
-
elif isinstance(e, openai.APIConnectionError):
|
191 |
-
error_message = "Could not connect to OpenAI. Please check your network connection."
|
192 |
-
elif isinstance(e, openai.InternalServerError):
|
193 |
-
error_message = "OpenAI server error. Please try again later."
|
194 |
-
|
195 |
return error_message
|
196 |
|
197 |
|
@@ -204,383 +128,195 @@ def generate(
|
|
204 |
quality: str,
|
205 |
out_fmt: str,
|
206 |
compression: int,
|
207 |
-
transparent_bg: bool,
|
208 |
):
|
209 |
if not prompt:
|
210 |
raise gr.Error("Please enter a prompt.")
|
211 |
try:
|
212 |
client = _client(api_key)
|
213 |
-
# Request b64_json for local processing/conversion
|
214 |
common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
|
215 |
-
common_args["response_format"] = "b64_json" # Ensure we get base64
|
216 |
-
|
217 |
-
print(f"Generating images with args: {common_args}") # Debug print
|
218 |
resp = client.images.generate(**common_args)
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
for img_np in imgs_np:
|
224 |
-
if isinstance(img_np, np.ndarray):
|
225 |
-
# Apply transparency removal or format conversion here if needed
|
226 |
-
# Note: True transparency generation isn't standard. This handles format conversion.
|
227 |
-
# If transparent_bg was intended for background removal, that needs a separate model/tool.
|
228 |
-
converted_img = convert_to_format(img_np, out_fmt, compression)
|
229 |
-
final_imgs.append(converted_img)
|
230 |
-
else:
|
231 |
-
# If we somehow got a URL (shouldn't with b64_json), append it directly
|
232 |
-
final_imgs.append(img_np)
|
233 |
-
|
234 |
-
if not final_imgs:
|
235 |
-
raise gr.Error("Failed to generate or process images. Check logs.")
|
236 |
-
|
237 |
-
return final_imgs
|
238 |
except (openai.APIError, openai.OpenAIError) as e:
|
239 |
-
print(f"OpenAI API Error during generation: {type(e).__name__}: {e}")
|
240 |
raise gr.Error(_format_openai_error(e))
|
241 |
except Exception as e:
|
242 |
print(f"Unexpected error during generation: {type(e).__name__}: {e}")
|
243 |
-
import traceback
|
244 |
-
traceback.print_exc() # Print full traceback for unexpected errors
|
245 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
246 |
|
247 |
|
248 |
# ---------- Edit / Inpaint ---------- #
|
249 |
-
def _bytes_from_numpy(arr: np.ndarray
|
250 |
-
"""Converts numpy array to bytes in the specified format."""
|
251 |
img = Image.fromarray(arr.astype(np.uint8))
|
252 |
buf = io.BytesIO()
|
253 |
-
img.save(buf, format=
|
254 |
return buf.getvalue()
|
255 |
|
256 |
-
def _ensure_rgba_for_mask(mask_array: np.ndarray) -> np.ndarray:
|
257 |
-
"""Ensures the mask is RGBA, converting grayscale/RGB if necessary."""
|
258 |
-
if mask_array.ndim == 2: # Grayscale
|
259 |
-
# Convert grayscale to RGBA: White areas (255) become transparent (alpha=0), others opaque black
|
260 |
-
alpha = np.where(mask_array == 255, 0, 255).astype(np.uint8)
|
261 |
-
rgb = np.zeros((*mask_array.shape, 3), dtype=np.uint8) # Black RGB
|
262 |
-
rgba = np.dstack((rgb, alpha))
|
263 |
-
return rgba
|
264 |
-
elif mask_array.ndim == 3:
|
265 |
-
if mask_array.shape[2] == 3: # RGB
|
266 |
-
# Assume white RGB (255, 255, 255) means transparent for Gradio mask
|
267 |
-
is_white = np.all(mask_array == [255, 255, 255], axis=2)
|
268 |
-
alpha = np.where(is_white, 0, 255).astype(np.uint8)
|
269 |
-
rgba = np.dstack((mask_array, alpha))
|
270 |
-
return rgba
|
271 |
-
elif mask_array.shape[2] == 4: # Already RGBA
|
272 |
-
# Ensure correct interpretation: 0 alpha means transparent (area to edit)
|
273 |
-
# Gradio ImageMask often uses white paint on transparent bg.
|
274 |
-
# We need alpha=0 for transparent areas (edit target).
|
275 |
-
# If alpha channel is mostly 255 (opaque), invert it assuming white paint = transparent target.
|
276 |
-
alpha_channel = mask_array[:, :, 3]
|
277 |
-
if np.mean(alpha_channel) > 128: # Heuristic: if mostly opaque
|
278 |
-
print("Inverting mask alpha channel based on heuristic (mostly opaque).")
|
279 |
-
mask_array[:, :, 3] = 255 - alpha_channel
|
280 |
-
return mask_array # Assume it's correctly formatted otherwise
|
281 |
-
raise ValueError("Unsupported mask format/dimensions")
|
282 |
-
|
283 |
|
284 |
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
|
285 |
-
"""Extracts the mask numpy array from Gradio's ImageMask output."""
|
286 |
if mask_value is None:
|
287 |
-
print("Mask input is None.")
|
288 |
return None
|
289 |
-
# Gradio ImageMask output is often a dict {'image': ndarray, 'mask': ndarray}
|
290 |
-
# Or sometimes just the mask ndarray directly depending on version/setup
|
291 |
-
mask_array = None
|
292 |
if isinstance(mask_value, dict):
|
293 |
mask_array = mask_value.get("mask")
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
if isinstance(mask_array, np.ndarray):
|
300 |
-
# Add basic validation
|
301 |
-
if mask_array.ndim < 2 or mask_array.ndim > 3:
|
302 |
-
print(f"Warning: Unexpected mask dimensions: {mask_array.ndim}")
|
303 |
-
return None
|
304 |
-
if mask_array.size == 0:
|
305 |
-
print("Warning: Received empty mask array.")
|
306 |
-
return None
|
307 |
-
print(f"Successfully extracted mask array, shape: {mask_array.shape}, dtype: {mask_array.dtype}, min/max: {np.min(mask_array)}/{np.max(mask_array)}")
|
308 |
-
return mask_array
|
309 |
-
|
310 |
-
print(f"Could not extract ndarray mask from input type: {type(mask_value)}")
|
311 |
return None
|
312 |
|
313 |
|
314 |
def edit_image(
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
):
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
|
|
|
|
340 |
|
341 |
-
mask_tuple: Optional[Tuple[str, bytes, str]] = None
|
342 |
-
mask_numpy = _extract_mask_array(mask_input)
|
343 |
-
|
344 |
-
if mask_numpy is not None:
|
345 |
try:
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
|
|
|
|
|
|
|
|
|
|
362 |
except Exception as e:
|
363 |
-
print(f"
|
364 |
-
raise gr.Error("
|
365 |
-
else:
|
366 |
-
# If no mask is provided, it's an 'edit' without inpainting (DALL-E 2 supported this, DALL-E 3 might interpret differently)
|
367 |
-
# The API might require a mask for the /edit endpoint. Check API docs for the specific model.
|
368 |
-
# For DALL-E 2, omitting mask was allowed. Let's assume it might work or fail gracefully.
|
369 |
-
print("No valid mask provided or extracted. Proceeding without mask.")
|
370 |
-
# raise gr.Error("Please paint a mask to indicate the edit area.") # Uncomment if mask is strictly required
|
371 |
|
372 |
-
try:
|
373 |
-
client = _client(api_key)
|
374 |
-
# Get common args, ensure response format is b64_json
|
375 |
-
common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
|
376 |
-
common_args["response_format"] = "b64_json" # Ensure we get base64
|
377 |
-
|
378 |
-
# Prepare final API arguments
|
379 |
-
api_kwargs = {
|
380 |
-
"image": image_tuple,
|
381 |
-
**common_args
|
382 |
-
}
|
383 |
-
if mask_tuple is not None:
|
384 |
-
api_kwargs["mask"] = mask_tuple
|
385 |
-
else:
|
386 |
-
# If mask is omitted, remove prompt from common_args if the API treats it like variations?
|
387 |
-
# DALL-E 2 /edit without mask needed no prompt. DALL-E 3 might differ.
|
388 |
-
# Let's keep the prompt for now. The API error will tell us if it's wrong.
|
389 |
-
pass
|
390 |
-
# api_kwargs.pop("prompt", None) # Consider this if API complains about prompt without mask
|
391 |
-
|
392 |
-
print(f"Editing image with args: { {k: v if k not in ['image', 'mask'] else (v[0], f'{len(v[1])} bytes', v[2]) for k, v in api_kwargs.items()} }") # Debug print
|
393 |
-
resp = client.images.edit(**api_kwargs) # Call the edit endpoint
|
394 |
-
|
395 |
-
imgs_np = _img_list(resp) # Should be list of numpy arrays
|
396 |
-
|
397 |
-
# Post-generation conversion
|
398 |
-
final_imgs = []
|
399 |
-
for img_np in imgs_np:
|
400 |
-
if isinstance(img_np, np.ndarray):
|
401 |
-
converted_img = convert_to_format(img_np, out_fmt, compression)
|
402 |
-
final_imgs.append(converted_img)
|
403 |
-
else:
|
404 |
-
final_imgs.append(img_np) # Append URL if received
|
405 |
-
|
406 |
-
if not final_imgs:
|
407 |
-
raise gr.Error("Failed to edit or process images. Check logs.")
|
408 |
-
|
409 |
-
return final_imgs
|
410 |
-
|
411 |
-
except (openai.APIError, openai.OpenAIError) as e:
|
412 |
-
print(f"OpenAI API Error during edit: {type(e).__name__}: {e}")
|
413 |
-
raise gr.Error(_format_openai_error(e))
|
414 |
-
except Exception as e:
|
415 |
-
print(f"Unexpected error during edit: {type(e).__name__}: {e}")
|
416 |
-
import traceback
|
417 |
-
traceback.print_exc()
|
418 |
-
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
419 |
|
420 |
|
421 |
# ---------- Variations ---------- #
|
422 |
-
def variation_image(
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
):
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
|
437 |
-
# Convert source image to PNG bytes
|
438 |
-
try:
|
439 |
-
img_bytes = _bytes_from_numpy(image_numpy, format="PNG")
|
440 |
-
# --- FIX: Provide image data as a tuple ---
|
441 |
-
image_tuple: Tuple[str, bytes, str] = ("image.png", img_bytes, "image/png")
|
442 |
-
print(f"Prepared source image for variation: {image_tuple[0]}, size={len(image_tuple[1])} bytes, type={image_tuple[2]}")
|
443 |
-
except Exception as e:
|
444 |
-
print(f"Error converting source image to bytes for variation: {e}")
|
445 |
-
raise gr.Error("Failed to process source image.")
|
446 |
-
|
447 |
-
try:
|
448 |
-
client = _client(api_key)
|
449 |
-
# Prepare args for variations endpoint
|
450 |
-
var_args: Dict[str, Any] = {
|
451 |
-
"model": MODEL, # Use the selected model, though it might fail
|
452 |
-
"n": n,
|
453 |
-
"response_format": "b64_json" # Request base64
|
454 |
-
}
|
455 |
-
if size != "auto":
|
456 |
-
var_args["size"] = size
|
457 |
-
# Quality parameter is generally NOT supported for variations
|
458 |
-
# if quality != "auto":
|
459 |
-
# var_args["quality"] = quality # This will likely cause an error
|
460 |
-
|
461 |
-
print(f"Creating variations with args: { {k: v if k != 'image' else (v[0], f'{len(v[1])} bytes', v[2]) for k, v in {**var_args, 'image': image_tuple}.items()} }") # Debug print
|
462 |
-
|
463 |
-
# Pass the tuple to the image parameter
|
464 |
-
resp = client.images.create_variation(image=image_tuple, **var_args)
|
465 |
-
|
466 |
-
imgs_np = _img_list(resp) # Should be list of numpy arrays
|
467 |
-
|
468 |
-
# Post-generation conversion
|
469 |
-
final_imgs = []
|
470 |
-
for img_np in imgs_np:
|
471 |
-
if isinstance(img_np, np.ndarray):
|
472 |
-
converted_img = convert_to_format(img_np, out_fmt, compression)
|
473 |
-
final_imgs.append(converted_img)
|
474 |
-
else:
|
475 |
-
final_imgs.append(img_np) # Append URL if received
|
476 |
-
|
477 |
-
if not final_imgs:
|
478 |
-
raise gr.Error("Failed to create variations or process images. Check logs.")
|
479 |
-
|
480 |
-
return final_imgs
|
481 |
-
|
482 |
-
except (openai.APIError, openai.OpenAIError) as e:
|
483 |
-
print(f"OpenAI API Error during variation: {type(e).__name__}: {e}")
|
484 |
-
err_msg = _format_openai_error(e)
|
485 |
-
# Add specific check for variation incompatibility
|
486 |
-
if isinstance(e, openai.BadRequestError) and ("model does not support variations" in err_msg.lower() or "not supported" in err_msg.lower()):
|
487 |
-
raise gr.Error(f"As warned, the selected model ('{MODEL}') does not support the variations endpoint. Try using 'dall-e-2'.")
|
488 |
-
raise gr.Error(err_msg)
|
489 |
-
except Exception as e:
|
490 |
-
print(f"Unexpected error during variation: {type(e).__name__}: {e}")
|
491 |
-
import traceback
|
492 |
-
traceback.print_exc()
|
493 |
-
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
494 |
|
495 |
|
496 |
# ---------- UI ---------- #
|
497 |
def build_ui():
|
498 |
-
with gr.Blocks(title="
|
499 |
-
gr.Markdown(
|
500 |
-
|
501 |
-
**Selected Model:** `{MODEL}` (Ensure your key has access)
|
502 |
-
""")
|
503 |
-
with gr.Accordion("🔐 API key & Model Info", open=False):
|
504 |
api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
|
505 |
-
gr.Markdown(f"""
|
506 |
-
* **Model:** `{MODEL}` is configured in the code. This might be a placeholder; official models are typically `dall-e-3` or `dall-e-2`.
|
507 |
-
* **Variations:** Officially only supported by `dall-e-2`. Using other models here will likely fail.
|
508 |
-
* **Edit/Inpainting:** Requires a model supporting the `/images/edits` endpoint (like `dall-e-2`).
|
509 |
-
* **Size/Quality:** Options shown may not be supported by all models. Check OpenAI documentation for `{MODEL}` if it's a real model. DALL-E 3 uses `quality` ('standard'/'hd'), DALL-E 2 does not.
|
510 |
-
""")
|
511 |
|
512 |
with gr.Row():
|
513 |
n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
|
514 |
-
size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size
|
515 |
-
quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality
|
516 |
with gr.Row():
|
517 |
-
out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format"
|
518 |
-
compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False
|
519 |
-
|
520 |
-
# Actual transparency depends on model/post-processing.
|
521 |
-
transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)", info="Request transparency if model supports, or save PNG/WebP with alpha if generated.", visible=False) # Hidden for now as it's not directly controllable via API param
|
522 |
|
523 |
def _toggle_compression(fmt):
|
524 |
return gr.update(visible=fmt in {"jpeg", "webp"})
|
525 |
out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
|
526 |
|
527 |
-
# Combine common controls for easier passing to functions
|
528 |
common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
|
529 |
|
530 |
with gr.Tabs():
|
531 |
with gr.TabItem("Generate"):
|
532 |
-
prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic
|
533 |
-
btn_gen = gr.Button("Generate 🚀"
|
534 |
-
gallery_gen = gr.Gallery(
|
535 |
-
|
536 |
-
inputs_gen = [api, prompt_gen] + common_controls
|
537 |
-
prompt_gen.submit(generate, inputs=inputs_gen, outputs=gallery_gen)
|
538 |
-
btn_gen.click(generate, inputs=inputs_gen, outputs=gallery_gen)
|
539 |
-
|
540 |
|
541 |
with gr.TabItem("Edit / Inpaint"):
|
542 |
-
gr.Markdown("Upload an image,
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
prompt_edit.submit(edit_image, inputs=inputs_edit, outputs=gallery_edit)
|
557 |
-
btn_edit.click(edit_image, inputs=inputs_edit, outputs=gallery_edit)
|
558 |
-
|
559 |
-
|
560 |
-
with gr.TabItem("Variations (DALL·E 2 only)"):
|
561 |
-
gr.Markdown("Upload an image to generate variations. **Warning:** This endpoint is officially supported only by DALL·E 2.")
|
562 |
-
img_var_src = gr.Image(type="numpy", label="Source Image", height=400)
|
563 |
-
btn_var = gr.Button("Create Variations ✨", variant="primary")
|
564 |
-
gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
|
565 |
-
|
566 |
-
# Define inputs for the variation function
|
567 |
-
inputs_var = [api, img_var_src] + common_controls
|
568 |
-
# Variations don't use prompt, quality typically ignored
|
569 |
-
btn_var.click(variation_image, inputs=inputs_var, outputs=gallery_var)
|
570 |
-
|
571 |
return demo
|
572 |
|
573 |
|
574 |
if __name__ == "__main__":
|
575 |
-
# For debugging purposes, you can preload an API key from env vars
|
576 |
-
# Make sure to handle security appropriately if deploying publicly
|
577 |
-
# api_key_env = os.getenv("OPENAI_API_KEY")
|
578 |
-
|
579 |
app = build_ui()
|
580 |
-
|
581 |
-
|
582 |
-
share=os.getenv("GRADIO_SHARE") == "true",
|
583 |
-
debug=os.getenv("GRADIO_DEBUG") == "true",
|
584 |
-
server_name="0.0.0.0" # Bind to all interfaces for Docker compatibility
|
585 |
-
# auth=("user", "password") # Add basic auth if needed for sharing
|
586 |
-
)
|
|
|
34 |
Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly.
|
35 |
"""
|
36 |
imgs: List[Union[np.ndarray, str]] = []
|
|
|
|
|
|
|
|
|
37 |
for d in resp.data:
|
38 |
if hasattr(d, "b64_json") and d.b64_json:
|
39 |
+
data = base64.b64decode(d.b64_json)
|
40 |
+
img = Image.open(io.BytesIO(data))
|
41 |
+
imgs.append(np.array(img))
|
|
|
|
|
|
|
42 |
elif getattr(d, "url", None):
|
43 |
imgs.append(d.url)
|
|
|
|
|
44 |
return imgs
|
45 |
|
46 |
|
|
|
49 |
n: int,
|
50 |
size: str,
|
51 |
quality: str,
|
52 |
+
out_fmt: str,
|
53 |
+
compression: int,
|
54 |
+
transparent_bg: bool,
|
55 |
) -> Dict[str, Any]:
|
56 |
"""Prepare keyword args for OpenAI Images API."""
|
57 |
kwargs: Dict[str, Any] = {
|
58 |
"model": MODEL,
|
59 |
"n": n,
|
|
|
|
|
60 |
}
|
61 |
if size != "auto":
|
62 |
kwargs["size"] = size
|
|
|
|
|
63 |
if quality != "auto":
|
64 |
+
kwargs["quality"] = quality
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
if prompt is not None:
|
66 |
kwargs["prompt"] = prompt
|
67 |
+
if transparent_bg and out_fmt in {"png", "webp"}:
|
68 |
+
# Insert background removal flag when supported
|
69 |
+
kwargs["background"] = "transparent"
|
|
|
|
|
|
|
|
|
70 |
return kwargs
|
71 |
|
72 |
|
|
|
77 |
) -> np.ndarray:
|
78 |
"""
|
79 |
Convert a PIL numpy array to target_fmt (JPEG/WebP) and return as numpy array.
|
|
|
80 |
"""
|
|
|
|
|
|
|
|
|
81 |
img = Image.fromarray(img_array.astype(np.uint8))
|
82 |
buf = io.BytesIO()
|
83 |
+
img.save(buf, format=target_fmt.upper(), quality=quality)
|
84 |
+
buf.seek(0)
|
85 |
+
img2 = Image.open(buf)
|
86 |
+
return np.array(img2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
|
89 |
def _format_openai_error(e: Exception) -> str:
    """Build a short, user-facing message from an OpenAI client exception.

    The detail text is taken from ``e.body`` when it carries a message
    (either ``body["error"]["message"]`` or ``body["message"]``); a body that
    is not valid JSON is shown verbatim. If the body yields nothing,
    ``e.message`` is used. Specific exception classes then replace or extend
    the generic message.

    Args:
        e: The exception raised by the OpenAI client.

    Returns:
        A human-readable error string.
    """
    details = ""
    body = getattr(e, "body", None)
    if body:
        try:
            parsed = body if isinstance(body, dict) else json.loads(str(body))
        except ValueError:
            # json.JSONDecodeError is a ValueError: the body is not JSON, so
            # surface it as-is rather than dropping it.
            details = str(body)
        else:
            if isinstance(parsed, dict):
                nested = parsed.get("error")
                if isinstance(nested, dict) and nested.get("message"):
                    details = str(nested["message"])
                elif parsed.get("message"):
                    details = str(parsed["message"])

    # Fall back to the exception's own message when the body had none.
    if not details and getattr(e, "message", None):
        details = str(e.message)

    if details:
        error_message = f"OpenAI API Error: {details}"
    else:
        error_message = f"An error occurred: {type(e).__name__}"

    # details is always a str here, so substring checks below cannot raise.
    lowered = details.lower()

    if isinstance(e, openai.AuthenticationError):
        error_message = "Invalid OpenAI API key. Please check your key."
    elif isinstance(e, openai.PermissionDeniedError):
        prefix = "Permission Denied."
        if "organization verification" in lowered:
            prefix += " Your organization may need verification to use this feature/model."
        error_message = f"{prefix} Details: {details}" if details else prefix
    elif isinstance(e, openai.RateLimitError):
        error_message = "Rate limit exceeded. Please wait and try again later."
    elif isinstance(e, openai.BadRequestError):
        error_message = f"OpenAI Bad Request: {details or str(e)}"
        if "mask" in lowered:
            error_message += " (Check mask format/dimensions)"
        if "size" in lowered:
            error_message += " (Check image/mask dimensions)"
        if "model does not support variations" in lowered:
            error_message += " (gpt-image-1 does not support variations)."
    return error_message
|
120 |
|
121 |
|
|
|
128 |
quality: str,
|
129 |
out_fmt: str,
|
130 |
compression: int,
|
131 |
+
transparent_bg: bool,
|
132 |
):
|
133 |
if not prompt:
|
134 |
raise gr.Error("Please enter a prompt.")
|
135 |
try:
|
136 |
client = _client(api_key)
|
|
|
137 |
common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
|
|
|
|
|
|
|
138 |
resp = client.images.generate(**common_args)
|
139 |
+
imgs = _img_list(resp)
|
140 |
+
if out_fmt in {"jpeg", "webp"}:
|
141 |
+
imgs = [convert_to_format(img, out_fmt, compression) for img in imgs]
|
142 |
+
return imgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
except (openai.APIError, openai.OpenAIError) as e:
|
|
|
144 |
raise gr.Error(_format_openai_error(e))
|
145 |
except Exception as e:
|
146 |
print(f"Unexpected error during generation: {type(e).__name__}: {e}")
|
|
|
|
|
147 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
148 |
|
149 |
|
150 |
# ---------- Edit / Inpaint ---------- #
|
151 |
+
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode an image array as PNG and return the encoded bytes.

    The array is cast to ``uint8`` before being handed to PIL.
    """
    png_buffer = io.BytesIO()
    Image.fromarray(arr.astype(np.uint8)).save(png_buffer, format="PNG")
    return png_buffer.getvalue()
|
156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
|
|
|
159 |
if mask_value is None:
|
|
|
160 |
return None
|
|
|
|
|
|
|
161 |
if isinstance(mask_value, dict):
|
162 |
mask_array = mask_value.get("mask")
|
163 |
+
if isinstance(mask_array, np.ndarray):
|
164 |
+
return mask_array
|
165 |
+
if isinstance(mask_value, np.ndarray):
|
166 |
+
return mask_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
return None
|
168 |
|
169 |
|
170 |
def edit_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    mask_dict: Optional[Dict[str, Any]],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Send an image and prompt to the images edit endpoint.

    The source image is PNG-encoded and uploaded as a
    ``(filename, bytes, mime)`` tuple. Returned images are converted to
    ``out_fmt`` when that is ``"jpeg"`` or ``"webp"``.

    Args:
        api_key: OpenAI API key passed to ``_client``.
        image_numpy: Source image; required.
        mask_dict: Raw mask input passed to ``_extract_mask_array``.
        prompt: Edit instruction; required.
        n, size, quality, out_fmt, compression, transparent_bg: Forwarded to
            ``_common_kwargs``; ``out_fmt``/``compression`` also drive the
            local format conversion.

    Returns:
        A list of result images (numpy arrays, or URL strings when the
        response carries URLs).

    Raises:
        gr.Error: On missing input, OpenAI API errors, or unexpected failures.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please enter an edit prompt.")

    img_bytes = _bytes_from_numpy(image_numpy)

    # NOTE(review): the mask is extracted but never encoded, so mask_bytes
    # stays None and no mask is sent with the request. Encoding it requires
    # knowing how the UI marks the editable area - TODO confirm and implement.
    mask_bytes: Optional[bytes] = None
    mask_numpy = _extract_mask_array(mask_dict)
    if mask_numpy is not None:
        pass  # Placeholder: mask conversion not implemented yet.

    try:
        client = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)

        # The upload must be a (filename, bytes, mime) tuple, not raw bytes.
        image_tuple = ("image.png", img_bytes, "image/png")
        api_kwargs = {"image": image_tuple, **common_args}

        if mask_bytes is not None:
            api_kwargs["mask"] = ("mask.png", mask_bytes, "image/png")

        resp = client.images.edit(**api_kwargs)
        imgs = _img_list(resp)
        if out_fmt in {"jpeg", "webp"}:
            # _img_list may return URL strings; only arrays can be re-encoded.
            imgs = [
                convert_to_format(img, out_fmt, compression)
                if isinstance(img, np.ndarray)
                else img
                for img in imgs
            ]
        return imgs
    except (openai.APIError, openai.OpenAIError) as e:
        raise gr.Error(_format_openai_error(e)) from e
    except Exception as e:
        print(f"Unexpected error during edit: {type(e).__name__}: {e}")
        raise gr.Error("An unexpected application error occurred. Please check logs.") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
|
227 |
# ---------- Variations ---------- #
|
228 |
+
def variation_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Request variations of an uploaded image from the OpenAI Images API.

    Args:
        api_key: Key handed to ``_client`` to build the API client.
        image_numpy: Uploaded source image; ``None`` means nothing was uploaded.
        n: Number of variations to request.
        size: Requested size; ``"auto"`` leaves ``size`` out of the request.
        quality: Accepted so the signature lines up with the shared UI
            controls; it is not sent to the variations endpoint.
        out_fmt: Output format; ``"jpeg"`` and ``"webp"`` trigger a local
            conversion through ``convert_to_format``.
        compression: Forwarded to ``convert_to_format`` when converting.
        transparent_bg: Accepted for signature parity with the shared UI
            controls; not used here.

    Returns:
        The images extracted by ``_img_list``, converted with
        ``convert_to_format`` when ``out_fmt`` is ``"jpeg"`` or ``"webp"``.

    Raises:
        gr.Error: If no image was uploaded, if the OpenAI call fails, or if
            any other unexpected error occurs.
    """
    gr.Warning("Note: Image Variations are officially supported for DALL·E 2/3, not gpt-image-1. This may fail.")
    if image_numpy is None:
        raise gr.Error("Please upload an image.")

    img_bytes = _bytes_from_numpy(image_numpy)
    try:
        client = _client(api_key)
        var_args: Dict[str, Any] = {"model": MODEL, "n": n}
        if size != "auto":
            var_args["size"] = size

        # The upload is passed as a (filename, content, MIME type) tuple
        # rather than bare bytes.
        image_tuple = ("image.png", img_bytes, "image/png")
        resp = client.images.create_variation(image=image_tuple, **var_args)

        imgs = _img_list(resp)
        if out_fmt in {"jpeg", "webp"}:
            imgs = [convert_to_format(img, out_fmt, compression) for img in imgs]
        return imgs
    except gr.Error:
        # A user-facing error raised by one of the helpers above already
        # carries its own message; let it through instead of replacing it with
        # the generic text in the catch-all handler below.
        # NOTE(review): the helpers are defined elsewhere in this file;
        # confirm which of them raise gr.Error.
        raise
    except (openai.APIError, openai.OpenAIError) as e:
        err_msg = _format_openai_error(e)
        # Give a clearer message when the API rejects variations for this model.
        if isinstance(e, openai.BadRequestError) and "model does not support variations" in err_msg.lower():
            raise gr.Error("As warned, the selected model (gpt-image-1) does not support the variations endpoint.") from e
        raise gr.Error(err_msg) from e
    except Exception as e:
        print(f"Unexpected error during variation: {type(e).__name__}: {e}")
        raise gr.Error("An unexpected application error occurred. Please check logs.") from e
|
269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
|
272 |
# ---------- UI ---------- #
|
273 |
def build_ui():
    """Assemble the Gradio Blocks app and return it.

    The page has an API-key box inside a collapsed accordion, two rows of
    controls shared by every tab, and three tabs (Generate, Edit / Inpaint,
    Variations). Each tab's button is wired to ``generate``, ``edit_image``
    or ``variation_image`` with the API key, the tab's own inputs and the
    shared controls, in that order.
    """
    with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
        gr.Markdown("""# GPT-Image-1 Playground 🖼️🔑\nGenerate • Edit • Variations""")

        with gr.Accordion("🔐 API key", open=False):
            api = gr.Textbox(
                label="OpenAI API key",
                type="password",
                placeholder="sk-...",
            )

        # Controls shared by all three tabs.
        with gr.Row():
            n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
            size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
            quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
        with gr.Row():
            out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format")
            compression = gr.Slider(
                0,
                100,
                value=75,
                step=1,
                label="Compression % (JPEG/WebP)",
                visible=False,
            )
            transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)")

        def _toggle_compression(fmt):
            # Show the compression slider only for the JPEG and WebP formats.
            show_slider = fmt in {"jpeg", "webp"}
            return gr.update(visible=show_slider)

        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)

        # Order matches the trailing parameters of the three handlers.
        common_controls = [n_slider, size, quality, out_fmt, compression, transparent]

        with gr.Tabs():
            with gr.TabItem("Generate"):
                prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic...")
                btn_gen = gr.Button("Generate 🚀")
                gallery_gen = gr.Gallery(columns=2, height="auto")
                btn_gen.click(
                    generate,
                    inputs=[api, prompt_gen, *common_controls],
                    outputs=gallery_gen,
                )

            with gr.TabItem("Edit / Inpaint"):
                gr.Markdown("Upload an image, then paint the area to change…")
                img_edit = gr.Image(type="numpy", label="Source Image", height=400)
                mask_canvas = gr.ImageMask(type="numpy", label="Mask – paint white", height=400)
                prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky…")
                btn_edit = gr.Button("Edit 🖌️")
                gallery_edit = gr.Gallery(columns=2, height="auto")
                btn_edit.click(
                    edit_image,
                    inputs=[api, img_edit, mask_canvas, prompt_edit, *common_controls],
                    outputs=gallery_edit,
                )

            with gr.TabItem("Variations"):
                gr.Markdown("Upload an image to generate variations…")
                img_var = gr.Image(type="numpy", label="Source Image", height=400)
                btn_var = gr.Button("Create Variations ✨")
                gallery_var = gr.Gallery(columns=2, height="auto")
                btn_var.click(
                    variation_image,
                    inputs=[api, img_var, *common_controls],
                    outputs=gallery_var,
                )

    return demo
|
317 |
|
318 |
|
319 |
if __name__ == "__main__":
    # Build the UI and start the server. Sharing and debug mode are enabled
    # only when the matching environment variable is exactly "true".
    app = build_ui()
    app.launch(
        share=os.getenv("GRADIO_SHARE") == "true",
        debug=os.getenv("GRADIO_DEBUG") == "true",
    )
|
322 |
+
|
|
|
|
|
|
|
|
|
|