# Hosting-page residue captured with the file (not part of the program):
#   Zack3D's picture · Update app.py · 55375ee verified · raw · history blame · 8.79 kB
"""
Gradio Space: GPT‑Image‑1 – BYOT playground
Generate · Edit (paint mask!) · Variations
==========================================
Adds an **in‑browser paint tool** for the edit / inpaint workflow so users can
draw the mask directly instead of uploading one.
### How mask painting works
* Upload an image.
* Use the *Mask* canvas to **paint the areas you’d like changed** (white =
editable, black = keep).
The new `gr.ImageMask` component captures your brush strokes.
* The painted mask is converted to a PNG image and sent to the
`images.edit()` endpoint.
All other controls (size, quality, format, compression, n, background) stay the
same.
"""
from __future__ import annotations
import io
import os
from typing import List, Optional, Union, Dict, Any
import gradio as gr
import numpy as np
from PIL import Image
import openai
# Model name sent with every request built by `_common_kwargs`.
MODEL = "gpt-image-1"

# Option lists backing the shared UI controls (size dropdown, quality
# dropdown and output-format radio buttons).
SIZE_CHOICES = [
    "auto",
    "1024x1024",
    "1536x1024",
    "1024x1536",
]
QUALITY_CHOICES = [
    "auto",
    "low",
    "medium",
    "high",
]
FORMAT_CHOICES = [
    "png",
    "jpeg",
    "webp",
]
def _client(key: Optional[str]) -> openai.OpenAI:
    """Create an OpenAI client from the UI-supplied key or the environment.

    Args:
        key: API key typed into the UI. Blank, whitespace-only or ``None``
            values fall back to the ``OPENAI_API_KEY`` environment variable.

    Returns:
        An ``openai.OpenAI`` client configured with the resolved key.

    Raises:
        gr.Error: If no key is available from either source.
    """
    # `(key or "")` keeps a missing textbox value from raising AttributeError
    # on `.strip()` and lets it fall through to the environment variable.
    api_key = (key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=api_key)
def _img_list(resp, *, fmt: str, transparent: bool) -> List[str]:
mime = "image/png" if fmt == "png" or transparent else f"image/{fmt}"
return [
f"data:{mime};base64,{d.b64_json}" if hasattr(d, "b64_json") else d.url
for d in resp.data
]
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> Dict[str, Any]:
    """Build the keyword arguments shared by the image endpoints.

    Args:
        prompt: Text prompt, or ``None`` to leave the ``prompt`` key out.
        n: Number of images to request (coerced to ``int``).
        size: Requested size string, e.g. ``"auto"`` or ``"1024x1024"``.
        quality: Requested quality string.
        out_fmt: Output format (``"png"``, ``"jpeg"`` or ``"webp"``).
        compression: Compression level 0-100; only sent for JPEG/WebP.
        transparent_bg: Whether to request a transparent background.

    Returns:
        A dict suitable for ``**``-expansion into an ``images.*`` call.
    """
    kwargs: Dict[str, Any] = {
        "model": MODEL,
        "n": int(n),  # UI sliders may deliver a float
        "size": size,
        "quality": quality,
        "output_format": out_fmt,
    }
    # `response_format` is deliberately not sent: per the Images API reference
    # it is unsupported for gpt-image-1, which always returns base64 data.
    if transparent_bg:
        # The API parameter is `background`, not `transparent_background`.
        kwargs["background"] = "transparent"
    if prompt is not None:
        kwargs["prompt"] = prompt
    if out_fmt in {"jpeg", "webp"}:
        # The API parameter is `output_compression`, an integer percentage.
        kwargs["output_compression"] = int(compression)
    return kwargs
# ---------- Generate ---------- #
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> List[str]:
    """Generate images from a text prompt.

    Args:
        api_key: OpenAI API key from the UI (may be blank; see ``_client``).
        prompt: Text description of the desired image.
        n: Number of images to request.
        size: Requested size string.
        quality: Requested quality string.
        out_fmt: Output format.
        compression: Compression level used for JPEG/WebP.
        transparent_bg: Whether to request a transparent background.

    Returns:
        The list produced by ``_img_list`` for the response.

    Raises:
        gr.Error: If the prompt is blank, no API key is available, or the
            request fails.
    """
    # Fail fast with a readable message instead of a round trip to the API.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    client = _client(api_key)
    request = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
    try:
        resp = client.images.generate(**request)
    except Exception as e:
        # Surface the failure in the UI while keeping the original cause.
        raise gr.Error(f"OpenAI error: {e}") from e
    return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)
# ---------- Edit / Inpaint ---------- #
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode an image array as PNG and return the encoded bytes.

    The array is cast to ``uint8`` before being handed to Pillow.
    """
    with io.BytesIO() as buffer:
        Image.fromarray(arr.astype(np.uint8)).save(buffer, format="PNG")
        return buffer.getvalue()
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
"""Handle ImageMask / ImageEditor return formats and extract a numpy mask array."""
if mask_value is None:
return None
# If we already have a numpy array (ImageMask with type="numpy")
if isinstance(mask_value, np.ndarray):
return mask_value
# If it's an EditorValue dict coming from ImageEditor/ImageMask with type="numpy"
if isinstance(mask_value, dict):
# Prefer the composite (all layers merged) if present
comp = mask_value.get("composite")
if comp is not None:
return np.asarray(comp)
# Fallback to the topmost layer
layers = mask_value.get("layers")
if layers:
return np.asarray(layers[-1])
# Unknown format – ignore
return None
def _transparency_mask_png(mask_numpy: np.ndarray, width: int, height: int) -> Optional[bytes]:
    """Turn painted strokes into an RGBA PNG whose transparent pixels mark the edit area.

    Returns ``None`` when the mask has an unsupported rank or nothing was painted.
    """
    mask_numpy = np.asarray(mask_numpy)
    if mask_numpy.ndim == 3 and mask_numpy.shape[-1] == 4:
        # RGBA strokes: alpha marks painted pixels regardless of brush colour.
        painted = mask_numpy[:, :, 3] > 0
    elif mask_numpy.ndim == 3:
        painted = np.any(mask_numpy != 0, axis=-1)
    elif mask_numpy.ndim == 2:
        painted = mask_numpy != 0
    else:
        return None
    if not painted.any():
        return None

    # The edit endpoint treats fully transparent pixels as the editable
    # region, so painted pixels get alpha 0 and everything else stays opaque.
    alpha = Image.fromarray(np.where(painted, 0, 255).astype(np.uint8))
    if alpha.size != (width, height):
        # The mask must have the same dimensions as the image.
        alpha = alpha.resize((width, height), Image.NEAREST)
    rgba = Image.new("RGBA", (width, height), (0, 0, 0, 255))
    rgba.putalpha(alpha)
    out = io.BytesIO()
    rgba.save(out, format="PNG")
    return out.getvalue()


def edit_image(
    api_key: str,
    image_numpy: np.ndarray,
    mask_value: Optional[Union[np.ndarray, Dict[str, Any]]],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> List[str]:
    """Edit / inpaint an uploaded image, optionally restricted by a painted mask.

    Args:
        api_key: OpenAI API key from the UI (may be blank; see ``_client``).
        image_numpy: Source image array from the image component.
        mask_value: Value of the mask canvas (array, editor dict or ``None``).
        prompt: Description of the desired edit.
        n: Number of images to request.
        size: Requested size string.
        quality: Requested quality string.
        out_fmt: Output format.
        compression: Compression level used for JPEG/WebP.
        transparent_bg: Whether to request a transparent background.

    Returns:
        The list produced by ``_img_list`` for the response.

    Raises:
        gr.Error: If no image or prompt was supplied, no API key is
            available, or the request fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter an edit prompt.")

    img_bytes = _bytes_from_numpy(image_numpy)
    height, width = np.asarray(image_numpy).shape[:2]

    mask_bytes: Optional[bytes] = None
    mask_numpy = _extract_mask_array(mask_value)
    if mask_numpy is not None:
        mask_bytes = _transparency_mask_png(mask_numpy, width, height)

    request = dict(_common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg))
    # Upload as (filename, content, content-type) so the multipart part is
    # labelled as PNG rather than as an untyped byte blob.
    request["image"] = ("image.png", img_bytes, "image/png")
    if mask_bytes is not None:
        # Only send a mask when something was painted; an absent mask lets
        # the whole image be edited.
        request["mask"] = ("mask.png", mask_bytes, "image/png")

    client = _client(api_key)
    try:
        resp = client.images.edit(**request)
    except Exception as e:
        raise gr.Error(f"OpenAI error: {e}") from e
    return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)
# ---------- Variations ---------- #
def variation_image(
    api_key: str,
    image_numpy: np.ndarray,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> List[str]:
    """Create variations of an uploaded image.

    The SDK method is ``images.create_variation`` and, per the Images API
    reference, the variations endpoint only supports ``dall-e-2`` with the
    ``n``, ``size`` and ``response_format`` options. ``quality``, ``out_fmt``,
    ``compression`` and ``transparent_bg`` are therefore accepted for UI
    wiring compatibility but not sent.

    Args:
        api_key: OpenAI API key from the UI (may be blank; see ``_client``).
        image_numpy: Source image array from the image component.
        n: Number of variations to request.
        size: Requested size; values the endpoint does not list fall back to
            ``"1024x1024"``.
        quality: Unused by this endpoint.
        out_fmt: Unused by this endpoint.
        compression: Unused by this endpoint.
        transparent_bg: Unused by this endpoint.

    Returns:
        The list produced by ``_img_list`` for the response (PNG data URLs).

    Raises:
        gr.Error: If no image was supplied, no API key is available, or the
            request fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    img_bytes = _bytes_from_numpy(image_numpy)
    # NOTE(review): the endpoint documents a square PNG input; non-square
    # uploads are left for the API to reject with its own message.
    variation_size = size if size in {"256x256", "512x512", "1024x1024"} else "1024x1024"
    client = _client(api_key)
    try:
        resp = client.images.create_variation(
            image=("image.png", img_bytes, "image/png"),
            model="dall-e-2",
            n=int(n),
            size=variation_size,
            response_format="b64_json",
        )
    except Exception as e:
        raise gr.Error(f"OpenAI error: {e}") from e
    # The request does not pass the UI's format options, so label results as PNG.
    return _img_list(resp, fmt="png", transparent=False)
# ---------- UI ---------- #
def build_ui():
    """Assemble the Gradio Blocks interface and return it.

    The layout has an API-key accordion, a set of controls shared by all
    actions, and three tabs (Generate, Edit / Inpaint, Variations) whose
    buttons are wired to ``generate``, ``edit_image`` and ``variation_image``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="GPT‑Image‑1 (BYOT)") as demo:
        gr.Markdown("""# GPT‑Image‑1 Playground 🖼️🔑\nGenerate • Edit (paint mask) • Variations""")

        with gr.Accordion("🔐 API key", open=False):
            api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk‑…")

        # Common controls
        n_slider = gr.Slider(1, 10, value=1, step=1, label="Number of images (n)")
        size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
        quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
        out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Format")
        # Start hidden: the default format is PNG, for which the toggle below
        # hides this slider anyway.
        compression = gr.Slider(
            0, 100, value=75, step=1, label="Compression (JPEG/WebP)", visible=False
        )
        transparent = gr.Checkbox(False, label="Transparent background (PNG only)")

        def _toggle_compression(fmt):
            # Show the compression slider only for formats that use it.
            return gr.update(visible=fmt in {"jpeg", "webp"})

        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)

        with gr.Tabs():
            # ----- Generate Tab ----- #
            with gr.TabItem("Generate"):
                prompt_gen = gr.Textbox(label="Prompt", lines=2, placeholder="A photorealistic ginger cat astronaut on Mars")
                btn_gen = gr.Button("Generate 🚀")
                gallery_gen = gr.Gallery(columns=2, height="auto")
                btn_gen.click(
                    generate,
                    inputs=[api, prompt_gen, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_gen,
                )

            # ----- Edit Tab ----- #
            with gr.TabItem("Edit / Inpaint"):
                img_edit = gr.Image(label="Image", type="numpy")
                mask_canvas = gr.ImageMask(label="Mask – paint white where the image should change", type="numpy")
                prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky with a starry night")
                btn_edit = gr.Button("Edit 🖌️")
                gallery_edit = gr.Gallery(columns=2, height="auto")
                btn_edit.click(
                    edit_image,
                    inputs=[api, img_edit, mask_canvas, prompt_edit, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_edit,
                )

            # ----- Variations Tab ----- #
            with gr.TabItem("Variations"):
                img_var = gr.Image(label="Source image", type="numpy")
                btn_var = gr.Button("Variations 🔄")
                gallery_var = gr.Gallery(columns=2, height="auto")
                btn_var.click(
                    variation_image,
                    inputs=[api, img_var, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_var,
                )

    return demo
# Build the interface at import time and keep it under the module-level
# name `demo`.
demo = build_ui()

# Start the server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()