import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, AutoPipelineForText2Image
from PIL import Image
import time
import random
import os
import gc  # garbage collector
import logging

# --- Configuration ---

# Setup basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure CPU is used
DEVICE = "cpu"
TORCH_DTYPE = torch.float32 # float16/bfloat16 not practical on CPU
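
# Let PyTorch use all available CPU cores (a minor tweak, and an assumption that
# it helps here; PyTorch's default thread count is often already sensible):
torch.set_num_threads(os.cpu_count() or 1)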

# Model definitions
# We need to know the base model for LoRAs and compatible IP-Adapters
MODEL_CONFIG = {
    "BlaireSilver13/youtube-thumbnail": {
        "repo_id": "BlaireSilver13/youtube-thumbnail",
        "is_lora": True,
        "lora_filename": "FLUX-youtube-thumbnails.safetensors", 
        "base_model": "black-forest-labs/FLUX.1-dev",
        "pipeline_class": AutoPipelineForText2Image, 
        "ip_adapter_repo": "h94/IP-Adapter", 
        "ip_adapter_weights": "ip-adapter_sd15.bin",
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 
    },
    "itzzdeep/youtube-thumbnails-sdxl-lora": {
        "repo_id": "itzzdeep/youtube-thumbnails-sdxl-lora",
        "is_lora": True,
        "lora_filename": "youtube-thumbnails-sdxl-lora.safetensors", 
        "base_model": "stabilityai/stable-diffusion-xl-base-1.0",
        "pipeline_class": AutoPipelineForText2Image, # Handles SDXL loading better
        "ip_adapter_repo": "h94/IP-Adapter", # SDXL IP-Adapter repo
        "ip_adapter_weights": "ip-adapter-plus_sdxl_vit-h.bin", # SDXL weights
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" # Usually the same encoder repo
    },
    "justmalhar/flux-thumbnails-v3": {
        "repo_id": "justmalhar/flux-thumbnails-v3",
        "is_lora": False, # Assuming this is a full SD 1.5 fine-tune based on common practice
        "base_model": None,
        "pipeline_class": StableDiffusionPipeline,
        "ip_adapter_repo": "h94/IP-Adapter",
        "ip_adapter_weights": "ip-adapter_sd15.bin",
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    },
    "saq1b/mrbeast-thumbnail-style": {
        "repo_id": "saq1b/mrbeast-thumbnail-style",
        "is_lora": True, # This is typically a LoRA
        "lora_filename": None, # Auto-detect or specify e.g., "pytorch_lora_weights.safetensors"
        "base_model": "runwayml/stable-diffusion-v1-5", # Common base for SD 1.5 LoRAs
        "pipeline_class": StableDiffusionPipeline,
        "ip_adapter_repo": "h94/IP-Adapter",
        "ip_adapter_weights": "ip-adapter_sd15.bin",
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    }
}
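
# A small startup sanity check over MODEL_CONFIG (a sketch; the substring
# heuristics below are assumptions about naming conventions, not guarantees
# about the actual architectures). It only logs warnings, so a questionable
# entry still shows up in the UI and simply fails at load time.
def validate_model_config(config: dict) -> None:
    """Log warnings for model/IP-Adapter pairings that look inconsistent."""
    for key, cfg in config.items():
        base = (cfg.get("base_model") or cfg["repo_id"]).lower()
        weights = cfg.get("ip_adapter_weights", "").lower()
        # SDXL bases should pair with SDXL IP-Adapter weights, and vice versa
        if ("xl" in base) != ("sdxl" in weights):
            logger.warning(f"{key}: base model and IP-Adapter weights may be mismatched.")
        # Standard h94 IP-Adapters target SD 1.5 / SDXL UNets, not FLUX transformers
        if "flux" in base and weights:
            logger.warning(f"{key}: IP-Adapter support for a FLUX pipeline is unlikely.")

validate_model_config(MODEL_CONFIG)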

AVAILABLE_MODELS = list(MODEL_CONFIG.keys())

# NOTE: a global pipeline could be kept resident to avoid reloading between
# runs, but on a memory-constrained CPU Space it is safer to load (and free)
# the pipeline inside the generation function, which is what happens below.
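
# For reference, a minimal single-slot cache (a sketch, unused by default; the
# per-request strategy above is the one actually wired up, because a resident
# SDXL pipeline in float32 can exhaust a small Space's RAM):
_PIPELINE_CACHE: dict = {}

def get_cached_pipeline(model_key: str):
    """Load the base pipeline for model_key once, evicting any previous one."""
    if model_key not in _PIPELINE_CACHE:
        _PIPELINE_CACHE.clear()  # keep at most one pipeline resident
        gc.collect()
        cfg = MODEL_CONFIG[model_key]
        model_id = cfg["base_model"] if cfg["is_lora"] else cfg["repo_id"]
        _PIPELINE_CACHE[model_key] = cfg["pipeline_class"].from_pretrained(
            model_id, torch_dtype=TORCH_DTYPE
        ).to(DEVICE)
    return _PIPELINE_CACHE[model_key]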

# --- Helper Functions ---

def cleanup_memory():
    """Run garbage collection (and empty the CUDA cache if a GPU is somehow present)."""
    logger.info("Attempting to clean up memory...")
    try:
        gc.collect()
        if torch.cuda.is_available():  # no-op on this CPU Space, but harmless
            torch.cuda.empty_cache()
        logger.info("Memory cleanup done.")
    except Exception as e:
        logger.error(f"Error during memory cleanup: {e}")


# --- Main Generation Function ---

def generate_thumbnail(
    model_key: str,
    prompt: str,
    negative_prompt: str,
    reference_image_pil: Image.Image | None, # Gradio provides PIL image
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    ip_adapter_scale: float,
    progress=gr.Progress()
):
    """Generates an image using the selected model, IP-Adapter, and settings."""
    start_time = time.time()
    debug_log = f"--- Generation Log ({time.strftime('%Y-%m-%d %H:%M:%S')}) ---\n"
    debug_log += f"Selected Model Key: {model_key}\n"
    debug_log += f"Prompt: {prompt}\n"
    debug_log += f"Negative Prompt: {negative_prompt}\n"
    debug_log += f"Steps: {num_inference_steps}, CFG Scale: {guidance_scale}\n"
    debug_log += f"Seed: {seed}\n"
    debug_log += f"Reference Image Provided: {'Yes' if reference_image_pil else 'No'}\n"
    debug_log += f"IP Adapter Scale: {ip_adapter_scale}\n"
    debug_log += f"Device: {DEVICE}, Dtype: {TORCH_DTYPE}\n\n"

    pipeline = None # Ensure pipeline is defined in this scope

    try:
        if not model_key:
            raise ValueError("No model selected.")

        config = MODEL_CONFIG[model_key]
        repo_id = config["repo_id"]
        is_lora = config["is_lora"]
        base_model = config["base_model"]
        pipeline_class = config["pipeline_class"]
        ip_adapter_repo = config["ip_adapter_repo"]
        ip_adapter_weights = config["ip_adapter_weights"]
        ip_adapter_subfolder = config.get("ip_adapter_subfolder", "models")
        # The matching image encoder is resolved by diffusers when the IP-Adapter loads.

        # --- Model Loading ---
        load_start_time = time.time()
        debug_log += f"[{time.time() - start_time:.2f}s] Cleaning up memory before loading...\n"
        progress(0.1, desc="Cleaning up memory...")
        cleanup_memory() # Attempt cleanup before loading new model

        debug_log += f"[{time.time() - start_time:.2f}s] Loading model: {'LoRA ' + repo_id if is_lora else repo_id}...\n"
        progress(0.2, desc=f"Loading {'LoRA ' + repo_id if is_lora else repo_id}...")

        model_load_id = base_model if is_lora else repo_id
        debug_log += f"[{time.time() - start_time:.2f}s] Base/Model ID for pipeline: {model_load_id}\n"

        pipeline = pipeline_class.from_pretrained(
            model_load_id,
            torch_dtype=TORCH_DTYPE,
            # Add any specific args needed for the pipeline class if necessary
            # safety_checker=None, # Disable safety checker if needed/causes issues on CPU
            # requires_safety_checker=False,
        )
        pipeline.to(DEVICE)
        debug_log += f"[{time.time() - start_time:.2f}s] Base pipeline loaded onto {DEVICE}.\n"

        if is_lora:
            lora_load_start = time.time()
            debug_log += f"[{time.time() - start_time:.2f}s] Loading LoRA weights from {repo_id}...\n"
            progress(0.4, desc=f"Loading LoRA {repo_id}...")
            try:
                lora_filename = config.get("lora_filename")  # specific filename, if provided
                if lora_filename:
                    debug_log += f"[{time.time() - start_time:.2f}s] Using specified LoRA filename: {lora_filename}\n"
                    pipeline.load_lora_weights(repo_id, weight_name=lora_filename)
                else:
                    # Let diffusers auto-detect standard names like pytorch_lora_weights.safetensors
                    debug_log += f"[{time.time() - start_time:.2f}s] Attempting auto-detection of LoRA filename.\n"
                    pipeline.load_lora_weights(repo_id)

                # load_lora_weights already activates the adapter; re-setting the weight
                # explicitly keeps the intent visible and makes it easy to tune later.
                # (pipeline.fuse_lora() is the alternative, but it bakes the weights in.)
                pipeline.set_adapters(pipeline.get_active_adapters(), adapter_weights=1.0)
                debug_log += f"[{time.time() - start_time:.2f}s] LoRA weights loaded and adapters set in {time.time() - lora_load_start:.2f}s.\n"

            except Exception as e:
                debug_log += f"[{time.time() - start_time:.2f}s] ERROR loading LoRA: {e}. Check the LoRA repo structure/filename.\n"
                raise ValueError(f"Failed to load LoRA weights for {repo_id}: {e}") from e

        # --- IP Adapter Loading ---
        if reference_image_pil is not None and ip_adapter_scale > 0:
            ip_load_start = time.time()
            debug_log += f"[{time.time() - start_time:.2f}s] Loading IP-Adapter: {ip_adapter_repo} ({ip_adapter_weights})...\n"
            progress(0.6, desc="Loading IP-Adapter...")
            try:
                # Ensure the pipeline supports IP-Adapters at all
                if not hasattr(pipeline, "load_ip_adapter"):
                    raise AttributeError("This pipeline class does not support load_ip_adapter. Check the diffusers version or pipeline type.")

                pipeline.load_ip_adapter(
                    ip_adapter_repo,
                    subfolder=ip_adapter_subfolder,  # "models" for SD 1.5, "sdxl_models" for SDXL
                    weight_name=ip_adapter_weights,
                    # The matching image encoder is resolved from the repo by diffusers
                )
                pipeline.set_ip_adapter_scale(ip_adapter_scale)
                debug_log += f"[{time.time() - start_time:.2f}s] IP-Adapter loaded and scale set ({ip_adapter_scale}) in {time.time() - ip_load_start:.2f}s.\n"
                # The IP-Adapter just needs an RGB PIL image
                ip_image = reference_image_pil.convert("RGB")
                debug_log += f"[{time.time() - start_time:.2f}s] Reference image prepared for IP-Adapter.\n"

            except Exception as e:
                debug_log += f"[{time.time() - start_time:.2f}s] WARNING: Failed to load IP-Adapter: {e}. Proceeding without image guidance.\n"
                ip_image = None
                ip_adapter_scale = 0  # disable image guidance for this run
                if hasattr(pipeline, "unload_ip_adapter"):
                    try:
                        pipeline.unload_ip_adapter()  # drop any partially attached adapter
                    except Exception:
                        pass
        else:
            # A freshly loaded pipeline has no IP-Adapter attached, so there is
            # no scale to reset; just skip image guidance entirely.
            ip_image = None
            debug_log += f"[{time.time() - start_time:.2f}s] No reference image provided or IP-Adapter scale is 0. Skipping IP-Adapter loading.\n"


        debug_log += f"[{time.time() - start_time:.2f}s] Total Model & IP-Adapter Loading time: {time.time() - load_start_time:.2f}s\n"


        # --- Generation ---
        gen_start_time = time.time()
        debug_log += f"[{time.time() - start_time:.2f}s] Starting generation...\n"
        progress(0.7, desc="Generating image...")

        # Handle seed
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
            debug_log += f"[{time.time() - start_time:.2f}s] Using random seed: {seed}\n"
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        # Prepare arguments for pipeline call
        pipeline_args = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        # Pass the reference image only if the IP-Adapter actually loaded
        if ip_image is not None and ip_adapter_scale > 0:
            pipeline_args["ip_adapter_image"] = ip_image
            # The strength itself was applied earlier via set_ip_adapter_scale
            debug_log += f"[{time.time() - start_time:.2f}s] Passing reference image to pipeline with IP scale {ip_adapter_scale}.\n"
        else:
            debug_log += f"[{time.time() - start_time:.2f}s] Not passing reference image to pipeline.\n"


        # Run inference
        with torch.inference_mode(): # More modern than no_grad for inference
            output_image = pipeline(**pipeline_args).images[0]

        gen_end_time = time.time()
        debug_log += f"[{time.time() - start_time:.2f}s] Generation finished in {gen_end_time - gen_start_time:.2f}s.\n"

        # --- Cleanup ---
        debug_log += f"[{time.time() - start_time:.2f}s] Unloading model from memory (CPU strategy)...\n"
        progress(0.95, desc="Cleaning up...")
        del pipeline # Explicitly delete pipeline
        cleanup_memory() # Call garbage collection

        total_time = time.time() - start_time
        debug_log += f"\n--- Total time: {total_time:.2f} seconds ---\n"

        return output_image, debug_log

    except Exception as e:
        logger.exception(f"Error during generation for model {model_key}") # Log full traceback
        error_time = time.time() - start_time
        debug_log += f"\n\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
        debug_log += f"ERROR occurred after {error_time:.2f}s:\n{e}\n"
        debug_log += f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
        # Try cleanup even on error
        if 'pipeline' in locals() and pipeline is not None:
            del pipeline
        cleanup_memory()
        # Return None for image, and the log containing the error
        return None, debug_log


# --- Gradio Interface ---

css = """
#warning {
    background-color: #FFCCCB; /* Light red */
    padding: 10px;
    border-radius: 5px;
    text-align: center;
    font-weight: bold;
}
#debug_log_area textarea {
    font-family: monospace;
    font-size: 10px; /* Smaller font for logs */
    white-space: pre-wrap; /* Wrap long lines */
    word-wrap: break-word; /* Break words if necessary */
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# YouTube Thumbnail Generator with IP-Adapter")
    gr.Markdown(
        "Select a thumbnail model, provide a text prompt, and optionally upload a reference image "
        "to guide the generation using IP-Adapter."
    )
    gr.HTML("<div id='warning'>⚠️ Warning: Inference on CPU is VERY SLOW (minutes per image, especially SDXL models). Please be patient.</div>")

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Select Thumbnail Model",
                choices=AVAILABLE_MODELS,
                value=AVAILABLE_MODELS[0] if AVAILABLE_MODELS else None,
            )
            prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="e.g., Epic landscape, dramatic lighting, YouTube thumbnail style")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, placeholder="e.g., blurry, low quality, text, signature, watermark")
            reference_image_input = gr.Image(label="Reference Image (for IP-Adapter)", type="pil", sources=["upload"])

            with gr.Accordion("Advanced Settings", open=False):
                steps_slider = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=30, step=1)
                cfg_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=20.0, value=7.0, step=0.5)
                ip_adapter_scale_slider = gr.Slider(label="IP-Adapter Scale", minimum=0.0, maximum=1.5, value=0.6, step=0.05,
                                                    info="Strength of the reference image influence (0 = disabled).")
                seed_input = gr.Number(label="Seed", value=-1, precision=0, info="-1 for random seed")

            generate_button = gr.Button("Generate Thumbnail", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Thumbnail", type="pil")
            debug_output = gr.Textbox(label="Debug Log", lines=20, interactive=False, elem_id="debug_log_area")

    generate_button.click(
        fn=generate_thumbnail,
        inputs=[
            model_dropdown,
            prompt_input,
            negative_prompt_input,
            reference_image_input,
            steps_slider,
            cfg_slider,
            seed_input,
            ip_adapter_scale_slider
        ],
        outputs=[output_image, debug_output]
    )

# --- Launch ---
if __name__ == "__main__":
    logger.info("Starting Gradio App...")
    # Queueing is important for handling multiple users on Spaces, even if slow
    demo.queue().launch(debug=True) # debug=True provides Gradio debug info in console