# Edutube / streamlit_app.py
# Author: zhixiusue — last commit 1184219 ("Update streamlit_app.py")
import html
import os

import streamlit as st
import torch
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
from transformers import AutoModelForTokenClassification, AutoTokenizer
from youtube_transcript_api import YouTubeTranscriptApi
# Hugging Face ๋ชจ๋ธ
MODEL_REPO = "zhixiusue/EduTubeNavigator"
# YouTube API Key
YOUTUBE_API_KEY = "AIzaSyA8SG7--MfQvWET6UOam0PVAcC5MDm4sbc"
youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)
# ID โ†’ ๋ผ๋ฒจ ๋งคํ•‘
token_label_map = {
0: 'O', 1: 'B-TOPIC', 2: 'I-TOPIC', 3: 'B-STYLE', 4: 'I-STYLE',
5: 'B-LENGTH', 6: 'I-LENGTH', 7: 'B-LANGUAGE', 8: 'I-LANGUAGE'
}
@st.cache_resource
def load_model():
    """Load the tokenizer and token-classification model published at MODEL_REPO.

    Decorated with ``st.cache_resource`` so the pair is reused across Streamlit
    reruns instead of being reloaded every time the script executes.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    repo = MODEL_REPO
    return (
        AutoTokenizer.from_pretrained(repo),
        AutoModelForTokenClassification.from_pretrained(repo),
    )
def predict_entities(text, tokenizer, model):
    """Extract learning conditions (topic, style, length, language) from *text*.

    The text is split into single characters, and each character is passed to
    the tokenizer as one "word". The token-classification model assigns a BIO
    tag to every sub-token. Consecutive B-/I- tags of the same type are merged
    into one span.

    Args:
        text: free-form learning goal typed by the user.
        tokenizer: Hugging Face tokenizer paired with *model*. It must be a
            "fast" tokenizer, because ``BatchEncoding.word_ids`` is used.
        model: token-classification model whose predicted class ids are keys
            of ``token_label_map``.

    Returns:
        dict mapping entity type (e.g. ``"TOPIC"``) to the matching substring
        of *text*. If a type occurs in several spans, the last one wins.
        Characters beyond the 128-token truncation limit are ignored.
    """
    if not text:
        return {}
    chars = list(text)
    inputs = tokenizer(chars, is_split_into_words=True, return_tensors="pt",
                       truncation=True, max_length=128)
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=-1)[0].tolist()
    word_ids = inputs.word_ids(batch_index=0)

    entities = {}
    current_type = ""
    current_entity = ""
    span_end = -1  # index in `text` of the last character of the open span
    previous_word_id = None
    for word_id, label_id in zip(word_ids, predictions):
        # Skip special tokens. Use only the first sub-token of each character:
        # a character split into several sub-tokens was previously appended
        # once per sub-token, duplicating it in the entity text.
        if word_id is None or word_id == previous_word_id:
            continue
        previous_word_id = word_id
        label = token_label_map[label_id]
        if label.startswith("B-"):
            if current_type:
                entities[current_type] = current_entity
            current_type = label[2:]
            current_entity = text[word_id]
            span_end = word_id
        elif label.startswith("I-") and label[2:] == current_type:
            # Slice from the previous character so that characters producing
            # no sub-tokens at all (typically spaces) stay inside the span.
            current_entity += text[span_end + 1:word_id + 1]
            span_end = word_id
        else:
            # "O", or an I- tag that does not continue the open span, closes it.
            if current_type:
                entities[current_type] = current_entity
            current_type = ""
            current_entity = ""
    if current_type:
        entities[current_type] = current_entity
    return entities
def search_youtube_videos(query, max_results=10):
    """Search YouTube for videos matching *query*.

    Args:
        query: free-text search string, sent as the API's ``q`` parameter.
        max_results: number of results requested (the API's ``maxResults``).

    Returns:
        list of dicts with ``video_id``, ``title`` and ``description`` keys, in
        the order the API returned them. Title and description are
        HTML-unescaped.

    Raises:
        googleapiclient.errors.HttpError: if the API request fails (e.g. an
            invalid key or exhausted quota).
    """
    request = youtube.search().list(
        q=query, part="snippet", type="video", maxResults=max_results
    )
    response = request.execute()
    # The Data API returns snippet text HTML-escaped (e.g. "&#39;", "&quot;").
    # Decode it so titles display correctly and the embedder sees real text.
    return [
        {
            'video_id': item['id']['videoId'],
            'title': html.unescape(item['snippet']['title']),
            'description': html.unescape(item['snippet']['description']),
        }
        for item in response.get('items', [])
    ]
def get_transcript(video_id):
    """Return the video's transcript as one space-joined string, or "" on failure.

    Korean captions are requested first, with English as the fallback (the
    order of ``languages``). Failure is deliberately best-effort, so a video
    without usable captions can still be ranked on its title and description.
    """
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['ko', 'en'])
        return " ".join(segment['text'] for segment in transcript)
    except Exception:
        # Narrowed from a bare ``except:`` so that KeyboardInterrupt, SystemExit
        # and other BaseException-derived signals are no longer swallowed.
        return ""
@st.cache_resource
def load_embedder():
    """Load the SBERT model ``snunlp/KR-SBERT-V40K-klueNLI-augSTS``.

    Decorated with ``st.cache_resource`` so the model is reused across
    Streamlit reruns instead of being reloaded each time.
    """
    model_name = 'snunlp/KR-SBERT-V40K-klueNLI-augSTS'
    embedder = SentenceTransformer(model_name)
    return embedder
def embed_texts(embedder, texts):
    """Encode *texts* with *embedder*, returning the embeddings as a torch tensor."""
    tensor_embeddings = embedder.encode(texts, convert_to_tensor=True)
    return tensor_embeddings
def recommend_video(embedder, user_conditions, video_infos, top_k=3):
    """Rank candidate videos by semantic similarity to the user's conditions.

    Each video is represented by its title, description and transcript (when
    one is available). The user's truthy condition values are joined into a
    single query text.

    Args:
        embedder: sentence embedding model, passed to ``embed_texts``.
        user_conditions: dict of extracted conditions; only its values are used.
        video_infos: list of dicts with ``video_id``, ``title`` and
            ``description`` keys.
        top_k: maximum number of results to return (default 3).

    Returns:
        Up to *top_k* ``(score, video)`` tuples, highest cosine similarity first.
    """
    if not video_infos:
        return []
    user_text = " ".join(v for v in user_conditions.values() if v)
    # NOTE(review): sentence-transformers truncates inputs to the model's
    # max_seq_length, so most of a long transcript is probably ignored.
    video_texts = [
        video['title'] + " " + video['description'] + " " + get_transcript(video['video_id'])
        for video in video_infos
    ]
    # One batched encode for all candidates, instead of one call per video.
    user_embedding = embed_texts(embedder, [user_text])      # shape (1, dim)
    video_embeddings = embed_texts(embedder, video_texts)    # shape (n, dim)
    scores = cosine_similarity(user_embedding, video_embeddings, dim=1).tolist()
    scored = list(zip(scores, video_infos))
    return sorted(scored, reverse=True, key=lambda pair: pair[0])[:top_k]
# Streamlit UI: read a learning goal, extract conditions, search YouTube and
# show the best-matching videos.
st.title("EduTube Navigator")
st.write("ํ•™์Šต ๋ชฉํ‘œ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด ์กฐ๊ฑด์„ ์ถ”์ถœํ•˜๊ณ  ์œ ํŠœ๋ธŒ ์˜์ƒ์„ ์ถ”์ฒœํ•ฉ๋‹ˆ๋‹ค.")
user_input = st.text_input("ํ•™์Šต ๋ชฉํ‘œ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”", "๋”ฅ๋Ÿฌ๋‹์„ ์‹ค์Šต ์œ„์ฃผ๋กœ 30๋ถ„ ์•ˆ์— ๋ฐฐ์šฐ๊ณ  ์‹ถ์–ด์š”")
if st.button("์ถ”์ฒœ ์‹œ์ž‘"):
    if not user_input.strip():
        # Nothing to extract conditions from or to search for.
        st.warning("학습 목표를 입력해 주세요.")
    else:
        tokenizer, model = load_model()
        embedder = load_embedder()
        entities = predict_entities(user_input, tokenizer, model)
        # When the NER model finds no conditions, fall back to the raw input so
        # that neither the YouTube query nor the similarity target is empty.
        conditions = entities if any(entities.values()) else {"QUERY": user_input.strip()}
        search_query = " ".join(v for v in conditions.values() if v)
        try:
            video_candidates = search_youtube_videos(search_query)
        except HttpError as err:
            # e.g. an invalid API key or an exhausted daily quota
            st.error(f"YouTube 검색 중 오류가 발생했습니다: {err}")
        else:
            with st.spinner("추천 영상을 분석하는 중입니다..."):
                top_recommendations = recommend_video(embedder, conditions, video_candidates)
            st.subheader("์ถ”์ฒœ ์œ ํŠœ๋ธŒ ์˜์ƒ")
            if not top_recommendations:
                st.info("추천할 영상을 찾지 못했습니다.")
            for _score, video in top_recommendations:
                st.markdown(f"**{video['title']}** ")
                st.markdown(f"๐Ÿ”—[๋งํฌ](https://www.youtube.com/watch?v={video['video_id']})")