ThongCoding committed on
Commit c2cc1e4 · verified · 1 parent: e603c7b

Update app.py

Files changed (1)
  1. app.py +50 -38
app.py CHANGED
@@ -1,43 +1,55 @@
 import os
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
-import spaces
 import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
-@spaces.GPU
+# Load model and tokenizer
 def load_model():
-    model_id = "microsoft/phi-2"
-    access_token = os.environ.get("HF_AUTH_TOKEN")
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        torch_dtype=torch.float16,
-        use_auth_token=access_token
-    )
-    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    return model, tokenizer, streamer
-
-model, tokenizer, streamer = load_model()
-
-def generate(prompt, history):
-    messages = [{"role": "user", "content": prompt}]
-    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
-    output = model.generate(
-        **inputs,
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.8,
-        top_p=0.95,
-        top_k=50,
-        streamer=streamer
-    )
-    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
-    # Split the response off from the prompt
-    response = decoded_output.split(prompt_text)[-1].strip()
-    return response
-
-gr.ChatInterface(generate, title="💬 Chatbot Phi-2").launch()
+    model_name = "viet-ai/vistral-7b-chat"  # Vistral from Viet-Mistral
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN"))
+    model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN"))
+    return model, tokenizer
+
+# Set up and load the model
+model, tokenizer = load_model()
+
+# Generate a response from the conversation history
+def generate(messages):
+    prompt_text = ""
+    for message in messages:
+        role = message["role"]
+        content = message["content"]
+        if role == "user":
+            prompt_text += f"User: {content}\n"
+        else:
+            prompt_text += f"Assistant: {content}\n"
+    prompt_text += "Assistant: "  # prime the model to continue as the assistant
+
+    # Tokenize the input prompt
+    inputs = tokenizer(prompt_text, return_tensors="pt")
+
+    # Generate the response
+    with torch.no_grad():
+        output = model.generate(inputs.input_ids, max_length=512, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+
+    response = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    return response.strip()
+
+# Gradio interface
+def chatbot_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("# Chatbot sử dụng Vistral của Viet-Mistral")
+        chatbox = gr.Chatbot()
+        message = gr.Textbox(placeholder="Gửi tin nhắn...")
+        send_button = gr.Button("Gửi")
+
+        send_button.click(generate, inputs=message, outputs=chatbox)
+
+    return demo
+
+# Main entry point for the app
+if __name__ == "__main__":
+    demo = chatbot_interface()
+    demo.launch(share=True)
+
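A note on the new event wiring: send_button.click(generate, inputs=message, outputs=chatbox) passes the raw Textbox string to generate(), which expects a list of role/content dicts, and writes the returned string straight into gr.Chatbot, which expects a list of (user, assistant) pairs. A minimal sketch of compatible wiring; respond() is a hypothetical adapter, not part of the commit, and it assumes Gradio's classic tuple-format Chatbot:

# Sketch only: respond() is a hypothetical adapter between the
# Textbox/Chatbot components and generate().
def respond(user_message, history):
    history = history or []
    # Rebuild the role/content list that generate() expects
    messages = []
    for user_turn, bot_turn in history:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": bot_turn})
    messages.append({"role": "user", "content": user_message})
    reply = generate(messages)
    history.append((user_message, reply))
    return history, ""  # updated history for the Chatbot; clear the Textbox

send_button.click(respond, inputs=[message, chatbox], outputs=[chatbox, message])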
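Relatedly, the old version split the decoded text on prompt_text to drop the echoed prompt, while the new generate() decodes output[0] whole, so every reply still begins with the full "User: ... Assistant:" transcript. A sketch of trimming to only the newly generated tokens, reusing the inputs and output variables inside generate():

# Decode only the tokens produced after the prompt, so the echoed
# transcript is not returned as part of the reply.
prompt_len = inputs.input_ids.shape[-1]
response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)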
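Finally, the old code already used the newer token= keyword when loading the tokenizer; recent transformers releases deprecate use_auth_token= in favor of token=, so load_model() could read (a sketch, assuming a transformers version recent enough to accept token=):

# token= supersedes the deprecated use_auth_token= keyword.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HF_AUTH_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(model_name, token=os.getenv("HF_AUTH_TOKEN"))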