Gradio / app.py
ThongCoding's picture
Update app.py
c2cc1e4 verified
raw
history blame
1.79 kB
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load model and tokenizer
def load_model():
    """Load the Vistral chat model and its tokenizer.

    The Hugging Face access token is read from the ``HF_AUTH_TOKEN``
    environment variable; if the variable is unset, ``None`` is passed.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_name = "viet-ai/vistral-7b-chat"  # Vistral by Viet-Mistral
    # NOTE(review): the public Hub repo appears to be published as
    # "Viet-Mistral/Vistral-7B-Chat" -- confirm that this id resolves.

    # Read the token once and reuse it for both downloads.
    hf_token = os.getenv("HF_AUTH_TOKEN")

    # `token` is the supported keyword; `use_auth_token` is deprecated in
    # transformers and slated for removal.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
    return model, tokenizer
# Load the model and tokenizer once at import time; generate() reads these
# module-level names on every call instead of reloading them.
model, tokenizer = load_model()
# Generate response based on conversation history
def generate(messages):
    """Generate the assistant's next reply for a conversation.

    Args:
        messages: Conversation history as a list of dicts, each with a
            ``"role"`` and a ``"content"`` key. Any role other than
            ``"user"`` is rendered as an assistant turn. A bare string is
            also accepted and treated as a single user message.

    Returns:
        The newly generated reply text with surrounding whitespace stripped.
        The prompt itself is not included in the returned text.
    """
    # A UI textbox hands over a plain string; wrap it as one user turn so
    # the loop below does not index into individual characters.
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    turns = []
    for message in messages:
        speaker = "User" if message["role"] == "user" else "Assistant"
        turns.append(f"{speaker}: {message['content']}\n")
    # Trailing cue so the model continues as the assistant.
    prompt_text = "".join(turns) + "Assistant: "

    # Tokenize input prompt
    inputs = tokenizer(prompt_text, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate response
    with torch.no_grad():
        output = model.generate(
            **inputs,  # forwards input_ids together with the attention mask
            # Budget applies to new tokens only; `max_length` also counted
            # the prompt, leaving no room to answer once history grew.
            max_new_tokens=512,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The output sequence starts with the prompt tokens; decode only what
    # was generated after them so the prompt is not echoed back.
    generated_tokens = output[0][prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response.strip()
# Gradio interface
def _respond(user_text, history):
    """Handle one chat turn for the UI.

    Converts the chat history into role/content messages, asks ``generate``
    for a reply, and appends the new exchange to the history.

    Args:
        user_text: Text currently in the message box.
        history: Current chat history as ``(user, assistant)`` pairs, or
            ``None`` when the chat is empty.

    Returns:
        A ``(textbox_value, history)`` tuple: an empty string to clear the
        message box, and the updated history.
    """
    # NOTE(review): assumes gr.Chatbot() exchanges history as (user, bot)
    # pairs -- confirm against the installed Gradio version.
    history = list(history or [])

    # Ignore empty submissions instead of sending a blank prompt to the model.
    if not user_text or not user_text.strip():
        return "", history

    messages = []
    for user_turn, bot_turn in history:
        if user_turn is not None:
            messages.append({"role": "user", "content": user_turn})
        if bot_turn is not None:
            messages.append({"role": "assistant", "content": bot_turn})
    messages.append({"role": "user", "content": user_text})

    history.append((user_text, generate(messages)))
    return "", history


def chatbot_interface():
    """Build the Gradio Blocks UI for the chatbot.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Chatbot sử dụng Vistral của Viet-Mistral")
        chatbox = gr.Chatbot()
        message = gr.Textbox(placeholder="Gửi tin nhắn...")
        send_button = gr.Button("Gửi")
        # Pass both the new text and the existing history in, and write the
        # cleared textbox and updated history back out.
        send_button.click(
            _respond,
            inputs=[message, chatbox],
            outputs=[message, chatbox],
        )
    return demo
# Script entry point: build the UI and start serving it.
if __name__ == "__main__":
    demo = chatbot_interface()
    # NOTE(review): share=True is passed straight to launch(); confirm a
    # shared link is actually wanted where this app is hosted.
    demo.launch(share=True)