mathstral_test / app.py
MarcdeFalco's picture
Fix
43b206a verified
raw
history blame contribute delete
2.77 kB
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
import torch
import os
# Module-level setup: everything below runs once at import time and is shared
# by every call to `generate`.

# Target device for the model and for the tokenized inputs.
device = "cuda"
# Hugging Face Hub identifier of the checkpoint to load.
model_name = "mistralai/mathstral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the weights in half precision (float16), then move the model to `device`.
model = AutoModelForCausalLM.from_pretrained(model_name,
torch_dtype=torch.float16).to(device)
# Raises KeyError at import time when HF_TOKEN is missing from the environment.
# NOTE(review): HF_TOKEN is not referenced anywhere else in the visible code;
# presumably it is only a presence check (or is consumed implicitly by the
# Hub client) - confirm before removing.
HF_TOKEN = os.environ['HF_TOKEN']
def format_prompt(message, history):
    """Build the instruction prompt for the current chat turn.

    Each past exchange is rendered as ``[INST] user [/INST] answer `` and the
    new ``message`` is appended as a final, unanswered ``[INST] ... [/INST]``
    segment.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from the
            previous turns. ``None`` or an empty iterable means "no history".

    Returns:
        The full prompt string to feed to the tokenizer.
    """
    # `history or []` lets callers pass None for a fresh conversation.
    segments = [
        f"[INST] {user_prompt} [/INST] {bot_response} "
        for user_prompt, bot_response in (history or [])
    ]
    segments.append(f"[INST] {message} [/INST]")
    # Join once instead of growing a string with += inside the loop.
    return "".join(segments)
@spaces.GPU
def generate(prompt, history,
             max_new_tokens=1024,
             repetition_penalty=1.2):
    """Stream the model's answer to ``prompt`` given the chat ``history``.

    Generation runs in a background thread; this generator consumes the
    streamer and yields the answer accumulated so far after every new chunk,
    which is the shape a streaming chat UI expects.

    Args:
        prompt: The new user message.
        history: Previous ``(user, bot)`` turns, forwarded to ``format_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The partial answer text (prompt and special tokens excluded).

    Returns:
        The complete answer text (as the generator's return value).
    """
    formatted_prompt = format_prompt(prompt, history)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

    # Let the streamer drop the echoed prompt and special tokens (e.g. the
    # closing </s>) itself. Slicing the decoded text by a character count only
    # works if decoding reproduces the prompt exactly, which is not guaranteed.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        # The UI slider allows 0; a non-positive budget is expected to make
        # generate() fail inside the worker thread, which would leave the
        # streamer loop below waiting forever. Always ask for >= 1 token.
        max_new_tokens=max(1, int(max_new_tokens)),
        repetition_penalty=repetition_penalty,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    text = ''
    for chunk in streamer:
        text += chunk
        yield text

    # The streamer is exhausted, so generation is done; reap the thread.
    worker.join()
    return text
# Extra controls shown next to the chat box. Their order matters: the values
# are passed positionally after (prompt, history), i.e. as `max_new_tokens`
# and then `repetition_penalty` of `generate`.
_max_new_tokens_slider = gr.Slider(
    label="Max new tokens",
    value=1024,
    minimum=0,
    maximum=4096,
    step=256,
    interactive=True,
    info="The maximum numbers of new tokens",
)

_repetition_penalty_slider = gr.Slider(
    label="Repetition penalty",
    value=1.2,
    minimum=1.0,
    maximum=2.0,
    step=0.05,
    interactive=True,
    info="Penalize repeated tokens",
)

additional_inputs = [
    _max_new_tokens_slider,
    _repetition_penalty_slider,
]
# Custom stylesheet handed to gr.Blocks: a fixed-height, scrollable, bordered
# box for the element with id "mkd".
css = (
    "\n"
    "#mkd {\n"
    "height: 500px;\n"
    "overflow: auto;\n"
    "border: 1px solid #ccc;\n"
    "}\n"
)
# Example prompts: one per non-blank line of exercices.md. The file is read
# inside a context manager so the handle is closed, with an explicit encoding
# so the text does not depend on the platform default.
with open("exercices.md", encoding="utf-8") as examples_file:
    example_prompts = [
        [line.strip()] for line in examples_file if line.strip()
    ]

# Page layout: two headings followed by the streaming chat interface.
with gr.Blocks(css=css) as demo:
    # Headings now use proper closing tags (</h1>, </h3>).
    gr.HTML("<h1><center>Mathstral Test</center></h1>")
    gr.HTML("<h3><center>Dans cette démo, vous pouvez poser des questions mathématiques et scientifiques à Mathstral. 🧮</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        theme=gr.themes.Soft(),
        cache_examples=False,
        # Pass None rather than an empty list when the file has no usable line.
        examples=example_prompts or None,
        chatbot=gr.Chatbot(
            # Delimiters the chatbot should render as LaTeX
            # (display=True: block math, display=False: inline math).
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": True},
                {"left": "\\[", "right": "\\]", "display": True},
                {"left": "\\(", "right": "\\)", "display": False},
                {"left": "$", "right": "$", "display": False},
            ]
        ),
    )

# Queue up to 100 pending requests, then start the server.
demo.queue(max_size=100).launch(debug=True)