Update streamlit_app.py
streamlit_app.py  (CHANGED, +92 -67)
[Removed version: the previous streamlit_app.py loaded the same zhixiusue/EduTubeNavigator token-classification model, split the input on whitespace (words = text.split()), aligned per-token predictions back to words, and collected spans with an extract_entities() helper that returned a list of {"entity": ..., "text": ...} dicts, which the UI folded into result_dict and displayed with st.subheader(). It had no YouTube search, transcript retrieval, or sentence-embedding ranking; those parts are added in the new version below.]
New version:

import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
from sentence_transformers import SentenceTransformer
from youtube_transcript_api import YouTubeTranscriptApi
from googleapiclient.discovery import build
from torch.nn.functional import cosine_similarity
import torch

# Hugging Face model
MODEL_REPO = "zhixiusue/EduTubeNavigator"

# YouTube API Key
YOUTUBE_API_KEY = "AIzaSyA8SG7--MfQvWET6UOam0PVAcC5MDm4sbc"
youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)

# ID → label mapping
token_label_map = {
    0: 'O', 1: 'B-TOPIC', 2: 'I-TOPIC', 3: 'B-STYLE', 4: 'I-STYLE',
    5: 'B-LENGTH', 6: 'I-LENGTH', 7: 'B-LANGUAGE', 8: 'I-LANGUAGE'
}

@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
    return tokenizer, model

def predict_entities(text, tokenizer, model):
    # Character-level tokens: the NER model tags each character of the input
    tokens = list(text)
    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True, max_length=128)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)[0]
    word_ids = inputs.word_ids(batch_index=0)
    # Merge B-/I- tags into one text span per entity type
    entities = {}
    current_entity = ""
    current_type = ""
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue
        label = token_label_map[predictions[idx].item()]
        if label.startswith("B-"):
            if current_type:
                entities[current_type] = current_entity
            current_type = label[2:]
            current_entity = tokens[word_id]
        elif label.startswith("I-") and label[2:] == current_type:
            current_entity += tokens[word_id]
        else:
            if current_type:
                entities[current_type] = current_entity
            current_type = ""
            current_entity = ""
    if current_type:
        entities[current_type] = current_entity
    return entities

def search_youtube_videos(query, max_results=10):
    response = youtube.search().list(q=query, part="snippet", type="video", maxResults=max_results).execute()
    results = []
    for item in response['items']:
        results.append({
            'video_id': item['id']['videoId'],
            'title': item['snippet']['title'],
            'description': item['snippet']['description']
        })
    return results

def get_transcript(video_id):
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['ko', 'en'])
        return " ".join([t['text'] for t in transcript])
    except Exception:
        # Fall back to an empty string if no transcript can be fetched
        return ""

@st.cache_resource
def load_embedder():
    return SentenceTransformer('snunlp/KR-SBERT-V40K-klueNLI-augSTS')

def embed_texts(embedder, texts):
    return embedder.encode(texts, convert_to_tensor=True)

def recommend_video(embedder, user_conditions, video_infos):
    # Embed the extracted conditions as one query text
    user_text = " ".join([v for v in user_conditions.values() if v])
    user_embedding = embed_texts(embedder, [user_text])[0]

    # Score each candidate by cosine similarity over title + description + transcript
    scored = []
    for video in video_infos:
        video_text = video['title'] + " " + video['description'] + " " + get_transcript(video['video_id'])
        video_embedding = embed_texts(embedder, [video_text])[0]
        score = cosine_similarity(user_embedding, video_embedding, dim=0).item()
        scored.append((score, video))
    return sorted(scored, reverse=True, key=lambda x: x[0])[:3]

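# For reference: torch.nn.functional.cosine_similarity with dim=0 compares two 1-D
# embedding tensors directly, which is how recommend_video scores each candidate.
# A tiny made-up example (hypothetical values):
#     a = torch.tensor([1.0, 0.0])
#     b = torch.tensor([1.0, 1.0])
#     cosine_similarity(a, b, dim=0)   # tensor(0.7071)
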
# Streamlit UI
st.title("Learning-condition based YouTube recommender")
st.write("Enter a learning goal; the app extracts your conditions and recommends YouTube videos.")

user_input = st.text_input("Enter your learning goal", "I want to learn deep learning hands-on in about 30 minutes")
if st.button("Start recommendation"):
    tokenizer, model = load_model()
    embedder = load_embedder()

    entities = predict_entities(user_input, tokenizer, model)
    st.subheader("Extracted conditions")
    st.json(entities)

    search_query = " ".join([v for v in entities.values() if v])
    video_candidates = search_youtube_videos(search_query)
    top_recommendations = recommend_video(embedder, entities, video_candidates)

    st.subheader("Recommended YouTube videos")
    for score, video in top_recommendations:
        st.markdown(f"**{video['title']}**")
        st.markdown(f"[Link](https://www.youtube.com/watch?v={video['video_id']})")
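
For reference, a minimal self-contained sketch of the B-/I- merging rule that predict_entities applies, run on a hand-written token/label sequence instead of real model output (toy_tokens, toy_labels, and the label positions below are made up for illustration):

toy_tokens = list("learn deep learning hands-on")            # character-level tokens, as in predict_entities
toy_labels = ["O"] * len(toy_tokens)                          # stand-in for the model's predictions
toy_labels[6:19] = ["B-TOPIC"] + ["I-TOPIC"] * 12             # tag "deep learning" as TOPIC
toy_labels[20:28] = ["B-STYLE"] + ["I-STYLE"] * 7             # tag "hands-on" as STYLE

entities, current_type, current_entity = {}, "", ""
for token, label in zip(toy_tokens, toy_labels):
    if label.startswith("B-"):
        if current_type:
            entities[current_type] = current_entity
        current_type, current_entity = label[2:], token
    elif label.startswith("I-") and label[2:] == current_type:
        current_entity += token
    else:
        if current_type:
            entities[current_type] = current_entity
        current_type, current_entity = "", ""
if current_type:
    entities[current_type] = current_entity

print(entities)   # {'TOPIC': 'deep learning', 'STYLE': 'hands-on'}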
|