# Scraped Hugging Face file-view header (not part of the program):
# Zack3D · "Create app.py" · commit c164914 (verified) · 7.84 kB
# app.py
"""
Gradio Space: GPT‑Image‑1 – BYOT playground
Generate · Edit (paint mask!) · Variations
==========================================
Adds an **in‑browser paint tool** for the edit / inpaint workflow so users can
draw the mask directly instead of uploading one.
### How mask painting works
* Upload an image.
* Use the *Mask* canvas to **paint the areas you’d like changed** (white =
editable, black = keep).
Gradio’s built‑in *sketch* tool captures your brush strokes.
* The painted mask is encoded as a PNG and sent to the `images.edit()`
endpoint together with the uploaded image.
All other controls (size, quality, format, compression, n, background) stay the
same.
"""
from __future__ import annotations
import io
import os
from typing import List, Optional
import gradio as gr
import numpy as np
from PIL import Image
import openai
# Model identifier placed in the shared request kwargs.
MODEL = "gpt-image-1"
# Options offered by the "Size" dropdown.
SIZE_CHOICES = ["auto", "1024x1024", "1536x1024", "1024x1536"]
# Options offered by the "Quality" dropdown.
QUALITY_CHOICES = ["auto", "low", "medium", "high"]
# Options offered by the "Format" radio group.
FORMAT_CHOICES = ["png", "jpeg", "webp"]
def _client(key: str) -> openai.OpenAI:
    """Build an OpenAI client from the UI-supplied key or the environment.

    Args:
        key: API key typed into the UI. A blank value (or ``None``) falls back
            to the ``OPENAI_API_KEY`` environment variable.

    Returns:
        An ``openai.OpenAI`` client configured with the resolved key.

    Raises:
        gr.Error: If neither the UI field nor the environment provides a key.
    """
    # Guard against ``None`` so the environment fallback still runs instead of
    # crashing with AttributeError on ``None.strip()``.
    api_key = (key or "").strip() or os.getenv("OPENAI_API_KEY", "")
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=api_key)
def _img_list(resp, *, fmt: str, transparent: bool) -> List[str]:
mime = "image/png" if fmt == "png" or transparent else f"image/{fmt}"
return [
f"data:{mime};base64,{d.b64_json}" if hasattr(d, "b64_json") else d.url
for d in resp.data
]
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Build the keyword arguments shared by the image requests.

    Args:
        prompt: Text prompt; omitted from the result when ``None``.
        n: Number of images to request (coerced to ``int``).
        size: Size choice, passed through unchanged.
        quality: Quality choice, passed through unchanged.
        out_fmt: Output format (``"png"``, ``"jpeg"`` or ``"webp"``).
        compression: Compression level 0-100; only sent for JPEG/WebP.
        transparent_bg: Whether to request a transparent background.

    Returns:
        A dict of keyword arguments to splat into an ``images.*`` call.
    """
    kwargs = dict(
        model=MODEL,
        n=int(n),  # UI sliders may deliver a float
        size=size,
        quality=quality,
        output_format=out_fmt,
    )
    # No ``response_format`` here: the Images API documents it for the DALL-E
    # models only, and gpt-image-1 always returns base64 data.
    if prompt is not None:
        kwargs["prompt"] = prompt
    if transparent_bg:
        # The API parameter is ``background``, not ``transparent_background``.
        kwargs["background"] = "transparent"
    if out_fmt in {"jpeg", "webp"}:
        # ``output_compression`` is an integer 0-100, not a "75%" string.
        kwargs["output_compression"] = int(compression)
    return kwargs
# ---------- Generate ---------- #
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Generate images from a text prompt.

    Args:
        api_key: Key from the UI; resolved by ``_client``.
        prompt: Text description of the desired image.
        n, size, quality, out_fmt, compression, transparent_bg: Shared output
            controls, forwarded through ``_common_kwargs``.

    Returns:
        A list of image sources produced by ``_img_list``.

    Raises:
        gr.Error: If the prompt is blank, the key is missing, or the API call
            fails.
    """
    # Reject a blank prompt locally instead of spending a network round-trip.
    if not (prompt or "").strip():
        raise gr.Error("Please enter a prompt.")
    client = _client(api_key)
    request = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
    try:
        resp = client.images.generate(**request)
    except Exception as e:
        # Surface any SDK/HTTP failure in the UI, keeping the cause for logs.
        raise gr.Error(f"OpenAI error: {e}") from e
    return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)
# ---------- Edit / Inpaint ---------- #
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode a pixel array as a PNG file and return the file's bytes.

    The array is cast to ``uint8`` before it is handed to Pillow.
    """
    pixels = arr.astype(np.uint8)
    with io.BytesIO() as buffer:
        Image.fromarray(pixels).save(buffer, format="PNG")
        return buffer.getvalue()
def edit_image(
api_key: str,
image_numpy: np.ndarray,
mask_numpy: Optional[np.ndarray],
prompt: str,
n: int,
size: str,
quality: str,
out_fmt: str,
compression: int,
transparent_bg: bool,
):
if image_numpy is None:
raise gr.Error("Please upload an image.")
img_bytes = _bytes_from_numpy(image_numpy)
mask_bytes: Optional[bytes] = None
if mask_numpy is not None:
# Convert painted area (alpha > 0) to white, else black; 1‑channel.
if mask_numpy.shape[-1] == 4: # RGBA from gr.Image sketch
alpha = mask_numpy[:, :, 3]
else: # RGB
alpha = np.any(mask_numpy != 0, axis=-1).astype(np.uint8) * 255
bw = np.stack([alpha] * 3, axis=-1) # 3‑channel white/black
mask_bytes = _bytes_from_numpy(bw)
client = _client(api_key)
try:
resp = client.images.edit(
image=img_bytes,
mask=mask_bytes,
**_common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg),
)
except Exception as e:
raise gr.Error(f"OpenAI error: {e}")
return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)
# ---------- Variations ---------- #
def variation_image(
    api_key: str,
    image_numpy: np.ndarray,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Request variations of an uploaded image.

    The variations endpoint is documented for ``dall-e-2`` only and accepts
    just ``model``, ``n``, ``size`` and ``response_format``. The remaining
    output controls are kept in the signature so the UI wiring is unchanged,
    but they are not forwarded.

    Args:
        api_key: Key from the UI; resolved by ``_client``.
        image_numpy: Source image as a pixel array.
        n: Number of variations (coerced to ``int``).
        size: Size choice; ``"auto"`` is omitted so the API default applies.
        quality, out_fmt, compression, transparent_bg: Not used by this
            endpoint.

    Returns:
        A list of PNG ``data:`` URIs produced by ``_img_list``.

    Raises:
        gr.Error: If no image was uploaded, the key is missing, or the API call
            fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    # Upload as (filename, bytes, content-type) so the multipart part is typed.
    upload = ("image.png", _bytes_from_numpy(image_numpy), "image/png")
    # Ask for base64 so the result does not depend on short-lived URLs.
    request = {"model": "dall-e-2", "n": int(n), "response_format": "b64_json"}
    if size != "auto":
        request["size"] = size
    client = _client(api_key)
    try:
        # The SDK method is ``create_variation``; ``images.variations`` does
        # not exist and raised AttributeError on every call.
        resp = client.images.create_variation(image=upload, **request)
    except Exception as e:
        raise gr.Error(f"OpenAI error: {e}") from e
    # The request carries no output-format option, so label results as PNG.
    return _img_list(resp, fmt="png", transparent=False)
# ---------- UI ---------- #
def build_ui():
    """Assemble the Gradio interface and wire each tab to its handler.

    Layout: an API-key accordion, the shared output controls, then three tabs
    (Generate, Edit / Inpaint, Variations), each with its own button and
    gallery.

    Returns:
        The constructed ``gr.Blocks`` application.
    """
    lossy_formats = {"jpeg", "webp"}
    default_fmt = "png"

    with gr.Blocks(title="GPT‑Image‑1 (BYOT)") as demo:
        gr.Markdown("""# GPT‑Image‑1 Playground 🖼️🔑\nGenerate • Edit (paint mask) • Variations""")
        with gr.Accordion("🔐 API key", open=False):
            api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk‑…")
        # Common controls
        n_slider = gr.Slider(1, 10, value=1, step=1, label="Number of images (n)")
        size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
        quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
        out_fmt = gr.Radio(FORMAT_CHOICES, value=default_fmt, label="Format")
        # Start hidden when the default format is lossless: the toggle below
        # only fires on change, so the initial state has to be set here.
        compression = gr.Slider(
            0,
            100,
            value=75,
            step=1,
            label="Compression (JPEG/WebP)",
            visible=default_fmt in lossy_formats,
        )
        transparent = gr.Checkbox(False, label="Transparent background (PNG only)")

        def _toggle_compression(fmt):
            # Show the compression slider only for formats that use it.
            return gr.update(visible=fmt in lossy_formats)

        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
        with gr.Tabs():
            # ----- Generate Tab ----- #
            with gr.TabItem("Generate"):
                prompt_gen = gr.Textbox(label="Prompt", lines=2, placeholder="A photorealistic ginger cat astronaut on Mars")
                btn_gen = gr.Button("Generate 🚀")
                gallery_gen = gr.Gallery(columns=2, height="auto")
                btn_gen.click(
                    generate,
                    inputs=[api, prompt_gen, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_gen,
                )
            # ----- Edit Tab ----- #
            with gr.TabItem("Edit / Inpaint"):
                img_edit = gr.Image(label="Image", type="numpy")
                # NOTE(review): the ``tool`` argument may not exist on every
                # Gradio major version -- confirm against the pinned release.
                mask_canvas = gr.Image(label="Mask – paint white where the image should change", type="numpy", tool="sketch")
                prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky with a starry night")
                btn_edit = gr.Button("Edit 🖌️")
                gallery_edit = gr.Gallery(columns=2, height="auto")
                btn_edit.click(
                    edit_image,
                    inputs=[api, img_edit, mask_canvas, prompt_edit, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_edit,
                )
            # ----- Variations Tab ----- #
            with gr.TabItem("Variations"):
                img_var = gr.Image(label="Source image", type="numpy")
                btn_var = gr.Button("Variations 🔄")
                gallery_var = gr.Gallery(columns=2, height="auto")
                btn_var.click(
                    variation_image,
                    inputs=[api, img_var, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_var,
                )
    return demo
# Build the app at import time so the module exposes a ready ``demo`` object.
demo = build_ui()

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()