|
import logging
import os

import streamlit as st
import torch
from googleapiclient.discovery import build
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
from transformers import AutoTokenizer, AutoModelForTokenClassification
from youtube_transcript_api import YouTubeTranscriptApi
|
|
|
|
|
# Hugging Face Hub repository id passed to ``from_pretrained`` for both the
# tokenizer and the token-classification model.
MODEL_REPO = "zhixiusue/EduTubeNavigator"

# SECURITY: an API key is committed here in plain text. The environment
# variable YOUTUBE_API_KEY now takes precedence; the literal is kept only so
# existing deployments keep working and should be rotated and removed.
YOUTUBE_API_KEY = (
    os.environ.get("YOUTUBE_API_KEY")
    or "AIzaSyA8SG7--MfQvWET6UOam0PVAcC5MDm4sbc"
)

# YouTube Data API v3 client, created once at import time and shared by the
# search helper below.
youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)
|
|
|
|
|
# Class index -> BIO tag, used to decode the arg-max prediction of each token.
# The position of a tag in the tuple is its class index.
# NOTE(review): presumably mirrors the label ids the checkpoint was trained
# with -- confirm against the model's config.
token_label_map = dict(enumerate((
    'O',
    'B-TOPIC', 'I-TOPIC',
    'B-STYLE', 'I-STYLE',
    'B-LENGTH', 'I-LENGTH',
    'B-LANGUAGE', 'I-LANGUAGE',
)))
|
|
|
@st.cache_resource
def load_model():
    """Load the tokenizer and token-classification model from ``MODEL_REPO``.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded
    objects instead of reloading them on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(MODEL_REPO),
        AutoModelForTokenClassification.from_pretrained(MODEL_REPO),
    )
|
|
|
def predict_entities(text, tokenizer, model):
    """Extract BIO-tagged entities from ``text`` with a token classifier.

    The text is split into single characters and each character is handed to
    the tokenizer as one pre-split "word". The arg-max label of every word is
    decoded through ``token_label_map`` and consecutive ``B-``/``I-`` labels
    of the same type are concatenated into one entity string.

    Args:
        text: Raw user sentence.
        tokenizer: Callable tokenizer whose result exposes ``word_ids``
            (presumably the Hugging Face tokenizer from ``load_model``).
        model: Model whose call result carries a ``logits`` attribute.

    Returns:
        dict mapping an entity type (the label without its ``B-``/``I-``
        prefix) to the concatenated characters of that entity. If a type is
        found more than once, the last occurrence wins. Empty ``text``
        yields an empty dict.
    """
    if not text:
        # Nothing to tag; skip the tokenizer/model round trip entirely.
        return {}

    tokens = list(text)
    # Input is truncated to 128 tokens, so characters beyond that are ignored.
    inputs = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    )

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=-1)[0]

    entities = {}
    current_type = ""
    current_chars = []
    previous_word_id = None

    for idx, word_id in enumerate(inputs.word_ids(batch_index=0)):
        if word_id is None:
            # Positions that map to no input word (e.g. special tokens).
            continue
        if word_id == previous_word_id:
            # A character may be split into several sub-tokens. Decode only
            # the first one, otherwise the same character would be appended
            # repeatedly or would restart the running entity.
            continue
        previous_word_id = word_id

        label = token_label_map[predictions[idx].item()]

        if label.startswith("B-"):
            # A new entity begins: store the one in progress, if any.
            if current_type:
                entities[current_type] = "".join(current_chars)
            current_type = label[2:]
            current_chars = [tokens[word_id]]
        elif label.startswith("I-") and label[2:] == current_type:
            current_chars.append(tokens[word_id])
        else:
            # "O", or an "I-" tag that does not continue the running entity.
            if current_type:
                entities[current_type] = "".join(current_chars)
            current_type = ""
            current_chars = []

    # Flush an entity that runs up to the end of the sequence.
    if current_type:
        entities[current_type] = "".join(current_chars)

    return entities
|
|
|
def search_youtube_videos(query, max_results=10):
    """Search YouTube for videos matching ``query``.

    Args:
        query: Free-text search string.
        max_results: Value forwarded as the API's ``maxResults`` argument.

    Returns:
        A list of dicts with the keys ``'video_id'``, ``'title'`` and
        ``'description'``, one per item in the API response.
    """
    request = youtube.search().list(
        q=query,
        part="snippet",
        type="video",
        maxResults=max_results,
    )
    payload = request.execute()

    return [
        {
            'video_id': hit['id']['videoId'],
            'title': hit['snippet']['title'],
            'description': hit['snippet']['description'],
        }
        for hit in payload['items']
    ]
|
|
|
def get_transcript(video_id):
    """Return a video's transcript as one string, or ``""`` if unavailable.

    Requests the transcript with the language list ``['ko', 'en']`` and
    joins the ``'text'`` field of every returned segment with spaces.

    Args:
        video_id: YouTube video id.

    Returns:
        The joined transcript text. Any failure while fetching or joining is
        logged and reported as an empty string, so one video without a
        usable transcript never aborts the caller.
    """
    # NOTE(review): newer youtube-transcript-api releases replace the static
    # ``get_transcript`` with ``YouTubeTranscriptApi().fetch(...)`` -- confirm
    # the pinned version still provides this call.
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['ko', 'en'])
        return " ".join(segment['text'] for segment in transcript)
    except Exception as exc:
        # Deliberately best-effort, but no longer a bare ``except``:
        # KeyboardInterrupt/SystemExit propagate and the cause is visible.
        logging.getLogger(__name__).warning(
            "No transcript for video %s: %s", video_id, exc
        )
        return ""
|
|
|
@st.cache_resource
def load_embedder():
    """Build the sentence-embedding model used for ranking videos.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the instance
    instead of re-creating it on every script rerun.

    Returns:
        A ``SentenceTransformer`` for 'snunlp/KR-SBERT-V40K-klueNLI-augSTS'.
    """
    embedder = SentenceTransformer('snunlp/KR-SBERT-V40K-klueNLI-augSTS')
    return embedder
|
|
|
def embed_texts(embedder, texts):
    """Encode ``texts`` with ``embedder`` and return the result as a tensor.

    Args:
        embedder: Object with an ``encode(texts, convert_to_tensor=...)``
            method.
        texts: The texts to encode.

    Returns:
        Whatever ``embedder.encode`` returns when called with
        ``convert_to_tensor=True``.
    """
    embeddings = embedder.encode(texts, convert_to_tensor=True)
    return embeddings
|
|
|
def recommend_video(embedder, user_conditions, video_infos, top_k=3):
    """Rank candidate videos by similarity to the user's conditions.

    The non-empty values of ``user_conditions`` are joined into one query
    text. Each video is represented by its title, description and transcript
    (empty when none can be fetched). Both texts are embedded and compared
    with cosine similarity.

    Args:
        embedder: Embedding model accepted by ``embed_texts``.
        user_conditions: Mapping whose non-empty values form the query text.
        video_infos: Dicts with ``'video_id'``, ``'title'`` and
            ``'description'`` keys.
        top_k: Maximum number of results to return. Defaults to 3, the
            value that used to be hard-coded.

    Returns:
        Up to ``top_k`` ``(score, video)`` tuples, highest score first.
        Empty when there are no candidates or ``top_k`` is not positive.
    """
    if not video_infos or top_k <= 0:
        # Nothing to rank: skip embedding the query altogether.
        return []

    user_text = " ".join(value for value in user_conditions.values() if value)
    user_embedding = embed_texts(embedder, [user_text])[0]

    scored = []
    for video in video_infos:
        video_text = " ".join(
            (video['title'], video['description'], get_transcript(video['video_id']))
        )
        video_embedding = embed_texts(embedder, [video_text])[0]
        score = cosine_similarity(user_embedding, video_embedding, dim=0).item()
        scored.append((score, video))

    # Sort on the score only so that ties never compare the video dicts.
    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    return ranked[:top_k]
|
|
|
|
|
# ---------------------------------------------------------------------------
# Streamlit page: read a learning goal, extract conditions from it, search
# YouTube with them and show the best-matching videos.
# ---------------------------------------------------------------------------
st.title("EduTube Navigator")
st.write("학습 목표를 입력하면 조건을 추출하고 유튜브 영상을 추천합니다.")

user_input = st.text_input("학습 목표를 입력하세요", "딥러닝을 실습 위주로 30분 안에 배우고 싶어요")

if st.button("추천 시작"):
    tokenizer, model = load_model()
    embedder = load_embedder()

    # Extracted conditions, e.g. keyed by TOPIC / STYLE / LENGTH / LANGUAGE.
    entities = predict_entities(user_input, tokenizer, model)
    search_query = " ".join(value for value in entities.values() if value)

    if not search_query:
        # Without any extracted condition the search would run on an empty
        # query and return unrelated videos, so ask for better input instead.
        st.warning("입력에서 학습 조건을 찾지 못했습니다. 학습 목표를 조금 더 구체적으로 입력해 주세요.")
    else:
        video_candidates = search_youtube_videos(search_query)
        top_recommendations = recommend_video(embedder, entities, video_candidates)

        st.subheader("추천 유튜브 영상")
        for _score, video in top_recommendations:
            st.markdown(f"**{video['title']}** ")
            # NOTE(review): the emoji in front of the link was unrecoverable
            # from the garbled source; the link symbol is a best guess.
            st.markdown(f"🔗[링크](https://www.youtube.com/watch?v={video['video_id']})")
|
|