ciga / app.py
kimhyunwoo's picture
Create app.py
7fe5267 verified
raw
history blame contribute delete
10.2 kB
import streamlit as st
import av # streamlit-webrtc๊ฐ€ ๋น„๋””์˜ค ํ”„๋ ˆ์ž„์„ ๋‹ค๋ฃจ๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ
import cv2
import numpy as np
from ultralytics import YOLO
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode
# --- ์„ค์ • ---
MODEL_PATH = 'trained_model.pt' # ํ•™์Šตํ•œ YOLO ๋ชจ๋ธ ํŒŒ์ผ ๊ฒฝ๋กœ
CONFIDENCE_THRESHOLD = 0.4 # ๊ฐ์ฒด ํƒ์ง€ ์ตœ์†Œ ์‹ ๋ขฐ๋„ (๋ชจ๋ธ์— ๋”ฐ๋ผ ์กฐ์ •)
SEND_ALERT_INTERVAL = 30 # ๋‹ด๋ฐฐ ํƒ์ง€ ์‹œ ๋ฐ์ดํ„ฐ ์ฑ„๋„ ๋ฉ”์‹œ์ง€๋ฅผ ๋ช‡ ํ”„๋ ˆ์ž„๋งˆ๋‹ค ๋ณด๋‚ผ์ง€ (๋„ˆ๋ฌด ์งง์œผ๋ฉด ํด๋ผ์ด์–ธํŠธ ๋ถ€๋‹ด)
# --- YOLO ๋ชจ๋ธ ๋กœ๋“œ ---
@st.cache_resource # Streamlit์˜ ์บ์‹ฑ ๊ธฐ๋Šฅ ์‚ฌ์šฉ: ์•ฑ ์‹คํ–‰ ์ค‘ ๋ชจ๋ธ์„ ํ•œ ๋ฒˆ๋งŒ ๋กœ๋“œ
def load_yolo_model(model_path):
try:
model = YOLO(model_path)
if hasattr(model, 'model') and model.model is not None:
st.success(f"YOLO ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต: {model_path}")
return model
else:
st.error(f"YOLO ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ๋˜๋Š” ๊ฐ์ฒด ์ดˆ๊ธฐํ™” ๋ฌธ์ œ: {model_path}")
st.stop()
except FileNotFoundError:
st.error(f"์˜ค๋ฅ˜: ๋ชจ๋ธ ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค. '{model_path}' ๊ฒฝ๋กœ๋ฅผ ํ™•์ธํ•˜์„ธ์š”.")
st.stop() # ๋ชจ๋ธ ํŒŒ์ผ ์—†์œผ๋ฉด ์•ฑ ์ค‘์ง€
except Exception as e:
st.error(f"YOLO ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
st.stop() # ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ ์•ฑ ์ค‘์ง€
model = load_yolo_model(MODEL_PATH)
# --- Streamlit-WebRTC๋ฅผ ์œ„ํ•œ ๋น„๋””์˜ค ๋ณ€ํ™˜ ํด๋ž˜์Šค ---
# ์ด ํด๋ž˜์Šค์˜ ์ธ์Šคํ„ด์Šค๊ฐ€ ๋น„๋””์˜ค ํ”„๋ ˆ์ž„๋งˆ๋‹ค ํ˜ธ์ถœ๋ฉ๋‹ˆ๋‹ค.
class YOLOVideoTransformer(VideoTransformerBase):
# __init__์— data_channel ์ธ์ž๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ํด๋ผ์ด์–ธํŠธ์™€ ํ†ต์‹ 
def __init__(self, model, confidence_thresh, send_interval, data_channel):
self.model = model
self.confidence_thresh = confidence_thresh
self.send_interval = send_interval
self._data_channel = data_channel # ํด๋ผ์ด์–ธํŠธ์™€ ํ†ต์‹ ํ•  ๋ฐ์ดํ„ฐ ์ฑ„๋„ ๊ฐ์ฒด
self.detected_in_prev_frame = False # ์ด์ „ ํ”„๋ ˆ์ž„์—์„œ ํƒ์ง€๋˜์—ˆ๋Š”์ง€ ์—ฌ๋ถ€
self.frame_counter = 0 # ํ”„๋ ˆ์ž„ ์นด์šดํ„ฐ
# ๊ฐ ๋น„๋””์˜ค ํ”„๋ ˆ์ž„์„ ์ฒ˜๋ฆฌํ•˜๋Š” ๋ฉ”์„œ๋“œ (๋น„๋™๊ธฐ ํ•จ์ˆ˜๋กœ ์ •์˜)
async def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
self.frame_counter += 1
# AV ํ”„๋ ˆ์ž„์„ OpenCV(numpy) ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
img = frame.to_ndarray(format="bgr24")
# YOLOv8 ๋ชจ๋ธ๋กœ ๊ฐ์ฒด ํƒ์ง€
# verbose=False: ์ฝ˜์†”์— ํƒ์ง€ ๊ฒฐ๊ณผ ์ถœ๋ ฅ ์•ˆ ํ•จ
results = self.model(img, conf=self.confidence_thresh, verbose=False)
cigarette_detected_in_current_frame = False
# ๊ฒฐ๊ณผ์—์„œ 'cigarette' ๊ฐ์ฒด๊ฐ€ ํƒ์ง€๋˜์—ˆ๋Š”์ง€ ํ™•์ธ
if results and len(results) > 0:
for box in results[0].boxes:
class_id = int(box.cls[0])
confidence = float(box.conf[0])
class_name = self.model.names[class_id]
if class_name == 'cigarette' and confidence >= self.confidence_thresh:
cigarette_detected_in_current_frame = True
break # ํ•˜๋‚˜๋ผ๋„ ํƒ์ง€๋˜๋ฉด ๋” ์ด์ƒ ํ™•์ธํ•  ํ•„์š” ์—†์Œ
# --- ํด๋ผ์ด์–ธํŠธ ์†Œ๋ฆฌ ์•Œ๋ฆผ ๋กœ์ง (๋ฐ์ดํ„ฐ ์ฑ„๋„ ์‚ฌ์šฉ) ---
# ๋‹ด๋ฐฐ๊ฐ€ ํ˜„์žฌ ํ”„๋ ˆ์ž„์—์„œ ํƒ์ง€๋˜์—ˆ๊ณ , ๋ฐ์ดํ„ฐ ์ฑ„๋„์ด ์—ด๋ ค ์žˆ์œผ๋ฉฐ,
# ๋ฉ”์‹œ์ง€ ์ „์†ก ๊ฐ„๊ฒฉ์— ๋„๋‹ฌํ–ˆ์„ ๋•Œ ๋ฉ”์‹œ์ง€ ์ „์†ก
if cigarette_detected_in_current_frame and self._data_channel and self._data_channel.readyState == "open":
if not self.detected_in_prev_frame or self.frame_counter % self.send_interval == 0:
# print("Sending DETECT_CIGARETTE message to client...") # ๋””๋ฒ„๊ทธ ์ถœ๋ ฅ
await self._data_channel.send("DETECT_CIGARETTE") # ํด๋ผ์ด์–ธํŠธ๋กœ ๋ฉ”์‹œ์ง€ ์ „์†ก
# ๋‹ค์Œ ํ”„๋ ˆ์ž„์„ ์œ„ํ•ด ํ˜„์žฌ ํƒ์ง€ ์ƒํƒœ ์ €์žฅ
self.detected_in_prev_frame = cigarette_detected_in_current_frame
# ํƒ์ง€๋œ ๊ฒฐ๊ณผ๋ฅผ ์ด๋ฏธ์ง€์— ํ‘œ์‹œ (๋ฐ”์šด๋”ฉ ๋ฐ•์Šค, ๋ผ๋ฒจ ๋“ฑ)
annotated_img = results[0].plot()
# ์ฒ˜๋ฆฌ๋œ ์ด๋ฏธ์ง€(numpy)๋ฅผ ๋‹ค์‹œ AV ํ”„๋ ˆ์ž„์œผ๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋ฐ˜ํ™˜
return av.VideoFrame.from_ndarray(annotated_img, format="bgr24")
# --- ํด๋ผ์ด์–ธํŠธ ์ธก JavaScript ์ฝ”๋“œ (๋ฐ์ดํ„ฐ ์ฑ„๋„ ๋ฉ”์‹œ์ง€ ์ˆ˜์‹  ๋ฐ ์†Œ๋ฆฌ ์žฌ์ƒ) ---
# webrtc_streamer์˜ on_data_channel ์ธ์ž์— ์ „๋‹ฌ๋  JavaScript ํ•จ์ˆ˜ ์ •์˜
# ์ด ํ•จ์ˆ˜๋Š” data channel ๊ฐ์ฒด๋ฅผ ์ธ์ž๋กœ ๋ฐ›์Šต๋‹ˆ๋‹ค.
# ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐ›์œผ๋ฉด ์›น ์˜ค๋””์˜ค API๋กœ ์‚ฌ์ธํŒŒ๋ฅผ ์ƒ์„ฑํ•˜์—ฌ ์žฌ์ƒํ•ฉ๋‹ˆ๋‹ค.
JS_CLIENT_SOUND_SCRIPT = """
(channel) => {
// ์˜ค๋””์˜ค ์ปจํ…์ŠคํŠธ ์ƒ์„ฑ (ํด๋ฆญ ๋“ฑ ์‚ฌ์šฉ์ž ์ƒํ˜ธ์ž‘์šฉ ํ›„์— ์ƒ์„ฑํ•ด์•ผ ํ•  ์ˆ˜ ์žˆ์Œ)
// webrtc_streamer ์‹œ์ž‘ ๋ฒ„ํŠผ์ด ์ด๋ฏธ ์ƒํ˜ธ์ž‘์šฉ ์—ญํ• ์„ ํ•ฉ๋‹ˆ๋‹ค.
const audioContext = new (window.AudioContext || window.webkitAudioContext)();
let lastPlayTime = 0; // ๋งˆ์ง€๋ง‰ ์†Œ๋ฆฌ ์žฌ์ƒ ์‹œ๊ฐ„ (ms)
const playCooldown = 200; // ์†Œ๋ฆฌ ์žฌ์ƒ ์ตœ์†Œ ๊ฐ„๊ฒฉ (ms)
// ์‚ฌ์ธํŒŒ ์†Œ๋ฆฌ๋ฅผ ์žฌ์ƒํ•˜๋Š” ํ•จ์ˆ˜
const playSineWaveAlert = () => {
const now = audioContext.currentTime * 1000; // ํ˜„์žฌ ์‹œ๊ฐ„์„ ๋ฐ€๋ฆฌ์ดˆ๋กœ ๋ณ€ํ™˜
if (now - lastPlayTime < playCooldown) {
// console.log("Cooldown active. Skipping sound."); // ๋””๋ฒ„๊ทธ ์ถœ๋ ฅ
return; // ์ฟจ๋‹ค์šด ์ค‘์ด๋ฉด ์žฌ์ƒํ•˜์ง€ ์•Š์Œ
}
lastPlayTime = now; // ๋งˆ์ง€๋ง‰ ์žฌ์ƒ ์‹œ๊ฐ„ ์—…๋ฐ์ดํŠธ
try {
const oscillator = audioContext.createOscillator();
const gainNode = audioContext.createGain();
oscillator.type = 'sine'; // ์‚ฌ์ธํŒŒ
oscillator.frequency.setValueAtTime(600, audioContext.currentTime); // ์ฃผํŒŒ์ˆ˜ (์˜ˆ: 600 Hz)
gainNode.gain.setValueAtTime(0.3, audioContext.currentTime); // ๋ณผ๋ฅจ (0.0 ~ 1.0)
oscillator.connect(gainNode);
gainNode.connect(audioContext.destination);
oscillator.start();
oscillator.stop(audioContext.currentTime + 0.2); // 0.2์ดˆ ์žฌ์ƒ
// console.log("Playing sine wave sound."); // ๋””๋ฒ„๊ทธ ์ถœ๋ ฅ
} catch (e) {
console.error("Error playing sine wave:", e);
}
};
// ๋ฐ์ดํ„ฐ ์ฑ„๋„๋กœ๋ถ€ํ„ฐ ๋ฉ”์‹œ์ง€๋ฅผ ์ˆ˜์‹ ํ–ˆ์„ ๋•Œ ์‹คํ–‰๋  ์ฝœ๋ฐฑ ํ•จ์ˆ˜
channel.onmessage = (event) => {
// console.log("Received message:", event.data); // ์ˆ˜์‹  ๋ฉ”์‹œ์ง€ ํ™•์ธ
if (event.data === "DETECT_CIGARETTE") {
// ์„œ๋ฒ„์—์„œ ๋‹ด๋ฐฐ ํƒ์ง€ ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐ›์œผ๋ฉด ์†Œ๋ฆฌ ์žฌ์ƒ
playSineWaveAlert();
}
};
// ๋ฐ์ดํ„ฐ ์ฑ„๋„์ด ์—ด๋ ธ์„ ๋•Œ
channel.onopen = () => {
console.log("Data channel opened!");
};
// ๋ฐ์ดํ„ฐ ์ฑ„๋„์ด ๋‹ซํ˜”์„ ๋•Œ
channel.onclose = () => {
console.log("Data channel closed.");
};
// ๋ฐ์ดํ„ฐ ์ฑ„๋„ ์—๋Ÿฌ ๋ฐœ์ƒ ์‹œ
channel.onerror = (error) => {
console.error("Data channel error:", error);
};
}
"""
# --- Streamlit ์•ฑ ๋ ˆ์ด์•„์›ƒ ๊ตฌ์„ฑ ---
st.title("๐Ÿšฌ ์‹ค์‹œ๊ฐ„ ๋‹ด๋ฐฐ ํƒ์ง€ ์›น ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ (ํด๋ผ์ด์–ธํŠธ ์†Œ๋ฆฌ)")
st.write("""
์›น์บ  ํ”ผ๋“œ๋ฅผ ํ†ตํ•ด ๋‹ด๋ฐฐ ๊ฐ์ฒด๋ฅผ ์‹ค์‹œ๊ฐ„์œผ๋กœ ํƒ์ง€ํ•˜๊ณ  ์˜์ƒ์— ํ‘œ์‹œํ•ฉ๋‹ˆ๋‹ค.
๋‹ด๋ฐฐ๊ฐ€ ํƒ์ง€๋˜๋ฉด **์‚ฌ์šฉ์ž์˜ ๋ธŒ๋ผ์šฐ์ €**์—์„œ ์•Œ๋ฆผ ์†Œ๋ฆฌ(์‚ฌ์ธํŒŒ)๊ฐ€ ์žฌ์ƒ๋ฉ๋‹ˆ๋‹ค.
**์ฃผ์˜:**
* ์ด ์•ฑ์€ ์‚ฌ์šฉ์ž์˜ ๋ธŒ๋ผ์šฐ์ € ์›น์บ  ๋ฐ ์˜ค๋””์˜ค ์žฌ์ƒ ๊ถŒํ•œ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ๋ธŒ๋ผ์šฐ์ € ์š”์ฒญ ์‹œ ํ—ˆ์šฉํ•ด์ฃผ์„ธ์š”.
* ๋„คํŠธ์›Œํฌ ์ƒํƒœ ๋ฐ ์ปดํ“จํ„ฐ ์„ฑ๋Šฅ์— ๋”ฐ๋ผ ์˜์ƒ ์ฒ˜๋ฆฌ์— ์ง€์—ฐ์ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
* `trained_model.pt` ํŒŒ์ผ์ด ์Šคํฌ๋ฆฝํŠธ ํŒŒ์ผ๊ณผ ๊ฐ™์€ ๋””๋ ‰ํ† ๋ฆฌ์— ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.
""")
st.write("---")
st.subheader("์›น์บ  ์ŠคํŠธ๋ฆผ ๋ฐ ๋‹ด๋ฐฐ ํƒ์ง€ ๊ฒฐ๊ณผ")
# RTC ์„ค์ • (NAT ํ†ต๊ณผ๋ฅผ ์œ„ํ•ด ํ•„์š”, Google STUN ์„œ๋ฒ„ ์‚ฌ์šฉ)
# ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ ๊ธฐ๋ณธ ์„ค์ •์œผ๋กœ ์ถฉ๋ถ„ํ•˜๋‚˜, ๋ช…์‹œ์ ์œผ๋กœ ์„ค์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
rtc_configuration = RTCConfiguration({"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
# Streamlit-WebRTC ์ปดํฌ๋„ŒํŠธ ์ถ”๊ฐ€
webrtc_ctx = webrtc_streamer(
key="yolo-detection-client-sound", # ๊ณ ์œ  ํ‚ค
mode=WebRtcMode.SENDRECV, # ๋น„๋””์˜ค๋ฅผ ๋ณด๋‚ด๊ณ  (SEND) ์„œ๋ฒ„์—์„œ ์ฒ˜๋ฆฌ๋œ ๋น„๋””์˜ค๋ฅผ ๋‹ค์‹œ ๋ฐ›์Œ (RECV)
video_processor_factory=lambda: YOLOVideoTransformer( # ๋น„๋””์˜ค ๋ณ€ํ™˜ ํด๋ž˜์Šค ํŒฉํ† ๋ฆฌ
model=model,
confidence_thresh=CONFIDENCE_THRESHOLD,
send_interval=SEND_ALERT_INTERVAL,
# NOTE: video_processor_factory๊ฐ€ lambda ํ•จ์ˆ˜๋กœ ์‚ฌ์šฉ๋  ๋•Œ,
# webrtc_streamer ๋‚ด๋ถ€์ ์œผ๋กœ ์ƒ์„ฑ๋œ data_channel ๊ฐ์ฒด๊ฐ€ VideoTransformer ์ธ์Šคํ„ด์Šค์— ์ „๋‹ฌ๋ฉ๋‹ˆ๋‹ค.
# ๋ช…์‹œ์ ์œผ๋กœ lambda ์ธ์ž๋กœ channel์„ ๋ฐ›์ง€ ์•Š์•„๋„ ๋ฉ๋‹ˆ๋‹ค.
# (webrtc_streamer์˜ ๊ตฌํ˜„ ๋ฐฉ์‹์— ๋”ฐ๋ผ ๋‹ค๋ฅผ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ๋ฌธ์„œ ํ™•์ธ ํ•„์š”)
# ์ตœ์‹  ๋ฒ„์ „์—์„œ๋Š” __init__์— data_channel=data_channel ํ˜•ํƒœ๋กœ ์ „๋‹ฌ๋จ
data_channel=None # ์ดˆ๊ธฐ๊ฐ’์€ None, webrtc_streamer๊ฐ€ ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ ์‹œ ์‹ค์ œ ๊ฐ์ฒด ์ฃผ์ž…
# -> ์•„๋‹˜, lambda ํŒฉํ† ๋ฆฌ๊ฐ€ ์ธ์ž๋ฅผ ๋ฐ›๋„๋ก ๋ณ€๊ฒฝํ•ด์•ผ ํ•จ
# lambda channel: YOLOVideoTransformer(..., data_channel=channel) ์ด ๋” ์ •ํ™•ํ•จ
),
rtc_configuration=rtc_configuration,
media_stream_constraints={"video": True, "audio": False}, # ์›น์บ  ๋น„๋””์˜ค๋งŒ ์‚ฌ์šฉ
async_processing=True, # ๋น„๋””์˜ค ์ฒ˜๋ฆฌ๋ฅผ ๋น„๋™๊ธฐ๋กœ ์‹คํ–‰
on_data_channel=JS_CLIENT_SOUND_SCRIPT # ๋ฐ์ดํ„ฐ ์ฑ„๋„ ๊ด€๋ จ ํด๋ผ์ด์–ธํŠธ JS ์ฝ”๋“œ
)
# ์œ„ video_processor_factory lambda ๋ถ€๋ถ„์„ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋ช…์‹œ์ ์œผ๋กœ data_channel์„ ๋ฐ›๋„๋ก ์ˆ˜์ •ํ•ฉ๋‹ˆ๋‹ค.
# lambda channel: YOLOVideoTransformer(
# model=model,
# confidence_thresh=CONFIDENCE_THRESHOLD,
# send_interval=SEND_ALERT_INTERVAL,
# data_channel=channel # ๋ฐ์ดํ„ฐ ์ฑ„๋„ ๊ฐ์ฒด ์ „๋‹ฌ
# ),
st.write("---")
st.info("์›น์บ  ์ŠคํŠธ๋ฆผ์„ ์‹œ์ž‘ํ•˜๋ฉด ๋ธŒ๋ผ์šฐ์ €์—์„œ ๋‹ด๋ฐฐ ํƒ์ง€ ์‹œ ์•Œ๋ฆผ ์†Œ๋ฆฌ๊ฐ€ ์žฌ์ƒ๋ฉ๋‹ˆ๋‹ค.")