gaur3009 committed
Commit 1e3319b · verified · 1 Parent(s): 38469c3

Create app.py

Files changed (1)
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+ from PIL import Image
+ import gradio as gr
+ import cv2
+
+ # Pre-trained Stable Diffusion base model (frozen) and a Canny-edge ControlNet for structural control
+ model_id = "runwayml/stable-diffusion-v1-5"
+ controlnet_id = "lllyasviel/control_v11p_sd15_canny"  # ControlNet conditioned on Canny edges
+
+ # Use half precision on GPU, full precision on CPU
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if device == "cuda" else torch.float32
+
+ # Load the ControlNet model
+ controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
+
+ # Load the img2img Stable Diffusion pipeline with ControlNet (accepts an init image plus a control image)
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+     model_id, controlnet=controlnet, torch_dtype=dtype
+ )
+
+ # Use an efficient scheduler
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ pipe.to(device)
+
+ # Generate the control image (Canny edge map) from the uploaded image
+ def generate_control_image(input_image_path):
+     image = cv2.imread(input_image_path, cv2.IMREAD_GRAYSCALE)
+     edges = cv2.Canny(image, 100, 200)  # Canny edge detection (low/high thresholds 100/200)
+     control_image = Image.fromarray(edges).convert("RGB").resize((512, 512))  # 3-channel, 512x512
+     control_image.save("control_image.png")  # PNG keeps the edge map free of JPEG artifacts
+     return "control_image.png"
+
+ # Apply the color change described by the prompt
+ def apply_color_change(input_image, prompt):
+     # Save the uploaded image temporarily
+     input_image_path = "input_image.jpg"
+     input_image.convert("RGB").save(input_image_path)
+
+     # Generate the control image (edges)
+     control_image_path = generate_control_image(input_image_path)
+
+     # Load the processed input and control images
+     init_image = Image.open(input_image_path).convert("RGB").resize((512, 512))
+     control_image = Image.open(control_image_path).convert("RGB")
+
+     # Generate the new image with the pipeline
+     generator = torch.Generator(device=device).manual_seed(42)  # For reproducibility
+     output_image = pipe(
+         prompt=prompt,
+         image=init_image,
+         control_image=control_image,
+         strength=0.8,  # How strongly the original image is repainted
+         generator=generator,
+         num_inference_steps=30,
+     ).images[0]
+
+     output_image.save("output_color_changed.png")
+     return "output_color_changed.png"
+
+ # Gradio wrapper
+ def gradio_interface(input_image, prompt):
+     return apply_color_change(input_image, prompt)
+
+ # Build the Gradio interface with drag-and-drop upload
+ interface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[
+         gr.Image(type="pil", label="Upload your image"),  # Drag-and-drop upload
+         gr.Textbox(label="Enter prompt", placeholder="e.g. A hoodie with a blue and white design"),
+     ],
+     outputs=gr.Image(label="Color Changed Output"),
+     title="AI-Powered Clothing Color Changer",
+     description="Upload an image of clothing, enter a prompt, and get a redesigned color version.",
+ )
+
+ interface.launch()