sagar007 committed on
Commit
55ca942
verified
1 Parent(s): 16296e7

Update app.py

Files changed (1)
  1. app.py +206 -88
app.py CHANGED
@@ -1,107 +1,225 @@
  import os
  import torch
- #import
  import gradio as gr
- from tqdm import tqdm
  from PIL import Image
- import torch.nn.functional as F
- from torchvision import transforms as tfms
- from transformers import CLIPTextModel, CLIPTokenizer, logging
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline
-
- torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
- if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
-
- # Load the pipeline
  model_path = "CompVis/stable-diffusion-v1-4"
- sd_pipeline = DiffusionPipeline.from_pretrained(
-     model_path,
-     low_cpu_mem_usage=True,
-     torch_dtype=torch.float32
- ).to(torch_device)
-
- # Load textual inversions
- sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
- sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
- sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
- sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
-
- # Update style token dictionary
- style_token_dict = {
-     "Illustration Style": '<illustration-style>',
-     "Line Art": '<line-art>',
-     "Hitokomoru Style": '<hitokomoru-style-nao>',
-     "Marc Allante": '<Marc_Allante>',
-     "Midjourney": '<midjourney-style>',
-     "Hanfu Anime": '<hanfu-anime-style>',
-     "Birb Style": '<birb-style>'
- }
-
-
- def set_timesteps(scheduler, num_inference_steps):
-     scheduler.set_timesteps(num_inference_steps)
-     scheduler.timesteps = scheduler.timesteps.to(torch.float32)
-
- def pil_to_latent(input_im):
-     with torch.no_grad():
-         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
-     return 0.18215 * latent.latent_dist.sample()
-
- def latents_to_pil(latents):
-     latents = (1 / 0.18215) * latents
-     with torch.no_grad():
-         image = vae.decode(latents).sample
-     image = (image / 2 + 0.5).clamp(0, 1)
-     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-     images = (image * 255).round().astype("uint8")
-     pil_images = [Image.fromarray(image) for image in images]
-     return pil_images
  def generate_with_pipeline(prompt, num_inference_steps, guidance_scale, seed):
      generator = torch.Generator(device=torch_device).manual_seed(seed)
-     image = sd_pipeline(
-         prompt,
-         num_inference_steps=num_inference_steps,
-         guidance_scale=guidance_scale,
-         generator=generator
-     ).images[0]
-     return image

  def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
-     prompt = text + " " + style_token_dict[style]
-
-     # Generate image with pipeline
      image_pipeline = generate_with_pipeline(prompt, inference_step, guidance_scale, seed)

-     # For the guided image, we'll need to implement a custom pipeline or modify the existing one
-     # This is a placeholder and would need to be implemented
-     image_guide = image_pipeline # This should be replaced with actual guided generation

      return image_pipeline, image_guide

- title = "Generative with Textual Inversion"
- description = "A simple Gradio interface to infer Stable Diffusion and generate images with different art styles"
  examples = [
-     ["A majestic castle on a floating island", 'Illustration Style', 20, 7.5, 42, 'Grayscale', 200],
-     ["A cyberpunk cityscape at night", 'Midjourney', 25, 8.0, 123, 'Contrast', 300]
  ]

- demo = gr.Interface(inference,
-     inputs = [gr.Textbox(label="Prompt", type="text"),
-         gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"),
-         gr.Slider(1, 50, 10, step = 1, label="Inference steps"),
-         gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
-         gr.Slider(0, 10000, 42, step = 1, label="Seed"),
-         gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
-             'Symmetry', 'Saturation'], value="Grayscale"),
-         gr.Slider(100, 10000, 200, step = 100, label="Loss scale")],
-     outputs= [gr.Image(width=512, height=512, label="Generated art"),
-         gr.Image(width=512, height=512, label="Generated art with guidance")],
-     title=title,
-     description=description,
-     examples=examples)
-
- demo.launch()

  import os
  import torch
+ #import # Unnecessary import
  import gradio as gr
+ # from tqdm import tqdm # Not used directly in the simplified inference function
  from PIL import Image
+ import torch.nn.functional as F # Not used directly in the simplified inference function
+ from torchvision import transforms as tfms # Not used directly in the simplified inference function
+ # from transformers import CLIPTextModel, CLIPTokenizer, logging # Not used directly
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline
+ import warnings
+
+ # Suppress specific warnings if needed (optional)
+ # logging.set_verbosity_error()
+ warnings.filterwarnings("ignore", category=FutureWarning) # Example: Ignore FutureWarnings
+
+ # --- Device Setup ---
+ if torch.cuda.is_available():
+     torch_device = "cuda"
+     print("Using CUDA (GPU)")
+ elif torch.backends.mps.is_available():
+     torch_device = "mps"
+     os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
+     print("Using MPS (Apple Silicon GPU)")
+ else:
+     torch_device = "cpu"
+     print("Using CPU")
+
+ # --- Configuration ---
  model_path = "CompVis/stable-diffusion-v1-4"
+ # Use float16 for faster inference and lower memory on CUDA
+ use_fp16 = torch_device == "cuda"
+ dtype = torch.float16 if use_fp16 else torch.float32
+ print(f"Using dtype: {dtype}")
+
+
+ # --- Load the Pipeline ---
+ print(f"Loading Stable Diffusion pipeline from {model_path}...")
+ try:
+     sd_pipeline = DiffusionPipeline.from_pretrained(
+         model_path,
+         # revision="fp16" if use_fp16 else "main", # Use fp16 revision if available and using fp16
+         torch_dtype=dtype,
+         # low_cpu_mem_usage=True, # Useful for large models, might slightly slow down loading
+     ).to(torch_device)
+     print("Pipeline loaded successfully.")
+ except Exception as e:
+     print(f"Error loading pipeline: {e}")
+     print("Ensure you have enough RAM/VRAM and are authenticated if required (huggingface-cli login).")
+     exit() # Exit if pipeline fails to load
+
+
+ # --- Enable xformers (Optional Speed/Memory Optimization) ---
+ try:
+     import xformers
+     sd_pipeline.enable_xformers_memory_efficient_attention()
+     print("xFormers enabled for memory efficient attention.")
+ except ImportError:
+     print("xFormers not found. For potential speedup, install with: pip install xformers")
+
+
+ # --- Load Textual Inversions ---
+ print("Loading textual inversions...")
+ try:
+     # Define paths or URLs - using Hugging Face Hub concepts library paths
+     inversions = {
+         "illustration-style": "sd-concepts-library/illustration-style",
+         "line-art": "sd-concepts-library/line-art",
+         "hitokomoru-style-nao": "sd-concepts-library/hitokomoru-style-nao",
+         "style-of-marc-allante": "sd-concepts-library/style-of-marc-allante", # Placeholder name likely needs adjustment
+         "midjourney-style": "sd-concepts-library/midjourney-style",
+         "hanfu-anime-style": "sd-concepts-library/hanfu-anime-style",
+         "birb-style": "sd-concepts-library/birb-style",
+     }
+     for name, path in inversions.items():
+         print(f" Loading: {name} ({path})")
+         sd_pipeline.load_textual_inversion(path) # Assumes weights are downloaded or accessible
+
+     print("Textual inversions loaded.")
+
+     # Update style token dictionary based on loaded concepts
+     # Ensure the placeholder names match the actual token learned during TI training
+     style_token_dict = {
+         "Illustration Style": '<illustration-style>',
+         "Line Art": '<line-art>',
+         "Hitokomoru Style": '<hitokomoru-style-nao>',
+         "Marc Allante": '<style-of-marc-allante>', # Corrected placeholder based on repo name convention
+         "Midjourney": '<midjourney-style>',
+         "Hanfu Anime": '<hanfu-anime-style>',
+         "Birb Style": '<birb-style>'
+     }
+
+ except Exception as e:
+     print(f"Error loading textual inversions: {e}")
+     print("Please ensure the concepts exist and paths are correct.")
+     # Continue without textual inversions or exit, depending on desired behavior
+     style_token_dict = {"Default": ""} # Fallback
+
+
+ # --- Helper functions (Keep for potential future use with custom guidance) ---
+ # Note: These are not used in the current simplified 'generate_with_pipeline' approach
+
+ # def set_timesteps(scheduler, num_inference_steps):
+ #     scheduler.set_timesteps(num_inference_steps)
+ #     scheduler.timesteps = scheduler.timesteps.to(torch.float32)
+
+ # def pil_to_latent(vae, input_im):
+ #     # VAE is part of sd_pipeline.vae
+ #     transform = tfms.Compose([
+ #         tfms.ToTensor(),
+ #         tfms.Normalize([0.5], [0.5]) # Important: Normalize to [-1, 1]
+ #     ])
+ #     with torch.no_grad():
+ #         # Ensure image is RGB
+ #         if input_im.mode != "RGB":
+ #             input_im = input_im.convert("RGB")
+ #         image = transform(input_im).unsqueeze(0).to(torch_device, dtype=vae.dtype)
+ #         latent = vae.encode(image) # Note scaling
+ #     return 0.18215 * latent.latent_dist.sample() # Magic number from SD
+
+ # def latents_to_pil(vae, latents):
+ #     # VAE is part of sd_pipeline.vae
+ #     latents = (1 / 0.18215) * latents # Reverse magic number
+ #     with torch.no_grad():
+ #         image = vae.decode(latents).sample
+ #     image = (image / 2 + 0.5).clamp(0, 1) # Denormalize
+ #     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+ #     images = (image * 255).round().astype("uint8")
+ #     pil_images = [Image.fromarray(image) for image in images]
+ #     return pil_images
+
+ # --- Generation Functions ---

  def generate_with_pipeline(prompt, num_inference_steps, guidance_scale, seed):
+     """Generates an image using the main Diffusers pipeline."""
      generator = torch.Generator(device=torch_device).manual_seed(seed)
+     try:
+         # Offload VAE if low VRAM causes issues (will slow down inference)
+         # sd_pipeline.enable_vae_slicing() # Alternative memory saving
+         # sd_pipeline.enable_model_cpu_offload() # If really low on VRAM
+
+         print(f"\nGenerating with: Prompt='{prompt}', Steps={num_inference_steps}, Scale={guidance_scale}, Seed={seed}")
+         with torch.autocast(torch_device, enabled=use_fp16): # Use autocast for fp16
+             image = sd_pipeline(
+                 prompt,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 generator=generator
+             ).images[0]
+         print("Generation complete.")
+         return image
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         # Return a placeholder or raise error
+         return Image.new('RGB', (512, 512), color = 'grey') # Placeholder grey image
+
+
+ # --- Main Inference Function for Gradio ---

  def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
+     """
+     Gradio interface function. Currently generates one image using the standard pipeline.
+     The guidance method and loss scale parameters are placeholders for future implementation.
+     """
+     if style in style_token_dict:
+         style_token = style_token_dict[style]
+         # Handle potential empty token for 'Default' or errors
+         prompt = f"{text} {style_token}".strip()
+     else:
+         print(f"Warning: Style '{style}' not found in token dictionary. Using prompt without style token.")
+         prompt = text
+
+     # Generate image with the standard pipeline
      image_pipeline = generate_with_pipeline(prompt, inference_step, guidance_scale, seed)

+     # --- Placeholder for Guided Image Generation ---
+     # The code for custom guidance (Grayscale, Contrast, etc.) would go here.
+     # This typically involves a custom diffusion loop, calculating losses based on the
+     # guidance method, and modifying the latents at each step. It's significantly
+     # more complex and computationally intensive than the standard pipeline call.
+     # For now, we just return the same image as a placeholder for the second output.
+     print(f"Guidance method '{guidance_method}' and Loss Scale '{loss_scale}' are currently placeholders.")
+     image_guide = image_pipeline # Placeholder

      return image_pipeline, image_guide

+ # --- Gradio Interface Definition ---
+ title = "Generative Art with Textual Inversion Styles"
+ description = """
+ A Gradio interface to generate images using Stable Diffusion v1.4 with Textual Inversion styles.
+ Select a style, enter a prompt, and adjust generation parameters.
+ *Note:* The 'Generated art with guidance' output currently shows the same image as the first. Custom guidance logic (Grayscale, Contrast, etc.) is not yet implemented. Using lower inference steps speeds up generation. Enable the queue if timeouts occur.
+ """
  examples = [
+     ["A majestic castle on a floating island, detailed, fantasy art", 'Illustration Style', 25, 7.5, 1001, 'Grayscale', 200],
+     ["A cyberpunk cityscape at night, neon lights, rain, cinematic", 'Midjourney', 30, 8.0, 42, 'Contrast', 300],
+     ["Portrait of a woman in traditional chinese dress, anime style", "Hanfu Anime", 30, 7.0, 1234, 'Saturation', 250],
+     ["Cute cartoon bird character sitting on a branch", "Birb Style", 20, 7.5, 5678, 'Symmetry', 150]
  ]

+ demo = gr.Interface(
+     inference,
+     inputs = [
+         gr.Textbox(label="Prompt", info="Describe the image you want to create.", type="text"),
+         gr.Dropdown(label="Style", info="Select an art style (requires loaded textual inversion).", choices=list(style_token_dict.keys()), value="Illustration Style"),
+         gr.Slider(10, 50, 25, step = 1, label="Inference steps", info="More steps can improve detail but take longer."), # Default 25 steps
+         gr.Slider(1.0, 15.0, 7.5, step = 0.1, label="Guidance scale (CFG)", info="How strongly the prompt guides the image. Higher values follow prompt more closely."),
+         gr.Slider(0, 100000, 42, step = 1, label="Seed", info="Same seed + prompt = same image. 0 for random."),
+         gr.Dropdown(label="Guidance method (Placeholder)", info="Custom guidance method (Not implemented yet).", choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'], value="Grayscale"),
+         gr.Slider(100, 10000, 200, step = 100, label="Loss scale (Placeholder)", info="Strength of custom guidance (Not implemented yet).")
+     ],
+     outputs= [
+         gr.Image(width=512, height=512, label="Generated Art (Standard Pipeline)"),
+         gr.Image(width=512, height=512, label="Generated Art with Guidance (Placeholder)")
+     ],
+     title=title,
+     description=description,
+     examples=examples,
+     allow_flagging='never' # Disable flagging if not needed
+ )
+
+ # --- Launch the Interface with Queue Enabled ---
+ # Use .queue() to handle long inference times and prevent timeouts
+ print("Launching Gradio interface...")
+ demo.queue().launch(share=False) # Set share=True to get a public link (use with caution)
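
Editor's note, not part of this commit: the placeholder comments in inference describe the usual "loss guidance" recipe, i.e. run a manual denoising loop, decode the current latents to an image estimate, score it with a differentiable loss (grayscale, contrast, and so on), and nudge the latents along the negative gradient before the next scheduler step. Below is a minimal sketch of that pattern under stated assumptions: it reuses the module-level sd_pipeline, torch_device and dtype from app.py, swaps in the already-imported LMSDiscreteScheduler for its sigmas, and the names grayscale_loss and generate_with_guidance are illustrative only; they do not exist in the repo.

import torch
import torch.nn.functional as F
from PIL import Image
from diffusers import LMSDiscreteScheduler

def grayscale_loss(images):
    # Hypothetical guidance loss: penalize colorfulness by pulling each channel
    # toward the per-pixel channel mean (images in [0, 1], shape [B, 3, H, W]).
    gray = images.mean(dim=1, keepdim=True)
    return F.mse_loss(images, gray.expand_as(images))

def generate_with_guidance(prompt, num_inference_steps=30, guidance_scale=7.5,
                           seed=42, loss_fn=grayscale_loss, loss_scale=200):
    # Swap in an LMS scheduler so we can read per-step sigmas.
    scheduler = LMSDiscreteScheduler.from_config(sd_pipeline.scheduler.config)
    scheduler.set_timesteps(num_inference_steps)

    # Classifier-free guidance needs an unconditional and a conditional embedding.
    def embed(texts):
        tok = sd_pipeline.tokenizer(texts, padding="max_length",
                                    max_length=sd_pipeline.tokenizer.model_max_length,
                                    truncation=True, return_tensors="pt")
        with torch.no_grad():
            return sd_pipeline.text_encoder(tok.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([embed([""]), embed([prompt])])

    generator = torch.Generator(device=torch_device).manual_seed(seed)
    latents = torch.randn((1, sd_pipeline.unet.config.in_channels, 64, 64),
                          generator=generator, device=torch_device, dtype=dtype)
    latents = latents * scheduler.init_noise_sigma

    for i, t in enumerate(scheduler.timesteps):
        latent_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
        with torch.no_grad():
            noise_pred = sd_pipeline.unet(latent_input, t,
                                          encoder_hidden_states=text_embeddings).sample
        uncond, cond = noise_pred.chunk(2)
        noise_pred = uncond + guidance_scale * (cond - uncond)

        if i % 5 == 0:  # apply the guidance loss only every few steps to save time and memory
            sigma = scheduler.sigmas[i]
            latents = latents.detach().requires_grad_()
            pred_x0 = latents - sigma * noise_pred            # rough denoised estimate
            decoded = sd_pipeline.vae.decode((1 / 0.18215) * pred_x0).sample
            decoded = (decoded / 2 + 0.5).clamp(0, 1)
            loss = loss_fn(decoded) * loss_scale              # may need float32 or a smaller scale on fp16
            grad = torch.autograd.grad(loss, latents)[0]
            latents = (latents - sigma ** 2 * grad).detach()

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Decode the final latents exactly like the commented-out latents_to_pil helper.
    with torch.no_grad():
        image = sd_pipeline.vae.decode((1 / 0.18215) * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return Image.fromarray((image[0] * 255).round().astype("uint8"))

In inference, image_guide would then come from something like generate_with_guidance(prompt, inference_step, guidance_scale, seed, loss_fn=..., loss_scale=loss_scale) instead of reusing image_pipeline; expect noticeably longer runtimes than the plain pipeline call because of the extra VAE decodes and the backward pass through them.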