File size: 3,716 Bytes
628d1d2
 
 
5708eb0
628d1d2
 
5708eb0
628d1d2
 
 
 
 
 
 
 
5708eb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628d1d2
 
 
 
 
84736e5
 
 
 
628d1d2
 
 
5708eb0
 
 
 
 
 
 
 
 
 
 
 
 
 
628d1d2
 
5708eb0
 
 
 
628d1d2
5708eb0
19e83aa
5708eb0
 
628d1d2
 
5708eb0
 
 
 
628d1d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84736e5
628d1d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5708eb0
628d1d2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
import numpy as np
import json
from sentence_transformers import SentenceTransformer, util
import time

st.set_page_config(initial_sidebar_state="collapsed")
# Load the FAQ records; every record is read through its "question" and
# "answer" keys below, so both must be present.
with open("data/qa_data.json", "r", encoding="utf-8") as qa_file:
    data = json.load(qa_file)

# NOTE(review): `questions` is not referenced elsewhere in this file view —
# matching uses the precomputed embeddings; confirm it is still needed.
questions = [entry["question"] for entry in data]
answers = [entry["answer"] for entry in data]


# Cached with st.cache_resource so the model is constructed once and reused
# across reruns instead of being reloaded on every interaction.
@st.cache_resource
def load_model():
    """Return the SentenceTransformer used to embed user queries."""
    model_id = "pkshatech/GLuCoSE-base-ja"
    return SentenceTransformer(model_id)


# Cached with st.cache_data so the .npy files are not re-read on every rerun.
@st.cache_data
def load_embeddings():
    """Load the precomputed question and answer embeddings from disk.

    Returns:
        tuple: ``(question_embeddings, answer_embeddings)``, each exactly as
        returned by ``np.load`` for its ``.npy`` file.
    """
    question_matrix = np.load("data/question_embeddings.npy")
    answer_matrix = np.load("data/answer_embeddings.npy")
    return question_matrix, answer_matrix


# Load the model and embeddings once (both loaders are cached by Streamlit).
model = load_model()
question_embeddings, answer_embeddings = load_embeddings()

# search_answer() uses the best-matching embedding row index to look up
# `answers`, so the precomputed .npy files must line up row-for-row with
# qa_data.json. Fail loudly here instead of silently returning mismatched
# answers (or hitting an IndexError mid-conversation).
if not (len(question_embeddings) == len(answer_embeddings) == len(answers)):
    st.error(
        "埋め込みデータとQAデータの件数が一致しません "
        f"(質問: {len(question_embeddings)}, 回答: {len(answer_embeddings)}, QA: {len(answers)})。"
        "埋め込みを再生成してください。"
    )
    st.stop()

# Sidebar settings: similarity thresholds read by search_answer(), plus a
# button that clears the chat history and reruns the script.
with st.sidebar.expander("⚙️ 設定", expanded=False):
    threshold_q = st.slider("質問の類似度しきい値", 0.0, 1.0, 0.7, 0.01)
    threshold_a = st.slider("回答の類似度しきい値", 0.0, 1.0, 0.65, 0.01)

    if st.button("新しいチャット", use_container_width=True):
        st.session_state.messages = []
        st.rerun()


def search_answer(user_input):
    """Find the best FAQ answer for ``user_input`` by cosine similarity.

    The query is embedded and compared against every stored question
    embedding first; if the best score reaches ``threshold_q`` (sidebar
    slider) the corresponding answer is returned. Only if that fails are the
    stored answer embeddings compared, accepting the best match when it
    reaches ``threshold_a``.

    Returns:
        tuple[str, str]: ``(answer, info)``. ``answer`` has every "\\n"
        replaced by a Markdown hard line break ("  \\n"). ``info`` is a short
        label naming the match type and its score, or "一致なし" together with
        an apology message as ``answer`` when nothing passes its threshold
        (also the case when there is no QA data to search).
    """
    fallback = (
        "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?",
        "一致なし",
    )
    # With no QA data, argmax over an empty score row would raise.
    if not answers:
        return fallback

    # Single-item batch with the progress bar disabled. Note: util.cos_sim
    # normalizes its inputs itself, so normalize_embeddings here does not
    # make the similarity computation cheaper.
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        normalize_embeddings=True,
    )

    # Question match first; answer similarities are only computed if the
    # question match fails its threshold.
    for embeddings, threshold, label in (
        (question_embeddings, threshold_q, "質問にマッチ"),
        (answer_embeddings, threshold_a, "回答にマッチ"),
    ):
        scores = util.cos_sim(user_embedding, embeddings)[0]
        # Use the tensor's own argmax and convert to plain Python scalars,
        # rather than routing a torch tensor through np.argmax's fallback.
        best_idx = int(scores.argmax())
        best_score = float(scores[best_idx])
        if best_score >= threshold:
            return (
                answers[best_idx].replace("\n", "  \n"),
                f"{label} ({best_score:.2f})",
            )

    return fallback


def stream_response(response, delay=0.05):
    """Yield ``response`` one character at a time for ``st.write_stream``.

    Each "\\n" is emitted as "  \\n" (a Markdown hard line break) so line
    breaks survive Markdown rendering; every other character is yielded as is.

    Args:
        response: Text to stream.
        delay: Seconds to pause after each yielded chunk to produce a typing
            effect. Defaults to 0.05 (the previous hard-coded value). Pass 0
            or a negative number to stream without pausing.

    Yields:
        str: a single character, or "  \\n" in place of a newline.
    """
    for char in response:
        yield "  \n" if char == "\n" else char
        # Guard: time.sleep() raises ValueError on negative values.
        if delay > 0:
            time.sleep(delay)


# Streamlit chat interface
st.title("🤖 よくある質問チャットボット")

# Create the per-session chat history on the first run.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Streamlit re-executes the script on every interaction, so replay the
# stored conversation each time.
for past in st.session_state.messages:
    role = past["role"]
    with st.chat_message(role):
        st.markdown(past["content"])

user_input = st.chat_input("質問を入力してください:")
if user_input:
    # Record and echo the user's turn.
    st.session_state.messages.append(dict(role="user", content=user_input))
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.spinner("考え中... お待ちください。"):
        answer, info = search_answer(user_input)
        # Print the match type / score label to the server console.
        print(info)

    # Stream the reply into a single placeholder inside the assistant bubble.
    with st.chat_message("assistant"):
        st.empty().write_stream(stream_response(answer))

    st.session_state.messages.append(dict(role="assistant", content=answer))