nevreal commited on
Commit
78f6a44
·
verified ·
1 Parent(s): 9af81fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -35
app.py CHANGED
@@ -1,60 +1,74 @@
1
  import gradio as gr
2
- from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
  # Function to automatically switch between GPU and CPU
6
- def load_model(base_model_id, adapter_model_id=None):
7
- if torch.cuda.is_available():
8
- device = "cuda"
9
- info = "Running on GPU (CUDA) 🔥"
10
- else:
11
- device = "cpu"
12
- info = "Running on CPU 🥶"
13
-
14
- # Load the base model dynamically on the correct device
15
- pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
16
- pipe = pipe.to(device)
17
-
18
- # If an adapter model is provided, load and merge the adapter model
19
- if adapter_model_id:
20
- adapter_model = StableDiffusionPipeline.from_pretrained(adapter_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
21
- pipe.unet.load_attn_procs(adapter_model_id) # This applies the adapter like LoRA to the model's UNet
22
- info += f" with Adapter Model: {adapter_model_id}"
23
 
24
- return pipe, info
 
 
 
 
 
25
 
26
- # Function for text-to-image generation with dynamic model ID and device info
 
 
 
 
 
 
 
 
 
 
27
  def generate_image(base_model_id, adapter_model_id, prompt):
28
  pipe, info = load_model(base_model_id, adapter_model_id)
29
- image = pipe(prompt).images[0]
30
- return image, info
 
31
 
32
- # Check device (GPU/CPU) once at the start and show it in the UI
33
- if torch.cuda.is_available():
34
- device = "cuda"
35
- info = "Running on GPU (CUDA) 🔥"
36
- else:
37
- device = "cpu"
38
- info = "Running on CPU 🥶"
39
 
40
  # Create the Gradio interface
41
  with gr.Blocks() as demo:
42
  gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
43
- gr.Markdown(f"**{info}**") # Display GPU/CPU information in the UI
44
 
45
  with gr.Row():
46
  with gr.Column():
47
- base_model_id = gr.Textbox(label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)", placeholder="Base Model ID")
48
- adapter_model_id = gr.Textbox(label="Enter Adapter Model ID (optional, e.g., nevreal/vMurderDrones-Lora)", placeholder="Adapter Model ID (optional)", value="")
49
- prompt = gr.Textbox(label="Enter your prompt", placeholder="Describe the image you want to generate")
 
 
 
 
 
 
 
 
 
 
50
  generate_btn = gr.Button("Generate Image")
51
 
52
  with gr.Column():
53
  output_image = gr.Image(label="Generated Image")
54
- device_info = gr.Markdown() # To display if GPU or CPU is used and whether an adapter is applied
55
 
56
  # Link the button to the image generation function
57
- generate_btn.click(fn=generate_image, inputs=[base_model_id, adapter_model_id, prompt], outputs=[output_image, device_info])
 
 
 
 
58
 
59
  # Launch the app
60
  demo.launch()
 
1
  import gradio as gr
2
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline
3
  import torch
4
 
5
  # Function to automatically switch between GPU and CPU
6
+ def load_model(base_model_id, adapter_model_id):
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ info = f"Running on {'GPU (CUDA) 🔥' if device == 'cuda' else 'CPU 🥶'}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ try:
11
+ # Load the base model dynamically on the correct device
12
+ pipe = StableDiffusionPipeline.from_pretrained(
13
+ base_model_id,
14
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
15
+ ).to(device)
16
 
17
+ # If an adapter model is provided, load and merge the adapter model
18
+ if adapter_model_id:
19
+ adapter_pipe = DiffusionPipeline.from_pretrained(adapter_model_id)
20
+ adapter_pipe.load_lora_weights(base_model_id)
21
+ pipe = pipe.to(device)
22
+
23
+ return pipe, info
24
+ except Exception as e:
25
+ return None, f"Error loading model: {str(e)}"
26
+
27
+ # Function for text-to-image generation
28
  def generate_image(base_model_id, adapter_model_id, prompt):
29
  pipe, info = load_model(base_model_id, adapter_model_id)
30
+
31
+ if pipe is None:
32
+ return None, info
33
 
34
+ # Generate image based on the prompt
35
+ try:
36
+ image = pipe(prompt).images[0]
37
+ return image, info
38
+ except Exception as e:
39
+ return None, f"Error generating image: {str(e)}"
 
40
 
41
  # Create the Gradio interface
42
  with gr.Blocks() as demo:
43
  gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
 
44
 
45
  with gr.Row():
46
  with gr.Column():
47
+ base_model_id = gr.Textbox(
48
+ label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)",
49
+ placeholder="Base Model ID"
50
+ )
51
+ adapter_model_id = gr.Textbox(
52
+ label="Enter Adapter Model ID (optional, e.g., nevreal/vMurderDrones-Lora)",
53
+ placeholder="Adapter Model ID (optional)",
54
+ value=""
55
+ )
56
+ prompt = gr.Textbox(
57
+ label="Enter your prompt",
58
+ placeholder="Describe the image you want to generate"
59
+ )
60
  generate_btn = gr.Button("Generate Image")
61
 
62
  with gr.Column():
63
  output_image = gr.Image(label="Generated Image")
64
+ device_info = gr.Markdown() # To display device info and any error messages
65
 
66
  # Link the button to the image generation function
67
+ generate_btn.click(
68
+ fn=generate_image,
69
+ inputs=[base_model_id, adapter_model_id, prompt],
70
+ outputs=[output_image, device_info]
71
+ )
72
 
73
  # Launch the app
74
  demo.launch()