OniXinO commited on
Commit
75eb9ca
·
1 Parent(s): 6d7b830

3rd attempt of init pre-release

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
3
 
4
  @st.cache_resource
5
  def load_model():
6
- chatbot = pipeline("conversational", model="facebook/blenderbot-400M-distill")
7
- return chatbot
 
8
 
9
  st.title("Український Чат-бот")
10
 
@@ -13,14 +15,24 @@ if "history" not in st.session_state:
13
 
14
  user_input = st.text_input("Ви:", "")
15
 
16
- if st.button("Надіслати"):
17
- chatbot = load_model()
18
- response = chatbot(st.session_state.history + [{"role": "user", "content": user_input}])
19
- st.session_state.history.extend([{"role": "user", "content": user_input}, {"role": "assistant", "content": response.generated_responses[0]}])
 
 
 
 
 
 
20
 
21
  if st.session_state.history:
22
- for message in st.session_state.history:
23
- if message["role"] == "user":
24
- st.write(f"Ви: {message['content']}")
25
- else:
26
- st.write(f"Бот: {message['content']}")
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
  @st.cache_resource
6
  def load_model():
7
+ tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
8
+ model = AutoModelForCausalLM.from_pretrained("facebook/blenderbot-400M-distill")
9
+ return tokenizer, model
10
 
11
  st.title("Український Чат-бот")
12
 
 
15
 
16
  user_input = st.text_input("Ви:", "")
17
 
18
+ tokenizer, model = load_model()
19
+
20
+ if st.button("Надіслати") or st.session_state.get("enter_pressed", False):
21
+ st.session_state.enter_pressed = False
22
+ if user_input:
23
+ inputs = tokenizer(st.session_state.history + [user_input], return_tensors="pt")
24
+ with torch.no_grad():
25
+ outputs = model.generate(**inputs, max_length=100)
26
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
+ st.session_state.history.extend([user_input, response])
28
 
29
  if st.session_state.history:
30
+ for i in range(0, len(st.session_state.history), 2):
31
+ st.write(f"Ви: {st.session_state.history[i]}")
32
+ if i + 1 < len(st.session_state.history):
33
+ st.write(f"Бот: {st.session_state.history[i+1]}")
34
+
35
+ def set_enter_pressed():
36
+ st.session_state.enter_pressed = True
37
+
38
+ st.text_input("Ви:", key="user_input_enter", on_change=set_enter_pressed)