import random
import os
import uuid
from datetime import datetime

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from PIL import Image

# Create permanent storage directory
SAVE_DIR = "saved_images"  # Gradio will handle the persistence
os.makedirs(SAVE_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "black-forest-labs/FLUX.1-dev"
lora_id = "seawolf2357/nsfw-detection"  # LoRA model

print("Loading pipeline...")
# Use AutoPipelineForText2Image, which resolves to the right pipeline class
# and has better compatibility with LoRA loading
pipeline = AutoPipelineForText2Image.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
pipeline = pipeline.to(device)

# Try to load the LoRA weights directly (simplest approach)
print("Loading LoRA weights...")
try:
    pipeline.load_lora_weights(lora_id)
    print("LoRA weights loaded successfully!")
    lora_loaded = True
except Exception as e:
    print(f"Could not load LoRA weights using standard method: {e}")
    print("Continuing without LoRA functionality.")
    lora_loaded = False

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


def save_generated_image(image, prompt):
    # Generate a unique filename with a timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = str(uuid.uuid4())[:8]
    filename = f"{timestamp}_{unique_id}.png"
    filepath = os.path.join(SAVE_DIR, filename)

    # Save the image
    image.save(filepath)

    # Append metadata (pipe-delimited: filename|prompt|timestamp)
    metadata_file = os.path.join(SAVE_DIR, "metadata.txt")
    with open(metadata_file, "a", encoding="utf-8") as f:
        f.write(f"{filename}|{prompt}|{timestamp}\n")

    return filepath


# Ensure "nsfw" and "[trigger]" are in the prompt
def process_prompt(prompt):
    # Add the "nsfw" prefix if not already present
    if not prompt.lower().startswith("nsfw "):
        prompt = "nsfw " + prompt

    # Add the "[trigger]" suffix if not already present
    if not prompt.lower().endswith("[trigger]"):
        prompt = prompt.rstrip() + " [trigger]"

    return prompt


@spaces.GPU(duration=120)
def inference(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    lora_scale: float,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    # Process the prompt to ensure it has the required format
    processed_prompt = process_prompt(prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if lora_loaded:
            # FLUX pipelines take the LoRA scale via joint_attention_kwargs;
            # cross_attention_kwargs is the older UNet-pipeline argument and
            # is not accepted by FluxPipeline
            image = pipeline(
                prompt=processed_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
            ).images[0]
        else:
            # Fall back to standard generation if the LoRA wasn't loaded
            image = pipeline(
                prompt=processed_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]
    except Exception as e:
        print(f"Error during inference with LoRA scale argument: {e}")
        # Retry without the LoRA-specific argument
        image = pipeline(
            prompt=processed_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    # Save the generated image and its metadata
    save_generated_image(image, processed_prompt)

    # Return the image, the seed actually used, and the processed prompt
    return image, seed, processed_prompt
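# Optional alternative (a sketch, not part of the original app): diffusers'
# LoRA loader mixin also offers fuse_lora(), which bakes the LoRA strength
# into the weights once at startup instead of passing a scale on every call.
# The 0.8 below is an illustrative value, not a tuned one:
#
#     if lora_loaded:
#         pipeline.fuse_lora(lora_scale=0.8)
#
# Fusing gives slightly faster inference but gives up per-request control of
# lora_scale, so it only fits if the scale is fixed for all users.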

examples = [
    "A young couple, their bodies glistening with sweat, make love in the rain, the woman",
]

# Brighter custom CSS with vibrant colors
custom_css = """
:root {
    --color-primary: #FF9E6C;
    --color-secondary: #FFD8A9;
}

footer {
    visibility: hidden;
}

.gradio-container {
    background: linear-gradient(to right, #FFF4E0, #FFEDDB);
}

.title {
    color: #E25822 !important;
    font-size: 2.5rem !important;
    font-weight: 700 !important;
    text-align: center;
    margin: 1rem 0;
    text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
}

.subtitle {
    color: #2B3A67 !important;
    font-size: 1.2rem !important;
    text-align: center;
    margin-bottom: 2rem;
}

.model-description {
    background-color: rgba(255, 255, 255, 0.7);
    border-radius: 10px;
    padding: 20px;
    margin: 20px 0;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    border-left: 5px solid #E25822;
}

button.primary {
    background-color: #E25822 !important;
}

button:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
}
"""

with gr.Blocks(css=custom_css, analytics_enabled=False) as demo:
    gr.HTML('