import os
import sys
import gc
import base64

import torch
import numpy as np
import cv2
import aiohttp
from PIL import Image
from dotenv import load_dotenv
from fal import Client as FalClient
from transformers import AutoProcessor, AutoModelForCausalLM, CLIPProcessor, CLIPModel

# Load environment variables (API keys) from the local .env file
load_dotenv()

# The RetinaFace wrapper is vendored in the ComfyUI_AutoCropFaces repo
sys.path.append('./ComfyUI_AutoCropFaces')
from Pytorch_Retinaface.pytorch_retinaface import Pytorch_RetinaFace

# Keep Hugging Face model downloads in a persistent cache directory
CACHE_DIR = '/workspace/huggingface_cache'
os.environ["HF_HOME"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
def clear_cuda_memory():
    """Aggressively clear CUDA memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
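# Note: clear_cuda_memory() is called after each Florence-2 caption below; it only
# returns VRAM to the allocator's free pool once the tensors holding it have been
# released (e.g. after the model dict goes out of scope or is deleted).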
def load_vision_models():
    print("Loading CLIP and Florence models...")
    # Load CLIP
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        cache_dir=CACHE_DIR
    ).to(device)
    clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-large-patch14",
        cache_dir=CACHE_DIR
    )
    # Load Florence
    florence_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-large",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        cache_dir=CACHE_DIR
    ).to(device)
    florence_processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-large",
        trust_remote_code=True,
        cache_dir=CACHE_DIR
    )
    return {
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'florence_model': florence_model,
        'florence_processor': florence_processor,
    }
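# Usage sketch (hypothetical, not called anywhere in this file): callers that are done
# with the returned models can free GPU memory roughly like this:
#
#     models = load_vision_models()
#     ...  # run inference
#     del models
#     clear_cuda_memory()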
def generate_caption(image):
    vision_models = load_vision_models()
    # Ensure the image is a PIL Image
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Convert the image to RGB if it has an alpha channel
    if image.mode == 'RGBA':
        image = image.convert('RGB')
    prompt = "<DETAILED_CAPTION>"
    inputs = vision_models['florence_processor'](
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(device, torch.float16 if torch.cuda.is_available() else torch.float32)
    generated_ids = vision_models['florence_model'].generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False
    )
    generated_text = vision_models['florence_processor'].batch_decode(generated_ids, skip_special_tokens=True)[0]
    parsed_answer = vision_models['florence_processor'].post_process_generation(
        generated_text,
        task="<DETAILED_CAPTION>",
        image_size=(image.width, image.height)
    )
    clear_cuda_memory()
    return parsed_answer['<DETAILED_CAPTION>']
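# Usage sketch ("portrait.jpg" is a hypothetical example path):
#
#     caption = generate_caption(Image.open("portrait.jpg"))
#     print(caption)
#
# Note that generate_caption() reloads CLIP and Florence-2 on every call; for batch
# captioning it is cheaper to load the models once and reuse them.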
def crop_face(image_path, output_dir, output_name, scale_factor=4.0):
    image = Image.open(image_path).convert("RGB")
    img_raw = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    img_raw = img_raw.astype(np.float32)
    rf = Pytorch_RetinaFace(
        cfg='mobile0.25',
        pretrained_path='./weights/mobilenet0.25_Final.pth',
        confidence_threshold=0.02,
        nms_threshold=0.4,
        vis_thres=0.6
    )
    dets = rf.detect_faces(img_raw)
    print("Dets: ", dets)
    # Instead of asserting, handle missing or multiple faces gracefully
    if len(dets) == 0:
        print("No faces detected!")
        return False
    # If multiple faces are detected, use the one with the highest confidence
    if len(dets) > 1:
        print(f"Warning: {len(dets)} faces detected, using the one with highest confidence")
        # Assuming dets is a list of [bbox, landmark, score], sort by confidence score
        dets = sorted(dets, key=lambda x: x[2], reverse=True)
        # Keep only the highest-confidence detection
        dets = [dets[0]]
    # Pass scale_factor through to center_and_crop_rescale for an adjustable crop size
    try:
        # Unpack the tuple correctly - the function returns (cropped_imgs, bbox_infos)
        cropped_imgs, bbox_infos = rf.center_and_crop_rescale(img_raw, dets, shift_factor=0.45, scale_factor=scale_factor)
        # Check whether we got any cropped images
        if not cropped_imgs:
            print("No cropped images returned")
            return False
        # Use the first cropped face image directly - it's not nested
        img_to_save = cropped_imgs[0]
        os.makedirs(output_dir, exist_ok=True)
        cv2.imwrite(os.path.join(output_dir, output_name), img_to_save)
        print(f"Saved: {output_name}")
        return True
    except Exception as e:
        print(f"Error during face cropping: {e}")
        return False
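# Usage sketch (paths are hypothetical examples):
#
#     ok = crop_face("inputs/photo.jpg", "outputs", "face.png", scale_factor=4.0)
#
# scale_factor is forwarded to center_and_crop_rescale, so it controls how much of the
# surrounding image is kept around the detected face.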
async def upscale_image(image_path, output_path):
    """Upscale an image using fal.ai's RealESRGAN model."""
    fal_client = FalClient()
    # Read and base64-encode the image as a data URI
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
    data_uri = f"data:image/jpeg;base64,{encoded_image}"
    try:
        # Submit the upscaling request
        handler = await fal_client.submit_async(
            "fal-ai/real-esrgan",
            arguments={
                "image_url": data_uri,
                "scale": 2,
                "model": "RealESRGAN_x4plus",
                "output_format": "png",
                "face": True
            },
        )
        result = await handler.get()
        # Download and save the upscaled image
        image_url = result['image_url']
        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as response:
                if response.status == 200:
                    with open(output_path, 'wb') as f:
                        f.write(await response.read())
                    return True
                else:
                    print(f"Failed to download upscaled image: {response.status}")
                    return False
    except Exception as e:
        print(f"Error during upscaling: {e}")
        return False
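# Minimal smoke test, assuming FAL_KEY is available in the environment (e.g. via the
# .env file loaded above); the file names are hypothetical examples.
if __name__ == "__main__":
    import asyncio

    asyncio.run(upscale_image("outputs/face.png", "outputs/face_upscaled.png"))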