File size: 1,803 Bytes
a96c72f
c585309
75eb9ca
a96c72f
6d7b830
 
22986a6
 
75eb9ca
a96c72f
6d7b830
a96c72f
6d7b830
 
a96c72f
9000ced
 
a96c72f
75eb9ca
 
56843e1
9000ced
 
75eb9ca
 
 
9000ced
0cf9b7f
9000ced
 
e22ba0b
9000ced
0cf9b7f
9417eab
0cf9b7f
 
14e602c
0cf9b7f
 
 
 
6d7b830
 
75eb9ca
 
 
56843e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

@st.cache_resource
def load_model():
    """Load the tokenizer and seq2seq model for the chat bot.

    The function is wrapped in ``st.cache_resource`` so the returned objects
    are reused instead of being reloaded on every script run.

    Returns:
        A ``(tokenizer, model)`` tuple, both loaded from the
        ``google/mt5-base`` checkpoint. The tokenizer is the slow
        (non-fast) implementation configured with left-side padding.
    """
    checkpoint = "google/mt5-base"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint, padding_side="left", use_fast=False),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )

st.title("Український Чат-бот")

# Seed the per-session state only when a key is missing, so values set during
# earlier runs of the script are kept.
#   history    - flat list of alternating user messages and bot replies
#   user_input - the message waiting to be sent
for _key, _initial in (("history", []), ("user_input", "")):
    if _key not in st.session_state:
        st.session_state[_key] = _initial

# Module-level handles used by send_message().
tokenizer, model = load_model()

def send_message():
    """Generate a bot reply for the pending user message and record the turn.

    Reads ``st.session_state.user_input``. When it is empty nothing happens.
    Otherwise the conversation so far plus the new message is encoded as ONE
    input sequence, a reply is generated, and ``[user_message, reply]`` is
    appended to ``st.session_state.history``. ``user_input`` is then reset to
    the empty string.
    """
    user_text = st.session_state.user_input
    if not user_text:
        return

    # Build a single prompt string. Passing a list to the tokenizer encodes
    # every element as its own batch row, so ``outputs[0]`` would be the
    # generation for the *oldest* message rather than for the new one.
    prompt = " ".join(st.session_state.history + [user_text])
    # NOTE(review): with truncation=True an over-long prompt is cut according
    # to the tokenizer's truncation side; if that is the right-hand side the
    # newest text is what gets dropped - consider capping the history length.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=100)

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.session_state.history.extend([user_text, response])
    st.session_state.user_input = ""

def update_user_input():
    """Copy the text widget's current value into ``user_input``.

    Registered as the ``on_change`` callback of the text input whose widget
    key is ``temp_user_input``.
    """
    latest_text = st.session_state.temp_user_input
    st.session_state.user_input = latest_text

# Message box; its value lives under the "temp_user_input" key and is copied
# into "user_input" by the on_change callback.
st.text_input("Ви:", key="temp_user_input", on_change=update_user_input)

send_clicked = st.button("Надіслати")
if send_clicked:
    send_message()

# Handle the Enter key: send when the box holds text that differs from the
# last text this branch processed ("last_input").
typed_text = st.session_state.get("temp_user_input")
if typed_text and typed_text != st.session_state.get("last_input", ""):
    st.session_state["last_input"] = typed_text
    send_message()

# Render the conversation. History is a flat list, so walk it two items at a
# time: a user message followed (when present) by the bot reply.
chat_log = st.session_state.history
for start in range(0, len(chat_log), 2):
    exchange = chat_log[start:start + 2]
    st.write(f"Ви: {exchange[0]}")
    if len(exchange) == 2:
        st.write(f"Бот: {exchange[1]}")