cai-qi commited on
Commit
107bed2
·
verified ·
1 Parent(s): 1be5da8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +833 -4
app.py CHANGED
@@ -1,7 +1,836 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import time
5
+ import traceback
6
+ from io import BytesIO
7
+ import io
8
+ import base64
9
+ from openai import OpenAI
10
+ import uuid
11
+ import requests
12
  import gradio as gr
13
+ import requests
14
+ from PIL import Image, PngImagePlugin
15
 
16
+ # Set up logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
20
+ )
21
+ logger = logging.getLogger(__name__)
22
 
23
+ # API Configuration
24
+ API_TOKEN = os.environ.get("HIDREAM_API_TOKEN")
25
+ API_REQUEST_URL = os.environ.get("API_REQUEST_URL")
26
+ API_RESULT_URL = os.environ.get("API_RESULT_URL")
27
+ API_IMAGE_URL = os.environ.get("API_IMAGE_URL")
28
+ API_VERSION = os.environ.get("API_VERSION")
29
+ API_MODEL_NAME = os.environ.get("API_MODEL_NAME")
30
+ OSS_IMAGE_BUCKET = os.environ.get("OSS_IMAGE_BUCKET")
31
+ OSS_MEDIA_BUCKET = os.environ.get("OSS_MEDIA_BUCKET")
32
+ OSS_TOKEN_URL = os.environ.get("OSS_TOKEN_URL")
33
+ MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT", "3"))
34
+ POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL", "1"))
35
+ MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME", "300"))
36
+
37
+ def get_oss_token(is_image=True, prefix="p_"):
38
+ head = {
39
+ "Cookie": os.environ.get("OSS_AUTH_COOKIE", "")
40
+ }
41
+ if is_image:
42
+ filename = f"p_{uuid.uuid4()}" if prefix == "p_" else f"j_{uuid.uuid4()}"
43
+ bucket = OSS_IMAGE_BUCKET
44
+ else:
45
+ filename = f"{uuid.uuid4()}.mp4"
46
+ bucket = OSS_MEDIA_BUCKET
47
+ token_url = f"{OSS_TOKEN_URL}{bucket}?filename={filename}"
48
+ req = requests.get(token_url, headers=head)
49
+ if req.status_code == 200 and req.json()["code"] == 0:
50
+ return req.json()["result"], filename
51
+ else:
52
+ print(req.status_code, req.text)
53
+ return None, None
54
+
55
+ def upload_to_gcs(signed_url: str, file_io, is_image=True):
56
+ if is_image:
57
+ headers = {
58
+ "Content-Type": "image/png", # ensure content-type matches the signed url
59
+ }
60
+ else:
61
+ headers = {
62
+ "Content-Type": "video/mp4", # ensure content-type matches the signed url
63
+ }
64
+ # with open(file_path, "rb") as f:
65
+ # response = requests.put(signed_url, data=f, headers=headers)
66
+ response = requests.put(signed_url, data=file_io, headers=headers)
67
+ if response.status_code == 200:
68
+ print("✅ Upload success")
69
+ else:
70
+ print(f"❌ Upload failed, status code: {response.status_code}, response content: {response.text}")
71
+
72
+ # Instruction refinement prompt
73
+ INSTRUCTION_PROMPT = """Your Role: You are an analytical assistant. Your task is to process a source image and a corresponding editing instruction, assuming the instruction accurately describes a desired transformation. You will 1) describe the source image, 2) output the editing instruction (potentially refined for clarity based on the source image context), and 3) describe the *imagined* result of applying that instruction.
74
+
75
+ Input:
76
+ 1. Source Image: The original 'before' image.
77
+ 2. Source Instruction: A text instruction describing the edit to be performed on the Source Image. You *must assume* this instruction is accurate and feasible for the purpose of this task.
78
+
79
+ Task Breakdown:
80
+ 1. **Describe Source Image:** Generate a description (e.g., key subject, setting) of the Source Image by analyzing it. This will be the first line of your output.
81
+
82
+ 2. **Output Editing Instruction:** This step determines the second line of your output.
83
+ * **Assumption:** The provided Source Instruction *accurately* describes the desired edit.
84
+ * **Goal:** Output a concise, single-line instruction based on the Source Instruction.
85
+ * **Refinement based on Source Image:** While the Source Instruction is assumed correct, analyze the Source Image to see if the instruction needs refinement for specificity. If the Source Image contains multiple similar objects and the Source Instruction is potentially ambiguous (e.g., "change the car color" when there are three cars), refine the instruction to be specific, using positional qualifiers (e.g., 'the left car', 'the bird on the top branch'), size ('the smaller dog', 'the largest building'), or other distinguishing visual features apparent in the Source Image. If the Source Instruction is already specific or if there's no ambiguity in the Source Image context, you can use it directly or with minor phrasing adjustments for naturalness. The *core meaning* of the Source Instruction must be preserved.
86
+ * **Output:** Present the resulting specific, single-line instruction as the second line.
87
+
88
+ 3. **Describe Imagined Target Image:** Based *only* on the Source Image description (Line 1) and the Editing Instruction (Line 2), generate a description of the *imagined outcome*.
89
+ * Describe the scene from Line 1 *as if* the instruction from Line 2 has been successfully applied. Conceptualize the result of the edit on the source description.
90
+ * This description must be purely a logical prediction based on applying the instruction (Line 2) to the description in Line 1. Do *not* invent details not implied by the instruction or observed in the source image beyond the specified edit. This will be the third line of your output.
91
+
92
+ Output Format:
93
+ * Your response *must* consist of exactly three lines.
94
+ * Do not include any other explanations, comments, introductory phrases, labels (like "Line 1:"), or formatting.
95
+ * Your output should be in English.
96
+
97
+ Line 1: [Description of the Source Image]
98
+ Line 2: [The specific, single-line editing instruction based on the Source Instruction and Source Image context]
99
+ Line 3: [Description of the Imagined Target Image based on Lines 1 & 2]
100
+
101
+ Now, please generate the three-line output based on the Source Image and the Source Instruction: {source_instruction}
102
+ """
103
+
104
+ def filter_response(src_instruction):
105
+ try:
106
+ src_instruction = src_instruction.strip().split("\n")
107
+ src_instruction = [k.strip() for k in src_instruction if k.strip()]
108
+ src_instruction = [k for k in src_instruction if len(k) > 0]
109
+ if len(src_instruction) != 3:
110
+ return ""
111
+ instruction = src_instruction[1]
112
+ target_description = src_instruction[2]
113
+ instruction = instruction.strip().strip(".")
114
+ inst_format = "Editing Instruction: {}. Target Image Description: {}"
115
+ return inst_format.format(instruction, target_description)
116
+ except:
117
+ return ""
118
+
119
+ def refine_instruction(src_image, src_instruction):
120
+ MAX_TOKENS_RESPONSE = 500 # Limit response tokens as output format is structured
121
+ client = OpenAI()
122
+ src_image = src_image.convert("RGB")
123
+ src_image_buffer = io.BytesIO()
124
+ src_image.save(src_image_buffer, format="JPEG")
125
+ src_image_buffer.seek(0)
126
+ src_base64 = base64.b64encode(src_image_buffer.read()).decode('utf-8')
127
+ encoded_str = f"data:image/jpeg;base64,{src_base64}"
128
+ image_content = [
129
+ {"type": "image_url", "image_url": {"url": encoded_str,}},
130
+ ]
131
+ instruction_text = INSTRUCTION_PROMPT.format(source_instruction=src_instruction)
132
+ message_content = [
133
+ {"type": "text", "text": instruction_text},
134
+ *image_content # Unpack the list of image dictionaries
135
+ ]
136
+ completion = client.chat.completions.create(
137
+ model="gpt-4o",
138
+ messages=[
139
+ {"role": "system", "content": "You are a professional digital artist."},
140
+ {"role": "user", "content": message_content}
141
+ ],
142
+ max_tokens=MAX_TOKENS_RESPONSE, # Good practice to set max tokens
143
+ temperature=0.2 # Lower temperature for more deterministic output
144
+ )
145
+ evaluation_result = completion.choices[0].message.content
146
+ refined_instruction = filter_response(evaluation_result)
147
+ return refined_instruction
148
+
149
+ # Resolution options
150
+ ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
151
+
152
+ # Log configuration details
153
+ logger.info(f"API configuration loaded: REQUEST_URL={API_REQUEST_URL}, RESULT_URL={API_RESULT_URL}, VERSION={API_VERSION}, MODEL={API_MODEL_NAME}")
154
+ logger.info(f"OSS configuration: IMAGE_BUCKET={OSS_IMAGE_BUCKET}, MEDIA_BUCKET={OSS_MEDIA_BUCKET}, TOKEN_URL={OSS_TOKEN_URL}")
155
+ logger.info(f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s")
156
+
157
+
158
+ class APIError(Exception):
159
+ """Custom exception for API-related errors"""
160
+ pass
161
+
162
+
163
+ def create_request(prompt, image, guidance_scale=5.0, image_guidance_scale=4.0, seed=-1):
164
+ """
165
+ Create an image editing request to the API.
166
+
167
+ Args:
168
+ prompt (str): Text prompt describing the image edit
169
+ image (PIL.Image): Input image to edit
170
+ guidance_scale (float): Strength of instruction following
171
+ image_guidance_scale (float): Strength of image preservation
172
+ seed (int): Seed for reproducibility, -1 for random
173
+
174
+ Returns:
175
+ tuple: (task_id, seed) - Task ID if successful and the seed used
176
+
177
+ Raises:
178
+ APIError: If the API request fails
179
+ """
180
+ logger.info(f"Starting create_request with prompt='{prompt[:50]}...', guidance_scale={guidance_scale}, image_guidance_scale={image_guidance_scale}, seed={seed}")
181
+ image_io = io.BytesIO()
182
+ image = image.convert("RGB")
183
+ image.save(image_io, format="PNG")
184
+ image_io.seek(0)
185
+ token_url, filename = get_oss_token(is_image=True)
186
+ upload_to_gcs(token_url, image_io, is_image=True)
187
+
188
+ if not prompt or not prompt.strip():
189
+ logger.error("Empty prompt provided to create_request")
190
+ raise ValueError("Prompt cannot be empty")
191
+
192
+ if not image:
193
+ logger.error("No image provided to create_request")
194
+ raise ValueError("Image cannot be empty")
195
+
196
+ # Generate random seed if not provided
197
+ if seed == -1:
198
+ seed = random.randint(1, 1000000)
199
+ logger.info(f"Generated random seed: {seed}")
200
+
201
+ # Validate seed
202
+ try:
203
+ seed = int(seed)
204
+ if seed < -1 or seed > 1000000:
205
+ logger.info(f"Invalid seed value: {seed}, forcing to 8888")
206
+ seed = 8888
207
+ except (TypeError, ValueError) as e:
208
+ logger.error(f"Seed validation failed: {str(e)}")
209
+ raise ValueError(f"Seed must be an integer but got {seed}")
210
+
211
+ headers = {
212
+ "Authorization": f"Bearer {API_TOKEN}",
213
+ "X-accept-language": "en",
214
+ "X-source": "api",
215
+ "Content-Type": "application/json",
216
+ }
217
+
218
+ generate_data = {
219
+ "module": "image_edit",
220
+ "images": [filename, ],
221
+ "prompt": prompt,
222
+ "params": {
223
+ "seed": seed,
224
+ "custom_params": {
225
+ "sample_steps": 28,
226
+ "guidance_scale": guidance_scale,
227
+ "image_guidance_scale": image_guidance_scale
228
+ },
229
+ },
230
+ "version": API_VERSION,
231
+ }
232
+
233
+ retry_count = 0
234
+ while retry_count < MAX_RETRY_COUNT:
235
+ try:
236
+ logger.info(f"Sending API request [attempt {retry_count+1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'")
237
+ response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
238
+
239
+ # Log response status code
240
+ logger.info(f"API request response status: {response.status_code}")
241
+
242
+ response.raise_for_status()
243
+
244
+ result = response.json()
245
+ if not result or "result" not in result:
246
+ logger.error(f"Invalid API response format: {str(result)}")
247
+ raise APIError(f"Invalid response format from API when sending request: {str(result)}")
248
+
249
+ task_id = result.get("result", {}).get("task_id")
250
+ if not task_id:
251
+ logger.error(f"No task ID in API response: {str(result)}")
252
+ raise APIError(f"No task ID returned from API: {str(result)}")
253
+
254
+ logger.info(f"Successfully created task with ID: {task_id}, seed: {seed}")
255
+ return task_id, seed
256
+
257
+ except requests.exceptions.Timeout:
258
+ retry_count += 1
259
+ logger.warning(f"Request timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
260
+ time.sleep(1)
261
+
262
+ except requests.exceptions.HTTPError as e:
263
+ status_code = e.response.status_code
264
+ error_message = f"HTTP error {status_code}"
265
+
266
+ try:
267
+ error_detail = e.response.json()
268
+ error_message += f": {error_detail}"
269
+ logger.error(f"API response error content: {error_detail}")
270
+ except:
271
+ logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")
272
+
273
+ if status_code == 401:
274
+ logger.error(f"Authentication failed with API token. Status code: {status_code}")
275
+ raise APIError("Authentication failed. Please check your API token.")
276
+ elif status_code == 429:
277
+ retry_count += 1
278
+ wait_time = min(2 ** retry_count, 10) # Exponential backoff
279
+ logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...")
280
+ time.sleep(wait_time)
281
+ elif 400 <= status_code < 500:
282
+ try:
283
+ error_detail = e.response.json()
284
+ error_message += f": {error_detail.get('message', 'Client error')}"
285
+ except:
286
+ pass
287
+ logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}")
288
+ raise APIError(error_message)
289
+ else:
290
+ retry_count += 1
291
+ logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
292
+ time.sleep(1)
293
+
294
+ except requests.exceptions.RequestException as e:
295
+ logger.error(f"Request error: {str(e)}")
296
+ logger.debug(f"Request error details: {traceback.format_exc()}")
297
+ raise APIError(f"Failed to connect to API: {str(e)}")
298
+
299
+ except Exception as e:
300
+ logger.error(f"Unexpected error in create_request: {str(e)}")
301
+ logger.error(f"Full traceback: {traceback.format_exc()}")
302
+ raise APIError(f"Unexpected error: {str(e)}")
303
+
304
+ logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'")
305
+ raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
306
+
307
+
308
+ def get_results(task_id):
309
+ """
310
+ Check the status of an image generation task.
311
+
312
+ Args:
313
+ task_id (str): The task ID to check
314
+
315
+ Returns:
316
+ dict: Task result information
317
+
318
+ Raises:
319
+ APIError: If the API request fails
320
+ """
321
+ logger.debug(f"Checking status for task ID: {task_id}")
322
+
323
+ if not task_id:
324
+ logger.error("Empty task ID provided to get_results")
325
+ raise ValueError("Task ID cannot be empty")
326
+
327
+ url = f"{API_RESULT_URL}?task_id={task_id}"
328
+ headers = {
329
+ "Authorization": f"Bearer {API_TOKEN}",
330
+ "X-accept-language": "en",
331
+ }
332
+
333
+ try:
334
+ response = requests.get(url, headers=headers, timeout=10)
335
+ logger.debug(f"Status check response code: {response.status_code}")
336
+
337
+ response.raise_for_status()
338
+ result = response.json()
339
+
340
+ if not result or "result" not in result:
341
+ logger.warning(f"Invalid response format from API when checking task {task_id}: {str(result)}")
342
+ raise APIError(f"Invalid response format from API when checking task {task_id}: {str(result)}")
343
+
344
+ return result
345
+
346
+ except requests.exceptions.Timeout:
347
+ logger.warning(f"Request timed out when checking task {task_id}")
348
+ return None
349
+
350
+ except requests.exceptions.HTTPError as e:
351
+ status_code = e.response.status_code
352
+ logger.warning(f"HTTP error {status_code} when checking task {task_id}")
353
+
354
+ try:
355
+ error_content = e.response.json()
356
+ logger.error(f"Error response content: {error_content}")
357
+ except:
358
+ logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")
359
+
360
+ if status_code == 401:
361
+ logger.error(f"Authentication failed when checking task {task_id}")
362
+ raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}")
363
+ elif 400 <= status_code < 500:
364
+ try:
365
+ error_detail = e.response.json()
366
+ error_message = f"HTTP error {status_code}: {error_detail.get('message', 'Client error')}"
367
+ except:
368
+ error_message = f"HTTP error {status_code}"
369
+ logger.error(error_message)
370
+ return None
371
+ else:
372
+ logger.warning(f"Server error {status_code} when checking task {task_id}")
373
+ return None
374
+
375
+ except requests.exceptions.RequestException as e:
376
+ logger.warning(f"Network error when checking task {task_id}: {str(e)}")
377
+ logger.debug(f"Network error details: {traceback.format_exc()}")
378
+ return None
379
+
380
+ except Exception as e:
381
+ logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
382
+ logger.error(f"Full traceback: {traceback.format_exc()}")
383
+ return None
384
+
385
+
386
+ def download_image(image_url):
387
+ """
388
+ Download an image from a URL and return it as a PIL Image.
389
+ Converts WebP to PNG format while preserving original image data.
390
+
391
+ Args:
392
+ image_url (str): URL of the image
393
+
394
+ Returns:
395
+ PIL.Image: Downloaded image object converted to PNG format
396
+
397
+ Raises:
398
+ APIError: If the download fails
399
+ """
400
+ logger.info(f"Starting download_image from URL: {image_url}")
401
+
402
+ if not image_url:
403
+ logger.error("Empty image URL provided to download_image")
404
+ raise ValueError("Image URL cannot be empty when downloading image")
405
+
406
+ retry_count = 0
407
+ while retry_count < MAX_RETRY_COUNT:
408
+ try:
409
+ logger.info(f"Downloading image [attempt {retry_count+1}/{MAX_RETRY_COUNT}] from {image_url}")
410
+ response = requests.get(image_url, timeout=15)
411
+
412
+ logger.debug(f"Image download response status: {response.status_code}, Content-Type: {response.headers.get('Content-Type')}, Content-Length: {response.headers.get('Content-Length')}")
413
+
414
+ response.raise_for_status()
415
+
416
+ # Open the image from response content
417
+ image = Image.open(BytesIO(response.content))
418
+ logger.info(f"Image opened successfully. Format: {image.format}, Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")
419
+
420
+ # Get original metadata before conversion
421
+ original_metadata = {}
422
+ for key, value in image.info.items():
423
+ if isinstance(key, str) and isinstance(value, str):
424
+ original_metadata[key] = value
425
+
426
+ logger.debug(f"Original image metadata: {original_metadata}")
427
+
428
+ # Convert to PNG regardless of original format (WebP, JPEG, etc.)
429
+ if image.format != 'PNG':
430
+ logger.info(f"Converting image from {image.format} to PNG format")
431
+ png_buffer = BytesIO()
432
+
433
+ # If the image has an alpha channel, preserve it, otherwise convert to RGB
434
+ if 'A' in image.getbands():
435
+ logger.debug("Preserving alpha channel in image conversion")
436
+ image_to_save = image
437
+ else:
438
+ logger.debug("Converting image to RGB mode")
439
+ image_to_save = image.convert('RGB')
440
+
441
+ image_to_save.save(png_buffer, format='PNG')
442
+ png_buffer.seek(0)
443
+ image = Image.open(png_buffer)
444
+ logger.debug(f"Image converted to PNG. New size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")
445
+
446
+ # Preserve original metadata
447
+ for key, value in original_metadata.items():
448
+ image.info[key] = value
449
+ logger.debug("Original metadata preserved in converted image")
450
+
451
+ logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
452
+ return image
453
+
454
+ except requests.exceptions.Timeout:
455
+ retry_count += 1
456
+ logger.warning(f"Download timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
457
+ time.sleep(1)
458
+
459
+ except requests.exceptions.HTTPError as e:
460
+ status_code = e.response.status_code
461
+ logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
462
+
463
+ try:
464
+ error_content = e.response.text[:500]
465
+ logger.error(f"Error response content: {error_content}")
466
+ except:
467
+ logger.error("Could not read error response content")
468
+
469
+ if 400 <= status_code < 500:
470
+ error_message = f"HTTP error {status_code} when downloading image"
471
+ logger.error(error_message)
472
+ raise APIError(error_message)
473
+ else:
474
+ retry_count += 1
475
+ logger.warning(f"Server error {status_code}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
476
+ time.sleep(1)
477
+
478
+ except requests.exceptions.RequestException as e:
479
+ retry_count += 1
480
+ logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
481
+ logger.debug(f"Network error details: {traceback.format_exc()}")
482
+ time.sleep(1)
483
+
484
+ except Exception as e:
485
+ logger.error(f"Error processing image from {image_url}: {str(e)}")
486
+ logger.error(f"Full traceback: {traceback.format_exc()}")
487
+ raise APIError(f"Failed to process image: {str(e)}")
488
+
489
+ logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
490
+ raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
491
+
492
+
493
+ def add_metadata_to_image(image, metadata):
494
+ """
495
+ Add metadata to a PIL image.
496
+
497
+ Args:
498
+ image (PIL.Image): The image to add metadata to
499
+ metadata (dict): Metadata to add to the image
500
+
501
+ Returns:
502
+ PIL.Image: Image with metadata
503
+ """
504
+ logger.debug(f"Adding metadata to image: {metadata}")
505
+
506
+ if not image:
507
+ logger.error("Null image provided to add_metadata_to_image")
508
+ return None
509
+ try:
510
+ # Get any existing metadata
511
+ existing_metadata = {}
512
+ for key, value in image.info.items():
513
+ if isinstance(key, str) and isinstance(value, str):
514
+ existing_metadata[key] = value
515
+
516
+ logger.debug(f"Existing image metadata: {existing_metadata}")
517
+
518
+ # Merge with new metadata (new values override existing ones)
519
+ all_metadata = {**existing_metadata, **metadata}
520
+ logger.debug(f"Combined metadata: {all_metadata}")
521
+
522
+ # Create a new metadata dictionary for PNG
523
+ meta = PngImagePlugin.PngInfo()
524
+
525
+ # Add each metadata item
526
+ for key, value in all_metadata.items():
527
+ meta.add_text(key, str(value))
528
+
529
+ # Save with metadata to a buffer
530
+ buffer = BytesIO()
531
+ image.save(buffer, format='PNG', pnginfo=meta)
532
+ logger.debug("Image saved to buffer with metadata")
533
+
534
+ # Reload the image from the buffer
535
+ buffer.seek(0)
536
+ result_image = Image.open(buffer)
537
+ logger.debug("Image reloaded from buffer with metadata")
538
+ return result_image
539
+
540
+ except Exception as e:
541
+ logger.error(f"Failed to add metadata to image: {str(e)}")
542
+ logger.error(f"Full traceback: {traceback.format_exc()}")
543
+ return image # Return original image if metadata addition fails
544
+
545
+ # Create Gradio interface
546
+ def create_ui():
547
+ logger.info("Creating Gradio UI")
548
+ with gr.Blocks(title="HiDream-E1-Full Image Editor", theme=gr.themes.Base()) as demo:
549
+ with gr.Row(equal_height=True):
550
+ with gr.Column(scale=1):
551
+ gr.Markdown("""
552
+ # HiDream-E1-Full Image Editor
553
+
554
+ Edit images using natural language instructions with state-of-the-art AI [🤗 HuggingFace](https://huggingface.co/HiDream-ai/HiDream-E1-Full) | [GitHub](https://github.com/HiDream-ai/HiDream-E1) | [Twitter](https://x.com/vivago_ai)
555
+
556
+ <span style="color: #FF5733; font-weight: bold">For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/).</span>
557
+ """)
558
+
559
+ with gr.Row():
560
+ # Input column
561
+ with gr.Column(scale=1):
562
+ input_image = gr.Image(
563
+ type="pil",
564
+ label="Input Image",
565
+ height=400,
566
+ show_download_button=True,
567
+ show_label=True,
568
+ scale=1,
569
+ container=True,
570
+ image_mode="RGB"
571
+ )
572
+
573
+ instruction = gr.Textbox(
574
+ label="Editing Instruction",
575
+ placeholder="e.g., convert the image into a Ghibli style",
576
+ lines=3
577
+ )
578
+
579
+ gr.Markdown("""
580
+ <div style="padding: 8px; margin-bottom: 10px; background-color: #E2F0FF; border-left: 5px solid #2E86DE; color: #2C3E50;">
581
+ <strong>Note:</strong> For optimal results, we recommend using the <strong>Refine Instruction</strong> button which formats your input into:
582
+ <br><em>"Editing Instruction: [your instruction]. Target Image Description: [expected result]"</em>
583
+ </div>
584
+ """)
585
+
586
+ with gr.Row():
587
+ refine_btn = gr.Button("Refine Instruction")
588
+ generate_btn = gr.Button("Generate", variant="primary", size="lg")
589
+
590
+ with gr.Accordion("Advanced Settings", open=True):
591
+ gr.Markdown("""
592
+ <div style="padding: 8px; margin: 15px 0; background-color: #FFF3CD; border-left: 5px solid #FFDD57; color: #856404;">
593
+ <strong>Important:</strong> Adjust these parameters based on your editing needs:
594
+ <ul>
595
+ <li>For style changes, use higher image preservation strength (e.g., 3.0-4.0)</li>
596
+ <li>For local edits like adding, deleting, replacing elements, use lower image preservation strength (e.g., 2.0-3.0)</li>
597
+ </ul>
598
+ </div>
599
+ """)
600
+ with gr.Row():
601
+ guidance_scale = gr.Slider(
602
+ minimum=1.0,
603
+ maximum=10.0,
604
+ step=0.1,
605
+ value=5.0,
606
+ label="Instruction Following Strength"
607
+ )
608
+ image_guidance_scale = gr.Slider(
609
+ minimum=1.0,
610
+ maximum=10.0,
611
+ step=0.1,
612
+ value=3.0,
613
+ label="Image Preservation Strength"
614
+ )
615
+
616
+ seed = gr.Number(
617
+ label="Seed (use -1 for random)",
618
+ value=82706,
619
+ precision=0
620
+ )
621
+
622
+
623
+
624
+ progress = gr.Progress(track_tqdm=False)
625
+
626
+ # Output column
627
+ with gr.Column(scale=1):
628
+ output_image = gr.Image(
629
+ label="Generated Image",
630
+ type="pil",
631
+ height=400,
632
+ interactive=False,
633
+ show_download_button=True,
634
+ scale=1,
635
+ container=True,
636
+ image_mode="RGB"
637
+ )
638
+
639
+ with gr.Accordion("Image Information", open=False):
640
+ image_info = gr.JSON(label="Details")
641
+
642
+ def refine_instruction_ui(image, instruction):
643
+ if not image or not instruction:
644
+ return instruction
645
+ try:
646
+ refined = refine_instruction(image, instruction)
647
+ if len(refined) > 0:
648
+ return refined
649
+ else:
650
+ logger.warning("Instruction refinement service returned empty result")
651
+ gr.Warning("Instruction refinement service is currently not working. Please try again later.")
652
+ return instruction
653
+ except Exception as e:
654
+ logger.error(f"Error refining instruction: {str(e)}")
655
+ gr.Warning("Instruction refinement service is currently not working. Please try again later.")
656
+ return instruction
657
+
658
+ # Generate function with progress updates
659
+ def generate_with_progress(image, instruction, seed, guidance_scale, image_guidance_scale, progress=gr.Progress()):
660
+ logger.info(f"Starting image generation with instruction='{instruction[:50]}...', seed={seed}")
661
+
662
+ try:
663
+ if not image:
664
+ logger.error("No image provided in UI")
665
+ return None, None
666
+
667
+ if not instruction.strip():
668
+ logger.error("Empty instruction provided in UI")
669
+ return None, None
670
+
671
+ # Create request
672
+ logger.info("Creating API request")
673
+ task_id, used_seed = create_request(
674
+ prompt=instruction,
675
+ image=image,
676
+ guidance_scale=guidance_scale,
677
+ image_guidance_scale=image_guidance_scale,
678
+ seed=seed
679
+ )
680
+
681
+ # Poll for results
682
+ start_time = time.time()
683
+ last_completion_ratio = 0
684
+ progress(0, desc="Initializing...")
685
+ logger.info(f"Starting to poll for results for task ID: {task_id}")
686
+
687
+ while time.time() - start_time < MAX_POLL_TIME:
688
+ result = get_results(task_id)
689
+ if not result:
690
+ time.sleep(POLL_INTERVAL)
691
+ continue
692
+
693
+ sub_results = result.get("result", {}).get("sub_task_results", [])
694
+ if not sub_results:
695
+ time.sleep(POLL_INTERVAL)
696
+ continue
697
+
698
+ status = sub_results[0].get("task_status")
699
+ logger.debug(f"Task status for ID {task_id}: {status}")
700
+
701
+ # Get and display completion ratio
702
+ completion_ratio = sub_results[0].get('task_completion', 0) * 100
703
+ if completion_ratio != last_completion_ratio:
704
+ # Only update UI when completion ratio changes
705
+ last_completion_ratio = completion_ratio
706
+ progress(completion_ratio / 100, desc=f"Generating image")
707
+ logger.info(f"Generation progress - Task ID: {task_id}, Completion: {completion_ratio:.1f}%")
708
+
709
+ # Check task status
710
+ if status == 1: # Success
711
+ logger.info(f"Task completed successfully - Task ID: {task_id}")
712
+ progress(1.0, desc="Generation complete")
713
+ image_name = sub_results[0].get("image")
714
+ if not image_name:
715
+ logger.error(f"No image name in successful response. Response: {sub_results[0]}")
716
+ return None, None
717
+
718
+ image_url = f"{API_IMAGE_URL}{image_name}.png"
719
+ logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
720
+ image = download_image(image_url)
721
+
722
+ if image:
723
+ # Add metadata to the image
724
+ logger.info(f"Adding metadata to image - Task ID: {task_id}")
725
+ metadata = {
726
+ "prompt": instruction,
727
+ "seed": str(used_seed),
728
+ "model": API_MODEL_NAME,
729
+ "guidance_scale": str(guidance_scale),
730
+ "image_guidance_scale": str(image_guidance_scale),
731
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
732
+ "generated_by": "HiDream-E1-Full Editor"
733
+ }
734
+
735
+ image_with_metadata = add_metadata_to_image(image, metadata)
736
+
737
+ # Create info for display
738
+ info = {
739
+ "model": API_MODEL_NAME,
740
+ "prompt": instruction,
741
+ "seed": used_seed,
742
+ "guidance_scale": guidance_scale,
743
+ "image_guidance_scale": image_guidance_scale,
744
+ "generated_at": time.strftime("%Y-%m-%d %H:%M:%S")
745
+ }
746
+
747
+ logger.info(f"Image generation complete - Task ID: {task_id}")
748
+ return image_with_metadata, info
749
+ else:
750
+ logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
751
+ return None, None
752
+
753
+ elif status in {3, 4}: # Failed or Canceled
754
+ error_msg = sub_results[0].get("task_error", "Unknown error")
755
+ logger.error(f"Task failed - Task ID: {task_id}, Status: {status}, Error: {error_msg}")
756
+ return None, None
757
+
758
+ time.sleep(POLL_INTERVAL)
759
+
760
+ logger.error(f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s")
761
+ return None, None
762
+
763
+ except Exception as e:
764
+ logger.error(f"Error during image generation: {str(e)}")
765
+ logger.error(f"Full traceback: {traceback.format_exc()}")
766
+ return None, None
767
+
768
+ # Set up event handlers
769
+ refine_btn.click(
770
+ fn=refine_instruction_ui,
771
+ inputs=[input_image, instruction],
772
+ outputs=[instruction]
773
+ )
774
+
775
+ generate_btn.click(
776
+ fn=generate_with_progress,
777
+ inputs=[input_image, instruction, seed, guidance_scale, image_guidance_scale],
778
+ outputs=[output_image, image_info]
779
+ )
780
+
781
+ # Define a combined function to refine instruction and then generate image
782
+ def refine_and_generate(image, instruction, seed, guidance_scale, image_guidance_scale, progress=gr.Progress()):
783
+ try:
784
+ # First refine the instruction
785
+ if not image or not instruction:
786
+ return None, None, instruction
787
+
788
+ logger.info(f"Refining instruction: '{instruction[:50]}...'")
789
+ refined_instruction = refine_instruction_ui(image, instruction)
790
+
791
+ if not refined_instruction or refined_instruction.strip() == "":
792
+ logger.warning("Instruction refinement failed, using original instruction")
793
+ refined_instruction = instruction
794
+ gr.Warning("Instruction refinement failed, using original instruction instead.")
795
+ else:
796
+ logger.info(f"Instruction refined to: '{refined_instruction[:50]}...'")
797
+
798
+ # Then generate with the refined instruction
799
+ progress(0.2, desc="Instruction refined, generating image...")
800
+ generated_image, image_info = generate_with_progress(image, refined_instruction, seed, guidance_scale, image_guidance_scale, progress)
801
+ return generated_image, image_info, refined_instruction
802
+ except Exception as e:
803
+ logger.error(f"Error in refine_and_generate: {str(e)}")
804
+ logger.error(f"Full traceback: {traceback.format_exc()}")
805
+ gr.Warning(f"An error occurred during processing: {str(e)}")
806
+ return None, None, instruction
807
+
808
+ # Examples
809
+ gr.Examples(
810
+ examples=[
811
+ ["assets/test_1.png", "convert the image into a Ghibli style",3, 5, 4],
812
+ ["assets/test_1.png", "change the image into Disney Pixar style",3, 5, 4],
813
+ ["assets/test_1.png", "turn to sketch style",3, 5, 4],
814
+ ["assets/test_1.png", "add a sunglasses to the girl",3, 5, 2],
815
+ ["assets/test_1.png", "change the background to a sunset",3, 5, 2],
816
+ ["assets/test_2.jpg", "convert this image into a ink sketch image",3, 5, 2],
817
+ ["assets/test_2.jpg", "add butterfly'",3, 5, 2],
818
+ ["assets/test_2.jpg", "remove the wooden sign'",3, 5, 2],
819
+ ],
820
+ inputs=[input_image, instruction, seed, guidance_scale, image_guidance_scale],
821
+ outputs=[output_image, image_info, instruction],
822
+ fn=refine_and_generate,
823
+ cache_examples=True,
824
+ cache_mode = "lazy"
825
+ )
826
+
827
+ logger.info("Gradio UI created successfully")
828
+ return demo
829
+
830
+ # Launch app
831
+ if __name__ == "__main__":
832
+ logger.info("Starting HiDream-E1-Full Image Generator application")
833
+ demo = create_ui()
834
+ logger.info("Launching Gradio interface with queue")
835
+ demo.queue(max_size=50, default_concurrency_limit=8).launch(show_api=False)
836
+ logger.info("Application shutdown")