Rishi Desai committed on
Commit 8308bbd · 1 Parent(s): 632672e
Files changed (3)
  1. .env +0 -0
  2. main.py +50 -0
  3. utils.py +197 -0
.env ADDED
File without changes
main.py ADDED
@@ -0,0 +1,50 @@
+import argparse
+import asyncio
+import os
+import shutil
+from utils import crop_face, upscale_image
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Face Enhancement Tool')
+    parser.add_argument('--input', type=str, required=True, help='Path to the input image')
+    parser.add_argument('--crop', action='store_true', help='Whether to crop the image')
+    parser.add_argument('--upscale', action='store_true', help='Whether to upscale the image')
+    parser.add_argument('--output', type=str, required=True, help='Path to save the output image')
+    args = parser.parse_args()
+
+    # Validate that the input file exists
+    if not os.path.exists(args.input):
+        parser.error(f"Input file does not exist: {args.input}")
+
+    # Validate that the output directory exists
+    output_dir = os.path.dirname(args.output)
+    if output_dir and not os.path.exists(output_dir):
+        parser.error(f"Output directory does not exist: {output_dir}")
+
+    return args
+
+def main():
+    args = parse_args()
+    print(f"Processing image: {args.input}")
+    print(f"Crop enabled: {args.crop}")
+    print(f"Upscale enabled: {args.upscale}")
+    print(f"Output will be saved to: {args.output}")
+
+    os.makedirs("./scratch", exist_ok=True)
+
+    face_image = args.input
+    if args.crop:
+        # crop_face expects (image_path, output_dir, output_name)
+        crop_face(args.input, "./scratch", "cropped_face.png")
+        face_image = "./scratch/cropped_face.png"
+
+    if args.upscale:
+        # upscale_image is a coroutine, so drive it to completion here
+        asyncio.run(upscale_image(face_image, "./scratch/upscaled_face.png"))
+        face_image = "./scratch/upscaled_face.png"
+
+    # Copy the final intermediate result to the requested output path
+    shutil.copyfile(face_image, args.output)
+
+if __name__ == "__main__":
+    main()
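Example invocation (file names are illustrative; assumes the RetinaFace weights and a fal.ai key are in place):
python main.py --input portrait.jpg --crop --upscale --output enhanced.png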
utils.py ADDED
@@ -0,0 +1,197 @@
+import base64
+import gc
+import os
+import sys
+
+# HF_HOME must be set before transformers is imported, otherwise the
+# cache location below is ignored in favor of the default path.
+CACHE_DIR = '/workspace/huggingface_cache'
+os.environ["HF_HOME"] = CACHE_DIR
+os.makedirs(CACHE_DIR, exist_ok=True)
+
+import aiohttp
+import cv2
+import fal_client
+import numpy as np
+import torch
+from dotenv import load_dotenv
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, CLIPModel, CLIPProcessor
+
+load_dotenv()
+
+# The RetinaFace wrapper ships inside the ComfyUI_AutoCropFaces checkout
+sys.path.append('./ComfyUI_AutoCropFaces')
+from Pytorch_Retinaface.pytorch_retinaface import Pytorch_RetinaFace
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def clear_cuda_memory():
+    """Aggressively clear CUDA memory"""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+
+def load_vision_models():
+    print("Loading CLIP and Florence models...")
+    # Load CLIP
+    clip_model = CLIPModel.from_pretrained(
+        "openai/clip-vit-large-patch14",
+        cache_dir=CACHE_DIR
+    ).to(device)
+    clip_processor = CLIPProcessor.from_pretrained(
+        "openai/clip-vit-large-patch14",
+        cache_dir=CACHE_DIR
+    )
+
+    # Load Florence
+    florence_model = AutoModelForCausalLM.from_pretrained(
+        "microsoft/Florence-2-large",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        trust_remote_code=True,
+        cache_dir=CACHE_DIR
+    ).to(device)
+    florence_processor = AutoProcessor.from_pretrained(
+        "microsoft/Florence-2-large",
+        trust_remote_code=True,
+        cache_dir=CACHE_DIR
+    )
+
+    return {
+        'clip_model': clip_model,
+        'clip_processor': clip_processor,
+        'florence_model': florence_model,
+        'florence_processor': florence_processor,
+    }
+
+
+def generate_caption(image):
+    # Note: this reloads the models on every call; callers captioning many
+    # images may want to load them once and reuse them.
+    vision_models = load_vision_models()
+
+    # Ensure the image is a PIL Image
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    # Convert the image to RGB if it has an alpha channel
+    if image.mode == 'RGBA':
+        image = image.convert('RGB')
+
+    prompt = "<DETAILED_CAPTION>"
+    inputs = vision_models['florence_processor'](
+        text=prompt,
+        images=image,
+        return_tensors="pt"
+    ).to(device, torch.float16 if torch.cuda.is_available() else torch.float32)
+
+    generated_ids = vision_models['florence_model'].generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        num_beams=3,
+        do_sample=False
+    )
+    generated_text = vision_models['florence_processor'].batch_decode(generated_ids, skip_special_tokens=True)[0]
+    parsed_answer = vision_models['florence_processor'].post_process_generation(
+        generated_text, task="<DETAILED_CAPTION>",
+        image_size=(image.width, image.height)
+    )
+
+    clear_cuda_memory()
+    return parsed_answer['<DETAILED_CAPTION>']
+
+
+def crop_face(image_path, output_dir, output_name, scale_factor=4.0):
+    image = Image.open(image_path).convert("RGB")
+
+    # RetinaFace expects a float32 BGR array
+    img_raw = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    img_raw = img_raw.astype(np.float32)
+
+    rf = Pytorch_RetinaFace(
+        cfg='mobile0.25',
+        pretrained_path='./weights/mobilenet0.25_Final.pth',
+        confidence_threshold=0.02,
+        nms_threshold=0.4,
+        vis_thres=0.6
+    )
+
+    dets = rf.detect_faces(img_raw)
+    print("Dets: ", dets)
+
+    # Instead of asserting, handle missing or multiple faces gracefully
+    if len(dets) == 0:
+        print("No faces detected!")
+        return False
+
+    # If multiple faces are detected, keep the one with the highest confidence
+    if len(dets) > 1:
+        print(f"Warning: {len(dets)} faces detected, using the one with highest confidence")
+        # Assuming each detection is [bbox, landmark, score], sort by score
+        dets = sorted(dets, key=lambda x: x[2], reverse=True)
+        dets = [dets[0]]
+
+    try:
+        # center_and_crop_rescale returns (cropped_imgs, bbox_infos);
+        # scale_factor controls how much context is kept around the face
+        cropped_imgs, bbox_infos = rf.center_and_crop_rescale(img_raw, dets, shift_factor=0.45, scale_factor=scale_factor)
+
+        if not cropped_imgs:
+            print("No cropped images returned")
+            return False
+
+        # Use the first cropped face image directly - it's not nested
+        img_to_save = cropped_imgs[0]
+
+        os.makedirs(output_dir, exist_ok=True)
+        cv2.imwrite(os.path.join(output_dir, output_name), img_to_save)
+        print(f"Saved: {output_name}")
+        return True
+
+    except Exception as e:
+        print(f"Error during face cropping: {e}")
+        return False
+
+
+async def upscale_image(image_path, output_path):
+    """Upscale an image using fal.ai's RealESRGAN model"""
+    # Read the image and encode it as a base64 data URI (the MIME type is
+    # nominal; the bytes are passed through unchanged)
+    with open(image_path, "rb") as image_file:
+        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
+    data_uri = f"data:image/jpeg;base64,{encoded_image}"
+
+    try:
+        # Submit the upscaling request; fal_client authenticates via the
+        # FAL_KEY environment variable (loaded above by load_dotenv)
+        handler = await fal_client.submit_async(
+            "fal-ai/real-esrgan",
+            arguments={
+                "image_url": data_uri,
+                "scale": 2,
+                "model": "RealESRGAN_x4plus",
+                "output_format": "png",
+                "face": True
+            },
+        )
+        result = await handler.get()
+
+        # Download and save the upscaled image
+        image_url = result["image"]["url"]
+        async with aiohttp.ClientSession() as session:
+            async with session.get(image_url) as response:
+                if response.status == 200:
+                    with open(output_path, 'wb') as f:
+                        f.write(await response.read())
+                    return True
+                else:
+                    print(f"Failed to download upscaled image: {response.status}")
+                    return False
+
+    except Exception as e:
+        print(f"Error during upscaling: {e}")
+        return False
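Note: load_dotenv() runs before the fal request, so the otherwise-empty .env added in this commit is presumably meant to hold the fal.ai credential, e.g. a line of the form FAL_KEY=... (the environment variable fal_client reads; the value itself is left to the user).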