joselobenitezg's picture
set tf32 matmul
5f51879
raw
history blame contribute delete
3.42 kB
import torch
import numpy as np
from PIL import Image, ImageDraw
from torchvision import transforms
from config import SAPIENS_LITE_MODELS_PATH
def load_model(task, version):
    """Load the TorchScript model registered for ``task`` / ``version``.

    Args:
        task: First-level key into ``SAPIENS_LITE_MODELS_PATH``.
        version: Second-level key selecting the model variant.

    Returns:
        A ``(model, device)`` tuple with the model in eval mode on ``device``,
        or ``(None, None)`` when the task/version pair is not registered.
    """
    # Only the lookup can raise the KeyError we handle, so keep the try narrow.
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Enable TF32 kernels when the selected GPU reports compute capability
    # major version 8 or newer.
    if device.type == "cuda" and torch.cuda.get_device_properties(device).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # map_location remaps the stored tensors onto the chosen device, so an
    # archive saved from a GPU can still be loaded on a CPU-only host.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device
def preprocess_image(image, input_shape):
    """Turn a PIL image into a normalized, batched model input tensor.

    Args:
        image: PIL image. Any mode other than ``"RGB"`` is converted to RGB
            first so the array always has exactly three channels.
        input_shape: ``(channels, height, width)`` of the model input; only
            height and width are used, for resizing.

    Returns:
        A float tensor of shape ``(1, 3, height, width)`` with channels in
        R, G, B order, normalized per channel as ``(pixel - mean) / std``.
    """
    # Grayscale / palette / RGBA inputs would otherwise yield a 2-D or
    # 4-channel array and break the transpose or the normalization below.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # PIL's resize takes (width, height).
    img = image.resize((input_shape[2], input_shape[1]))
    img = np.array(img).transpose(2, 0, 1)  # HWC -> CHW
    img = torch.from_numpy(img).float()

    # The statistics below are listed in R, G, B order, so the channels are
    # kept in RGB order. Swapping them to BGR first (as this function used
    # to) applied the red statistics to the blue channel and vice versa.
    mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
    img = (img - mean) / std
    return img.unsqueeze(0)
def udp_decode(heatmap, img_size, heatmap_size):
    """Pick the peak of every heatmap channel and map it to image coordinates.

    NOTE: this is a simplified arg-max decoder, not the full UDP decoding
    procedure; replace it if full UDP decoding is required.

    Args:
        heatmap: Array indexed as ``[channel, row, col]``.
        img_size: ``(height, width)`` of the target coordinate space.
        heatmap_size: ``(height, width)`` used as the divisor when scaling
            heatmap indices to ``img_size``.

    Returns:
        ``(keypoints, keypoint_scores)``: an array of ``(x, y)`` positions with
        one row per channel, and the heatmap value at each channel's peak.
    """
    hm_height, hm_width = heatmap_size
    channel_count = heatmap.shape[0]
    keypoints = np.zeros((channel_count, 2))
    keypoint_scores = np.zeros(channel_count)

    for channel_idx, channel in enumerate(heatmap):
        # Location of the first maximum, expressed as per-axis indices.
        peak = np.unravel_index(np.argmax(channel), channel.shape)
        keypoints[channel_idx, 0] = peak[1] * img_size[1] / hm_width   # x from column
        keypoints[channel_idx, 1] = peak[0] * img_size[0] / hm_height  # y from row
        keypoint_scores[channel_idx] = channel[peak]

    return keypoints, keypoint_scores
def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3):
    """Mark every sufficiently confident keypoint with a small red dot.

    The dots are drawn directly onto ``image``, which is also returned.

    Args:
        image: PIL image to draw on.
        keypoints: Iterable of ``(x, y)`` positions in image pixels.
        keypoint_scores: Iterable of scores, paired with ``keypoints``.
        threshold: A keypoint is drawn only when its score is strictly
            greater than this value.

    Returns:
        The same ``image`` object that was passed in.
    """
    canvas = ImageDraw.Draw(image)
    radius = 2
    for (px, py), confidence in zip(keypoints, keypoint_scores):
        if not confidence > threshold:
            continue
        top_left = (px - radius, py - radius)
        bottom_right = (px + radius, py + radius)
        canvas.ellipse([top_left, bottom_right], fill='red', outline='red')
    return image
def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
    """Run the pose model on one image or video frame and draw the keypoints.

    Args:
        input_data: A PIL image, or a numpy array holding a single frame.
        task: Task key forwarded to ``load_model``.
        version: Model variant key forwarded to ``load_model``.

    Returns:
        A new RGB PIL image with the detected keypoints drawn on it, or
        ``None`` when the input type is unsupported or the model could not
        be loaded. The caller's image/array is not drawn on.
    """
    # Reject unsupported inputs before paying for the model load.
    if isinstance(input_data, np.ndarray):  # video frame
        frame = Image.fromarray(input_data)
    elif isinstance(input_data, Image.Image):  # still image
        frame = input_data
    else:
        print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
        return None

    model, device = load_model(task, version)
    if model is None or device is None:
        return None

    # (channels, height, width) of the tensor fed to the model.
    input_shape = (3, 1024, 768)

    # convert() returns a converted copy, so this both normalizes every mode
    # (RGBA, grayscale, palette, ...) to RGB and keeps the drawing below from
    # modifying the caller's image.
    frame = frame.convert('RGB')

    img = preprocess_image(frame, input_shape)
    with torch.no_grad():
        heatmap = model(img.to(device))

    # Drop the batch dimension and take the heatmap resolution from the
    # actual model output instead of assuming a fixed stride.
    heatmaps = heatmap[0].cpu().float().numpy()
    keypoints, keypoint_scores = udp_decode(
        heatmaps,
        input_shape[1:],
        heatmaps.shape[-2:],
    )

    # Keypoints come back in model-input pixels; rescale to the frame size.
    scale_x = frame.width / input_shape[2]
    scale_y = frame.height / input_shape[1]
    keypoints[:, 0] *= scale_x
    keypoints[:, 1] *= scale_y

    return visualize_keypoints(frame, keypoints, keypoint_scores)