Spaces:
Running
Running
# app.py
"""
Gradio Space: GPT‑Image‑1 – BYOT playground
Generate · Edit (paint mask!) · Variations
==========================================
Adds an **in‑browser paint tool** for the edit / inpaint workflow so users can
draw the mask directly instead of uploading one.

### How mask painting works
* Upload an image.
* Use the *Mask* canvas to **paint the areas you’d like changed** (white =
  editable, black = keep).
  Gradio’s built‑in *sketch* tool captures your brush strokes.
* The painted mask is converted to a PNG mask and sent to the
  `images.edit()` endpoint.

All other controls (size, quality, format, compression, n, background) stay the
same.
"""
from __future__ import annotations | |
import io | |
import os | |
from typing import List, Optional | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import openai | |
# Model name passed to the OpenAI Images API by `_common_kwargs`.
MODEL = "gpt-image-1"
# Options offered by the "Size" dropdown.
SIZE_CHOICES = ["auto", "1024x1024", "1536x1024", "1024x1536"]
# Options offered by the "Quality" dropdown.
QUALITY_CHOICES = ["auto", "low", "medium", "high"]
# Options offered by the "Format" radio group.
FORMAT_CHOICES = ["png", "jpeg", "webp"]
def _client(key: str) -> openai.OpenAI:
    """Build an OpenAI client from the user-supplied key or the environment.

    Args:
        key: API key typed into the UI. ``None`` or a blank string falls back
            to the ``OPENAI_API_KEY`` environment variable.

    Returns:
        An ``openai.OpenAI`` client configured with the resolved key.

    Raises:
        gr.Error: If neither the UI field nor the environment provides a key.
    """
    # Guard against None (the field may arrive unset) and treat a
    # whitespace-only environment value the same as a missing one.
    api_key = (key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=api_key)
def _img_list(resp, *, fmt: str, transparent: bool) -> List[str]:
    """Convert an Images API response into values a gallery can display.

    Args:
        resp: Response object exposing a ``data`` iterable whose items may
            carry a ``b64_json`` payload and/or a ``url``.
        fmt: Requested output format (``"png"``, ``"jpeg"`` or ``"webp"``).
        transparent: Whether a transparent background was requested; such
            results are labelled as PNG.

    Returns:
        One entry per usable item: a ``data:`` URI when base64 data is
        present, otherwise the item's URL. Items with neither are skipped.
    """
    mime = "image/png" if fmt == "png" or transparent else f"image/{fmt}"
    images: List[str] = []
    for item in resp.data:
        # The attribute can exist yet be None (URL responses), so test the
        # value rather than mere attribute presence.
        b64 = getattr(item, "b64_json", None)
        if b64:
            images.append(f"data:{mime};base64,{b64}")
            continue
        url = getattr(item, "url", None)
        if url:
            images.append(url)
    return images
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Assemble the keyword arguments shared by the image endpoints.

    Only parameters defined by the OpenAI Images API for GPT-Image models are
    emitted: ``background`` (not ``transparent_background``), an integer
    ``output_compression`` (not a ``"NN%"`` string), and no
    ``response_format``, since these models always return base64 data.

    Args:
        prompt: Text prompt, or ``None`` to omit the ``prompt`` key.
        n: Number of images to request.
        size: Requested size, e.g. ``"auto"`` or ``"1024x1024"``.
        quality: Requested quality level.
        out_fmt: Output format: ``"png"``, ``"jpeg"`` or ``"webp"``.
        compression: Compression level (0-100) for JPEG/WebP output.
        transparent_bg: Whether to request a transparent background.

    Returns:
        A dict that can be splatted into an ``images.*`` call.
    """
    # Transparency is offered as "PNG only" in the UI and `_img_list` labels
    # transparent results as PNG, so force the format to match.
    fmt = "png" if transparent_bg else out_fmt
    kwargs = {
        "model": MODEL,
        "n": int(n),  # UI sliders may deliver a float
        "size": size,
        "quality": quality,
        "output_format": fmt,
    }
    if transparent_bg:
        kwargs["background"] = "transparent"
    if prompt is not None:
        kwargs["prompt"] = prompt
    if fmt in {"jpeg", "webp"}:
        kwargs["output_compression"] = int(compression)
    return kwargs
# ---------- Generate ---------- #
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Generate images from a text prompt.

    Args:
        api_key: OpenAI API key from the UI (resolved by ``_client``).
        prompt: Text description of the desired image.
        n: Number of images to request.
        size: Requested image size.
        quality: Requested quality level.
        out_fmt: Output format.
        compression: Compression level for JPEG/WebP output.
        transparent_bg: Whether to request a transparent background.

    Returns:
        The list produced by ``_img_list`` for the API response.

    Raises:
        gr.Error: If the prompt is blank, no API key is available, or the
            API call fails.
    """
    # Reject a blank prompt up front instead of spending a request on it.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    client = _client(api_key)
    request = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
    try:
        resp = client.images.generate(**request)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"OpenAI error: {e}") from e
    return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)
# ---------- Edit / Inpaint ---------- #
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode an image array as PNG and return the encoded bytes.

    The array is cast to ``uint8`` before being handed to Pillow.
    """
    pixels = arr.astype(np.uint8)
    with io.BytesIO() as buffer:
        Image.fromarray(pixels).save(buffer, format="PNG")
        return buffer.getvalue()
def edit_image(
    api_key: str,
    image_numpy: np.ndarray,
    mask_numpy: Optional[np.ndarray],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Edit (inpaint) an uploaded image, optionally limited to a painted mask.

    Args:
        api_key: OpenAI API key from the UI (resolved by ``_client``).
        image_numpy: Source image as an array.
        mask_numpy: Painted mask as an array, a Gradio 3.x sketch payload
            (dict with a ``"mask"`` entry), or ``None`` for no mask.
        prompt: Description of the desired edit.
        n: Number of images to request.
        size: Requested image size.
        quality: Requested quality level.
        out_fmt: Output format.
        compression: Compression level for JPEG/WebP output.
        transparent_bg: Whether to request a transparent background.

    Returns:
        The list produced by ``_img_list`` for the API response.

    Raises:
        gr.Error: If the image or prompt is missing, the mask size differs
            from the image size, no API key is available, or the API call
            fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter an edit prompt.")

    # Gradio 3.x sketch canvases deliver {"image": ..., "mask": ...} rather
    # than a bare array; only the painted layer is needed here.
    if isinstance(mask_numpy, dict):
        mask_numpy = mask_numpy.get("mask")

    # Upload as (filename, bytes, content-type) so each multipart part is
    # labelled image/png; bare bytes carry no type information.
    uploads = {"image": ("image.png", _bytes_from_numpy(image_numpy), "image/png")}
    if mask_numpy is not None:
        rgba_mask = _alpha_mask_from_sketch(np.asarray(mask_numpy))
        # The edit endpoint requires mask and image to share dimensions.
        if rgba_mask.shape[:2] != np.asarray(image_numpy).shape[:2]:
            raise gr.Error("The mask must have the same dimensions as the image.")
        uploads["mask"] = ("mask.png", _bytes_from_numpy(rgba_mask), "image/png")

    client = _client(api_key)
    request = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
    try:
        resp = client.images.edit(**uploads, **request)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"OpenAI error: {e}") from e
    return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)


def _alpha_mask_from_sketch(sketch: np.ndarray) -> np.ndarray:
    """Turn a painted sketch into the RGBA mask format used by ``images.edit``.

    The endpoint edits wherever the mask is fully transparent, so painted
    pixels get alpha 0 and untouched pixels get alpha 255.
    """
    if sketch.ndim == 2:  # single-channel canvas
        painted = sketch != 0
    elif sketch.shape[-1] == 4:  # RGBA: brush strokes carry alpha
        painted = sketch[:, :, 3] > 0
    else:  # RGB: any non-black pixel counts as painted
        painted = np.any(sketch != 0, axis=-1)
    rgba = np.zeros(painted.shape + (4,), dtype=np.uint8)
    rgba[..., 3] = np.where(painted, 0, 255)
    return rgba
# ---------- Variations ---------- #
def variation_image(
    api_key: str,
    image_numpy: np.ndarray,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Create variations of an uploaded image.

    The OpenAI SDK exposes this endpoint as ``images.create_variation`` and
    the API serves it with DALL·E 2 only. The GPT-Image-specific options
    (``quality``, ``out_fmt``, ``compression``, ``transparent_bg``) are
    therefore accepted for signature compatibility but not sent.

    Args:
        api_key: OpenAI API key from the UI (resolved by ``_client``).
        image_numpy: Source image as an array.
        n: Number of variations to request.
        size: Requested size; values other than the square sizes supported
            by the endpoint fall back to ``"1024x1024"``.
        quality: Unused by this endpoint.
        out_fmt: Unused by this endpoint.
        compression: Unused by this endpoint.
        transparent_bg: Unused by this endpoint.

    Returns:
        The list produced by ``_img_list`` for the API response.

    Raises:
        gr.Error: If no image was uploaded, no API key is available, or the
            API call fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    # Upload as (filename, bytes, content-type) so the part is typed as PNG.
    image_file = ("image.png", _bytes_from_numpy(image_numpy), "image/png")
    # The endpoint accepts only these sizes; "auto" and the non-square
    # choices offered in the UI are not valid here.
    var_size = size if size in {"256x256", "512x512", "1024x1024"} else "1024x1024"
    client = _client(api_key)
    try:
        resp = client.images.create_variation(
            image=image_file,
            model="dall-e-2",
            n=int(n),
            size=var_size,
            response_format="b64_json",
        )
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"OpenAI error: {e}") from e
    # NOTE(review): results are labelled as PNG regardless of `out_fmt`,
    # because no output format is requested from this endpoint - confirm.
    return _img_list(resp, fmt="png", transparent=False)
# ---------- UI ---------- #
def build_ui():
    """Assemble the Gradio Blocks app: API key box, shared controls, three tabs.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="GPT‑Image‑1 (BYOT)") as app:
        gr.Markdown("""# GPT‑Image‑1 Playground 🖼️🔑\nGenerate • Edit (paint mask) • Variations""")

        with gr.Accordion("🔐 API key", open=False):
            key_box = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk‑…")

        # Settings shared by all three tabs.
        count = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Number of images (n)")
        size_dd = gr.Dropdown(choices=SIZE_CHOICES, value="auto", label="Size")
        quality_dd = gr.Dropdown(choices=QUALITY_CHOICES, value="auto", label="Quality")
        fmt_radio = gr.Radio(choices=FORMAT_CHOICES, value="png", label="Format")
        compression_slider = gr.Slider(
            minimum=0, maximum=100, value=75, step=1, label="Compression (JPEG/WebP)"
        )
        transparent_box = gr.Checkbox(value=False, label="Transparent background (PNG only)")
        shared = [count, size_dd, quality_dd, fmt_radio, compression_slider, transparent_box]

        def _toggle_compression(selected):
            # Show the compression slider only for the formats it applies to.
            show = selected in {"jpeg", "webp"}
            return gr.update(visible=show)

        fmt_radio.change(fn=_toggle_compression, inputs=fmt_radio, outputs=compression_slider)

        with gr.Tabs():
            # Text-to-image.
            with gr.TabItem("Generate"):
                gen_prompt = gr.Textbox(
                    label="Prompt",
                    lines=2,
                    placeholder="A photorealistic ginger cat astronaut on Mars",
                )
                gen_button = gr.Button("Generate 🚀")
                gen_gallery = gr.Gallery(columns=2, height="auto")
                gen_button.click(
                    fn=generate,
                    inputs=[key_box, gen_prompt, *shared],
                    outputs=gen_gallery,
                )

            # Image editing with a painted mask.
            with gr.TabItem("Edit / Inpaint"):
                edit_source = gr.Image(label="Image", type="numpy")
                edit_mask = gr.Image(
                    label="Mask – paint white where the image should change",
                    type="numpy",
                    tool="sketch",
                )
                edit_prompt = gr.Textbox(
                    label="Edit prompt",
                    lines=2,
                    placeholder="Replace the sky with a starry night",
                )
                edit_button = gr.Button("Edit 🖌️")
                edit_gallery = gr.Gallery(columns=2, height="auto")
                edit_button.click(
                    fn=edit_image,
                    inputs=[key_box, edit_source, edit_mask, edit_prompt, *shared],
                    outputs=edit_gallery,
                )

            # Variations of an uploaded image.
            with gr.TabItem("Variations"):
                var_source = gr.Image(label="Source image", type="numpy")
                var_button = gr.Button("Variations 🔄")
                var_gallery = gr.Gallery(columns=2, height="auto")
                var_button.click(
                    fn=variation_image,
                    inputs=[key_box, var_source, *shared],
                    outputs=var_gallery,
                )
    return app
# Build the app at import time so `demo` is available as a module attribute.
demo = build_ui()

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()