File size: 2,528 Bytes
b7b138d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from transformers import AutoModel, AutoTokenizer
import streamlit as st


# Configure the browser tab title/icon and use the full page width.
# Kept as the first Streamlit call in the script, which st.set_page_config has
# historically required. The title reads "ChatGLM2-6b Demo" (演示 = "demo");
# it previously contained mojibake (UTF-8 bytes mis-decoded as cp874).
st.set_page_config(
    page_title="ChatGLM2-6b 演示",
    page_icon=":robot:",
    layout='wide'
)


@st.cache_resource
def get_model(model_path="/mnt/workspace/chatglm2-6b"):
    """Load the ChatGLM2-6B tokenizer and model and cache them as a shared resource.

    ``st.cache_resource`` makes Streamlit reuse the returned objects across
    script reruns instead of reloading the weights on every interaction.

    Args:
        model_path: Directory (or hub id) passed to ``from_pretrained`` for both
            the tokenizer and the model. Defaults to the original hard-coded
            local checkout, so ``get_model()`` behaves exactly as before.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on CUDA and in eval mode.

    Note:
        ``trust_remote_code=True`` executes the Python modeling code shipped
        with the checkpoint, so ``model_path`` must point at a trusted source.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
    # Multi-GPU support: replace the line above with the two lines below and set
    # num_gpus to the number of GPUs you actually have.
    # (Assumes utils.load_model_on_gpus accepts the same path as from_pretrained — confirm.)
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus(model_path, num_gpus=2)
    # Inference only: switch off training-mode behaviour such as dropout.
    model = model.eval()
    return tokenizer, model


# Fetch the tokenizer/model (loaded once, then served from Streamlit's resource cache).
tokenizer, model = get_model()

st.title("ChatGLM2-6B")

# Generation controls; the values are forwarded to model.stream_chat below.
max_length = st.sidebar.slider(
    'max_length', 0, 32768, 8192, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01
)
# Lower bound is 0.01, not 0.0: Hugging Face's TemperatureLogitsWarper requires a
# strictly positive temperature and raises ValueError for 0.0 (assuming the remote
# ChatGLM2 code samples through the standard HF logits warpers).
temperature = st.sidebar.slider(
    'temperature', 0.01, 1.0, 0.8, step=0.01
)

# st.session_state survives Streamlit reruns, so the conversation and the model's
# cached key/value state are kept here between button presses.
if 'history' not in st.session_state:
    st.session_state.history = []

if 'past_key_values' not in st.session_state:
    st.session_state.past_key_values = None

# Replay every completed (query, response) turn.
for query, response in st.session_state.history:
    with st.chat_message(name="user", avatar="user"):
        st.markdown(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown(response)

# Placeholders for the turn that is about to be sent, filled in below.
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

# Label: "User command input"; placeholder: "Please enter your command here".
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

# Label: "Send".
button = st.button("发送", key="predict")

# Skip blank prompts so an empty turn is never sent to the model or stored in history.
if button and prompt_text.strip():
    input_placeholder.markdown(prompt_text)
    history, past_key_values = st.session_state.history, st.session_state.past_key_values
    # stream_chat yields progressively longer responses; re-render the placeholder
    # on each step so the answer appears to stream in.
    for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
                                                                past_key_values=past_key_values,
                                                                max_length=max_length, top_p=top_p,
                                                                temperature=temperature,
                                                                return_past_key_values=True):
        message_placeholder.markdown(response)

    # Persist the final history and KV cache for the next rerun.
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values