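# app.py — Gradio chat demo for Azure99/Blossom-V6-32B-AWQ.
# Generation runs in a background thread and decoded tokens are streamed to the UI as they arrive.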
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_NEW_TOKENS = 2048
MODEL_NAME = "Azure99/Blossom-V6-32B-AWQ"

# Load the quantized model and its tokenizer; device_map="auto" places the weights on the available devices.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def get_input_ids(inst, history):
    # Rebuild the conversation as a list of chat messages, then append the latest user turn.
    conversation = []
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": inst})
    # add_generation_prompt=True appends the assistant turn header so the model answers as the assistant.
    return tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)


@spaces.GPU
def chat(inst, history, temperature, top_p, repetition_penalty):
    # Stream decoded tokens as they are produced, skipping the prompt and special tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = get_input_ids(inst, history)
    generation_kwargs = dict(input_ids=input_ids,
                             streamer=streamer, do_sample=True, max_new_tokens=MAX_NEW_TOKENS,
                             temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
    # Run generation in a background thread so the streamer can be consumed here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    outputs = ""
    for new_text in streamer:
        outputs += new_text
        yield outputs
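
# Sampling controls exposed in the "Config" accordion of the UI.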
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.5,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Controls randomness: higher values give more varied output, lower values more deterministic output.",
    ),
    gr.Slider(
        label="Top-P",
        value=0.85,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Samples from the smallest set of tokens whose cumulative probability reaches top_p.",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.05,
        minimum=1.0,
        maximum=1.2,
        step=0.01,
        interactive=True,
        info="Controls how strongly repeated tokens are penalized.",
    ),
]
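
# Assemble the chat UI and launch it with request queuing enabled.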
gr.ChatInterface(chat,
                 chatbot=gr.Chatbot(show_label=False, height=500, show_copy_button=True, render_markdown=True),
                 textbox=gr.Textbox(placeholder="", container=False, scale=7),
                 title="Blossom-V6-32B-AWQ Demo",
                 description='Hello, I am Blossom, an open-source conversational large language model.🌠'
                             '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
                 theme="soft",
                 examples=[["Hello"], ["What is MBTI"],
                           ["用Python实现二分查找"],  # "Implement binary search in Python"
                           ["为switch写一篇小红书种草文案,带上emoji"]],  # "Write a Xiaohongshu post recommending the Switch, with emoji"
                 cache_examples=False,
                 additional_inputs=additional_inputs,
                 additional_inputs_accordion=gr.Accordion(label="Config", open=True),
                 clear_btn="🗑️Clear",
                 undo_btn="↩️Undo",
                 retry_btn="🔄Retry",
                 submit_btn="➡️Submit",
                 ).queue().launch()