Zack3D committed on
Commit
c164914
·
verified ·
1 Parent(s): 8945cd4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ """
3
+ Gradio Space: GPT‑Image‑1 – BYOT playground
4
+ Generate · Edit (paint mask!) · Variations
5
+ ==========================================
6
+ Adds an **in‑browser paint tool** for the edit / inpaint workflow so users can
7
+ draw the mask directly instead of uploading one.
8
+
9
+ ### How mask painting works
10
+ * Upload an image.
11
+ * Use the *Mask* canvas to **paint the areas you’d like changed** (white =
12
+ editable, black = keep).
13
+ Gradio’s built‑in *sketch* tool captures your brush strokes.
14
+ * The painted strokes are converted to a PNG mask and sent to the
15
+ `images.edit()` endpoint.
16
+
17
+ All other controls (size, quality, format, compression, n, background) stay the
18
+ same.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import io
24
+ import os
25
+ from typing import List, Optional
26
+
27
+ import gradio as gr
28
+ import numpy as np
29
+ from PIL import Image
30
+ import openai
31
+
32
# Image model requested by the helpers in this module.
MODEL = "gpt-image-1"

# Option lists offered by the size / quality / format controls in the UI.
SIZE_CHOICES = [
    "auto",
    "1024x1024",
    "1536x1024",
    "1024x1536",
]
QUALITY_CHOICES = [
    "auto",
    "low",
    "medium",
    "high",
]
FORMAT_CHOICES = [
    "png",
    "jpeg",
    "webp",
]
36
+
37
+
38
def _client(key: str) -> openai.OpenAI:
    """Build an OpenAI client from the UI-supplied key or the environment.

    Args:
        key: API key typed into the UI. May be empty or ``None``; in that
            case the ``OPENAI_API_KEY`` environment variable is used.

    Returns:
        An ``openai.OpenAI`` client configured with the resolved key.

    Raises:
        gr.Error: If neither *key* nor ``OPENAI_API_KEY`` yields a key.
    """
    # Guard against None so a missing value behaves like an empty textbox
    # instead of raising AttributeError on ``.strip()``.
    api_key = (key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=api_key)
43
+
44
+
45
def _img_list(resp, *, fmt: str, transparent: bool) -> List[str]:
    """Convert an images API response into a list of displayable sources.

    Args:
        resp: Response object exposing a ``data`` iterable whose items may
            carry a ``b64_json`` and/or ``url`` attribute.
        fmt: Requested output format (``"png"``, ``"jpeg"`` or ``"webp"``).
        transparent: Whether a transparent background was requested; such
            results are labelled as PNG.

    Returns:
        One entry per usable item: a ``data:`` URI when base64 data is
        present, otherwise the item's URL. Items with neither are skipped.
    """
    mime = "image/png" if fmt == "png" or transparent else f"image/{fmt}"
    sources: List[str] = []
    for item in resp.data:
        # The attribute can exist but be None, so test the value rather than
        # mere presence; otherwise we would emit "base64,None".
        b64 = getattr(item, "b64_json", None)
        if b64:
            sources.append(f"data:{mime};base64,{b64}")
            continue
        url = getattr(item, "url", None)
        if url:
            sources.append(url)
    return sources
51
+
52
+
53
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> dict:
    """Build the keyword arguments shared by the image API calls.

    Args:
        prompt: Text prompt, or ``None`` to leave ``prompt`` out entirely.
        n: Number of images to request (cast to ``int``; UI sliders may
            deliver floats).
        size: Size choice, passed through unchanged.
        quality: Quality choice, passed through unchanged.
        out_fmt: Desired output format (``"png"``, ``"jpeg"`` or ``"webp"``).
        compression: Compression level 0-100; only sent for JPEG/WebP.
        transparent_bg: Request a transparent background.

    Returns:
        A dict suitable for ``**``-expansion into the OpenAI images calls.
    """
    # The UI advertises transparency as PNG-only and the result is labelled
    # as PNG downstream, so force PNG output whenever it is requested.
    fmt = "png" if transparent_bg else out_fmt
    kwargs = dict(
        model=MODEL,
        n=int(n),
        size=size,
        quality=quality,
        output_format=fmt,
        # The API expresses transparency via ``background``; there is no
        # ``transparent_background`` parameter.
        background="transparent" if transparent_bg else "auto",
    )
    # ``response_format`` is intentionally not sent: gpt-image-1 always
    # returns base64 data and does not accept that parameter.
    if prompt is not None:
        kwargs["prompt"] = prompt
    if fmt in {"jpeg", "webp"}:
        # Integer percentage, not a "75%" string.
        kwargs["output_compression"] = int(compression)
    return kwargs
76
+
77
+
78
+ # ---------- Generate ---------- #
79
+
80
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Generate images from a text prompt.

    Args:
        api_key: OpenAI API key from the UI (may be empty to use the env var).
        prompt: Text description of the desired image; must not be blank.
        n: Number of images to request.
        size: Size choice.
        quality: Quality choice.
        out_fmt: Output format choice.
        compression: Compression level for JPEG/WebP output.
        transparent_bg: Whether to request a transparent background.

    Returns:
        A list of image sources (data URIs or URLs) for the gallery.

    Raises:
        gr.Error: If no API key or prompt is available, or the API call fails.
    """
    client = _client(api_key)
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    # Build the arguments outside the try block so that a local programming
    # error is not misreported as an OpenAI failure.
    kwargs = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
    try:
        resp = client.images.generate(**kwargs)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"OpenAI error: {e}") from e
    return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)
96
+
97
+
98
+ # ---------- Edit / Inpaint ---------- #
99
+
100
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode an RGBA/RGB image array as PNG and return the file bytes.

    The array is cast to ``uint8`` before being handed to Pillow.
    """
    buffer = io.BytesIO()
    Image.fromarray(arr.astype(np.uint8)).save(buffer, format="PNG")
    return buffer.getvalue()
106
+
107
+
108
def _mask_png_bytes(mask, height: int, width: int) -> Optional[bytes]:
    """Turn painted strokes into an RGBA PNG mask, or None if nothing is painted."""
    mask = np.asarray(mask)
    if mask.ndim == 2:  # single-channel sketch
        painted = mask != 0
    elif mask.shape[-1] == 4:  # RGBA: the alpha channel marks the strokes
        painted = mask[:, :, 3] != 0
    else:  # RGB: any non-black pixel counts as painted
        painted = np.any(mask != 0, axis=-1)
    if not painted.any():
        # A fully opaque mask would mark nothing as editable.
        return None

    # The edit endpoint treats fully transparent pixels (alpha == 0) as the
    # region to change, so painted pixels become transparent.
    alpha = np.where(painted, 0, 255).astype(np.uint8)
    if alpha.shape != (height, width):
        # The mask must have the same dimensions as the image.
        alpha = np.asarray(
            Image.fromarray(alpha).resize((width, height), Image.NEAREST)
        )
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    rgba[:, :, 3] = alpha
    return _bytes_from_numpy(rgba)


def edit_image(
    api_key: str,
    image_numpy: np.ndarray,
    mask_numpy: Optional[np.ndarray],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Edit / inpaint an uploaded image, optionally restricted by a mask.

    Args:
        api_key: OpenAI API key from the UI (may be empty to use the env var).
        image_numpy: Source image array; required.
        mask_numpy: Painted mask array, a sketch-tool payload dict with a
            ``"mask"`` entry, or ``None`` for no mask.
        prompt: Description of the desired edit; must not be blank.
        n: Number of images to request.
        size: Size choice.
        quality: Quality choice.
        out_fmt: Output format choice.
        compression: Compression level for JPEG/WebP output.
        transparent_bg: Whether to request a transparent background.

    Returns:
        A list of image sources (data URIs or URLs) for the gallery.

    Raises:
        gr.Error: If the image, prompt or API key is missing, or the API
            call fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter an edit prompt.")
    img_bytes = _bytes_from_numpy(image_numpy)

    # NOTE(review): some Gradio versions deliver the sketch tool's value as
    # {"image": ..., "mask": ...} rather than a bare array - confirm against
    # the pinned Gradio version.
    if isinstance(mask_numpy, dict):
        mask_numpy = mask_numpy.get("mask")

    mask_bytes: Optional[bytes] = None
    if mask_numpy is not None:
        height, width = np.asarray(image_numpy).shape[:2]
        mask_bytes = _mask_png_bytes(mask_numpy, height, width)

    client = _client(api_key)
    kwargs = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
    # Upload as (filename, content, mimetype) so the server can identify the
    # payload as PNG; bare bytes carry no filename or content type.
    kwargs["image"] = ("image.png", img_bytes, "image/png")
    if mask_bytes is not None:
        # Only include the mask when there is one; an explicit None is not a
        # valid file value for the client.
        kwargs["mask"] = ("mask.png", mask_bytes, "image/png")
    try:
        resp = client.images.edit(**kwargs)
    except Exception as e:
        raise gr.Error(f"OpenAI error: {e}") from e
    return _img_list(resp, fmt=out_fmt, transparent=transparent_bg)
144
+
145
+
146
+ # ---------- Variations ---------- #
147
+
148
def variation_image(
    api_key: str,
    image_numpy: np.ndarray,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    """Create variations of an uploaded image.

    The variations endpoint is served by ``dall-e-2`` and accepts only the
    image, ``n``, ``size`` and ``response_format``. *quality*, *out_fmt*,
    *compression* and *transparent_bg* are therefore accepted for interface
    parity with the other tabs but not sent.

    Args:
        api_key: OpenAI API key from the UI (may be empty to use the env var).
        image_numpy: Source image array; required.
        n: Number of variations to request.
        size: Size choice; forwarded only when the endpoint supports it.
        quality: Unused (see above).
        out_fmt: Unused (see above).
        compression: Unused (see above).
        transparent_bg: Unused (see above).

    Returns:
        A list of PNG image sources (data URIs or URLs) for the gallery.

    Raises:
        gr.Error: If the image or API key is missing, or the API call fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    img_bytes = _bytes_from_numpy(image_numpy)
    client = _client(api_key)

    kwargs = dict(
        # (filename, content, mimetype) lets the server recognise the PNG.
        image=("image.png", img_bytes, "image/png"),
        model="dall-e-2",
        n=int(n),
        response_format="b64_json",
    )
    # Sizes such as "auto" or the non-square options are not valid for this
    # endpoint; omit the parameter and let the API use its default.
    if size in {"256x256", "512x512", "1024x1024"}:
        kwargs["size"] = size

    try:
        # The SDK method is ``create_variation``; ``images.variations`` does
        # not exist.
        resp = client.images.create_variation(**kwargs)
    except Exception as e:
        raise gr.Error(f"OpenAI error: {e}") from e
    # No output format is requested from this endpoint, so label the results
    # as PNG regardless of the format control.
    return _img_list(resp, fmt="png", transparent=False)
170
+
171
+
172
+ # ---------- UI ---------- #
173
+
174
def build_ui():
    """Assemble the Gradio Blocks interface and return it.

    The layout consists of an API-key accordion, a row of controls shared by
    all tabs (count, size, quality, format, compression, transparency) and
    three tabs - Generate, Edit / Inpaint, Variations - each wiring a button
    to the matching handler and showing results in a gallery.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="GPT‑Image‑1 (BYOT)") as demo:
        gr.Markdown("""# GPT‑Image‑1 Playground 🖼️🔑\nGenerate • Edit (paint mask) • Variations""")

        with gr.Accordion("🔐 API key", open=False):
            api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk‑…")

        # Common controls
        n_slider = gr.Slider(1, 10, value=1, step=1, label="Number of images (n)")
        size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
        quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
        out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Format")
        # Hidden initially: the default format is PNG, and the change handler
        # below only shows this slider for JPEG/WebP.
        compression = gr.Slider(
            0, 100, value=75, step=1, label="Compression (JPEG/WebP)", visible=False
        )
        transparent = gr.Checkbox(False, label="Transparent background (PNG only)")

        def _toggle_compression(fmt):
            """Show the compression slider only for lossy formats."""
            return gr.update(visible=fmt in {"jpeg", "webp"})

        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)

        with gr.Tabs():
            # ----- Generate Tab ----- #
            with gr.TabItem("Generate"):
                prompt_gen = gr.Textbox(label="Prompt", lines=2, placeholder="A photorealistic ginger cat astronaut on Mars")
                btn_gen = gr.Button("Generate 🚀")
                gallery_gen = gr.Gallery(columns=2, height="auto")
                btn_gen.click(
                    generate,
                    inputs=[api, prompt_gen, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_gen,
                )

            # ----- Edit Tab ----- #
            with gr.TabItem("Edit / Inpaint"):
                img_edit = gr.Image(label="Image", type="numpy")
                # NOTE(review): ``tool="sketch"`` depends on the installed
                # Gradio version - confirm it is supported by the pinned one.
                mask_canvas = gr.Image(label="Mask – paint white where the image should change", type="numpy", tool="sketch")
                prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky with a starry night")
                btn_edit = gr.Button("Edit 🖌️")
                gallery_edit = gr.Gallery(columns=2, height="auto")
                btn_edit.click(
                    edit_image,
                    inputs=[api, img_edit, mask_canvas, prompt_edit, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_edit,
                )

            # ----- Variations Tab ----- #
            with gr.TabItem("Variations"):
                img_var = gr.Image(label="Source image", type="numpy")
                btn_var = gr.Button("Variations 🔄")
                gallery_var = gr.Gallery(columns=2, height="auto")
                btn_var.click(
                    variation_image,
                    inputs=[api, img_var, n_slider, size, quality, out_fmt, compression, transparent],
                    outputs=gallery_var,
                )

    return demo
231
+
232
+
233
# Module-level Blocks instance, built once at import time.
demo = build_ui()


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()