from ultralytics import YOLO
import torch
import tensorflow as tf
import time
import os
import logging
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from utils.download import download_file
from utils.turn import get_ice_servers
from PIL import Image  # Used to open uploaded and downloaded images
import requests
from io import BytesIO  # For handling byte streams of downloaded images
# CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
# Update the string below to set the display title of the analysis
# Default title - "Facial Sentiment Analysis"
ANALYSIS_TITLE = "YOLOv8 Object Detection Analysis"
# Load the YOLOv8 model
model = YOLO("yolov8n.pt")
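# Note: "yolov8n.pt" is the smallest (nano) Ultralytics checkpoint. Larger weights
# such as "yolov8s.pt" or "yolov8m.pt" can be swapped in for better accuracy at the
# cost of slower per-frame analysis, e.g.:
# model = YOLO("yolov8s.pt")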
# CHANGE THE CONTENTS OF THIS FUNCTION, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
#
# Set analysis results in img_container for display
# img_container["input"] - holds the input frame contents - of type np.ndarray
# img_container["analyzed"] - holds the analyzed frame with any added annotations - of type np.ndarray
# img_container["analysis_time"] - holds how long the analysis has taken in milliseconds
# img_container["detections"] - holds the analysis metadata results - a list of dictionaries
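# Illustrative example (values are made up) of the state left in img_container
# after analyzing one frame that contains a single person:
#   img_container["analysis_time"] -> 42.7
#   img_container["detections"]    -> [{"id": 1, "label": "person", "score": 0.91,
#                                       "box_coords": [34.5, 50.2, 210.8, 400.3]}]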
def analyze_frame(frame: np.ndarray):
    start_time = time.time()  # Start timing the analysis
    img_container["input"] = frame  # Store the input frame
    frame = frame.copy()  # Create a copy of the frame to modify
    # Run YOLOv8 tracking on the frame, persisting tracks between frames
    results = model.track(frame, persist=True)
    # Initialize a list to store the detection dictionaries
    detections = []
    object_counter = 1
    # Iterate over the detected boxes
    for box in results[0].boxes:
        detection = {}
        # Extract class id, label, score, and bounding box coordinates
        class_id = int(box.cls)
        detection["id"] = object_counter
        detection["label"] = model.names[class_id]
        detection["score"] = float(box.conf)
        detection["box_coords"] = [round(value.item(), 2)
                                   for value in box.xyxy.flatten()]
        detections.append(detection)
        object_counter += 1
    # Visualize the results on the frame
    frame = results[0].plot()
    end_time = time.time()  # End timing the analysis
    execution_time_ms = round(
        (end_time - start_time) * 1000, 2
    )  # Calculate execution time in milliseconds
    # Store the execution time
    img_container["analysis_time"] = execution_time_ms
    # Store the detections
    img_container["detections"] = detections
    img_container["analyzed"] = frame  # Store the analyzed frame
    return  # End of the function
#
#
# DO NOT TOUCH THE BELOW CODE (NO CHANGES NEEDED)
#
#
# Suppress FFmpeg logs
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"
# Suppress TensorFlow or PyTorch progress bars
tf.get_logger().setLevel("ERROR")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Suppress PyTorch logs
logging.getLogger().setLevel(logging.WARNING)
torch.set_num_threads(1)
logging.getLogger("torch").setLevel(logging.ERROR)
# Suppress Streamlit logs using the logging module
logging.getLogger("streamlit").setLevel(logging.ERROR)
# Container to hold image data and analysis results
img_container = {"input": None, "analyzed": None,
                 "analysis_time": None, "detections": None}
# Logger for debugging and information
logger = logging.getLogger(__name__)
# Callback function to process video frames
# This function is called for each video frame in the WebRTC stream.
# It converts the frame to a numpy array in RGB format, analyzes the frame,
# and returns the original frame.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    # Convert frame to numpy array in RGB format
    img = frame.to_ndarray(format="rgb24")
    analyze_frame(img)  # Analyze the frame
    return frame  # Return the original frame
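# Note: frame.to_ndarray(format="rgb24") yields an HxWx3 uint8 RGB array. Because the
# callback returns the incoming frame unchanged, the WebRTC widget keeps showing the
# raw camera feed, while the annotated result is published separately from
# img_container into the Analysis column.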
# Get ICE servers for WebRTC
ice_servers = get_ice_servers()
# Streamlit UI configuration
st.set_page_config(layout="wide")
# Custom CSS for the Streamlit page
st.markdown(
    """
    <style>
    .main {
        padding: 2rem;
    }
    h1, h2, h3 {
        font-family: 'Arial', sans-serif;
    }
    h1 {
        font-weight: 700;
        font-size: 2.5rem;
    }
    h2 {
        font-weight: 600;
        font-size: 2rem;
    }
    h3 {
        font-weight: 500;
        font-size: 1.5rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# Streamlit page title and subtitle
st.title("Computer Vision Playground")
# Add a link to the README file
st.markdown(
    """
    <div style="text-align: left;">
        <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
        target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
    </div>
    """,
    unsafe_allow_html=True,
)
st.subheader(ANALYSIS_TITLE)
# Columns for input and output streams
col1, col2 = st.columns(2)
with col1:
    st.header("Input Stream")
    st.subheader("input")
    # WebRTC streamer to get video input from the webcam
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )
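    # Note: async_processing=True is intended to let streamlit-webrtc run the frame
    # callback asynchronously, so a slow analysis should not stall the media stream;
    # results are handed back to the main script through the shared img_container.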
    # File uploader for images
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"])
    # Text input for image URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")
    # File uploader for videos
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
    )
    # Text input for video URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")
# Streamlit footer
st.markdown(
    """
    <div style="text-align: center; margin-top: 2rem;">
        <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
    </div>
    """,
    unsafe_allow_html=True
)
# Function to initialize the analysis UI
# This function sets up the placeholders and UI elements in the analysis section.
# It creates placeholders for input and output frames, analysis time, and detected labels.
def analysis_init():
    global analysis_time, show_labels, labels_placeholder, input_placeholder, output_placeholder

    with col2:
        st.header("Analysis")
        st.subheader("Input Frame")
        input_placeholder = st.empty()  # Placeholder for input frame
        st.subheader("Output Frame")
        output_placeholder = st.empty()  # Placeholder for output frame
        analysis_time = st.empty()  # Placeholder for analysis time
        show_labels = st.checkbox(
            "Show the detected labels", value=True
        )  # Checkbox to show/hide labels
        labels_placeholder = st.empty()  # Placeholder for labels
# Function to publish frames and results to the Streamlit UI
# This function retrieves the latest frames and results from the global img_container
# and updates the placeholders in the Streamlit UI with the current input frame,
# analyzed frame, analysis time, and detected labels.
def publish_frame():
    img = img_container["input"]
    if img is None:
        return
    input_placeholder.image(img, channels="RGB")  # Display the input frame
    analyzed = img_container["analyzed"]
    if analyzed is None:
        return
    # Display the analyzed frame
    output_placeholder.image(analyzed, channels="RGB")
    elapsed_ms = img_container["analysis_time"]
    if elapsed_ms is None:
        return
    # Display the analysis time
    analysis_time.text(f"Analysis Time: {elapsed_ms} ms")
    detections = img_container["detections"]
    if detections is None:
        return
    if show_labels:
        labels_placeholder.table(
            detections
        )  # Display labels if the checkbox is checked
# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    while True:
        publish_frame()  # Publish the frames and results
        time.sleep(0.1)  # Delay to control frame rate
# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    else:
        response = requests.get(image_url)  # Download the image from the URL
        # Open the downloaded image
        image = Image.open(BytesIO(response.content))
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    analyze_frame(img)  # Analyze the image
    publish_frame()  # Publish the results
# Function to process video files
# This function reads frames from a video file, runs the object detection analysis
# on each frame, and updates the Streamlit UI with the current input frame,
# analyzed frame, and detected labels.
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)  # Open the video file
    while cap.isOpened():
        ret, frame = cap.read()  # Read a frame from the video
        if not ret:
            break  # Exit the loop if no more frames are available
        # OpenCV reads frames in BGR order; convert to RGB to match the rest of the pipeline
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Display the current frame as the input frame
        input_placeholder.image(frame)
        analyze_frame(frame)  # Run the object detection analysis on the frame
        publish_frame()  # Publish the results
    cap.release()  # Release the video capture object
# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_video is not None:
        video_path = uploaded_video.name  # Get the name of the uploaded video
        with open(video_path, "wb") as f:
            # Save the uploaded video to a file
            f.write(uploaded_video.getbuffer())
    else:
        # Download the video from the URL
        video_path = download_file(video_url)
    process_video(video_path)  # Process the video