import logging

import cv2
import numpy as np
import supervision as sv
import torch
from rt_pose import PoseEstimationPipeline, PoseEstimationOutput

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class VitPose:
    def __init__(self):
        # Person detection + pose estimation pipeline from rt-pose
        self.pipeline = PoseEstimationPipeline(
            object_detection_checkpoint="PekingU/rtdetr_r50vd_coco_o365",
            pose_estimation_checkpoint="usyd-community/vitpose-plus-small",
            device="cuda" if torch.cuda.is_available() else "cpu",
            dtype=torch.bfloat16,
            compile=True,  # torch.compile the models for an extra speedup; set to False to skip compilation
        )
        self.output_video_path = None
        self.video_metadata = {}

    def video_to_frames(self, video):
        """Read all frames from the input video and record its metadata."""
        frames = []
        cap = cv2.VideoCapture(video)
        self.video_metadata = {
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        }
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()
        return frames

    def run(self, video):
        """Run pose estimation on every frame of the video and return the annotated frames."""
        frames = self.video_to_frames(video)
        annotated_frames = []
        for i, frame in enumerate(frames):
            logger.info(f"Processing frame {i + 1} of {len(frames)}")
            output = self.pipeline(frame)
            annotated_frame = self.visualize_output(frame, output)
            annotated_frames.append(annotated_frame)
        logger.info(f"Processed {len(annotated_frames)} frames")
        return annotated_frames

    def visualize_output(self, image: np.ndarray, output: PoseEstimationOutput, confidence: float = 0.3) -> np.ndarray:
        """
        Visualize pose estimation output by drawing skeleton edges and keypoint vertices.
        """
        keypoints_xy = output.keypoints_xy.float().cpu().numpy()
        scores = output.scores.float().cpu().numpy()

        # Supervision will not draw vertices with `0` score
        # and coordinates with `(0, 0)` value
        invisible_keypoints = scores < confidence
        scores[invisible_keypoints] = 0
        keypoints_xy[invisible_keypoints] = 0

        keypoints = sv.KeyPoints(xy=keypoints_xy, confidence=scores)

        # Scale marker radius and edge thickness with the average detected person height
        _, y_min, _, y_max = output.person_boxes_xyxy.T
        height = int((y_max - y_min).mean().item())
        radius = max(height // 100, 4)
        thickness = max(height // 200, 2)

        edge_annotator = sv.EdgeAnnotator(color=sv.Color.YELLOW, thickness=thickness)
        vertex_annotator = sv.VertexAnnotator(color=sv.Color.ROBOFLOW, radius=radius)

        annotated_frame = image.copy()
        annotated_frame = edge_annotator.annotate(annotated_frame, keypoints)
        annotated_frame = vertex_annotator.annotate(annotated_frame, keypoints)
        return annotated_frame

    def frames_to_video(self, frames):
        """Write the annotated frames to `self.output_video_path`, rotating landscape videos to vertical."""
        if self.output_video_path is None:
            raise ValueError("output_video_path must be set before calling frames_to_video")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        height = self.video_metadata["height"]
        width = self.video_metadata["width"]

        # Always ensure vertical orientation: rotate only if the video is in landscape mode
        rotate = width > height

        # The VideoWriter needs the output frame dimensions, which are swapped when rotating
        if rotate:
            logger.info(f"Original dimensions: {width}x{height}, rotated dimensions: {height}x{width}")
            out = cv2.VideoWriter(self.output_video_path, fourcc, self.video_metadata["fps"], (height, width))
        else:
            logger.info(f"Dimensions: {width}x{height}")
            out = cv2.VideoWriter(self.output_video_path, fourcc, self.video_metadata["fps"], (width, height))

        for frame in frames:
            if rotate:
                # Rotate landscape frames 90 degrees to make them vertical
                rotated_frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)
                out.write(rotated_frame)
            else:
                # Already vertical, no rotation needed
                out.write(frame)
        out.release()
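

# Minimal usage sketch (not part of the original module): builds the pipeline, annotates a
# video, and writes the result. "input.mp4" and "annotated.mp4" are hypothetical paths;
# substitute whatever the calling application provides.
if __name__ == "__main__":
    vitpose = VitPose()
    vitpose.output_video_path = "annotated.mp4"  # hypothetical output path
    annotated_frames = vitpose.run("input.mp4")  # hypothetical input video
    vitpose.frames_to_video(annotated_frames)
    logger.info(f"Saved annotated video to {vitpose.output_video_path}")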