File size: 1,594 Bytes
a96c72f c585309 75eb9ca a96c72f 6d7b830 c585309 75eb9ca a96c72f 6d7b830 a96c72f 6d7b830 a96c72f 9000ced a96c72f 75eb9ca 56843e1 9000ced 75eb9ca 9000ced 56843e1 6d7b830 75eb9ca 56843e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
@st.cache_resource
def load_model():
    """Load the mT5-small tokenizer and model once and reuse them across reruns.

    ``st.cache_resource`` keeps the returned objects alive between Streamlit
    script reruns, so the checkpoint is only loaded on the first call.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``google/mt5-small`` checkpoint.
    """
    checkpoint = "google/mt5-small"
    # padding_side="left" only has an effect when several sequences of
    # different lengths are padded together in one batch.
    tok = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    # NOTE(review): the raw pretrained mT5 checkpoint is not fine-tuned for
    # dialogue, so replies may be low quality — confirm this is intended.
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tok, seq2seq
st.title("Український Чат-бот")

# Seed per-session state on the first run; later reruns keep existing values.
# The tuple is rebuilt on every script run, so each session gets a fresh list.
for state_key, initial_value in (("history", []), ("user_input", "")):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Cached by load_model's decorator, so this is cheap after the first run.
tokenizer, model = load_model()
def send_message(*, max_context_messages=None, max_input_tokens=512):
    """Generate a bot reply to the pending user message and record the exchange.

    Reads the pending text from ``st.session_state.user_input``. If it is empty,
    nothing happens. Otherwise the previous messages in
    ``st.session_state.history`` plus the new message are joined into ONE prompt.
    A single reply is generated, and ``[user_message, reply]`` is appended to
    the history.

    Args:
        max_context_messages: How many of the most recent history entries to
            include as context. ``None`` (default) uses the whole history;
            ``0`` uses only the new message.
        max_input_tokens: Upper bound on encoder input length (must be > 0).
            Longer prompts are cut from the LEFT, so the newest message is
            always kept.
    """
    user_text = st.session_state.user_input
    if not user_text:
        return

    history = st.session_state.history
    if max_context_messages is None:
        context = list(history)
    elif max_context_messages > 0:
        context = history[-max_context_messages:]
    else:
        context = []

    # Encode a single sequence. Passing the list itself would create a batch of
    # independent inputs, and outputs[0] would then answer the oldest message
    # rather than the one the user just sent.
    prompt = "\n".join(context + [user_text])
    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep the trailing tokens: the default truncation would drop the end of the
    # prompt, i.e. exactly the new message.
    input_ids = encoded["input_ids"][:, -max_input_tokens:]
    attention_mask = encoded["attention_mask"][:, -max_input_tokens:]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=100,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    history.extend([user_text, response])
    # Reset the pending message so a repeated click does not resend it.
    # NOTE: this does not clear the visible text box (widget key
    # "temp_user_input"); only the stored copy is reset.
    st.session_state.user_input = ""
def update_user_input():
    """Copy the text box's current value into the pending ``user_input`` slot.

    Registered as the ``on_change`` callback of the ``temp_user_input`` widget.
    """
    state = st.session_state
    state.user_input = state.temp_user_input
st.text_input("Ви:", key="temp_user_input", on_change=update_user_input)
if st.button("Надіслати"):
    send_message()

# History alternates user / bot entries: even indices are the user's
# messages, odd indices the bot's replies.
transcript = st.session_state.history
for position, message in enumerate(transcript):
    if position % 2 == 0:
        st.write(f"Ви: {message}")
    else:
        st.write(f"Бот: {message}")