committed on
Commit
7b7e62e
·
1 Parent(s): 65dcc19

Update app.py

Files changed (1)
  1. app.py +138 -4
app.py CHANGED
@@ -14,14 +14,34 @@
 
 import dataclasses
 import json
+import base64
+import io
 from pathlib import Path
 
 import gradio as gr
 import torch
 import spaces
+from PIL import Image as PILImage
+from fastapi import FastAPI, Body
+from fastapi.middleware.cors import CORSMiddleware
 
 from uno.flux.pipeline import UNOPipeline
 
+# Create the FastAPI app
+app = FastAPI()
+
+# Add CORS middleware to allow cross-origin requests
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global pipeline variable
+pipeline = None
+
 def get_examples(examples_dir: str = "assets/examples") -> list:
     examples = Path(examples_dir)
     ans = []
@@ -54,6 +74,7 @@ def create_demo(
     device: str = "cuda" if torch.cuda.is_available() else "cpu",
     offload: bool = False,
 ):
+    global pipeline
     pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
     pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)
 
@@ -229,11 +250,117 @@ def create_demo(
         ],
     )
 
+    # API documentation accordion
+    with gr.Accordion("API Documentation", open=False):
+        gr.Markdown("""
+        ### API Usage
+
+        You can use the following endpoint to generate images programmatically:
+
+        **Endpoint:** `/api/generate`
+
+        **Method:** POST
+
+        **Request Body:**
+        ```json
+        {
+            "prompt": "your text prompt",
+            "image_refs": ["base64_encoded_image1", "base64_encoded_image2", ...],
+            "width": 512,
+            "height": 512,
+            "guidance": 4.0,
+            "num_steps": 25,
+            "seed": -1
+        }
+        ```
+
+        **Response:**
+        ```json
+        {
+            "image": "base64_encoded_generated_image"
+        }
+        ```
+
+        **Example JavaScript Usage:**
+        ```javascript
+        async function generateImage() {
+            const response = await fetch('/api/generate', {
+                method: 'POST',
+                headers: {
+                    'Content-Type': 'application/json',
+                },
+                body: JSON.stringify({
+                    prompt: "handsome woman in the city",
+                    image_refs: [],
+                    width: 512,
+                    height: 512
+                }),
+            });
+
+            const data = await response.json();
+            const imgElement = document.getElementById('generatedImage');
+            imgElement.src = `data:image/png;base64,${data.image}`;
+        }
+        ```
+        """)
+
     return demo
 
+# Create the API endpoint
+@app.post("/api/generate")
+async def generate_image(
+    prompt: str = Body(...),
+    width: int = Body(512),
+    height: int = Body(512),
+    guidance: float = Body(4.0),
+    num_steps: int = Body(25),
+    seed: int = Body(-1),
+    image_refs: list = Body([])
+):
+    global pipeline
+    # Decode the reference images (up to four are used)
+    ref_images = []
+    for i in range(min(4, len(image_refs))):
+        if image_refs[i]:
+            try:
+                # Decode a base64-encoded image
+                if isinstance(image_refs[i], str) and "base64" in image_refs[i]:
+                    # Strip the data-URL prefix, if any
+                    if "," in image_refs[i]:
+                        img_data = image_refs[i].split(",")[1]
+                    else:
+                        img_data = image_refs[i]
+
+                    img_data = base64.b64decode(img_data)
+                    ref_img = PILImage.open(io.BytesIO(img_data))
+                    ref_images.append(ref_img)
+                else:
+                    ref_images.append(None)
+            except Exception:
+                ref_images.append(None)
+        else:
+            ref_images.append(None)
+
+    # Pad the list to four images
+    while len(ref_images) < 4:
+        ref_images.append(None)
+
+    # Run the model
+    result_image, _ = pipeline.gradio_generate(
+        prompt, width, height, guidance, num_steps, seed,
+        ref_images[0], ref_images[1], ref_images[2], ref_images[3]
+    )
+
+    # Encode the generated image as a base64 PNG
+    buffered = io.BytesIO()
+    result_image.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+
+    return {"image": img_str}
+
 if __name__ == "__main__":
     from typing import Literal
-
+    import uvicorn
     from transformers import HfArgumentParser
 
     @dataclasses.dataclass
@@ -245,10 +372,17 @@ if __name__ == "__main__":
            metadata={"help": "If True, sequentially offload the models (ae, dit, text encoder) to CPU when not in use."}
        )
        port: int = 7860
+       host: str = "0.0.0.0"
 
    parser = HfArgumentParser([AppArgs])
-   args_tuple = parser.parse_args_into_dataclasses()  # type: tuple[AppArgs]
+   args_tuple = parser.parse_args_into_dataclasses()  # type: tuple[AppArgs]
    args = args_tuple[0]
-
+
+   # Build the Gradio demo
    demo = create_demo(args.name, args.device, args.offload)
-   demo.launch(server_port=args.port, ssr_mode=False)
+
+   # Mount the Gradio UI onto the FastAPI app
+   app = gr.mount_gradio_app(app, demo, path="/")
+
+   # Serve the combined app with uvicorn
+   uvicorn.run(app, host=args.host, port=args.port)
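
For reference, here is a minimal Python client sketch for the `/api/generate` endpoint added in this commit. It is an illustration, not part of the commit: the `BASE_URL`, the `generate()` helper, and the `requests` dependency are assumptions. Reference images are sent as data URLs because the endpoint only decodes strings that contain the literal substring "base64".

```python
# Hypothetical client for the /api/generate endpoint above.
# Assumes the app is reachable at BASE_URL; adjust as needed.
import base64
from typing import Sequence

import requests

BASE_URL = "http://localhost:7860"

def generate(prompt: str, ref_paths: Sequence[str] = ()) -> bytes:
    # The endpoint only decodes reference strings containing "base64",
    # so send full data URLs rather than bare base64 payloads.
    image_refs = []
    for path in ref_paths:
        with open(path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode()
        image_refs.append(f"data:image/png;base64,{encoded}")

    payload = {
        "prompt": prompt,
        "image_refs": image_refs,
        "width": 512,
        "height": 512,
        "guidance": 4.0,
        "num_steps": 25,
        "seed": -1,
    }
    resp = requests.post(f"{BASE_URL}/api/generate", json=payload, timeout=300)
    resp.raise_for_status()
    # The response body is {"image": "<base64-encoded PNG>"}
    return base64.b64decode(resp.json()["image"])

if __name__ == "__main__":
    with open("out.png", "wb") as f:
        f.write(generate("handsome woman in the city"))
```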