SeedOfEvil committed
Commit 14a0302 · verified · 1 Parent(s): 2d3754d

Update app.py

Files changed (1): app.py (+26 -24)
app.py CHANGED
@@ -1,28 +1,35 @@
 import gradio as gr
 import numpy as np
 import random
-
 import spaces
 from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
 import torch
 
+# Set device and model parameters
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_repo_id = "tensorart/stable-diffusion-3.5-large-TurboX"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
+# Load the pipeline with the specified torch_dtype and move it to the GPU
 pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-
-pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo_id, subfolder="scheduler", shift=5)
-
+pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+    model_repo_id, subfolder="scheduler", shift=5
+)
 pipe = pipe.to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
+def truncate_text(text, max_tokens=77):
+    """
+    Truncate the input text to a maximum of max_tokens using the pipeline's tokenizer.
+    """
+    if text.strip() == "":
+        return text
+    tokens = pipe.tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
+    truncated_text = pipe.tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
+    return truncated_text
+
 @spaces.GPU(duration=65)
 def infer(
     prompt,
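
Note: the dtype selection and scheduler setup above are behavior-preserving refactors; the genuinely new piece is truncate_text, which pre-trims prompts to CLIP's 77-token window so the text encoder stops warning about silent truncation. A minimal standalone sketch of the same pattern (the openai/clip-vit-large-patch14 checkpoint is an assumption standing in for pipe.tokenizer, which SD 3.5 pipelines expose):

    # Sketch: pre-truncate a prompt with a CLIP tokenizer, as truncate_text does.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed stand-in

    def truncate_text(text, max_tokens=77):
        if text.strip() == "":
            return text
        # truncation=True caps input_ids at max_tokens, counting BOS/EOS specials
        tokens = tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
        # Decode back to plain text so the pipeline re-tokenizes it without overflow
        return tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)

    long_prompt = "a capybara wearing a suit, " * 40
    print(len(tokenizer(long_prompt)["input_ids"]))                 # far past 77
    print(len(tokenizer(truncate_text(long_prompt))["input_ids"]))  # back inside the window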
@@ -38,8 +45,13 @@ def infer(
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    generator = torch.Generator().manual_seed(seed)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
+    # Explicitly truncate prompts to avoid CLIP token warnings.
+    prompt = truncate_text(prompt, max_tokens=77)
+    negative_prompt = truncate_text(negative_prompt, max_tokens=77) if negative_prompt.strip() else ""
+
+    # Generate the image using the truncated prompts.
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
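
Note: the seeding change is subtle but real. torch.Generator() always lives on the CPU, while torch.Generator(device=device) seeds the RNG on the device where the noise is drawn; CPU and CUDA use different RNG algorithms, so the same seed can produce different latents (and a different image) depending on the generator's device. A small sketch of the difference:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Same seed, two generator devices: on a CUDA machine the two tensors differ,
    # because the CPU and CUDA RNGs are distinct algorithms (Mersenne Twister vs. Philox).
    cpu_gen = torch.Generator().manual_seed(42)
    dev_gen = torch.Generator(device=device).manual_seed(42)

    print(torch.randn((2, 2), generator=cpu_gen))
    print(torch.randn((2, 2), generator=dev_gen, device=device))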
@@ -52,9 +64,9 @@ def infer(
 
     return image, seed
 
-
+# UI Layout
 examples = [
-    "A capybara wearing a suit holding a sign that reads Hello World",
+    "A capybara wearing a suit holding a sign that reads Hello World",
 ]
 
 css = """
@@ -76,18 +88,14 @@ with gr.Blocks(css=css) as demo:
                 placeholder="Enter your prompt",
                 container=False,
             )
-
             run_button = gr.Button("Run", scale=0, variant="primary")
-
         result = gr.Image(label="Result", show_label=False)
-
         with gr.Accordion("Advanced Settings", open=False):
             negative_prompt = gr.Text(
                 label="Negative prompt",
                 max_lines=1,
                 placeholder="Enter a negative prompt",
             )
-
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
@@ -95,18 +103,15 @@ with gr.Blocks(css=css) as demo:
                 step=1,
                 value=0,
             )
-
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
             with gr.Row():
                 width = gr.Slider(
                     label="Width",
                     minimum=512,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,
+                    value=1024,
                 )
-
                 height = gr.Slider(
                     label="Height",
                     minimum=512,
@@ -114,7 +119,6 @@ with gr.Blocks(css=css) as demo:
                     step=32,
                     value=1024,
                 )
-
             with gr.Row():
                 guidance_scale = gr.Slider(
                     label="Guidance scale",
@@ -123,15 +127,13 @@ with gr.Blocks(css=css) as demo:
                     step=0.1,
                     value=1.5,
                 )
-
                 num_inference_steps = gr.Slider(
                     label="Number of inference steps",
                     minimum=1,
                     maximum=50,
                     step=1,
-                    value=8,
+                    value=8,
                 )
-
         gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True, cache_mode="lazy")
         gr.on(
             triggers=[run_button.click, prompt.submit],
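
Note: two details worth calling out in this last hunk. cache_examples=True with cache_mode="lazy" runs infer on the example prompt the first time a user selects it, caching the result instead of precomputing at startup; gr.on binds multiple triggers to one handler. A minimal standalone sketch of the same wiring (the greet demo is hypothetical):

    import gradio as gr

    def greet(name):
        return f"Hello, {name}!"

    with gr.Blocks() as demo:
        name = gr.Textbox(label="Name")
        btn = gr.Button("Run")
        out = gr.Textbox(label="Greeting")

        # Lazy caching: greet runs the first time the example is selected, not at startup.
        gr.Examples(examples=["capybara"], inputs=[name], outputs=[out],
                    fn=greet, cache_examples=True, cache_mode="lazy")
        # One handler bound to both the button click and Enter in the textbox.
        gr.on(triggers=[btn.click, name.submit], fn=greet, inputs=[name], outputs=[out])

    demo.launch()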
 