# Hugging Face Space: GPT-Image-1 playground (generate / edit / variations).
from __future__ import annotations

import io
import json
import os
from typing import Any, Dict, List, Optional, Union

import gradio as gr
import numpy as np
import openai
from PIL import Image
# Model used for all Images API calls in this app.
MODEL = "gpt-image-1"
# UI option lists; "auto" means "omit the parameter and let the API pick its default".
SIZE_CHOICES = ["auto", "1024x1024", "1536x1024", "1024x1536"]
QUALITY_CHOICES = ["auto", "low", "medium", "high"]
FORMAT_CHOICES = ["png", "jpeg", "webp"]
def _client(key: str) -> openai.OpenAI:
    """Build an OpenAI client from the user-supplied key.

    Falls back to the OPENAI_API_KEY environment variable when the
    textbox value is blank; raises gr.Error if neither yields a key.
    """
    resolved_key = key.strip() or os.getenv("OPENAI_API_KEY", "")
    if not resolved_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=resolved_key)
def _img_list(resp, *, fmt: str) -> List[str]: | |
"""Return list of data URLs or direct URLs depending on API response.""" | |
mime = f"image/{fmt}" | |
# Ensure b64_json exists and is not None/empty before using it | |
return [ | |
f"data:{mime};base64,{d.b64_json}" if hasattr(d, "b64_json") and d.b64_json else d.url | |
for d in resp.data | |
] | |
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> Dict[str, Any]:
    """Assemble the keyword arguments shared by the Images API endpoints.

    "auto" size/quality and the default "png" format are omitted so the
    API defaults apply.  ``response_format`` is deliberately never sent:
    gpt-image-1 rejects it with a BadRequestError.
    """
    payload: Dict[str, Any] = {"model": MODEL, "n": n}
    if size != "auto":
        payload["size"] = size
    if quality != "auto":
        payload["quality"] = quality
    if prompt is not None:  # variations call this with prompt=None
        payload["prompt"] = prompt
    if out_fmt != "png":
        payload["output_format"] = out_fmt
    # Transparent backgrounds only make sense for formats with an alpha channel.
    if transparent_bg and out_fmt in ("png", "webp"):
        payload["background"] = "transparent"
    if out_fmt in ("jpeg", "webp"):
        # The API expects an integer percentage (0-100) for lossy formats.
        payload["output_compression"] = int(compression)
    return payload
# ---------- Generate ---------- # | |
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Call the OpenAI image generation endpoint and return displayable URLs.

    Raises:
        gr.Error: on missing prompt, auth/permission/rate-limit failures,
            or any other API error (with the server's message when available).
    """
    if not prompt:
        raise gr.Error("Please enter a prompt.")
    client = _client(api_key)
    try:
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        resp = client.images.generate(**common_args)
    except openai.AuthenticationError:
        raise gr.Error("Invalid OpenAI API key.")
    except openai.PermissionDeniedError:
        raise gr.Error("Permission denied. Check your API key permissions or complete required verification for gpt-image-1.")
    except openai.RateLimitError:
        raise gr.Error("Rate limit exceeded. Please try again later.")
    except openai.BadRequestError as e:
        error_message = f"OpenAI Bad Request: {e}"
        try:
            # e.body may already be a parsed dict, or a JSON string/bytes;
            # the old json.loads(str(e.body)) failed for both dict and bytes.
            body = e.body if isinstance(e.body, dict) else json.loads(e.body)
            if isinstance(body, dict) and "message" in body.get("error", {}):
                error_message = f"OpenAI Bad Request: {body['error']['message']}"
        except (TypeError, ValueError, AttributeError):
            pass  # keep the generic message built above
        raise gr.Error(error_message)
    except Exception as e:
        # Surface unexpected failures in the UI rather than crashing the app.
        raise gr.Error(f"An unexpected error occurred: {e}")
    return _img_list(resp, fmt=out_fmt)
# ---------- Edit / Inpaint ---------- # | |
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Convert RGBA/RGB uint8 numpy array to PNG bytes."""
    buffer = io.BytesIO()
    # Cast to uint8 first: PIL rejects other integer/float dtypes for PNG.
    Image.fromarray(arr.astype(np.uint8)).save(buffer, format="PNG")
    return buffer.getvalue()
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]: | |
"""Handle ImageMask / ImageEditor return formats and extract a numpy mask array.""" | |
if mask_value is None: | |
return None | |
# If we already have a numpy array (ImageMask with type="numpy") | |
if isinstance(mask_value, np.ndarray): | |
mask_arr = mask_value | |
# If it's an EditorValue dict coming from ImageEditor/ImageMask with type="file" or "pil" | |
elif isinstance(mask_value, dict): | |
# Prefer the composite (all layers merged) if present | |
comp = mask_value.get("composite") | |
if comp is not None and isinstance(comp, (Image.Image, np.ndarray)): | |
mask_arr = np.array(comp) if isinstance(comp, Image.Image) else comp | |
# Fallback to the mask if present (often from ImageMask) | |
elif mask_value.get("mask") is not None and isinstance(mask_value["mask"], (Image.Image, np.ndarray)): | |
mask_arr = np.array(mask_value["mask"]) if isinstance(mask_value["mask"], Image.Image) else mask_value["mask"] | |
# Fallback to the topmost layer | |
elif mask_value.get("layers"): | |
top_layer = mask_value["layers"][-1] | |
if isinstance(top_layer, (Image.Image, np.ndarray)): | |
mask_arr = np.array(top_layer) if isinstance(top_layer, Image.Image) else top_layer | |
else: | |
return None # Cannot process layer format | |
else: | |
return None # No usable image data found in dict | |
else: | |
# Unknown format – ignore | |
return None | |
# Ensure mask_arr is a numpy array now | |
if not isinstance(mask_arr, np.ndarray): | |
return None # Should not happen after above checks, but safeguard | |
return mask_arr | |
def _alpha_from_mask(mask_numpy: np.ndarray) -> np.ndarray:
    """Build the alpha channel the edit API expects from a Gradio mask array.

    Gradio masks mark the edit region in WHITE, while the API treats
    TRANSPARENT pixels as the region to edit.  So painted pixels get
    alpha 0 (edit) and untouched pixels get alpha 255 (keep).
    """
    if mask_numpy.ndim == 2:  # grayscale: any non-zero pixel = edit
        return (mask_numpy == 0).astype(np.uint8) * 255
    if mask_numpy.shape[-1] == 4:  # RGBA from gr.ImageMask: painted area is opaque
        return (mask_numpy[:, :, 3] == 0).astype(np.uint8) * 255
    if mask_numpy.shape[-1] == 3:  # RGB: pure white = edit
        is_white = np.all(mask_numpy == [255, 255, 255], axis=-1)
        return (~is_white).astype(np.uint8) * 255
    raise gr.Error("Unsupported mask format.")


def _prepare_mask_bytes(mask_value: Optional[Union[np.ndarray, Dict[str, Any]]]) -> Optional[bytes]:
    """Turn a Gradio mask value into API-ready PNG bytes, or None if unusable.

    Emits gr.Info/gr.Warning as UI side effects when the mask is missing
    or effectively empty (all black / fully transparent).
    """
    mask_numpy = _extract_mask_array(mask_value)
    if mask_numpy is None:
        gr.Info("No mask provided. The API will attempt to edit the image based on the prompt without a specific mask.")
        return None
    # Detect an effectively empty mask: nothing was painted.
    is_empty = False
    if mask_numpy.ndim == 2:  # grayscale
        is_empty = np.all(mask_numpy == 0)
    elif mask_numpy.shape[-1] == 4:  # RGBA
        is_empty = np.all(mask_numpy[:, :, 3] == 0)
    elif mask_numpy.shape[-1] == 3:  # RGB
        is_empty = np.all(mask_numpy == 0)
    if is_empty:
        gr.Warning("The provided mask appears empty (all black/transparent). The API might edit the entire image or ignore the mask.")
        # Per API docs, omit the mask entirely rather than sending a no-op one.
        return None
    alpha = _alpha_from_mask(mask_numpy)
    # The API expects an RGBA PNG whose alpha channel defines the mask
    # (transparent = edit).  Use a black base and attach the computed alpha.
    mask_img = Image.fromarray(alpha, mode='L')
    rgba_mask = Image.new("RGBA", mask_img.size, (0, 0, 0, 0))
    rgba_mask.putalpha(mask_img)
    out = io.BytesIO()
    rgba_mask.save(out, format="PNG")
    return out.getvalue()


def edit_image(
    api_key: str,
    image_numpy: np.ndarray,
    mask_value: Optional[Union[np.ndarray, Dict[str, Any]]],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Call the OpenAI image edit (inpaint) endpoint and return displayable URLs.

    Raises:
        gr.Error: on missing image/prompt, auth/permission/rate-limit
            failures, or any other API error (with targeted advice for
            common mask/size problems).
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please enter an edit prompt.")
    img_bytes = _bytes_from_numpy(image_numpy)
    mask_bytes = _prepare_mask_bytes(mask_value)
    client = _client(api_key)
    try:
        api_kwargs: Dict[str, Any] = {
            "image": img_bytes,
            **_common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg),
        }
        if mask_bytes is not None:
            api_kwargs["mask"] = mask_bytes
        resp = client.images.edit(**api_kwargs)
    except openai.AuthenticationError:
        raise gr.Error("Invalid OpenAI API key.")
    except openai.PermissionDeniedError:
        raise gr.Error("Permission denied. Check your API key permissions or complete required verification for gpt-image-1.")
    except openai.RateLimitError:
        raise gr.Error("Rate limit exceeded. Please try again later.")
    except openai.BadRequestError as e:
        error_message = f"OpenAI Bad Request: {e}"
        try:
            # e.body may already be a parsed dict, or a JSON string/bytes;
            # the old json.loads(str(e.body)) failed for both dict and bytes.
            body = e.body if isinstance(e.body, dict) else json.loads(e.body)
            if isinstance(body, dict) and "message" in body.get("error", {}):
                error_message = f"OpenAI Bad Request: {body['error']['message']}"
                # Targeted advice for the most common edit-endpoint failures.
                if "mask" in error_message.lower():
                    error_message += " (Ensure mask is a valid PNG with an alpha channel and matches the image dimensions.)"
                elif "size" in error_message.lower():
                    error_message += " (Ensure image and mask dimensions match and are supported.)"
        except (TypeError, ValueError, AttributeError):
            pass  # keep the generic message built above
        raise gr.Error(error_message)
    except Exception as e:
        # Surface unexpected failures in the UI rather than crashing the app.
        raise gr.Error(f"An unexpected error occurred: {e}")
    return _img_list(resp, fmt=out_fmt)
# ---------- Variations ---------- # | |
def variation_image(
    api_key: str,
    image_numpy: np.ndarray,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Call the OpenAI image variations endpoint and return displayable URLs.

    NOTE: Variations are only documented for DALL·E 2; this is expected to
    fail with gpt-image-1, so a UI warning is always shown first.

    Raises:
        gr.Error: on missing image, auth/permission/rate-limit failures,
            or any other API error (with the server's message when available).
    """
    gr.Warning("Note: Image variations are officially supported for DALL·E 2, not gpt-image-1. This may not work as expected.")
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    img_bytes = _bytes_from_numpy(image_numpy)
    client = _client(api_key)
    try:
        # Prompt is None for variations (the endpoint takes no prompt).
        common_args = _common_kwargs(None, n, size, quality, out_fmt, compression, transparent_bg)
        resp = client.images.variations(
            image=img_bytes,
            **common_args,
        )
    except openai.AuthenticationError:
        raise gr.Error("Invalid OpenAI API key.")
    except openai.PermissionDeniedError:
        raise gr.Error("Permission denied. Check your API key permissions.")
    except openai.RateLimitError:
        raise gr.Error("Rate limit exceeded. Please try again later.")
    except openai.BadRequestError as e:
        error_message = f"OpenAI Bad Request: {e}"
        try:
            # e.body may already be a parsed dict, or a JSON string/bytes;
            # the old json.loads(str(e.body)) failed for both dict and bytes.
            body = e.body if isinstance(e.body, dict) else json.loads(e.body)
            if isinstance(body, dict) and "message" in body.get("error", {}):
                error_message = f"OpenAI Bad Request: {body['error']['message']}"
                if "model does not support variations" in error_message.lower():
                    error_message += " (gpt-image-1 does not support variations, use DALL·E 2 instead)."
        except (TypeError, ValueError, AttributeError):
            pass  # keep the generic message built above
        raise gr.Error(error_message)
    except Exception as e:
        # Surface unexpected failures in the UI rather than crashing the app.
        raise gr.Error(f"An unexpected error occurred: {e}")
    return _img_list(resp, fmt=out_fmt)
# ---------- UI ---------- # | |
def build_ui():
    """Assemble the Gradio Blocks UI: shared controls plus three task tabs.

    Returns the gr.Blocks app; the caller is responsible for launching it.
    """
    with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
        gr.Markdown("""# GPT-Image-1 Playground 🖼️🔑\nGenerate • Edit (paint mask!) • Variations""")
        gr.Markdown(
            "Enter your OpenAI API key below. It's used directly for API calls and **never stored**."
            " This space uses the `gpt-image-1` model."
            " **Note:** `gpt-image-1` may require organization verification. Variations endpoint might not work with this model (use DALL·E 2)."
        )
        with gr.Accordion("🔐 API key", open=False):
            api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-…")
        # Common controls shared by all three tabs (passed to every handler).
        with gr.Row():
            n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)", info="Max 4 for this demo.")  # Limit n for stability/cost
            size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size", info="API default if 'auto'.")
            quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality", info="API default if 'auto'.")
        with gr.Row():
            out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Format", scale=1)
            compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False, scale=2)
            transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)", scale=1)

        def _toggle_compression(fmt):
            # The compression slider only applies to lossy formats.
            return gr.update(visible=fmt in {"jpeg", "webp"})

        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
        # Handler signatures expect these after their tab-specific inputs.
        common_inputs = [api, n_slider, size, quality, out_fmt, compression, transparent]
        with gr.Tabs():
            # ----- Generate Tab ----- #
            with gr.TabItem("Generate"):
                with gr.Row():
                    prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic ginger cat astronaut on Mars", scale=4)
                    btn_gen = gr.Button("Generate 🚀", variant="primary", scale=1)
                gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
                btn_gen.click(
                    generate,
                    inputs=[prompt_gen] + common_inputs,  # Prepend specific inputs
                    outputs=gallery_gen,
                    api_name="generate"
                )
            # ----- Edit Tab ----- #
            with gr.TabItem("Edit / Inpaint"):
                gr.Markdown("Upload an image, then **paint the area to change** in the mask canvas below (white = edit area). The API requires the mask and image to have the same dimensions.")
                with gr.Row():
                    img_edit = gr.Image(label="Source Image", type="numpy", height=400)
                    # Use ImageMask component for interactive painting
                    mask_canvas = gr.ImageMask(
                        label="Mask – Paint White Where Image Should Change",
                        type="numpy",  # Get mask as numpy array
                        height=400
                    )
                with gr.Row():
                    prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky with a starry night", scale=4)
                    btn_edit = gr.Button("Edit 🖌️", variant="primary", scale=1)
                gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
                btn_edit.click(
                    edit_image,
                    inputs=[img_edit, mask_canvas, prompt_edit] + common_inputs,  # Prepend specific inputs
                    outputs=gallery_edit,
                    api_name="edit"
                )
            # ----- Variations Tab ----- #
            with gr.TabItem("Variations (DALL·E 2 only)"):
                gr.Markdown("Upload an image to generate variations. **Note:** This endpoint is officially supported for DALL·E 2, not `gpt-image-1`. It likely won't work here.")
                with gr.Row():
                    img_var = gr.Image(label="Source Image", type="numpy", height=400, scale=4)
                    btn_var = gr.Button("Create Variations ✨", variant="primary", scale=1)
                gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
                btn_var.click(
                    variation_image,
                    inputs=[img_var] + common_inputs,  # Prepend specific inputs
                    outputs=gallery_var,
                    api_name="variations"
                )
    return demo
if __name__ == "__main__":
    app = build_ui()
    # Set share=True to create a public link (useful for Spaces);
    # controlled here via the GRADIO_SHARE environment variable ("true" enables it).
    # Set debug=True for more detailed logs in the console
    app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=True)