# Source: Hugging Face Space by aixsatoshi, app.py (commit 57f7053 "Update app.py", 5.06 kB)
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
# Hub checkpoint: Llama 3.1 70B Instruct, pre-quantized to 4-bit with AWQ.
model_id = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"

# The tokenizer and model are loaded once, at import time, and reused by every chat request.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="sequential",    # layer-placement strategy handed to accelerate
    torch_dtype=torch.float16,
    offload_state_dict=True,    # offload the state_dict as needed while loading
    offload_folder="offload",   # directory used for offloaded weights
)
# HTML heading shown at the top of the page (rendered with gr.HTML).
TITLE = "<h1><center>Meta-Llama-3.1-70B-Instruct-AWQ-INT4 Chat webui</center></h1>"
# HTML shown under the title: a link to the model card plus a one-line blurb.
DESCRIPTION = """
<h3>MODEL: <a href="https://hf.co/hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4">Meta-Llama-3.1-70B-Instruct-AWQ-INT4</a></h3>
<center>
<p>This model is designed for conversational interactions.</p>
</center>
"""
# Custom stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no component in this file uses the `duplicate-button` class, and the
# `.chatbox .messages .message.*` selectors may not match the DOM Gradio generates;
# confirm these rules actually take effect before relying on them.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
.chatbox .messages .message.user {
background-color: #e1f5fe;
}
.chatbox .messages .message.bot {
background-color: #eeeeee;
}
"""
@spaces.GPU(duration=120)
def stream_chat(
    message: str,
    history: list,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
):
    """Stream the model's reply to *message*, given the prior chat *history*.

    Each yielded value is the full reply text generated so far, so the chat
    window can be updated incrementally.

    Args:
        message: The new user message.
        history: Prior turns as ``(user_text, assistant_text)`` pairs.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Number of highest-probability tokens kept (used only when sampling).
        penalty: Repetition penalty passed to ``generate``; must be > 0.

    Yields:
        The accumulated assistant reply after each streamed chunk.
    """
    print(f'Message: {message}')
    print(f'History: {history}')

    # Rebuild the whole dialogue in the role/content form the chat template expects.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # The Llama 3.x chat template already starts with the BOS token. Letting the
    # tokenizer add special tokens again would prepend a second BOS.
    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(0)

    # skip_prompt: stream only new text. If no new text arrives within 10 s,
    # iteration raises instead of blocking forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Slider values may arrive as int or float. transformers' logits processors
    # type-check them (float repetition penalty, int top_k), so normalize here.
    temperature = float(temperature)
    if temperature > 0:
        decoding = dict(do_sample=True, temperature=temperature, top_p=float(top_p), top_k=int(top_k))
    else:
        # transformers rejects temperature == 0 when sampling; fall back to greedy decoding.
        decoding = dict(do_sample=False)

    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(penalty),
        # Stop on either id; presumably Llama 3's <|end_of_text|> and <|eot_id|>.
        eos_token_id=[128001, 128009],
        **decoding,
    )

    # generate() blocks until done, so run it in a worker thread and consume
    # decoded text from the streamer on this one.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer is exhausted only once generation has finished, so this returns promptly.
    thread.join()
# Chat display component; ChatInterface below renders it inside the page.
chatbot = gr.Chatbot(height=500)

# Page layout: the title and description HTML above a chat interface backed by stream_chat.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        theme="soft",
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Values of these controls are passed to stream_chat after (message, history),
        # in this order: temperature, max_new_tokens, top_p, top_k, penalty.
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                # transformers requires a strictly positive repetition penalty
                # (0.0 raises ValueError), so the lowest selectable value is 0.1.
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Explain Deep Learning as a pirate."],
            ["Give me five ideas for a child's summer science project."],
            ["Provide advice for writing a script for a puzzle game."],
            ["Create a tutorial for building a breakout game using markdown."],
            ["超能力を持つ主人公のSF物語のシナリオを考えてください。伏線の設定、テーマやログラインを理論的に使用してください"],
            ["子供の夏休みの自由研究のための、5つのアイデアと、その手法を簡潔に教えてください。"],
            ["パズルゲームのスクリプト作成のためにアドバイスお願いします"],
            ["マークダウン記法にて、ブロック崩しのゲーム作成の教科書作成してください"],
            ["お笑いのトンチ大会のお題を考えてください"],
            ["日本語の慣用句、ことわざについての試験問題を考えてください"],
        ],
        cache_examples=False,
    )
def _launch_app() -> None:
    """Start the Gradio server that hosts ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()