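"""Pose estimation on video with rt_pose's PoseEstimationPipeline.

Wraps an RT-DETR person detector and a ViTPose keypoint model: a video is
split into frames, each frame is run through the pipeline, keypoints are
drawn with supervision, and the annotated frames are re-encoded to MP4.
"""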
import logging

import cv2
import numpy as np
import supervision as sv
import torch
from rt_pose import PoseEstimationOutput, PoseEstimationPipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class VitPose:
    def __init__(self):
        self.pipeline = PoseEstimationPipeline(
            object_detection_checkpoint="PekingU/rtdetr_r50vd_coco_o365",
            pose_estimation_checkpoint="usyd-community/vitpose-plus-small",
            device="cuda" if torch.cuda.is_available() else "cpu",
            dtype=torch.bfloat16,
            compile=True,  # torch.compile the models for an extra speedup
        )
        self.output_video_path = None
        self.video_metadata = {}
    def video_to_frames(self, video):
        """Read a video file into a list of BGR frames and record its metadata."""
        frames = []
        cap = cv2.VideoCapture(video)
        self.video_metadata = {
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        }
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()
        return frames
    def run(self, video):
        """Run pose estimation on every frame of `video` and return the annotated frames."""
        frames = self.video_to_frames(video)
        annotated_frames = []
        for i, frame in enumerate(frames):
            logger.info(f"Processing frame {i + 1} of {len(frames)}")
            output = self.pipeline(frame)
            annotated_frame = self.visualize_output(frame, output)
            annotated_frames.append(annotated_frame)
        logger.info(f"Processed {len(annotated_frames)} frames")
        return annotated_frames
    def visualize_output(
        self, image: np.ndarray, output: PoseEstimationOutput, confidence: float = 0.3
    ) -> np.ndarray:
        """Draw skeleton edges and keypoint vertices from `output` onto a copy of `image`."""
        # Defensive guard: nothing to draw if no person was detected in this frame.
        if len(output.person_boxes_xyxy) == 0:
            return image.copy()

        keypoints_xy = output.keypoints_xy.float().cpu().numpy()
        scores = output.scores.float().cpu().numpy()

        # Supervision will not draw vertices with `0` score
        # and coordinates with `(0, 0)` value
        invisible_keypoints = scores < confidence
        scores[invisible_keypoints] = 0
        keypoints_xy[invisible_keypoints] = 0

        keypoints = sv.KeyPoints(xy=keypoints_xy, confidence=scores)

        # Scale marker size with the mean detected-person height so the
        # annotations stay legible across resolutions.
        _, y_min, _, y_max = output.person_boxes_xyxy.T
        height = int((y_max - y_min).mean().item())
        radius = max(height // 100, 4)
        thickness = max(height // 200, 2)
        edge_annotator = sv.EdgeAnnotator(color=sv.Color.YELLOW, thickness=thickness)
        vertex_annotator = sv.VertexAnnotator(color=sv.Color.ROBOFLOW, radius=radius)

        annotated_frame = image.copy()
        annotated_frame = edge_annotator.annotate(annotated_frame, keypoints)
        annotated_frame = vertex_annotator.annotate(annotated_frame, keypoints)
        return annotated_frame
    def frames_to_video(self, frames):
        """Encode `frames` to `self.output_video_path`, rotating landscape input to portrait."""
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        height = self.video_metadata["height"]
        width = self.video_metadata["width"]

        # Always ensure vertical orientation: rotate only if the video is landscape.
        rotate = width > height

        # The VideoWriter needs the dimensions of the *output* frames,
        # which are swapped when we rotate.
        if rotate:
            logger.info(f"Original dimensions: {width}x{height}, rotated dimensions: {height}x{width}")
            out = cv2.VideoWriter(self.output_video_path, fourcc, self.video_metadata["fps"], (height, width))
        else:
            logger.info(f"Dimensions: {width}x{height}")
            out = cv2.VideoWriter(self.output_video_path, fourcc, self.video_metadata["fps"], (width, height))

        for frame in frames:
            if rotate:
                # Rotate landscape frames 90 degrees to make them vertical.
                out.write(cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE))
            else:
                # Already vertical, no rotation needed.
                out.write(frame)
        out.release()
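

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the input and
    # output paths below are placeholder assumptions.
    vitpose = VitPose()
    vitpose.output_video_path = "annotated.mp4"  # hypothetical output path
    annotated = vitpose.run("input.mp4")  # hypothetical input path
    vitpose.frames_to_video(annotated)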