retwpay commited on
Commit
480ec3f
·
verified ·
1 Parent(s): 194eca1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -22
app.py CHANGED
@@ -4,44 +4,60 @@ import numpy as np
4
  import PIL.Image
5
  from PIL import Image
6
  import random
7
- from diffusers import ControlNetModel, StableDiffusionXLPipeline, AutoencoderKL
8
- from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
9
- import cv2
10
  import torch
11
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
 
14
  pipe = StableDiffusionXLPipeline.from_pretrained(
15
  "votepurchase/waiNSFWIllustrious_v110",
16
  torch_dtype=torch.float16,
 
 
17
  )
18
 
19
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
20
  pipe.to(device)
21
 
 
 
 
 
 
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  MAX_IMAGE_SIZE = 1216
24
-
25
 
26
  @spaces.GPU
27
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
28
-
 
 
 
29
  if randomize_seed:
30
  seed = random.randint(0, MAX_SEED)
31
 
32
- generator = torch.Generator().manual_seed(seed)
33
-
34
- output_image = pipe(
35
- prompt=prompt,
36
- negative_prompt=negative_prompt,
37
- guidance_scale=guidance_scale,
38
- num_inference_steps=num_inference_steps,
39
- width=width,
40
- height=height,
41
- generator=generator
42
- ).images[0]
43
-
44
- return output_image
 
 
 
 
 
 
45
 
46
 
47
  css = """
@@ -60,7 +76,7 @@ with gr.Blocks(css=css) as demo:
60
  label="Prompt",
61
  show_label=False,
62
  max_lines=1,
63
- placeholder="Enter your prompt",
64
  container=False,
65
  )
66
 
@@ -93,7 +109,7 @@ with gr.Blocks(css=css) as demo:
93
  minimum=256,
94
  maximum=MAX_IMAGE_SIZE,
95
  step=32,
96
- value=1024,#832,
97
  )
98
 
99
  height = gr.Slider(
@@ -101,7 +117,7 @@ with gr.Blocks(css=css) as demo:
101
  minimum=256,
102
  maximum=MAX_IMAGE_SIZE,
103
  step=32,
104
- value=1024,#1216,
105
  )
106
 
107
  with gr.Row():
@@ -121,7 +137,7 @@ with gr.Blocks(css=css) as demo:
121
  value=28,
122
  )
123
 
124
- run_button.click(#lambda x: None, inputs=None, outputs=result).then(
125
  fn=infer,
126
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
127
  outputs=[result]
 
4
  import PIL.Image
5
  from PIL import Image
6
  import random
7
+ from diffusers import StableDiffusionXLPipeline
8
+ from diffusers import EulerAncestralDiscreteScheduler
 
9
  import torch
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ # Make sure to use torch.float16 consistently throughout the pipeline
14
  pipe = StableDiffusionXLPipeline.from_pretrained(
15
  "votepurchase/waiNSFWIllustrious_v110",
16
  torch_dtype=torch.float16,
17
+ variant="fp16", # Explicitly use fp16 variant
18
+ use_safetensors=True # Use safetensors if available
19
  )
20
 
21
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
22
  pipe.to(device)
23
 
24
+ # Force all components to use the same dtype
25
+ pipe.text_encoder.to(torch.float16)
26
+ pipe.text_encoder_2.to(torch.float16)
27
+ pipe.vae.to(torch.float16)
28
+ pipe.unet.to(torch.float16)
29
+
30
  MAX_SEED = np.iinfo(np.int32).max
31
  MAX_IMAGE_SIZE = 1216
 
32
 
33
  @spaces.GPU
34
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
35
+ # Check and truncate prompt if too long (CLIP can only handle 77 tokens)
36
+ if len(prompt.split()) > 60: # Rough estimate to avoid exceeding token limit
37
+ print("Warning: Prompt may be too long and will be truncated by the model")
38
+
39
  if randomize_seed:
40
  seed = random.randint(0, MAX_SEED)
41
 
42
+ generator = torch.Generator(device=device).manual_seed(seed)
43
+
44
+ try:
45
+ output_image = pipe(
46
+ prompt=prompt,
47
+ negative_prompt=negative_prompt,
48
+ guidance_scale=guidance_scale,
49
+ num_inference_steps=num_inference_steps,
50
+ width=width,
51
+ height=height,
52
+ generator=generator
53
+ ).images[0]
54
+
55
+ return output_image
56
+ except RuntimeError as e:
57
+ print(f"Error during generation: {e}")
58
+ # Return a blank image with error message
59
+ error_img = Image.new('RGB', (width, height), color=(0, 0, 0))
60
+ return error_img
61
 
62
 
63
  css = """
 
76
  label="Prompt",
77
  show_label=False,
78
  max_lines=1,
79
+ placeholder="Enter your prompt (keep it under 60 words for best results)",
80
  container=False,
81
  )
82
 
 
109
  minimum=256,
110
  maximum=MAX_IMAGE_SIZE,
111
  step=32,
112
+ value=1024,
113
  )
114
 
115
  height = gr.Slider(
 
117
  minimum=256,
118
  maximum=MAX_IMAGE_SIZE,
119
  step=32,
120
+ value=1024,
121
  )
122
 
123
  with gr.Row():
 
137
  value=28,
138
  )
139
 
140
+ run_button.click(
141
  fn=infer,
142
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
143
  outputs=[result]