File size: 2,548 Bytes
c2a8649
6031dc7
 
 
 
1888310
6031dc7
 
 
c2a8649
6031dc7
 
1888310
 
 
6031dc7
1888310
 
 
 
 
 
6031dc7
 
234658e
1888310
 
 
 
 
 
 
 
6031dc7
1888310
 
6031dc7
 
c2a8649
6031dc7
 
1888310
 
c2a8649
 
1888310
 
6031dc7
c2a8649
 
 
6031dc7
1888310
6031dc7
 
1888310
c2a8649
6031dc7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch

# Function to automatically switch between GPU and CPU
def load_model(base_model_id, adapter_model_id=None):
    """Build a Stable Diffusion pipeline on the best available device.

    CUDA with float16 weights is used when ``torch.cuda.is_available()`` is
    true; otherwise the pipeline runs on the CPU with float32 weights.

    Args:
        base_model_id: Model ID or path handed to
            ``StableDiffusionPipeline.from_pretrained``.
        adapter_model_id: Optional ID or path of adapter weights (e.g. LoRA)
            to load into the pipeline's UNet. ``None``, an empty string, or a
            whitespace-only string means "no adapter".

    Returns:
        A ``(pipe, info)`` tuple: the pipeline moved to the chosen device,
        and a status string naming the device and, when one was applied,
        the adapter.
    """
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
        info = "Running on GPU (CUDA)"
    else:
        device, dtype = "cpu", torch.float32
        info = "Running on CPU"

    # Load the base model with the precision that suits the device, then move it there.
    pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=dtype)
    pipe = pipe.to(device)

    # Text inputs can arrive padded with whitespace; treat a blank value as "no adapter".
    if isinstance(adapter_model_id, str):
        adapter_model_id = adapter_model_id.strip()

    if adapter_model_id:
        # Only the adapter weights are needed here. Loading the adapter repo as a
        # second full pipeline (as was done before) produced an unused object,
        # cost a full extra download, and presumably fails for repos that contain
        # adapter weights only.
        pipe.unet.load_attn_procs(adapter_model_id)
        info += f" with Adapter Model: {adapter_model_id}"

    return pipe, info


# Detect the device once at import time so the page header can show it.
# `device` and `info` are module-level names; `info` is rendered in the UI.
if torch.cuda.is_available():
    device = "cuda"
    info = "Running on GPU (CUDA) 🔥"
else:
    device = "cpu"
    info = "Running on CPU 🥶"

# Function for text-to-image generation with dynamic model ID and device info
def generate_image(base_model_id, adapter_model_id, prompt):
    """Generate one image for ``prompt`` and report how the model was run.

    Args:
        base_model_id: Base model ID forwarded to ``load_model``.
        adapter_model_id: Adapter model ID forwarded to ``load_model``.
        prompt: Text prompt passed to the loaded pipeline.

    Returns:
        A ``(image, info)`` tuple: the first image produced by the pipeline
        and the status string returned by ``load_model``.
    """
    # The pipeline is built on every call from the IDs supplied by the caller.
    pipeline, status = load_model(base_model_id, adapter_model_id)
    result = pipeline(prompt)
    first_image = result.images[0]
    return first_image, status

# Create the Gradio interface
with gr.Blocks() as demo:
    # Page header: title plus the device status detected at import time.
    gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
    gr.Markdown(f"{info}")

    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            base_model_id = gr.Textbox(
                label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)",
                placeholder="Base Model ID",
            )
            adapter_model_id = gr.Textbox(
                label="Enter Adapter Model ID (optional, e.g., nevreal/vMurderDrones-Lora)",
                placeholder="Adapter Model ID (optional)",
                value="",
            )
            prompt = gr.Textbox(
                label="Enter your prompt",
                placeholder="Describe the image you want to generate",
            )
            generate_btn = gr.Button("Generate Image")

        # Right column: the generated image and the per-run status text
        # (device used and whether an adapter was applied).
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            device_info = gr.Markdown()

    # Clicking the button runs generate_image on the three textbox values and
    # routes its (image, info) result to the two output components.
    generate_btn.click(
        fn=generate_image,
        inputs=[base_model_id, adapter_model_id, prompt],
        outputs=[output_image, device_info],
    )

# Start the Gradio server for the interface defined above.
demo.launch()