seawolf2357 commited on
Commit
b7251fb
·
verified ·
1 Parent(s): 448c2fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -166
app.py CHANGED
@@ -6,18 +6,9 @@ import gradio as gr
6
  import numpy as np
7
  import spaces
8
  import torch
9
- from diffusers import DiffusionPipeline
10
  from PIL import Image
11
 
12
- # Make sure PEFT is installed
13
- try:
14
- import peft
15
- except ImportError:
16
- import subprocess
17
- print("Installing PEFT library...")
18
- subprocess.check_call(["pip", "install", "peft"])
19
- import peft
20
-
21
  # Create permanent storage directory
22
  SAVE_DIR = "saved_images" # Gradio will handle the persistence
23
  if not os.path.exists(SAVE_DIR):
@@ -25,17 +16,28 @@ if not os.path.exists(SAVE_DIR):
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  repo_id = "black-forest-labs/FLUX.1-dev"
28
- adapter_id = "seawolf2357/nsfw-detection" # Changed to Renoir model
29
 
30
- # Initialize pipeline
31
  print("Loading pipeline...")
32
- # Use DiffusionPipeline instead of FluxPipeline to ensure proper LoRA compatibility
33
- pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
34
- print("Loading LoRA weights...")
35
- # Add low_cpu_mem_usage=False to avoid PEFT assign=True incompatibility issue
36
- pipeline.load_lora_weights(adapter_id, low_cpu_mem_usage=False)
 
37
  pipeline = pipeline.to(device)
38
 
 
 
 
 
 
 
 
 
 
 
 
39
  MAX_SEED = np.iinfo(np.int32).max
40
  MAX_IMAGE_SIZE = 1024
41
 
@@ -56,28 +58,6 @@ def save_generated_image(image, prompt):
56
 
57
  return filepath
58
 
59
- def load_generated_images():
60
- if not os.path.exists(SAVE_DIR):
61
- return []
62
-
63
- # Load all images from the directory
64
- image_files = [os.path.join(SAVE_DIR, f) for f in os.listdir(SAVE_DIR)
65
- if f.endswith(('.png', '.jpg', '.jpeg', '.webp'))]
66
- # Sort by creation time (newest first)
67
- image_files.sort(key=lambda x: os.path.getctime(x), reverse=True)
68
- return image_files
69
-
70
- def load_predefined_images():
71
- predefined_images = [
72
- "assets/r1.webp",
73
- "assets/r2.webp",
74
- "assets/r3.webp",
75
- "assets/r4.webp",
76
- "assets/r5.webp",
77
- "assets/r6.webp",
78
- ]
79
- return predefined_images
80
-
81
  # Function to ensure "nsfw" and "[trigger]" are in the prompt
82
  def process_prompt(prompt):
83
  # Add "nsfw" prefix if not already present
@@ -112,21 +92,31 @@ def inference(
112
  seed = random.randint(0, MAX_SEED)
113
  generator = torch.Generator(device=device).manual_seed(seed)
114
 
115
- # Use joint_attention_kwargs to control LoRA scale
116
- # (FluxPipeline may use a different parameter name but attempt both)
117
  try:
118
- image = pipeline(
119
- prompt=processed_prompt,
120
- guidance_scale=guidance_scale,
121
- num_inference_steps=num_inference_steps,
122
- width=width,
123
- height=height,
124
- generator=generator,
125
- joint_attention_kwargs={"scale": lora_scale},
126
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
127
  except Exception as e:
128
- # If the above fails, try with cross_attention_kwargs which is more common
129
- print(f"First attempt failed with: {e}, trying alternative method...")
130
  image = pipeline(
131
  prompt=processed_prompt,
132
  guidance_scale=guidance_scale,
@@ -134,17 +124,17 @@ def inference(
134
  width=width,
135
  height=height,
136
  generator=generator,
137
- cross_attention_kwargs={"scale": lora_scale},
138
  ).images[0]
139
 
140
  # Save the generated image
141
  filepath = save_generated_image(image, processed_prompt)
142
 
143
- # Return the image, seed, and updated gallery
144
- return image, seed, processed_prompt, load_generated_images()
145
-
146
- examples = "A young couple, their bodies glistening with sweat, make love in the rain, the woman"
147
 
 
 
 
148
 
149
  # Brighter custom CSS with vibrant colors
150
  custom_css = """
@@ -187,125 +177,80 @@ button:hover {
187
  transform: translateY(-2px);
188
  box-shadow: 0 5px 15px rgba(0,0,0,0.1);
189
  }
190
- .tabs {
191
- margin-top: 20px;
192
- }
193
- .gallery {
194
- background-color: rgba(255, 255, 255, 0.5);
195
- border-radius: 10px;
196
- padding: 10px;
197
- }
198
  """
199
 
200
  with gr.Blocks(css=custom_css, analytics_enabled=False) as demo:
201
  gr.HTML('<div class="title">NSFW Detection STUDIO</div>')
202
 
203
- # Model description with the requested content
204
-
205
- with gr.Tabs(elem_classes="tabs") as tabs:
206
- with gr.Tab("Generation"):
207
- with gr.Column(elem_id="col-container"):
208
- with gr.Row():
209
- prompt = gr.Text(
210
- label="Prompt",
211
- show_label=False,
212
- max_lines=1,
213
- placeholder="Enter your prompt (nsfw and [trigger] will be added automatically)",
214
- container=False,
215
- )
216
- run_button = gr.Button("Generate", variant="primary", scale=0)
217
-
218
- result = gr.Image(label="Result", show_label=False)
219
- processed_prompt_display = gr.Textbox(label="Processed Prompt", show_label=True)
220
-
221
- with gr.Accordion("Advanced Settings", open=False):
222
- seed = gr.Slider(
223
- label="Seed",
224
- minimum=0,
225
- maximum=MAX_SEED,
226
- step=1,
227
- value=42,
228
- )
229
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
230
-
231
- with gr.Row():
232
- width = gr.Slider(
233
- label="Width",
234
- minimum=256,
235
- maximum=MAX_IMAGE_SIZE,
236
- step=32,
237
- value=1024,
238
- )
239
- height = gr.Slider(
240
- label="Height",
241
- minimum=256,
242
- maximum=MAX_IMAGE_SIZE,
243
- step=32,
244
- value=768,
245
- )
246
 
247
- with gr.Row():
248
- guidance_scale = gr.Slider(
249
- label="Guidance scale",
250
- minimum=0.0,
251
- maximum=10.0,
252
- step=0.1,
253
- value=3.5,
254
- )
255
- num_inference_steps = gr.Slider(
256
- label="Number of inference steps",
257
- minimum=1,
258
- maximum=50,
259
- step=1,
260
- value=30,
261
- )
262
- lora_scale = gr.Slider(
263
- label="LoRA scale",
264
- minimum=0.0,
265
- maximum=1.0,
266
- step=0.1,
267
- value=1.0,
268
- )
269
 
270
- gr.Examples(
271
- examples=examples,
272
- inputs=[prompt],
273
- outputs=[result, seed, processed_prompt_display],
274
- )
275
-
276
- with gr.Tab("Gallery"):
277
- gallery_header = gr.Markdown("### Your Generated Images")
278
- generated_gallery = gr.Gallery(
279
- label="Generated Images",
280
- columns=3,
281
- show_label=False,
282
- value=load_generated_images(),
283
- elem_id="generated_gallery",
284
- elem_classes="gallery",
285
- height="auto"
286
  )
287
- refresh_btn = gr.Button("🔄 Refresh Gallery", variant="primary")
288
 
289
- # Add sample gallery section at the bottom
290
- gr.Markdown("### Pierre-Auguste Renoir Style Examples")
291
- predefined_gallery = gr.Gallery(
292
- label="Sample Images",
293
- columns=3,
294
- rows=2,
295
- show_label=False,
296
- value=load_predefined_images(),
297
- elem_classes="gallery"
298
- )
 
 
 
 
 
299
 
300
- # Event handlers
301
- def refresh_gallery():
302
- return load_generated_images()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
- refresh_btn.click(
305
- fn=refresh_gallery,
306
- inputs=None,
307
- outputs=generated_gallery,
308
- )
309
 
310
  gr.on(
311
  triggers=[run_button.click, prompt.submit],
@@ -320,7 +265,7 @@ with gr.Blocks(css=custom_css, analytics_enabled=False) as demo:
320
  num_inference_steps,
321
  lora_scale,
322
  ],
323
- outputs=[result, seed, processed_prompt_display, generated_gallery],
324
  )
325
 
326
  demo.queue()
 
6
  import numpy as np
7
  import spaces
8
  import torch
9
+ from diffusers import AutoPipelineForText2Image
10
  from PIL import Image
11
 
 
 
 
 
 
 
 
 
 
12
  # Create permanent storage directory
13
  SAVE_DIR = "saved_images" # Gradio will handle the persistence
14
  if not os.path.exists(SAVE_DIR):
 
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  repo_id = "black-forest-labs/FLUX.1-dev"
19
+ lora_id = "seawolf2357/nsfw-detection" # LoRA model
20
 
 
21
  print("Loading pipeline...")
22
+ # Use AutoPipelineForText2Image which has better compatibility with LoRA loading
23
+ pipeline = AutoPipelineForText2Image.from_pretrained(
24
+ repo_id,
25
+ torch_dtype=torch.bfloat16,
26
+ use_safetensors=True
27
+ )
28
  pipeline = pipeline.to(device)
29
 
30
+ # Try to load the LoRA with direct method (simpler approach)
31
+ print("Loading LoRA weights...")
32
+ try:
33
+ pipeline.load_lora_weights(lora_id)
34
+ print("LoRA weights loaded successfully!")
35
+ lora_loaded = True
36
+ except Exception as e:
37
+ print(f"Could not load LoRA weights using standard method: {e}")
38
+ print("Continuing without LoRA functionality.")
39
+ lora_loaded = False
40
+
41
  MAX_SEED = np.iinfo(np.int32).max
42
  MAX_IMAGE_SIZE = 1024
43
 
 
58
 
59
  return filepath
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Function to ensure "nsfw" and "[trigger]" are in the prompt
62
  def process_prompt(prompt):
63
  # Add "nsfw" prefix if not already present
 
92
  seed = random.randint(0, MAX_SEED)
93
  generator = torch.Generator(device=device).manual_seed(seed)
94
 
 
 
95
  try:
96
+ # Try with cross_attention_kwargs if LoRA was loaded successfully
97
+ if lora_loaded:
98
+ image = pipeline(
99
+ prompt=processed_prompt,
100
+ guidance_scale=guidance_scale,
101
+ num_inference_steps=num_inference_steps,
102
+ width=width,
103
+ height=height,
104
+ generator=generator,
105
+ cross_attention_kwargs={"scale": lora_scale}
106
+ ).images[0]
107
+ else:
108
+ # Fall back to standard generation if LoRA wasn't loaded
109
+ image = pipeline(
110
+ prompt=processed_prompt,
111
+ guidance_scale=guidance_scale,
112
+ num_inference_steps=num_inference_steps,
113
+ width=width,
114
+ height=height,
115
+ generator=generator,
116
+ ).images[0]
117
  except Exception as e:
118
+ print(f"Error during inference with cross_attention_kwargs: {e}")
119
+ # Fall back to standard generation without LoRA parameters
120
  image = pipeline(
121
  prompt=processed_prompt,
122
  guidance_scale=guidance_scale,
 
124
  width=width,
125
  height=height,
126
  generator=generator,
 
127
  ).images[0]
128
 
129
  # Save the generated image
130
  filepath = save_generated_image(image, processed_prompt)
131
 
132
+ # Return the image, seed, and processed prompt
133
+ return image, seed, processed_prompt
 
 
134
 
135
+ examples = [
136
+ "A young couple, their bodies glistening with sweat, make love in the rain, the woman"
137
+ ]
138
 
139
  # Brighter custom CSS with vibrant colors
140
  custom_css = """
 
177
  transform: translateY(-2px);
178
  box-shadow: 0 5px 15px rgba(0,0,0,0.1);
179
  }
 
 
 
 
 
 
 
 
180
  """
181
 
182
  with gr.Blocks(css=custom_css, analytics_enabled=False) as demo:
183
  gr.HTML('<div class="title">NSFW Detection STUDIO</div>')
184
 
185
+ # Main generation interface
186
+ with gr.Column(elem_id="col-container"):
187
+ with gr.Row():
188
+ prompt = gr.Text(
189
+ label="Prompt",
190
+ show_label=False,
191
+ max_lines=1,
192
+ placeholder="Enter your prompt (nsfw and [trigger] will be added automatically)",
193
+ container=False,
194
+ )
195
+ run_button = gr.Button("Generate", variant="primary", scale=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
+ result = gr.Image(label="Result", show_label=False)
198
+ processed_prompt_display = gr.Textbox(label="Processed Prompt", show_label=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ with gr.Accordion("Advanced Settings", open=False):
201
+ seed = gr.Slider(
202
+ label="Seed",
203
+ minimum=0,
204
+ maximum=MAX_SEED,
205
+ step=1,
206
+ value=42,
 
 
 
 
 
 
 
 
 
207
  )
208
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
209
 
210
+ with gr.Row():
211
+ width = gr.Slider(
212
+ label="Width",
213
+ minimum=256,
214
+ maximum=MAX_IMAGE_SIZE,
215
+ step=32,
216
+ value=1024,
217
+ )
218
+ height = gr.Slider(
219
+ label="Height",
220
+ minimum=256,
221
+ maximum=MAX_IMAGE_SIZE,
222
+ step=32,
223
+ value=768,
224
+ )
225
 
226
+ with gr.Row():
227
+ guidance_scale = gr.Slider(
228
+ label="Guidance scale",
229
+ minimum=0.0,
230
+ maximum=10.0,
231
+ step=0.1,
232
+ value=3.5,
233
+ )
234
+ num_inference_steps = gr.Slider(
235
+ label="Number of inference steps",
236
+ minimum=1,
237
+ maximum=50,
238
+ step=1,
239
+ value=30,
240
+ )
241
+ lora_scale = gr.Slider(
242
+ label="LoRA scale",
243
+ minimum=0.0,
244
+ maximum=1.0,
245
+ step=0.1,
246
+ value=1.0,
247
+ )
248
 
249
+ gr.Examples(
250
+ examples=examples,
251
+ inputs=[prompt],
252
+ outputs=[result, seed, processed_prompt_display],
253
+ )
254
 
255
  gr.on(
256
  triggers=[run_button.click, prompt.submit],
 
265
  num_inference_steps,
266
  lora_scale,
267
  ],
268
+ outputs=[result, seed, processed_prompt_display],
269
  )
270
 
271
  demo.queue()