fenfan committed on
Commit 0a6db9d · verified · 1 Parent(s): 3994a93

fix: update app.py to fit zero gpu

Files changed (1)
  1. app.py +60 -81
app.py CHANGED
@@ -11,96 +11,75 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
- import dataclasses
-
  import gradio as gr
  import torch
  import spaces

  from uno.flux.pipeline import UNOPipeline


- def create_demo(
-     model_type: str,
-     device: str = "cuda" if torch.cuda.is_available() else "cpu",
-     offload: bool = False,
- ):
-     pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
-     pipeline.__call__ = spaces.GPU(duratioin=120)(pipeline.__call__)

-     with gr.Blocks() as demo:
-         gr.Markdown(f"# UNO by UNO team")
-         with gr.Row():
-             with gr.Column():
-                 prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
-                 with gr.Row():
-                     image_prompt1 = gr.Image(label="ref img1", visible=True, interactive=True, type="pil")
-                     image_prompt2 = gr.Image(label="ref img2", visible=True, interactive=True, type="pil")
-                     image_prompt3 = gr.Image(label="ref img3", visible=True, interactive=True, type="pil")
-                     image_prompt4 = gr.Image(label="ref img4", visible=True, interactive=True, type="pil")

-                 with gr.Row():
-                     with gr.Column():
-                         ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
-                     with gr.Column():
-                         gr.Markdown("📌 **The recommended ref scale** is related to the ref img number.\n")
-                         gr.Markdown(" 1->512 / 2->320 / 3...n->256")

-                 with gr.Row():
-                     with gr.Column():
-                         width = gr.Slider(512, 2048, 512, step=16, label="Gneration Width")
-                         height = gr.Slider(512, 2048, 512, step=16, label="Gneration Height")
-                     with gr.Column():
-                         gr.Markdown("📌 The model trained on 512x512 resolution.\n")
-                         gr.Markdown(
-                             "The size closer to 512 is more stable,"
-                             " and the higher size gives a better visual effect but is less stable"
-                         )
-
-                 with gr.Accordion("Generation Options", open=False):
-                     with gr.Row():
-                         num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
-                         guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
-                         seed = gr.Number(-1, label="Seed (-1 for random)")
-
-                 generate_btn = gr.Button("Generate")
-
-             with gr.Column():
-                 output_image = gr.Image(label="Generated Image")
-                 download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
-
-
-         inputs = [
-             prompt, width, height, guidance, num_steps,
-             seed, ref_long_side, image_prompt1, image_prompt2, image_prompt3, image_prompt4
-         ]
-         generate_btn.click(
-             fn=pipeline.gradio_generate,
-             inputs=inputs,
-             outputs=[output_image, download_btn],
-         )
-
-     return demo
-
- if __name__ == "__main__":
-     from typing import Literal
-
-     from transformers import HfArgumentParser
-
-     @dataclasses.dataclass
-     class AppArgs:
-         name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
-         device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
-         offload: bool = dataclasses.field(
-             default=False,
-             metadata={"help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."}
  )
-         port: int = 7860
-
-     parser = HfArgumentParser([AppArgs])
-     args_tuple = parser.parse_args_into_dataclasses()  # type: tuple[AppArgs]
-     args = args_tuple[0]

-     demo = create_demo(args.name, args.device, args.offload)
-     demo.launch(server_port=args.port)
 
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import gradio as gr
  import torch
  import spaces

  from uno.flux.pipeline import UNOPipeline

+ model_type = "flux-dev"
+ offload = False
+ device = "cuda"
+
+ pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
+
+
+ ## It seems ZeroGPU is only triggered when spaces.GPU is applied as a decorator;
+ ## manually wrapping with fn = spaces.GPU(duration=120)(fn) does not work.
+ @spaces.GPU(duration=120)
+ def generate_callback(*args, **kwargs):
+     return pipeline.gradio_generate(*args, **kwargs)
+
+ with gr.Blocks() as demo:
+     gr.Markdown(f"# UNO by UNO team")
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
+             with gr.Row():
+                 image_prompt1 = gr.Image(label="ref img1", visible=True, interactive=True, type="pil")
+                 image_prompt2 = gr.Image(label="ref img2", visible=True, interactive=True, type="pil")
+                 image_prompt3 = gr.Image(label="ref img3", visible=True, interactive=True, type="pil")
+                 image_prompt4 = gr.Image(label="ref img4", visible=True, interactive=True, type="pil")
+
+             with gr.Row():
+                 with gr.Column():
+                     ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
+                 with gr.Column():
+                     gr.Markdown("📌 **The recommended ref scale** is related to the ref img number.\n")
+                     gr.Markdown(" 1->512 / 2->320 / 3...n->256")
+
+             with gr.Row():
+                 with gr.Column():
+                     width = gr.Slider(512, 2048, 512, step=16, label="Gneration Width")
+                     height = gr.Slider(512, 2048, 512, step=16, label="Gneration Height")
+                 with gr.Column():
+                     gr.Markdown("📌 The model trained on 512x512 resolution.\n")
+                     gr.Markdown(
+                         "The size closer to 512 is more stable,"
+                         " and the higher size gives a better visual effect but is less stable"
+                     )
+
+             with gr.Accordion("Generation Options", open=False):
+                 with gr.Row():
+                     num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
+                     guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
+                     seed = gr.Number(-1, label="Seed (-1 for random)")

+             generate_btn = gr.Button("Generate")

+         with gr.Column():
+             output_image = gr.Image(label="Generated Image")
+             download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)


+     inputs = [
+         prompt, width, height, guidance, num_steps,
+         seed, ref_long_side, image_prompt1, image_prompt2, image_prompt3, image_prompt4
+     ]
+     generate_btn.click(
+         fn=generate_callback,
+         inputs=inputs,
+         outputs=[output_image, download_btn],
  )

+ demo.launch()
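
Note on the change: the commit's point is the ZeroGPU invocation pattern, which, per the author's comment, appears to take effect only when spaces.GPU decorates a module-level function at import time, not when a bound method is re-wrapped at runtime. Below is a minimal sketch of that pattern, not the repo's code: run_model and the tiny Blocks UI are hypothetical placeholders, and it assumes the standard spaces and gradio packages on a ZeroGPU Space.

import gradio as gr
import spaces

def run_model(prompt: str) -> str:
    # Placeholder for an expensive GPU call; the real app calls
    # UNOPipeline.gradio_generate instead.
    return f"generated: {prompt}"

# Decorating a module-level function lets the Space register it as a
# GPU task when the app starts.
@spaces.GPU(duration=120)
def generate(prompt: str) -> str:
    return run_model(prompt)

# The removed approach, wrapping a bound method after construction,
# reportedly does not trigger ZeroGPU allocation:
#   pipeline.__call__ = spaces.GPU(duration=120)(pipeline.__call__)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    gr.Button("Generate").click(fn=generate, inputs=prompt, outputs=output)

demo.launch()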