# Source: app.py from a Hugging Face Space (commit 68971bf, "Update app.py", by Zack3D).
"""
Gradio Space: GPT-Image-1 – bring-your-own-token (BYOT) playground
Generate · Edit (paint mask!) · Variations
==========================================
Adds an **in-browser paint tool** for the edit / inpaint workflow so users can
draw the mask directly instead of uploading one.
### How mask painting works
* Upload an image.
* Use the *Mask* canvas to **paint the areas you’d like changed** (white =
editable, black = keep).
The new `gr.ImageMask` component captures your brush strokes.
* The painted mask is converted to a 1‑channel PNG and sent to the
`images.edit()` endpoint.
All other controls (size, quality, format, compression, n, background) stay the
same.
"""
from __future__ import annotations
import io
import os
from typing import List, Optional, Union, Dict, Any
import gradio as gr
import numpy as np
from PIL import Image
import openai
# Model identifier placed in every request built by ``_common_kwargs``.
MODEL: str = "gpt-image-1"

# Options for the shared UI controls. An "auto" entry means "omit the
# parameter and let the API choose its default".
SIZE_CHOICES: List[str] = [
    "auto",
    "1024x1024",  # square
    "1536x1024",  # landscape
    "1024x1536",  # portrait
]
QUALITY_CHOICES: List[str] = ["auto", "low", "medium", "high"]
# "png" is treated as the API default; "jpeg"/"webp" also get a compression level.
FORMAT_CHOICES: List[str] = [
    "png",
    "jpeg",
    "webp",
]
def _client(key: str) -> openai.OpenAI:
    """Create an OpenAI client for one request.

    The key typed into the UI wins. If it is blank or ``None``, the
    ``OPENAI_API_KEY`` environment variable is used instead. Both sources are
    whitespace-trimmed, because pasted keys and secrets often carry a trailing
    newline. The key is only handed to the client constructor; this function
    does not persist it.

    Raises:
        gr.Error: if neither source yields a non-blank key.
    """
    # ``key or ""`` guards against API callers that send ``None`` for the textbox.
    api_key = (key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=api_key)
def _img_list(resp, *, fmt: str) -> List[str]:
"""Return list of data URLs or direct URLs depending on API response."""
mime = f"image/{fmt}"
return [
f"data:{mime};base64,{d.b64_json}" if hasattr(d, "b64_json") and d.b64_json else d.url
for d in resp.data
]
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> Dict[str, Any]:
    """Build the keyword arguments shared by the gpt-image-1 Images API calls.

    Settings left at "auto" (size, quality) or at "png" (output format) are
    omitted so the API applies its own defaults.

    Args:
        prompt: Text prompt; left out of the result when ``None``.
        n: Number of images to request (coerced to ``int``; UI sliders may
            deliver floats).
        size: One of ``SIZE_CHOICES``.
        quality: One of ``QUALITY_CHOICES``.
        out_fmt: One of ``FORMAT_CHOICES``.
        compression: 0-100 compression level, sent only for jpeg/webp
            (coerced to ``int``).
        transparent_bg: Request a transparent background; only honoured for
            png/webp.

    Returns:
        A dict to be splatted into ``client.images.generate`` / ``client.images.edit``.
    """
    # Deliberately no ``response_format``: the Images API documents it as not
    # supported for gpt-image-1, which always returns base64 data (``b64_json``,
    # already handled by ``_img_list``). Sending it makes every request fail.
    kwargs: Dict[str, Any] = {"model": MODEL, "n": int(n)}
    # Use API defaults if 'auto' is selected
    if size != "auto":
        kwargs["size"] = size
    if quality != "auto":
        kwargs["quality"] = quality
    # Prompt is optional (callers pass None when there is no prompt)
    if prompt is not None:
        kwargs["prompt"] = prompt
    # png is the API default output format, so only send non-png formats
    if out_fmt != "png":
        kwargs["output_format"] = out_fmt
    # Transparency via the background parameter (png & webp only)
    if transparent_bg and out_fmt in {"png", "webp"}:
        kwargs["background"] = "transparent"
    # Compression applies only to lossy formats; the API expects an integer 0-100
    if out_fmt in {"jpeg", "webp"}:
        kwargs["output_compression"] = int(compression)
    return kwargs
# ---------- Generate ---------- #
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Generate images from a text prompt via ``client.images.generate``.

    Args:
        api_key: Key from the UI (blank falls back to the environment, see ``_client``).
        prompt: Text prompt; must contain non-whitespace characters.
        n, size, quality, out_fmt, compression, transparent_bg: Shared UI
            settings, translated by ``_common_kwargs``.

    Returns:
        A list of image references for ``gr.Gallery`` (see ``_img_list``).

    Raises:
        gr.Error: for a blank prompt, a missing key, or any API failure
            (translated into a user-facing message).
    """
    # Reject whitespace-only prompts locally instead of spending an API round-trip.
    if not (prompt or "").strip():
        raise gr.Error("Please enter a prompt.")
    client = _client(api_key)
    try:
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        resp = client.images.generate(**common_args)
    # Chain the original exception so server logs keep the real cause.
    except openai.AuthenticationError as e:
        raise gr.Error("Invalid OpenAI API key.") from e
    except openai.PermissionDeniedError as e:
        raise gr.Error("Permission denied. Check your API key permissions.") from e
    except openai.RateLimitError as e:
        raise gr.Error("Rate limit exceeded. Please try again later.") from e
    except openai.BadRequestError as e:
        raise gr.Error(f"OpenAI Bad Request: {e}") from e
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    return _img_list(resp, fmt=out_fmt)
# ---------- Edit / Inpaint ---------- #
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode an image array as a PNG file and return its bytes.

    The array is cast to ``uint8`` first; values are not rescaled. Pillow
    infers the image mode (e.g. grayscale, RGB, RGBA) from the array's shape.
    """
    with io.BytesIO() as buffer:
        Image.fromarray(arr.astype(np.uint8)).save(buffer, format="PNG")
        return buffer.getvalue()
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
"""Handle ImageMask / ImageEditor return formats and extract a numpy mask array."""
if mask_value is None:
return None
# If we already have a numpy array (ImageMask with type="numpy")
if isinstance(mask_value, np.ndarray):
mask_arr = mask_value
# If it's an EditorValue dict coming from ImageEditor/ImageMask with type="file" or "pil"
elif isinstance(mask_value, dict):
# Prefer the composite (all layers merged) if present
comp = mask_value.get("composite")
if comp is not None and isinstance(comp, (Image.Image, np.ndarray)):
mask_arr = np.array(comp) if isinstance(comp, Image.Image) else comp
# Fallback to the mask if present (often from ImageMask)
elif mask_value.get("mask") is not None and isinstance(mask_value["mask"], (Image.Image, np.ndarray)):
mask_arr = np.array(mask_value["mask"]) if isinstance(mask_value["mask"], Image.Image) else mask_value["mask"]
# Fallback to the topmost layer
elif mask_value.get("layers"):
top_layer = mask_value["layers"][-1]
if isinstance(top_layer, (Image.Image, np.ndarray)):
mask_arr = np.array(top_layer) if isinstance(top_layer, Image.Image) else top_layer
else:
return None # Cannot process layer format
else:
return None # No usable image data found in dict
else:
# Unknown format – ignore
return None
# Ensure mask_arr is a numpy array now
if not isinstance(mask_arr, np.ndarray):
return None # Should not happen after above checks, but safeguard
return mask_arr
def _mask_alpha_from_array(mask_numpy: np.ndarray) -> np.ndarray:
    """Translate a painted mask into the alpha plane the Images API expects.

    Painted pixels (the area to change) become 0 (transparent = edit), and
    unpainted pixels become 255 (opaque = keep). "Painted" depends on layout:

    * 2-D or (H, W, 1) grayscale: any non-zero value.
    * (H, W, 4) RGBA: any non-zero alpha (strokes on a transparent layer).
    * (H, W, 3) RGB: any non-black pixel.

    Raises:
        gr.Error: for any other array layout.
    """
    if mask_numpy.ndim == 3 and mask_numpy.shape[-1] == 1:
        mask_numpy = mask_numpy[:, :, 0]
    if mask_numpy.ndim == 2:
        keep = mask_numpy == 0
    elif mask_numpy.ndim == 3 and mask_numpy.shape[-1] == 4:
        keep = mask_numpy[:, :, 3] == 0
    elif mask_numpy.ndim == 3 and mask_numpy.shape[-1] == 3:
        keep = np.all(mask_numpy == 0, axis=-1)
    else:
        raise gr.Error("Unsupported mask format.")
    return keep.astype(np.uint8) * 255


def _mask_png_bytes(alpha: np.ndarray, size: tuple[int, int]) -> bytes:
    """Encode an alpha plane as an RGBA PNG of ``size == (width, height)``.

    The edit endpoint needs a PNG *with an alpha channel* (fully transparent
    pixels mark the region to edit) and the same dimensions as the source
    image. A plain single-channel PNG therefore does not work. The plane is
    resized with nearest-neighbour sampling so it stays strictly 0/255.
    """
    alpha_img = Image.fromarray(alpha)
    if alpha_img.size != size:
        alpha_img = alpha_img.resize(size, Image.NEAREST)
    mask_img = Image.new("RGBA", size, (0, 0, 0, 255))
    mask_img.putalpha(alpha_img)
    buffer = io.BytesIO()
    mask_img.save(buffer, format="PNG")
    return buffer.getvalue()


def edit_image(
    api_key: str,
    image_numpy: np.ndarray,
    mask_value: Optional[Union[np.ndarray, Dict[str, Any]]],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Edit / inpaint an uploaded image via ``client.images.edit``.

    The canvas value is normalised by ``_extract_mask_array`` and converted
    into the mask the endpoint requires (see ``_mask_png_bytes``). If no usable
    or non-empty mask is painted, a warning is shown and the request is sent
    without a mask.

    Args:
        api_key: Key from the UI (blank falls back to the environment, see ``_client``).
        image_numpy: Source image array from ``gr.Image(type="numpy")``.
        mask_value: Raw value of the mask canvas (ndarray or editor dict).
        prompt: Edit instruction; must contain non-whitespace characters.
        n, size, quality, out_fmt, compression, transparent_bg: Shared UI
            settings, translated by ``_common_kwargs``.

    Returns:
        A list of image references for ``gr.Gallery`` (see ``_img_list``).

    Raises:
        gr.Error: missing image/prompt/key, unsupported mask layout, or API failure.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not (prompt or "").strip():
        raise gr.Error("Please enter an edit prompt.")
    img_bytes = _bytes_from_numpy(image_numpy)

    mask_bytes: Optional[bytes] = None
    mask_numpy = _extract_mask_array(mask_value)
    if mask_numpy is None:
        gr.Warning("No mask provided or mask could not be processed. The API might edit the entire image or apply a default mask.")
    else:
        alpha = _mask_alpha_from_array(mask_numpy)
        if np.all(alpha == 255):
            # Nothing painted: a fully opaque mask would protect every pixel,
            # so the mask is left out of the request instead.
            gr.Warning("The provided mask appears empty. The entire image might be edited if no mask is applied by the API.")
        else:
            height, width = image_numpy.shape[:2]
            mask_bytes = _mask_png_bytes(alpha, (width, height))

    client = _client(api_key)
    # Upload (filename, bytes, MIME type) tuples: bare bytes carry no filename
    # or content type, which the endpoint can reject.
    uploads: Dict[str, Any] = {"image": ("image.png", img_bytes, "image/png")}
    if mask_bytes is not None:
        # Omit ``mask`` entirely when unused; None is not a valid file value for the SDK.
        uploads["mask"] = ("mask.png", mask_bytes, "image/png")
    try:
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        resp = client.images.edit(**uploads, **common_args)
    except openai.AuthenticationError as e:
        raise gr.Error("Invalid OpenAI API key.") from e
    except openai.PermissionDeniedError as e:
        raise gr.Error("Permission denied. Check your API key permissions.") from e
    except openai.RateLimitError as e:
        raise gr.Error("Rate limit exceeded. Please try again later.") from e
    except openai.BadRequestError as e:
        # Provide more specific feedback if possible
        if "mask" in str(e) and "alpha channel" in str(e):
            raise gr.Error("OpenAI API Error: The mask must be a PNG image with transparency indicating the edit area. Ensure your mask was processed correctly.") from e
        elif "size" in str(e):
            raise gr.Error(f"OpenAI API Error: Image and mask size mismatch or invalid size. Ensure image is square if required by the model. Error: {e}") from e
        else:
            raise gr.Error(f"OpenAI Bad Request: {e}") from e
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    return _img_list(resp, fmt=out_fmt)
# ---------- Variations ---------- #
def variation_image(
    api_key: str,
    image_numpy: np.ndarray,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Create variations of an uploaded image.

    The OpenAI SDK exposes variations as ``client.images.create_variation``;
    there is no ``images.variations`` method. That endpoint only supports the
    ``dall-e-2`` model, so this tab uses dall-e-2 rather than gpt-image-1.

    dall-e-2 variations take a square PNG and accept only ``n``, ``size``
    (256x256, 512x512 or 1024x1024) and ``response_format``. ``size`` is
    forwarded only when it is one of those values; otherwise the API default
    applies. ``quality``, ``out_fmt``, ``compression`` and ``transparent_bg``
    are gpt-image-1 settings. They are accepted so the shared UI inputs still
    bind, but they are ignored here, and the results are PNG.

    Returns:
        A list of image references for ``gr.Gallery`` (see ``_img_list``).

    Raises:
        gr.Error: no image, non-square image, missing key, or API failure.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    height, width = image_numpy.shape[:2]
    if height != width:
        raise gr.Error(f"Variations require a square image (got {width}x{height}).")
    img_bytes = _bytes_from_numpy(image_numpy)
    client = _client(api_key)

    variation_args: Dict[str, Any] = {
        "model": "dall-e-2",
        "n": int(n),
        # dall-e-2 returns short-lived URLs by default; base64 avoids expiry.
        "response_format": "b64_json",
    }
    if size in {"256x256", "512x512", "1024x1024"}:
        variation_args["size"] = size
    gr.Warning("gpt-image-1 has no variations endpoint; this tab uses dall-e-2 (PNG output).")
    try:
        resp = client.images.create_variation(
            image=("image.png", img_bytes, "image/png"),
            **variation_args,
        )
    except openai.AuthenticationError as e:
        raise gr.Error("Invalid OpenAI API key.") from e
    except openai.PermissionDeniedError as e:
        raise gr.Error("Permission denied. Check your API key permissions.") from e
    except openai.RateLimitError as e:
        raise gr.Error("Rate limit exceeded. Please try again later.") from e
    except openai.BadRequestError as e:
        raise gr.Error(f"OpenAI Bad Request: {e}") from e
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    return _img_list(resp, fmt="png")
# ---------- UI ---------- #
def build_ui():
    """Assemble and return the Gradio Blocks app (not launched).

    Layout, top to bottom:
    * a title and key notice;
    * a collapsed API-key accordion;
    * two rows of controls shared by every tab (n, size, quality, format,
      compression, transparency);
    * three tabs (Generate, Edit / Inpaint, Variations), each wired to its
      handler with an explicit ``api_name`` so it is also callable through
      the Gradio API.

    Component creation order inside each ``with`` context defines the layout.

    Returns:
        The ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
        gr.Markdown("""# GPT-Image-1 Playground 🖼️🔑\nGenerate • Edit (paint mask!) • Variations""")
        gr.Markdown(
            "Enter your OpenAI API key below. It's used directly for API calls and **never stored**."
            " This space uses the `gpt-image-1` model."
        )
        # Key input; when left blank, _client falls back to the OPENAI_API_KEY env var.
        with gr.Accordion("🔐 API key", open=False):
            api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-…")
        # Common controls: passed as inputs to all three tab handlers below
        with gr.Row():
            n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)", info="Max 4 for this demo.")  # Limit n for stability/cost
            size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size", info="API default if 'auto'.")
            quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality", info="API default if 'auto'.")
        with gr.Row():
            out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Format")
            # Hidden initially (default format is png); _toggle_compression reveals it for jpeg/webp.
            compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False)
            transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)")

        def _toggle_compression(fmt):
            """Show the compression slider only for lossy output formats (jpeg/webp)."""
            return gr.update(visible=fmt in {"jpeg", "webp"})

        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
        with gr.Tabs():
            # ----- Generate Tab ----- #
            with gr.TabItem("Generate"):
                with gr.Row():
                    prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic ginger cat astronaut on Mars", scale=4)
                    btn_gen = gr.Button("Generate 🚀", variant="primary", scale=1)
                gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
                btn_gen.click(
                    generate,
                    inputs=[api, prompt_gen, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_gen,
                    api_name="generate"
                )
            # ----- Edit Tab ----- #
            with gr.TabItem("Edit / Inpaint"):
                gr.Markdown("Upload an image, then **paint the area to change** in the mask canvas below (white = edit).")
                with gr.Row():
                    img_edit = gr.Image(label="Source Image", type="numpy", height=400)
                    # Use ImageMask component for interactive painting.
                    # NOTE(review): this canvas is independent of img_edit, so users presumably
                    # need to load the same picture here to paint over it — confirm in the UI.
                    mask_canvas = gr.ImageMask(
                        label="Mask – Paint White Where Image Should Change",
                        type="numpy",  # Array-valued output; edit_image accepts an ndarray or an editor dict
                        # Optional tweaks, currently disabled:
                        # brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"), # Force white brush
                        # mask_opacity=0.7 # Adjust mask visibility on image
                        height=400
                    )
                with gr.Row():
                    prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky with a starry night", scale=4)
                    btn_edit = gr.Button("Edit 🖌️", variant="primary", scale=1)
                gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
                btn_edit.click(
                    edit_image,
                    inputs=[api, img_edit, mask_canvas, prompt_edit, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_edit,
                    api_name="edit"
                )
            # ----- Variations Tab ----- #
            with gr.TabItem("Variations"):
                gr.Markdown("Upload an image to generate variations.")
                with gr.Row():
                    img_var = gr.Image(label="Source Image", type="numpy", height=400, scale=4)
                    btn_var = gr.Button("Create Variations ✨", variant="primary", scale=1)
                gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
                btn_var.click(
                    variation_image,
                    inputs=[api, img_var, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_var,
                    api_name="variations"
                )
    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Blocks UI, then serve it with Gradio's
    # default launch settings.
    app = build_ui()
    app.launch()