Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -34,13 +34,22 @@ def _img_list(resp) -> List[Union[np.ndarray, str]]:
|
|
34 |
Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly.
|
35 |
"""
|
36 |
imgs: List[Union[np.ndarray, str]] = []
|
|
|
|
|
|
|
|
|
37 |
for d in resp.data:
|
38 |
if hasattr(d, "b64_json") and d.b64_json:
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
42 |
elif getattr(d, "url", None):
|
43 |
imgs.append(d.url)
|
|
|
|
|
44 |
return imgs
|
45 |
|
46 |
|
@@ -49,24 +58,38 @@ def _common_kwargs(
|
|
49 |
n: int,
|
50 |
size: str,
|
51 |
quality: str,
|
52 |
-
out_fmt: str,
|
53 |
-
compression: int,
|
54 |
-
transparent_bg: bool,
|
55 |
) -> Dict[str, Any]:
|
56 |
"""Prepare keyword args for OpenAI Images API."""
|
57 |
kwargs: Dict[str, Any] = {
|
58 |
"model": MODEL,
|
59 |
"n": n,
|
|
|
|
|
60 |
}
|
61 |
if size != "auto":
|
62 |
kwargs["size"] = size
|
|
|
|
|
63 |
if quality != "auto":
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
if prompt is not None:
|
66 |
kwargs["prompt"] = prompt
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
70 |
return kwargs
|
71 |
|
72 |
|
@@ -77,45 +100,98 @@ def convert_to_format(
|
|
77 |
) -> np.ndarray:
|
78 |
"""
|
79 |
Convert a PIL numpy array to target_fmt (JPEG/WebP) and return as numpy array.
|
|
|
80 |
"""
|
|
|
|
|
|
|
|
|
81 |
img = Image.fromarray(img_array.astype(np.uint8))
|
82 |
buf = io.BytesIO()
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
|
89 |
def _format_openai_error(e: Exception) -> str:
|
|
|
90 |
error_message = f"An error occurred: {type(e).__name__}"
|
91 |
details = ""
|
92 |
if hasattr(e, 'body') and e.body:
|
93 |
try:
|
94 |
-
|
|
|
95 |
if isinstance(body, dict) and 'error' in body and isinstance(body['error'], dict) and 'message' in body['error']:
|
96 |
details = body['error']['message']
|
97 |
-
elif isinstance(body, dict) and 'message' in body:
|
98 |
details = body['message']
|
99 |
-
|
|
|
|
|
|
|
100 |
details = str(e.body)
|
101 |
-
|
|
|
|
|
|
|
102 |
details = e.message
|
|
|
103 |
if details:
|
104 |
error_message = f"OpenAI API Error: {details}"
|
|
|
|
|
|
|
|
|
|
|
105 |
if isinstance(e, openai.AuthenticationError):
|
106 |
-
error_message = "Invalid OpenAI API key. Please check your key."
|
107 |
elif isinstance(e, openai.PermissionDeniedError):
|
108 |
prefix = "Permission Denied."
|
109 |
-
if "organization verification" in details.lower():
|
110 |
-
prefix += " Your organization may need verification to use this feature/model."
|
|
|
|
|
111 |
error_message = f"{prefix} Details: {details}" if details else prefix
|
112 |
elif isinstance(e, openai.RateLimitError):
|
113 |
-
error_message = "Rate limit exceeded. Please wait and try again later."
|
114 |
elif isinstance(e, openai.BadRequestError):
|
115 |
error_message = f"OpenAI Bad Request: {details or str(e)}"
|
116 |
-
if
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
return error_message
|
120 |
|
121 |
|
@@ -128,86 +204,217 @@ def generate(
|
|
128 |
quality: str,
|
129 |
out_fmt: str,
|
130 |
compression: int,
|
131 |
-
transparent_bg: bool,
|
132 |
):
|
133 |
if not prompt:
|
134 |
raise gr.Error("Please enter a prompt.")
|
135 |
try:
|
136 |
client = _client(api_key)
|
|
|
137 |
common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
|
|
|
|
|
|
|
138 |
resp = client.images.generate(**common_args)
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
except (openai.APIError, openai.OpenAIError) as e:
|
|
|
144 |
raise gr.Error(_format_openai_error(e))
|
145 |
except Exception as e:
|
146 |
print(f"Unexpected error during generation: {type(e).__name__}: {e}")
|
|
|
|
|
147 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
148 |
|
149 |
|
150 |
# ---------- Edit / Inpaint ---------- #
|
151 |
-
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
|
|
|
152 |
img = Image.fromarray(arr.astype(np.uint8))
|
153 |
buf = io.BytesIO()
|
154 |
-
img.save(buf, format=
|
155 |
return buf.getvalue()
|
156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
|
|
|
159 |
if mask_value is None:
|
|
|
160 |
return None
|
|
|
|
|
|
|
161 |
if isinstance(mask_value, dict):
|
162 |
mask_array = mask_value.get("mask")
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
return None
|
168 |
|
169 |
|
170 |
def edit_image(
|
171 |
api_key: str,
|
172 |
image_numpy: Optional[np.ndarray],
|
173 |
-
|
174 |
prompt: str,
|
175 |
n: int,
|
176 |
size: str,
|
177 |
quality: str,
|
178 |
out_fmt: str,
|
179 |
compression: int,
|
180 |
-
transparent_bg: bool,
|
181 |
):
|
182 |
if image_numpy is None:
|
183 |
raise gr.Error("Please upload an image.")
|
184 |
if not prompt:
|
185 |
raise gr.Error("Please enter an edit prompt.")
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
-
# (Mask handling code unchanged)
|
192 |
if mask_numpy is not None:
|
193 |
-
|
194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
try:
|
197 |
client = _client(api_key)
|
|
|
198 |
common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
except (openai.APIError, openai.OpenAIError) as e:
|
|
|
208 |
raise gr.Error(_format_openai_error(e))
|
209 |
except Exception as e:
|
210 |
print(f"Unexpected error during edit: {type(e).__name__}: {e}")
|
|
|
|
|
211 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
212 |
|
213 |
|
@@ -217,80 +424,163 @@ def variation_image(
|
|
217 |
image_numpy: Optional[np.ndarray],
|
218 |
n: int,
|
219 |
size: str,
|
220 |
-
quality: str,
|
221 |
out_fmt: str,
|
222 |
compression: int,
|
223 |
-
transparent_bg: bool,
|
224 |
):
|
225 |
-
|
|
|
226 |
if image_numpy is None:
|
227 |
raise gr.Error("Please upload an image.")
|
228 |
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
try:
|
231 |
client = _client(api_key)
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
233 |
if size != "auto":
|
234 |
var_args["size"] = size
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
except (openai.APIError, openai.OpenAIError) as e:
|
241 |
-
|
|
|
|
|
|
|
|
|
|
|
242 |
except Exception as e:
|
243 |
print(f"Unexpected error during variation: {type(e).__name__}: {e}")
|
|
|
|
|
244 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
245 |
|
246 |
|
247 |
# ---------- UI ---------- #
|
248 |
def build_ui():
|
249 |
-
with gr.Blocks(title="
|
250 |
-
gr.Markdown("""#
|
251 |
-
|
|
|
|
|
|
|
252 |
api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
with gr.Row():
|
255 |
n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
|
256 |
-
size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
|
257 |
-
quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
|
258 |
with gr.Row():
|
259 |
-
out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format")
|
260 |
-
compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False)
|
261 |
-
|
|
|
|
|
262 |
|
263 |
def _toggle_compression(fmt):
|
264 |
return gr.update(visible=fmt in {"jpeg", "webp"})
|
265 |
out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
|
266 |
|
|
|
267 |
common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
|
268 |
|
269 |
with gr.Tabs():
|
270 |
with gr.TabItem("Generate"):
|
271 |
-
prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic..." )
|
272 |
-
btn_gen = gr.Button("Generate 🚀")
|
273 |
-
gallery_gen = gr.Gallery(columns=2, height="auto")
|
274 |
-
|
|
|
|
|
|
|
|
|
275 |
|
276 |
with gr.TabItem("Edit / Inpaint"):
|
277 |
-
gr.Markdown("Upload an image,
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
gr.
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
return demo
|
292 |
|
293 |
|
294 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
295 |
app = build_ui()
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly.
|
35 |
"""
|
36 |
imgs: List[Union[np.ndarray, str]] = []
|
37 |
+
if not resp or not hasattr(resp, 'data'):
|
38 |
+
print("Warning: Response object missing or has no 'data' attribute.")
|
39 |
+
return imgs # Return empty list if response is invalid
|
40 |
+
|
41 |
for d in resp.data:
|
42 |
if hasattr(d, "b64_json") and d.b64_json:
|
43 |
+
try:
|
44 |
+
data = base64.b64decode(d.b64_json)
|
45 |
+
img = Image.open(io.BytesIO(data))
|
46 |
+
imgs.append(np.array(img))
|
47 |
+
except Exception as decode_err:
|
48 |
+
print(f"Error decoding base64 image: {decode_err}")
|
49 |
elif getattr(d, "url", None):
|
50 |
imgs.append(d.url)
|
51 |
+
else:
|
52 |
+
print(f"Warning: Response item has neither b64_json nor url: {d}")
|
53 |
return imgs
|
54 |
|
55 |
|
|
|
58 |
n: int,
|
59 |
size: str,
|
60 |
quality: str,
|
61 |
+
out_fmt: str, # Note: out_fmt is used *after* generation for conversion, not directly in API call
|
62 |
+
compression: int, # Note: compression is used *after* generation for conversion
|
63 |
+
transparent_bg: bool, # Note: transparent_bg is used *after* generation for conversion if not directly supported
|
64 |
) -> Dict[str, Any]:
|
65 |
"""Prepare keyword args for OpenAI Images API."""
|
66 |
kwargs: Dict[str, Any] = {
|
67 |
"model": MODEL,
|
68 |
"n": n,
|
69 |
+
# Always request PNG for maximum quality/editability before potential conversion
|
70 |
+
"response_format": "b64_json", # Request base64 to handle locally
|
71 |
}
|
72 |
if size != "auto":
|
73 |
kwargs["size"] = size
|
74 |
+
# DALL-E 3 uses 'quality': 'standard' or 'hd'. DALL-E 2 doesn't have quality.
|
75 |
+
# Adapt this based on the actual model's capabilities. Assuming 'hd' for 'high'.
|
76 |
if quality != "auto":
|
77 |
+
# Map your choices to OpenAI's expected values if needed
|
78 |
+
# Example for DALL-E 3:
|
79 |
+
# if quality == "high": kwargs["quality"] = "hd"
|
80 |
+
# elif quality == "medium": kwargs["quality"] = "standard" # or omit
|
81 |
+
# For now, pass it directly, but be aware it might not be supported by MODEL
|
82 |
+
kwargs["quality"] = quality
|
83 |
+
|
84 |
if prompt is not None:
|
85 |
kwargs["prompt"] = prompt
|
86 |
+
|
87 |
+
# Note: Background removal is not a standard parameter for generate/edit/variation.
|
88 |
+
# This would typically be a post-processing step or require a specific model/API.
|
89 |
+
# If transparent_bg is True, you might need to handle it after receiving the image.
|
90 |
+
# if transparent_bg and out_fmt in {"png", "webp"}:
|
91 |
+
# kwargs["background"] = "transparent" # This parameter is hypothetical
|
92 |
+
|
93 |
return kwargs
|
94 |
|
95 |
|
|
|
100 |
) -> np.ndarray:
|
101 |
"""
|
102 |
Convert a PIL numpy array to target_fmt (JPEG/WebP) and return as numpy array.
|
103 |
+
Handles PNG pass-through.
|
104 |
"""
|
105 |
+
if target_fmt.lower() == "png":
|
106 |
+
# No conversion needed if already PNG (assuming input from b64 is effectively PNG)
|
107 |
+
return img_array
|
108 |
+
|
109 |
img = Image.fromarray(img_array.astype(np.uint8))
|
110 |
buf = io.BytesIO()
|
111 |
+
save_kwargs = {}
|
112 |
+
fmt_upper = target_fmt.upper()
|
113 |
+
|
114 |
+
if fmt_upper in ["JPEG", "WEBP"]:
|
115 |
+
save_kwargs["quality"] = quality
|
116 |
+
# Handle transparency for WebP
|
117 |
+
if fmt_upper == "WEBP":
|
118 |
+
# Check if image has alpha channel
|
119 |
+
if img.mode == 'RGBA' or 'A' in img.getbands():
|
120 |
+
pass # WebP supports transparency inherently
|
121 |
+
else:
|
122 |
+
# If no alpha, don't need special handling unless forcing transparency loss
|
123 |
+
pass
|
124 |
+
# Handle transparency loss for JPEG
|
125 |
+
elif fmt_upper == "JPEG":
|
126 |
+
if img.mode == 'RGBA' or img.mode == 'LA' or (img.mode == 'P' and 'transparency' in img.info):
|
127 |
+
# Convert to RGB, losing transparency. Default white background.
|
128 |
+
img = img.convert('RGB')
|
129 |
+
|
130 |
+
try:
|
131 |
+
img.save(buf, format=fmt_upper, **save_kwargs)
|
132 |
+
buf.seek(0)
|
133 |
+
img2 = Image.open(buf)
|
134 |
+
return np.array(img2)
|
135 |
+
except Exception as e:
|
136 |
+
print(f"Error during image conversion to {target_fmt}: {e}")
|
137 |
+
# Fallback to returning original array on conversion error
|
138 |
+
return img_array
|
139 |
|
140 |
|
141 |
def _format_openai_error(e: Exception) -> str:
|
142 |
+
"""Formats OpenAI API errors into user-friendly messages."""
|
143 |
error_message = f"An error occurred: {type(e).__name__}"
|
144 |
details = ""
|
145 |
if hasattr(e, 'body') and e.body:
|
146 |
try:
|
147 |
+
# Try parsing as JSON first
|
148 |
+
body = json.loads(str(e.body))
|
149 |
if isinstance(body, dict) and 'error' in body and isinstance(body['error'], dict) and 'message' in body['error']:
|
150 |
details = body['error']['message']
|
151 |
+
elif isinstance(body, dict) and 'message' in body: # Sometimes the message is top-level
|
152 |
details = body['message']
|
153 |
+
else:
|
154 |
+
details = str(e.body) # Fallback if structure is unexpected
|
155 |
+
except json.JSONDecodeError:
|
156 |
+
# If body is not JSON, use its string representation
|
157 |
details = str(e.body)
|
158 |
+
except Exception:
|
159 |
+
# Catch any other parsing errors
|
160 |
+
details = str(e.body)
|
161 |
+
elif hasattr(e, 'message') and e.message: # Fallback for older error structures
|
162 |
details = e.message
|
163 |
+
|
164 |
if details:
|
165 |
error_message = f"OpenAI API Error: {details}"
|
166 |
+
else: # Keep the generic message if no details found
|
167 |
+
error_message = f"An OpenAI API error occurred: {str(e)}"
|
168 |
+
|
169 |
+
|
170 |
+
# Specific error type handling
|
171 |
if isinstance(e, openai.AuthenticationError):
|
172 |
+
error_message = "Invalid OpenAI API key. Please check your key and ensure it's active."
|
173 |
elif isinstance(e, openai.PermissionDeniedError):
|
174 |
prefix = "Permission Denied."
|
175 |
+
if details and "organization verification" in details.lower():
|
176 |
+
prefix += " Your organization may need verification or payment method update to use this feature/model."
|
177 |
+
elif details and "quota" in details.lower():
|
178 |
+
prefix += " You might have exceeded your usage quota."
|
179 |
error_message = f"{prefix} Details: {details}" if details else prefix
|
180 |
elif isinstance(e, openai.RateLimitError):
|
181 |
+
error_message = "Rate limit exceeded. Please wait and try again later, or check your usage limits."
|
182 |
elif isinstance(e, openai.BadRequestError):
|
183 |
error_message = f"OpenAI Bad Request: {details or str(e)}"
|
184 |
+
if details:
|
185 |
+
if "mask" in details.lower(): error_message += " (Check mask format/dimensions/transparency)"
|
186 |
+
if "size" in details.lower(): error_message += " (Check image/mask dimensions or requested size compatibility)"
|
187 |
+
if "model does not support variations" in details.lower(): error_message += f" ({MODEL} does not support variations)."
|
188 |
+
if "unsupported file format" in details.lower() or "unsupported mimetype" in details.lower(): error_message += " (Ensure input image is PNG, JPG, or WEBP)"
|
189 |
+
if "prompt" in details.lower() and "policy" in details.lower(): error_message += " (Prompt may violate OpenAI's safety policies)"
|
190 |
+
elif isinstance(e, openai.APIConnectionError):
|
191 |
+
error_message = "Could not connect to OpenAI. Please check your network connection."
|
192 |
+
elif isinstance(e, openai.InternalServerError):
|
193 |
+
error_message = "OpenAI server error. Please try again later."
|
194 |
+
|
195 |
return error_message
|
196 |
|
197 |
|
|
|
204 |
quality: str,
|
205 |
out_fmt: str,
|
206 |
compression: int,
|
207 |
+
transparent_bg: bool, # Note: Transparency handled post-generation if needed
|
208 |
):
|
209 |
if not prompt:
|
210 |
raise gr.Error("Please enter a prompt.")
|
211 |
try:
|
212 |
client = _client(api_key)
|
213 |
+
# Request b64_json for local processing/conversion
|
214 |
common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
|
215 |
+
common_args["response_format"] = "b64_json" # Ensure we get base64
|
216 |
+
|
217 |
+
print(f"Generating images with args: {common_args}") # Debug print
|
218 |
resp = client.images.generate(**common_args)
|
219 |
+
imgs_np = _img_list(resp) # Should be list of numpy arrays
|
220 |
+
|
221 |
+
# Post-generation conversion
|
222 |
+
final_imgs = []
|
223 |
+
for img_np in imgs_np:
|
224 |
+
if isinstance(img_np, np.ndarray):
|
225 |
+
# Apply transparency removal or format conversion here if needed
|
226 |
+
# Note: True transparency generation isn't standard. This handles format conversion.
|
227 |
+
# If transparent_bg was intended for background removal, that needs a separate model/tool.
|
228 |
+
converted_img = convert_to_format(img_np, out_fmt, compression)
|
229 |
+
final_imgs.append(converted_img)
|
230 |
+
else:
|
231 |
+
# If we somehow got a URL (shouldn't with b64_json), append it directly
|
232 |
+
final_imgs.append(img_np)
|
233 |
+
|
234 |
+
if not final_imgs:
|
235 |
+
raise gr.Error("Failed to generate or process images. Check logs.")
|
236 |
+
|
237 |
+
return final_imgs
|
238 |
except (openai.APIError, openai.OpenAIError) as e:
|
239 |
+
print(f"OpenAI API Error during generation: {type(e).__name__}: {e}")
|
240 |
raise gr.Error(_format_openai_error(e))
|
241 |
except Exception as e:
|
242 |
print(f"Unexpected error during generation: {type(e).__name__}: {e}")
|
243 |
+
import traceback
|
244 |
+
traceback.print_exc() # Print full traceback for unexpected errors
|
245 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
246 |
|
247 |
|
248 |
# ---------- Edit / Inpaint ---------- #
|
249 |
+
def _bytes_from_numpy(arr: np.ndarray, format: str = "PNG") -> bytes:
|
250 |
+
"""Converts numpy array to bytes in the specified format."""
|
251 |
img = Image.fromarray(arr.astype(np.uint8))
|
252 |
buf = io.BytesIO()
|
253 |
+
img.save(buf, format=format)
|
254 |
return buf.getvalue()
|
255 |
|
256 |
+
def _ensure_rgba_for_mask(mask_array: np.ndarray) -> np.ndarray:
|
257 |
+
"""Ensures the mask is RGBA, converting grayscale/RGB if necessary."""
|
258 |
+
if mask_array.ndim == 2: # Grayscale
|
259 |
+
# Convert grayscale to RGBA: White areas (255) become transparent (alpha=0), others opaque black
|
260 |
+
alpha = np.where(mask_array == 255, 0, 255).astype(np.uint8)
|
261 |
+
rgb = np.zeros((*mask_array.shape, 3), dtype=np.uint8) # Black RGB
|
262 |
+
rgba = np.dstack((rgb, alpha))
|
263 |
+
return rgba
|
264 |
+
elif mask_array.ndim == 3:
|
265 |
+
if mask_array.shape[2] == 3: # RGB
|
266 |
+
# Assume white RGB (255, 255, 255) means transparent for Gradio mask
|
267 |
+
is_white = np.all(mask_array == [255, 255, 255], axis=2)
|
268 |
+
alpha = np.where(is_white, 0, 255).astype(np.uint8)
|
269 |
+
rgba = np.dstack((mask_array, alpha))
|
270 |
+
return rgba
|
271 |
+
elif mask_array.shape[2] == 4: # Already RGBA
|
272 |
+
# Ensure correct interpretation: 0 alpha means transparent (area to edit)
|
273 |
+
# Gradio ImageMask often uses white paint on transparent bg.
|
274 |
+
# We need alpha=0 for transparent areas (edit target).
|
275 |
+
# If alpha channel is mostly 255 (opaque), invert it assuming white paint = transparent target.
|
276 |
+
alpha_channel = mask_array[:, :, 3]
|
277 |
+
if np.mean(alpha_channel) > 128: # Heuristic: if mostly opaque
|
278 |
+
print("Inverting mask alpha channel based on heuristic (mostly opaque).")
|
279 |
+
mask_array[:, :, 3] = 255 - alpha_channel
|
280 |
+
return mask_array # Assume it's correctly formatted otherwise
|
281 |
+
raise ValueError("Unsupported mask format/dimensions")
|
282 |
+
|
283 |
|
284 |
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
|
285 |
+
"""Extracts the mask numpy array from Gradio's ImageMask output."""
|
286 |
if mask_value is None:
|
287 |
+
print("Mask input is None.")
|
288 |
return None
|
289 |
+
# Gradio ImageMask output is often a dict {'image': ndarray, 'mask': ndarray}
|
290 |
+
# Or sometimes just the mask ndarray directly depending on version/setup
|
291 |
+
mask_array = None
|
292 |
if isinstance(mask_value, dict):
|
293 |
mask_array = mask_value.get("mask")
|
294 |
+
print(f"Extracted mask from dict: type={type(mask_array)}, shape={getattr(mask_array, 'shape', 'N/A')}")
|
295 |
+
elif isinstance(mask_value, np.ndarray):
|
296 |
+
mask_array = mask_value
|
297 |
+
print(f"Received mask as ndarray directly: shape={mask_array.shape}")
|
298 |
+
|
299 |
+
if isinstance(mask_array, np.ndarray):
|
300 |
+
# Add basic validation
|
301 |
+
if mask_array.ndim < 2 or mask_array.ndim > 3:
|
302 |
+
print(f"Warning: Unexpected mask dimensions: {mask_array.ndim}")
|
303 |
+
return None
|
304 |
+
if mask_array.size == 0:
|
305 |
+
print("Warning: Received empty mask array.")
|
306 |
+
return None
|
307 |
+
print(f"Successfully extracted mask array, shape: {mask_array.shape}, dtype: {mask_array.dtype}, min/max: {np.min(mask_array)}/{np.max(mask_array)}")
|
308 |
+
return mask_array
|
309 |
+
|
310 |
+
print(f"Could not extract ndarray mask from input type: {type(mask_value)}")
|
311 |
return None
|
312 |
|
313 |
|
314 |
def edit_image(
|
315 |
api_key: str,
|
316 |
image_numpy: Optional[np.ndarray],
|
317 |
+
mask_input: Optional[Union[np.ndarray, Dict[str, Any]]], # Renamed for clarity
|
318 |
prompt: str,
|
319 |
n: int,
|
320 |
size: str,
|
321 |
quality: str,
|
322 |
out_fmt: str,
|
323 |
compression: int,
|
324 |
+
transparent_bg: bool, # Note: Transparency handled post-generation if needed
|
325 |
):
|
326 |
if image_numpy is None:
|
327 |
raise gr.Error("Please upload an image.")
|
328 |
if not prompt:
|
329 |
raise gr.Error("Please enter an edit prompt.")
|
330 |
|
331 |
+
# Convert source image to PNG bytes
|
332 |
+
try:
|
333 |
+
img_bytes = _bytes_from_numpy(image_numpy, format="PNG")
|
334 |
+
# --- FIX: Provide image data as a tuple (filename, bytes, mimetype) ---
|
335 |
+
image_tuple: Tuple[str, bytes, str] = ("image.png", img_bytes, "image/png")
|
336 |
+
print(f"Prepared source image: {image_tuple[0]}, size={len(image_tuple[1])} bytes, type={image_tuple[2]}")
|
337 |
+
except Exception as e:
|
338 |
+
print(f"Error converting source image to bytes: {e}")
|
339 |
+
raise gr.Error("Failed to process source image.")
|
340 |
+
|
341 |
+
mask_tuple: Optional[Tuple[str, bytes, str]] = None
|
342 |
+
mask_numpy = _extract_mask_array(mask_input)
|
343 |
|
|
|
344 |
if mask_numpy is not None:
|
345 |
+
try:
|
346 |
+
# Ensure mask matches image dimensions (OpenAI requires this)
|
347 |
+
if image_numpy.shape[:2] != mask_numpy.shape[:2]:
|
348 |
+
raise gr.Error(f"Mask dimensions ({mask_numpy.shape[:2]}) must match image dimensions ({image_numpy.shape[:2]}). Please repaint the mask.")
|
349 |
+
|
350 |
+
# Convert mask to RGBA PNG bytes as required by OpenAI API
|
351 |
+
# The API expects a PNG where transparent pixels (alpha=0) indicate the area to edit.
|
352 |
+
mask_rgba = _ensure_rgba_for_mask(mask_numpy)
|
353 |
+
mask_bytes = _bytes_from_numpy(mask_rgba, format="PNG")
|
354 |
+
|
355 |
+
# --- FIX: Provide mask data as a tuple ---
|
356 |
+
mask_tuple = ("mask.png", mask_bytes, "image/png")
|
357 |
+
print(f"Prepared mask: {mask_tuple[0]}, size={len(mask_tuple[1])} bytes, type={mask_tuple[2]}")
|
358 |
+
|
359 |
+
except ValueError as e:
|
360 |
+
print(f"Error processing mask: {e}")
|
361 |
+
raise gr.Error(f"Failed to process mask: {e}")
|
362 |
+
except Exception as e:
|
363 |
+
print(f"Error converting mask to bytes: {e}")
|
364 |
+
raise gr.Error("Failed to process mask.")
|
365 |
+
else:
|
366 |
+
# If no mask is provided, it's an 'edit' without inpainting (DALL-E 2 supported this, DALL-E 3 might interpret differently)
|
367 |
+
# The API might require a mask for the /edit endpoint. Check API docs for the specific model.
|
368 |
+
# For DALL-E 2, omitting mask was allowed. Let's assume it might work or fail gracefully.
|
369 |
+
print("No valid mask provided or extracted. Proceeding without mask.")
|
370 |
+
# raise gr.Error("Please paint a mask to indicate the edit area.") # Uncomment if mask is strictly required
|
371 |
|
372 |
try:
|
373 |
client = _client(api_key)
|
374 |
+
# Get common args, ensure response format is b64_json
|
375 |
common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
|
376 |
+
common_args["response_format"] = "b64_json" # Ensure we get base64
|
377 |
+
|
378 |
+
# Prepare final API arguments
|
379 |
+
api_kwargs = {
|
380 |
+
"image": image_tuple,
|
381 |
+
**common_args
|
382 |
+
}
|
383 |
+
if mask_tuple is not None:
|
384 |
+
api_kwargs["mask"] = mask_tuple
|
385 |
+
else:
|
386 |
+
# If mask is omitted, remove prompt from common_args if the API treats it like variations?
|
387 |
+
# DALL-E 2 /edit without mask needed no prompt. DALL-E 3 might differ.
|
388 |
+
# Let's keep the prompt for now. The API error will tell us if it's wrong.
|
389 |
+
pass
|
390 |
+
# api_kwargs.pop("prompt", None) # Consider this if API complains about prompt without mask
|
391 |
+
|
392 |
+
print(f"Editing image with args: { {k: v if k not in ['image', 'mask'] else (v[0], f'{len(v[1])} bytes', v[2]) for k, v in api_kwargs.items()} }") # Debug print
|
393 |
+
resp = client.images.edit(**api_kwargs) # Call the edit endpoint
|
394 |
+
|
395 |
+
imgs_np = _img_list(resp) # Should be list of numpy arrays
|
396 |
+
|
397 |
+
# Post-generation conversion
|
398 |
+
final_imgs = []
|
399 |
+
for img_np in imgs_np:
|
400 |
+
if isinstance(img_np, np.ndarray):
|
401 |
+
converted_img = convert_to_format(img_np, out_fmt, compression)
|
402 |
+
final_imgs.append(converted_img)
|
403 |
+
else:
|
404 |
+
final_imgs.append(img_np) # Append URL if received
|
405 |
+
|
406 |
+
if not final_imgs:
|
407 |
+
raise gr.Error("Failed to edit or process images. Check logs.")
|
408 |
+
|
409 |
+
return final_imgs
|
410 |
+
|
411 |
except (openai.APIError, openai.OpenAIError) as e:
|
412 |
+
print(f"OpenAI API Error during edit: {type(e).__name__}: {e}")
|
413 |
raise gr.Error(_format_openai_error(e))
|
414 |
except Exception as e:
|
415 |
print(f"Unexpected error during edit: {type(e).__name__}: {e}")
|
416 |
+
import traceback
|
417 |
+
traceback.print_exc()
|
418 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
419 |
|
420 |
|
|
|
424 |
image_numpy: Optional[np.ndarray],
|
425 |
n: int,
|
426 |
size: str,
|
427 |
+
quality: str, # Note: Quality may not be supported by variations endpoint
|
428 |
out_fmt: str,
|
429 |
compression: int,
|
430 |
+
transparent_bg: bool, # Note: Transparency handled post-generation if needed
|
431 |
):
|
432 |
+
# Explicit warning as gpt-image-1 is likely not the correct model for variations
|
433 |
+
gr.Warning(f"Note: Image Variations are officially supported for DALL·E 2. Using model '{MODEL}' may fail or produce unexpected results.")
|
434 |
if image_numpy is None:
|
435 |
raise gr.Error("Please upload an image.")
|
436 |
|
437 |
+
# Convert source image to PNG bytes
|
438 |
+
try:
|
439 |
+
img_bytes = _bytes_from_numpy(image_numpy, format="PNG")
|
440 |
+
# --- FIX: Provide image data as a tuple ---
|
441 |
+
image_tuple: Tuple[str, bytes, str] = ("image.png", img_bytes, "image/png")
|
442 |
+
print(f"Prepared source image for variation: {image_tuple[0]}, size={len(image_tuple[1])} bytes, type={image_tuple[2]}")
|
443 |
+
except Exception as e:
|
444 |
+
print(f"Error converting source image to bytes for variation: {e}")
|
445 |
+
raise gr.Error("Failed to process source image.")
|
446 |
+
|
447 |
try:
|
448 |
client = _client(api_key)
|
449 |
+
# Prepare args for variations endpoint
|
450 |
+
var_args: Dict[str, Any] = {
|
451 |
+
"model": MODEL, # Use the selected model, though it might fail
|
452 |
+
"n": n,
|
453 |
+
"response_format": "b64_json" # Request base64
|
454 |
+
}
|
455 |
if size != "auto":
|
456 |
var_args["size"] = size
|
457 |
+
# Quality parameter is generally NOT supported for variations
|
458 |
+
# if quality != "auto":
|
459 |
+
# var_args["quality"] = quality # This will likely cause an error
|
460 |
+
|
461 |
+
print(f"Creating variations with args: { {k: v if k != 'image' else (v[0], f'{len(v[1])} bytes', v[2]) for k, v in {**var_args, 'image': image_tuple}.items()} }") # Debug print
|
462 |
+
|
463 |
+
# Pass the tuple to the image parameter
|
464 |
+
resp = client.images.create_variation(image=image_tuple, **var_args)
|
465 |
+
|
466 |
+
imgs_np = _img_list(resp) # Should be list of numpy arrays
|
467 |
+
|
468 |
+
# Post-generation conversion
|
469 |
+
final_imgs = []
|
470 |
+
for img_np in imgs_np:
|
471 |
+
if isinstance(img_np, np.ndarray):
|
472 |
+
converted_img = convert_to_format(img_np, out_fmt, compression)
|
473 |
+
final_imgs.append(converted_img)
|
474 |
+
else:
|
475 |
+
final_imgs.append(img_np) # Append URL if received
|
476 |
+
|
477 |
+
if not final_imgs:
|
478 |
+
raise gr.Error("Failed to create variations or process images. Check logs.")
|
479 |
+
|
480 |
+
return final_imgs
|
481 |
+
|
482 |
except (openai.APIError, openai.OpenAIError) as e:
|
483 |
+
print(f"OpenAI API Error during variation: {type(e).__name__}: {e}")
|
484 |
+
err_msg = _format_openai_error(e)
|
485 |
+
# Add specific check for variation incompatibility
|
486 |
+
if isinstance(e, openai.BadRequestError) and ("model does not support variations" in err_msg.lower() or "not supported" in err_msg.lower()):
|
487 |
+
raise gr.Error(f"As warned, the selected model ('{MODEL}') does not support the variations endpoint. Try using 'dall-e-2'.")
|
488 |
+
raise gr.Error(err_msg)
|
489 |
except Exception as e:
|
490 |
print(f"Unexpected error during variation: {type(e).__name__}: {e}")
|
491 |
+
import traceback
|
492 |
+
traceback.print_exc()
|
493 |
raise gr.Error("An unexpected application error occurred. Please check logs.")
|
494 |
|
495 |
|
496 |
# ---------- UI ---------- #
|
497 |
def build_ui():
|
498 |
+
with gr.Blocks(title="OpenAI Image Playground (BYOK)") as demo:
|
499 |
+
gr.Markdown(f"""# OpenAI Image Playground 🖼️🔑
|
500 |
+
Generate • Edit • Variations (using your own API key)
|
501 |
+
**Selected Model:** `{MODEL}` (Ensure your key has access)
|
502 |
+
""")
|
503 |
+
with gr.Accordion("🔐 API key & Model Info", open=False):
|
504 |
api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
|
505 |
+
gr.Markdown(f"""
|
506 |
+
* **Model:** `{MODEL}` is configured in the code. This might be a placeholder; official models are typically `dall-e-3` or `dall-e-2`.
|
507 |
+
* **Variations:** Officially only supported by `dall-e-2`. Using other models here will likely fail.
|
508 |
+
* **Edit/Inpainting:** Requires a model supporting the `/images/edits` endpoint (like `dall-e-2`).
|
509 |
+
* **Size/Quality:** Options shown may not be supported by all models. Check OpenAI documentation for `{MODEL}` if it's a real model. DALL-E 3 uses `quality` ('standard'/'hd'), DALL-E 2 does not.
|
510 |
+
""")
|
511 |
|
512 |
with gr.Row():
|
513 |
n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
|
514 |
+
size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size (if supported)", info="Set target size. 'auto' uses model default.")
|
515 |
+
quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality (if supported)", info="'auto' uses model default. DALL-E 3: 'standard'/'hd'. DALL-E 2 ignores this.")
|
516 |
with gr.Row():
|
517 |
+
out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format", info="Format for viewing/downloading generated images.")
|
518 |
+
compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False, info="Lower value = smaller file, lower quality.")
|
519 |
+
# Transparency generation is complex; this checkbox is mainly for format support.
|
520 |
+
# Actual transparency depends on model/post-processing.
|
521 |
+
transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)", info="Request transparency if model supports, or save PNG/WebP with alpha if generated.", visible=False) # Hidden for now as it's not directly controllable via API param
|
522 |
|
523 |
def _toggle_compression(fmt):
    """Show the compression slider only when a lossy output format is chosen."""
    is_lossy = fmt in ("jpeg", "webp")
    return gr.update(visible=is_lossy)
|
525 |
out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
|
526 |
|
527 |
+
# Combine common controls for easier passing to functions
|
528 |
common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
|
529 |
|
530 |
with gr.Tabs():
|
531 |
with gr.TabItem("Generate"):
|
532 |
+
prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic image of..." )
|
533 |
+
btn_gen = gr.Button("Generate 🚀", variant="primary")
|
534 |
+
gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
|
535 |
+
# Clear outputs on new click
|
536 |
+
inputs_gen = [api, prompt_gen] + common_controls
|
537 |
+
prompt_gen.submit(generate, inputs=inputs_gen, outputs=gallery_gen)
|
538 |
+
btn_gen.click(generate, inputs=inputs_gen, outputs=gallery_gen)
|
539 |
+
|
540 |
|
541 |
with gr.TabItem("Edit / Inpaint"):
|
542 |
+
gr.Markdown("Upload an image, **paint white** over the area you want the AI to change, then provide an edit prompt.")
|
543 |
+
with gr.Row():
|
544 |
+
img_edit_src = gr.Image(type="numpy", label="Source Image", height=400, tool="select")
|
545 |
+
# Use ImageMask tool for painting
|
546 |
+
mask_canvas = gr.ImageMask(type="numpy", label="Mask – Paint Area to Edit (White)", height=400, brush_radius=20)
|
547 |
+
# Link source image to mask canvas background
|
548 |
+
# img_edit_src.change(lambda x: x, inputs=img_edit_src, outputs=mask_canvas) # This might auto-clear mask, check Gradio docs if needed
|
549 |
+
|
550 |
+
prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Example: Make the cat wear a wizard hat")
|
551 |
+
btn_edit = gr.Button("Edit Image 🖌️", variant="primary")
|
552 |
+
gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
|
553 |
+
|
554 |
+
# Define inputs for the edit function
|
555 |
+
inputs_edit = [api, img_edit_src, mask_canvas, prompt_edit] + common_controls
|
556 |
+
prompt_edit.submit(edit_image, inputs=inputs_edit, outputs=gallery_edit)
|
557 |
+
btn_edit.click(edit_image, inputs=inputs_edit, outputs=gallery_edit)
|
558 |
+
|
559 |
+
|
560 |
+
with gr.TabItem("Variations (DALL·E 2 only)"):
|
561 |
+
gr.Markdown("Upload an image to generate variations. **Warning:** This endpoint is officially supported only by DALL·E 2.")
|
562 |
+
img_var_src = gr.Image(type="numpy", label="Source Image", height=400)
|
563 |
+
btn_var = gr.Button("Create Variations ✨", variant="primary")
|
564 |
+
gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
|
565 |
+
|
566 |
+
# Define inputs for the variation function
|
567 |
+
inputs_var = [api, img_var_src] + common_controls
|
568 |
+
# Variations don't use prompt, quality typically ignored
|
569 |
+
btn_var.click(variation_image, inputs=inputs_var, outputs=gallery_var)
|
570 |
+
|
571 |
return demo
|
572 |
|
573 |
|
574 |
if __name__ == "__main__":
    # NOTE(review): for local debugging an API key could be preloaded from the
    # environment (e.g. os.getenv("OPENAI_API_KEY")); handle with care if this
    # app is ever deployed publicly.
    demo_app = build_ui()
    # Launch the Gradio app; share/debug are driven by environment variables.
    demo_app.launch(
        share=os.getenv("GRADIO_SHARE") == "true",
        debug=os.getenv("GRADIO_DEBUG") == "true",
        server_name="0.0.0.0",  # bind to all interfaces for Docker compatibility
        # auth=("user", "password")  # enable basic auth if sharing is needed
    )
|