Spaces:
Sleeping
Sleeping
import streamlit as st | |
import av # streamlit-webrtc๊ฐ ๋น๋์ค ํ๋ ์์ ๋ค๋ฃจ๊ธฐ ์ํด ์ฌ์ฉ | |
import cv2 | |
import numpy as np | |
from ultralytics import YOLO | |
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode | |
# --- ์ค์ --- | |
MODEL_PATH = 'trained_model.pt' # ํ์ตํ YOLO ๋ชจ๋ธ ํ์ผ ๊ฒฝ๋ก | |
CONFIDENCE_THRESHOLD = 0.4 # ๊ฐ์ฒด ํ์ง ์ต์ ์ ๋ขฐ๋ (๋ชจ๋ธ์ ๋ฐ๋ผ ์กฐ์ ) | |
SEND_ALERT_INTERVAL = 30 # ๋ด๋ฐฐ ํ์ง ์ ๋ฐ์ดํฐ ์ฑ๋ ๋ฉ์์ง๋ฅผ ๋ช ํ๋ ์๋ง๋ค ๋ณด๋ผ์ง (๋๋ฌด ์งง์ผ๋ฉด ํด๋ผ์ด์ธํธ ๋ถ๋ด) | |
# --- YOLO ๋ชจ๋ธ ๋ก๋ --- | |
# Streamlit์ ์บ์ฑ ๊ธฐ๋ฅ ์ฌ์ฉ: ์ฑ ์คํ ์ค ๋ชจ๋ธ์ ํ ๋ฒ๋ง ๋ก๋ | |
def load_yolo_model(model_path): | |
try: | |
model = YOLO(model_path) | |
if hasattr(model, 'model') and model.model is not None: | |
st.success(f"YOLO ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต: {model_path}") | |
return model | |
else: | |
st.error(f"YOLO ๋ชจ๋ธ ๋ก๋ ์คํจ ๋๋ ๊ฐ์ฒด ์ด๊ธฐํ ๋ฌธ์ : {model_path}") | |
st.stop() | |
except FileNotFoundError: | |
st.error(f"์ค๋ฅ: ๋ชจ๋ธ ํ์ผ์ด ์์ต๋๋ค. '{model_path}' ๊ฒฝ๋ก๋ฅผ ํ์ธํ์ธ์.") | |
st.stop() # ๋ชจ๋ธ ํ์ผ ์์ผ๋ฉด ์ฑ ์ค์ง | |
except Exception as e: | |
st.error(f"YOLO ๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}") | |
st.stop() # ๋ชจ๋ธ ๋ก๋ ์คํจ ์ ์ฑ ์ค์ง | |
model = load_yolo_model(MODEL_PATH) | |
# --- Streamlit-WebRTC๋ฅผ ์ํ ๋น๋์ค ๋ณํ ํด๋์ค --- | |
# ์ด ํด๋์ค์ ์ธ์คํด์ค๊ฐ ๋น๋์ค ํ๋ ์๋ง๋ค ํธ์ถ๋ฉ๋๋ค. | |
class YOLOVideoTransformer(VideoTransformerBase): | |
# __init__์ data_channel ์ธ์๋ฅผ ์ถ๊ฐํ์ฌ ํด๋ผ์ด์ธํธ์ ํต์ | |
def __init__(self, model, confidence_thresh, send_interval, data_channel): | |
self.model = model | |
self.confidence_thresh = confidence_thresh | |
self.send_interval = send_interval | |
self._data_channel = data_channel # ํด๋ผ์ด์ธํธ์ ํต์ ํ ๋ฐ์ดํฐ ์ฑ๋ ๊ฐ์ฒด | |
self.detected_in_prev_frame = False # ์ด์ ํ๋ ์์์ ํ์ง๋์๋์ง ์ฌ๋ถ | |
self.frame_counter = 0 # ํ๋ ์ ์นด์ดํฐ | |
# ๊ฐ ๋น๋์ค ํ๋ ์์ ์ฒ๋ฆฌํ๋ ๋ฉ์๋ (๋น๋๊ธฐ ํจ์๋ก ์ ์) | |
async def recv(self, frame: av.VideoFrame) -> av.VideoFrame: | |
self.frame_counter += 1 | |
# AV ํ๋ ์์ OpenCV(numpy) ์ด๋ฏธ์ง๋ก ๋ณํ | |
img = frame.to_ndarray(format="bgr24") | |
# YOLOv8 ๋ชจ๋ธ๋ก ๊ฐ์ฒด ํ์ง | |
# verbose=False: ์ฝ์์ ํ์ง ๊ฒฐ๊ณผ ์ถ๋ ฅ ์ ํจ | |
results = self.model(img, conf=self.confidence_thresh, verbose=False) | |
cigarette_detected_in_current_frame = False | |
# ๊ฒฐ๊ณผ์์ 'cigarette' ๊ฐ์ฒด๊ฐ ํ์ง๋์๋์ง ํ์ธ | |
if results and len(results) > 0: | |
for box in results[0].boxes: | |
class_id = int(box.cls[0]) | |
confidence = float(box.conf[0]) | |
class_name = self.model.names[class_id] | |
if class_name == 'cigarette' and confidence >= self.confidence_thresh: | |
cigarette_detected_in_current_frame = True | |
break # ํ๋๋ผ๋ ํ์ง๋๋ฉด ๋ ์ด์ ํ์ธํ ํ์ ์์ | |
# --- ํด๋ผ์ด์ธํธ ์๋ฆฌ ์๋ฆผ ๋ก์ง (๋ฐ์ดํฐ ์ฑ๋ ์ฌ์ฉ) --- | |
# ๋ด๋ฐฐ๊ฐ ํ์ฌ ํ๋ ์์์ ํ์ง๋์๊ณ , ๋ฐ์ดํฐ ์ฑ๋์ด ์ด๋ ค ์์ผ๋ฉฐ, | |
# ๋ฉ์์ง ์ ์ก ๊ฐ๊ฒฉ์ ๋๋ฌํ์ ๋ ๋ฉ์์ง ์ ์ก | |
if cigarette_detected_in_current_frame and self._data_channel and self._data_channel.readyState == "open": | |
if not self.detected_in_prev_frame or self.frame_counter % self.send_interval == 0: | |
# print("Sending DETECT_CIGARETTE message to client...") # ๋๋ฒ๊ทธ ์ถ๋ ฅ | |
await self._data_channel.send("DETECT_CIGARETTE") # ํด๋ผ์ด์ธํธ๋ก ๋ฉ์์ง ์ ์ก | |
# ๋ค์ ํ๋ ์์ ์ํด ํ์ฌ ํ์ง ์ํ ์ ์ฅ | |
self.detected_in_prev_frame = cigarette_detected_in_current_frame | |
# ํ์ง๋ ๊ฒฐ๊ณผ๋ฅผ ์ด๋ฏธ์ง์ ํ์ (๋ฐ์ด๋ฉ ๋ฐ์ค, ๋ผ๋ฒจ ๋ฑ) | |
annotated_img = results[0].plot() | |
# ์ฒ๋ฆฌ๋ ์ด๋ฏธ์ง(numpy)๋ฅผ ๋ค์ AV ํ๋ ์์ผ๋ก ๋ณํํ์ฌ ๋ฐํ | |
return av.VideoFrame.from_ndarray(annotated_img, format="bgr24") | |
# --- ํด๋ผ์ด์ธํธ ์ธก JavaScript ์ฝ๋ (๋ฐ์ดํฐ ์ฑ๋ ๋ฉ์์ง ์์ ๋ฐ ์๋ฆฌ ์ฌ์) --- | |
# webrtc_streamer์ on_data_channel ์ธ์์ ์ ๋ฌ๋ JavaScript ํจ์ ์ ์ | |
# ์ด ํจ์๋ data channel ๊ฐ์ฒด๋ฅผ ์ธ์๋ก ๋ฐ์ต๋๋ค. | |
# ๋ฉ์์ง๋ฅผ ๋ฐ์ผ๋ฉด ์น ์ค๋์ค API๋ก ์ฌ์ธํ๋ฅผ ์์ฑํ์ฌ ์ฌ์ํฉ๋๋ค. | |
JS_CLIENT_SOUND_SCRIPT = """ | |
(channel) => { | |
// ์ค๋์ค ์ปจํ ์คํธ ์์ฑ (ํด๋ฆญ ๋ฑ ์ฌ์ฉ์ ์ํธ์์ฉ ํ์ ์์ฑํด์ผ ํ ์ ์์) | |
// webrtc_streamer ์์ ๋ฒํผ์ด ์ด๋ฏธ ์ํธ์์ฉ ์ญํ ์ ํฉ๋๋ค. | |
const audioContext = new (window.AudioContext || window.webkitAudioContext)(); | |
let lastPlayTime = 0; // ๋ง์ง๋ง ์๋ฆฌ ์ฌ์ ์๊ฐ (ms) | |
const playCooldown = 200; // ์๋ฆฌ ์ฌ์ ์ต์ ๊ฐ๊ฒฉ (ms) | |
// ์ฌ์ธํ ์๋ฆฌ๋ฅผ ์ฌ์ํ๋ ํจ์ | |
const playSineWaveAlert = () => { | |
const now = audioContext.currentTime * 1000; // ํ์ฌ ์๊ฐ์ ๋ฐ๋ฆฌ์ด๋ก ๋ณํ | |
if (now - lastPlayTime < playCooldown) { | |
// console.log("Cooldown active. Skipping sound."); // ๋๋ฒ๊ทธ ์ถ๋ ฅ | |
return; // ์ฟจ๋ค์ด ์ค์ด๋ฉด ์ฌ์ํ์ง ์์ | |
} | |
lastPlayTime = now; // ๋ง์ง๋ง ์ฌ์ ์๊ฐ ์ ๋ฐ์ดํธ | |
try { | |
const oscillator = audioContext.createOscillator(); | |
const gainNode = audioContext.createGain(); | |
oscillator.type = 'sine'; // ์ฌ์ธํ | |
oscillator.frequency.setValueAtTime(600, audioContext.currentTime); // ์ฃผํ์ (์: 600 Hz) | |
gainNode.gain.setValueAtTime(0.3, audioContext.currentTime); // ๋ณผ๋ฅจ (0.0 ~ 1.0) | |
oscillator.connect(gainNode); | |
gainNode.connect(audioContext.destination); | |
oscillator.start(); | |
oscillator.stop(audioContext.currentTime + 0.2); // 0.2์ด ์ฌ์ | |
// console.log("Playing sine wave sound."); // ๋๋ฒ๊ทธ ์ถ๋ ฅ | |
} catch (e) { | |
console.error("Error playing sine wave:", e); | |
} | |
}; | |
// ๋ฐ์ดํฐ ์ฑ๋๋ก๋ถํฐ ๋ฉ์์ง๋ฅผ ์์ ํ์ ๋ ์คํ๋ ์ฝ๋ฐฑ ํจ์ | |
channel.onmessage = (event) => { | |
// console.log("Received message:", event.data); // ์์ ๋ฉ์์ง ํ์ธ | |
if (event.data === "DETECT_CIGARETTE") { | |
// ์๋ฒ์์ ๋ด๋ฐฐ ํ์ง ๋ฉ์์ง๋ฅผ ๋ฐ์ผ๋ฉด ์๋ฆฌ ์ฌ์ | |
playSineWaveAlert(); | |
} | |
}; | |
// ๋ฐ์ดํฐ ์ฑ๋์ด ์ด๋ ธ์ ๋ | |
channel.onopen = () => { | |
console.log("Data channel opened!"); | |
}; | |
// ๋ฐ์ดํฐ ์ฑ๋์ด ๋ซํ์ ๋ | |
channel.onclose = () => { | |
console.log("Data channel closed."); | |
}; | |
// ๋ฐ์ดํฐ ์ฑ๋ ์๋ฌ ๋ฐ์ ์ | |
channel.onerror = (error) => { | |
console.error("Data channel error:", error); | |
}; | |
} | |
""" | |
# --- Streamlit ์ฑ ๋ ์ด์์ ๊ตฌ์ฑ --- | |
st.title("๐ฌ ์ค์๊ฐ ๋ด๋ฐฐ ํ์ง ์น ์ ํ๋ฆฌ์ผ์ด์ (ํด๋ผ์ด์ธํธ ์๋ฆฌ)") | |
st.write(""" | |
์น์บ ํผ๋๋ฅผ ํตํด ๋ด๋ฐฐ ๊ฐ์ฒด๋ฅผ ์ค์๊ฐ์ผ๋ก ํ์งํ๊ณ ์์์ ํ์ํฉ๋๋ค. | |
๋ด๋ฐฐ๊ฐ ํ์ง๋๋ฉด **์ฌ์ฉ์์ ๋ธ๋ผ์ฐ์ **์์ ์๋ฆผ ์๋ฆฌ(์ฌ์ธํ)๊ฐ ์ฌ์๋ฉ๋๋ค. | |
**์ฃผ์:** | |
* ์ด ์ฑ์ ์ฌ์ฉ์์ ๋ธ๋ผ์ฐ์ ์น์บ ๋ฐ ์ค๋์ค ์ฌ์ ๊ถํ์ด ํ์ํฉ๋๋ค. ๋ธ๋ผ์ฐ์ ์์ฒญ ์ ํ์ฉํด์ฃผ์ธ์. | |
* ๋คํธ์ํฌ ์ํ ๋ฐ ์ปดํจํฐ ์ฑ๋ฅ์ ๋ฐ๋ผ ์์ ์ฒ๋ฆฌ์ ์ง์ฐ์ด ๋ฐ์ํ ์ ์์ต๋๋ค. | |
* `trained_model.pt` ํ์ผ์ด ์คํฌ๋ฆฝํธ ํ์ผ๊ณผ ๊ฐ์ ๋๋ ํ ๋ฆฌ์ ์๋์ง ํ์ธํ์ธ์. | |
""") | |
st.write("---") | |
st.subheader("์น์บ ์คํธ๋ฆผ ๋ฐ ๋ด๋ฐฐ ํ์ง ๊ฒฐ๊ณผ") | |
# RTC ์ค์ (NAT ํต๊ณผ๋ฅผ ์ํด ํ์, Google STUN ์๋ฒ ์ฌ์ฉ) | |
# ๋๋ถ๋ถ์ ๊ฒฝ์ฐ ๊ธฐ๋ณธ ์ค์ ์ผ๋ก ์ถฉ๋ถํ๋, ๋ช ์์ ์ผ๋ก ์ค์ ํ ์ ์์ต๋๋ค. | |
rtc_configuration = RTCConfiguration({"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}) | |
# Streamlit-WebRTC ์ปดํฌ๋ํธ ์ถ๊ฐ | |
webrtc_ctx = webrtc_streamer( | |
key="yolo-detection-client-sound", # ๊ณ ์ ํค | |
mode=WebRtcMode.SENDRECV, # ๋น๋์ค๋ฅผ ๋ณด๋ด๊ณ (SEND) ์๋ฒ์์ ์ฒ๋ฆฌ๋ ๋น๋์ค๋ฅผ ๋ค์ ๋ฐ์ (RECV) | |
video_processor_factory=lambda: YOLOVideoTransformer( # ๋น๋์ค ๋ณํ ํด๋์ค ํฉํ ๋ฆฌ | |
model=model, | |
confidence_thresh=CONFIDENCE_THRESHOLD, | |
send_interval=SEND_ALERT_INTERVAL, | |
# NOTE: video_processor_factory๊ฐ lambda ํจ์๋ก ์ฌ์ฉ๋ ๋, | |
# webrtc_streamer ๋ด๋ถ์ ์ผ๋ก ์์ฑ๋ data_channel ๊ฐ์ฒด๊ฐ VideoTransformer ์ธ์คํด์ค์ ์ ๋ฌ๋ฉ๋๋ค. | |
# ๋ช ์์ ์ผ๋ก lambda ์ธ์๋ก channel์ ๋ฐ์ง ์์๋ ๋ฉ๋๋ค. | |
# (webrtc_streamer์ ๊ตฌํ ๋ฐฉ์์ ๋ฐ๋ผ ๋ค๋ฅผ ์ ์์ผ๋ฏ๋ก ๋ฌธ์ ํ์ธ ํ์) | |
# ์ต์ ๋ฒ์ ์์๋ __init__์ data_channel=data_channel ํํ๋ก ์ ๋ฌ๋จ | |
data_channel=None # ์ด๊ธฐ๊ฐ์ None, webrtc_streamer๊ฐ ์ธ์คํด์ค ์์ฑ ์ ์ค์ ๊ฐ์ฒด ์ฃผ์ | |
# -> ์๋, lambda ํฉํ ๋ฆฌ๊ฐ ์ธ์๋ฅผ ๋ฐ๋๋ก ๋ณ๊ฒฝํด์ผ ํจ | |
# lambda channel: YOLOVideoTransformer(..., data_channel=channel) ์ด ๋ ์ ํํจ | |
), | |
rtc_configuration=rtc_configuration, | |
media_stream_constraints={"video": True, "audio": False}, # ์น์บ ๋น๋์ค๋ง ์ฌ์ฉ | |
async_processing=True, # ๋น๋์ค ์ฒ๋ฆฌ๋ฅผ ๋น๋๊ธฐ๋ก ์คํ | |
on_data_channel=JS_CLIENT_SOUND_SCRIPT # ๋ฐ์ดํฐ ์ฑ๋ ๊ด๋ จ ํด๋ผ์ด์ธํธ JS ์ฝ๋ | |
) | |
# ์ video_processor_factory lambda ๋ถ๋ถ์ ๋ค์๊ณผ ๊ฐ์ด ๋ช ์์ ์ผ๋ก data_channel์ ๋ฐ๋๋ก ์์ ํฉ๋๋ค. | |
# lambda channel: YOLOVideoTransformer( | |
# model=model, | |
# confidence_thresh=CONFIDENCE_THRESHOLD, | |
# send_interval=SEND_ALERT_INTERVAL, | |
# data_channel=channel # ๋ฐ์ดํฐ ์ฑ๋ ๊ฐ์ฒด ์ ๋ฌ | |
# ), | |
st.write("---") | |
st.info("์น์บ ์คํธ๋ฆผ์ ์์ํ๋ฉด ๋ธ๋ผ์ฐ์ ์์ ๋ด๋ฐฐ ํ์ง ์ ์๋ฆผ ์๋ฆฌ๊ฐ ์ฌ์๋ฉ๋๋ค.") |