Spaces:
Sleeping
Sleeping
improve cache model
Browse files- app.py +47 -21
- data/answer_embeddings.npy +3 -0
- data/faiss_answer.index +0 -0
- data/faiss_question.index +0 -0
- data/question_embeddings.npy +3 -0
- preprocess.py +6 -14
app.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
-
import faiss
|
3 |
import numpy as np
|
4 |
import json
|
5 |
-
from sentence_transformers import SentenceTransformer
|
6 |
import time
|
7 |
|
|
|
8 |
# データを読み込む
|
9 |
with open("data/qa_data.json", "r", encoding="utf-8") as f:
|
10 |
data = json.load(f)
|
@@ -12,15 +12,27 @@ with open("data/qa_data.json", "r", encoding="utf-8") as f:
|
|
12 |
questions = [item["question"] for item in data]
|
13 |
answers = [item["answer"] for item in data]
|
14 |
|
15 |
-
# 埋め込みモデルをロード
|
16 |
-
model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
|
17 |
|
18 |
-
#
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# サイドバー設定
|
23 |
-
st.set_page_config(initial_sidebar_state="collapsed")
|
24 |
with st.sidebar.expander("⚙️ 設定", expanded=False):
|
25 |
threshold_q = st.slider("質問の類似度しきい値", 0.0, 1.0, 0.7, 0.01)
|
26 |
threshold_a = st.slider("回答の類似度しきい値", 0.0, 1.0, 0.65, 0.01)
|
@@ -31,24 +43,37 @@ with st.sidebar.expander("⚙️ 設定", expanded=False):
|
|
31 |
|
32 |
|
33 |
def search_answer(user_input):
|
34 |
-
"""
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
if score_q >= threshold_q:
|
42 |
-
|
43 |
-
|
|
|
|
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
-
|
|
|
48 |
|
49 |
if score_a >= threshold_a:
|
50 |
-
|
51 |
-
|
|
|
|
|
52 |
|
53 |
return "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?", "一致なし"
|
54 |
|
@@ -81,6 +106,7 @@ if user_input := st.chat_input("質問を入力してください:"):
|
|
81 |
|
82 |
with st.spinner("考え中... お待ちください。"):
|
83 |
answer, info = search_answer(user_input)
|
|
|
84 |
|
85 |
with st.chat_message("assistant"):
|
86 |
response_placeholder = st.empty()
|
|
|
1 |
import streamlit as st
|
|
|
2 |
import numpy as np
|
3 |
import json
|
4 |
+
from sentence_transformers import SentenceTransformer, util
|
5 |
import time
|
6 |
|
7 |
+
st.set_page_config(initial_sidebar_state="collapsed")
|
8 |
# データを読み込む
|
9 |
with open("data/qa_data.json", "r", encoding="utf-8") as f:
|
10 |
data = json.load(f)
|
|
|
12 |
questions = [item["question"] for item in data]
|
13 |
answers = [item["answer"] for item in data]
|
14 |
|
|
|
|
|
15 |
|
16 |
+
# Cache model ở level app
|
17 |
+
@st.cache_resource
def load_model():
    """Load the sentence-embedding model, cached once per app process.

    `st.cache_resource` keeps a single shared model instance across
    Streamlit reruns so the weights are not re-downloaded/re-loaded
    on every user interaction.
    """
    model_name = "pkshatech/GLuCoSE-base-ja"
    return SentenceTransformer(model_name)
20 |
+
|
21 |
+
|
22 |
+
# Cache embeddings data
|
23 |
+
@st.cache_data
def load_embeddings():
    """Load the precomputed embedding matrices, cached by Streamlit.

    Returns a `(question_embeddings, answer_embeddings)` tuple of numpy
    arrays produced by `preprocess.py`; `st.cache_data` avoids re-reading
    the .npy files on every rerun.
    """
    q_emb = np.load("data/question_embeddings.npy")
    a_emb = np.load("data/answer_embeddings.npy")
    return q_emb, a_emb
|
29 |
+
|
30 |
+
|
31 |
+
# Load model và embeddings một lần
|
32 |
+
model = load_model()
|
33 |
+
question_embeddings, answer_embeddings = load_embeddings()
|
34 |
|
35 |
# サイドバー設定
|
|
|
36 |
with st.sidebar.expander("⚙️ 設定", expanded=False):
|
37 |
threshold_q = st.slider("質問の類似度しきい値", 0.0, 1.0, 0.7, 0.01)
|
38 |
threshold_a = st.slider("回答の類似度しきい値", 0.0, 1.0, 0.65, 0.01)
|
|
|
43 |
|
44 |
|
45 |
def search_answer(user_input):
    """Find the best answer for *user_input* via cosine similarity.

    Matches the query against the cached question embeddings first; if the
    best question score clears `threshold_q`, that question's answer is
    returned. Otherwise it falls back to matching against the answer
    embeddings with `threshold_a`. Returns an `(answer_text, match_info)`
    tuple; on no match, a fixed apology message and "一致なし".

    Relies on module-level globals: `model`, `question_embeddings`,
    `answer_embeddings`, `answers`, `threshold_q`, `threshold_a`.
    """
    # Encode the query once; pre-normalized embeddings make cosine
    # similarity a plain dot product (faster).
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        normalize_embeddings=True,
    )

    # Cosine similarity against the stored question embeddings.
    cos_scores_q = util.cos_sim(user_embedding, question_embeddings)[0]
    best_q_idx = np.argmax(cos_scores_q)
    score_q = cos_scores_q[best_q_idx]

    if score_q >= threshold_q:
        return (
            answers[best_q_idx].replace("\n", " \n"),
            f"質問にマッチ ({score_q:.2f})",
        )

    # Fallback: cosine similarity against the stored answer embeddings.
    # BUG FIX: the original called `model.util.cos_sim(...)`, but
    # SentenceTransformer instances have no `util` attribute — that would
    # raise AttributeError whenever the question branch missed. `cos_sim`
    # lives in the `sentence_transformers.util` module imported at the top
    # of this file (and is already used correctly in the question branch).
    cos_scores_a = util.cos_sim(user_embedding, answer_embeddings)[0]
    best_a_idx = np.argmax(cos_scores_a)
    score_a = cos_scores_a[best_a_idx]

    if score_a >= threshold_a:
        return (
            answers[best_a_idx].replace("\n", " \n"),
            f"回答にマッチ ({score_a:.2f})",
        )

    return "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?", "一致なし"
|
79 |
|
|
|
106 |
|
107 |
with st.spinner("考え中... お待ちください。"):
|
108 |
answer, info = search_answer(user_input)
|
109 |
+
print(info)
|
110 |
|
111 |
with st.chat_message("assistant"):
|
112 |
response_placeholder = st.empty()
|
data/answer_embeddings.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:925632dc69ab4df0223970df60cc9054dd46e2958e597e6998514bd3b33fc703
|
3 |
+
size 67712
|
data/faiss_answer.index
DELETED
Binary file (67.6 kB)
|
|
data/faiss_question.index
DELETED
Binary file (67.6 kB)
|
|
data/question_embeddings.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:45127b6e8615f93324b2debb37305b93d3963c5f91f054f9c56def8cd00c1ca5
|
3 |
+
size 67712
|
preprocess.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import json
|
2 |
-
import faiss
|
3 |
import numpy as np
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
|
@@ -14,20 +13,13 @@ answers = [item["answer"] for item in data]
|
|
14 |
model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
|
15 |
|
16 |
# Tạo embedding cho câu hỏi và câu trả lời
|
17 |
-
question_embeddings = model.encode(questions)
|
18 |
-
answer_embeddings = model.encode(answers)
|
19 |
|
20 |
-
# Lưu
|
21 |
-
|
22 |
-
|
23 |
-
index_a = faiss.IndexFlatL2(dim)
|
24 |
-
|
25 |
-
index_q.add(np.array(question_embeddings).astype(np.float32))
|
26 |
-
index_a.add(np.array(answer_embeddings).astype(np.float32))
|
27 |
-
|
28 |
-
faiss.write_index(index_q, "faiss_question.index")
|
29 |
-
faiss.write_index(index_a, "faiss_answer.index")
|
30 |
|
31 |
# Lưu dữ liệu gốc
|
32 |
-
with open("qa_data.json", "w", encoding="utf-8") as f:
|
33 |
json.dump(data, f, ensure_ascii=False, indent=2)
|
|
|
1 |
import json
|
|
|
2 |
import numpy as np
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
|
|
|
13 |
model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
|
14 |
|
15 |
# Tạo embedding cho câu hỏi và câu trả lời
|
16 |
+
question_embeddings = model.encode(questions, convert_to_numpy=True)
|
17 |
+
answer_embeddings = model.encode(answers, convert_to_numpy=True)
|
18 |
|
19 |
+
# Lưu embedding dưới dạng numpy array
|
20 |
+
np.save("data/question_embeddings.npy", question_embeddings)
|
21 |
+
np.save("data/answer_embeddings.npy", answer_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# Lưu dữ liệu gốc
|
24 |
+
with open("data/qa_data.json", "w", encoding="utf-8") as f:
|
25 |
json.dump(data, f, ensure_ascii=False, indent=2)
|