import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import time
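# Model setup: a 1.5B-parameter distilled model small enough to run fully on CPU.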
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
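# Replace the bundled chat template with a minimal ChatML-style one: each
# message is wrapped as <|im_start|>{role}\n{content}<|im_end|>, ending with
# an open assistant turn so generation continues from there.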
tokenizer.chat_template = "{% for message in messages %}<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n{% endfor %}<|im_start|>assistant\n"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="cpu",
    low_cpu_mem_usage=True,
)
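# Streamer subclass that also measures throughput, so the UI can show a live
# tokens-per-second figure alongside the generated text.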
class DeepStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer):
        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
        self.token_count = 0
        self.start_time = None

    def put(self, value):
        # The first put() call carries the prompt tokens, which skip_prompt
        # discards; only time and count the tokens generated after that.
        if not self.next_tokens_are_prompt:
            if self.start_time is None:
                self.start_time = time.time()
            self.token_count += value.numel()
        return super().put(value)

    def get_tps(self):
        if self.start_time is None:
            return 0.0
        elapsed = time.time() - self.start_time
        return self.token_count / elapsed if elapsed > 0 else 0.0
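# Append a live tokens-per-second footer to the streamed markdown.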
def format_response(text, tps=None):  # tps = tokens per second
    return f"{text}\n\n**Tokens per second:** {tps:.2f}" if tps else text
def chat_response(message, history, max_tokens=512):
    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    formatted_input = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(formatted_input, return_tensors="pt").to("cpu")
    streamer = DeepStreamer(tokenizer)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )
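    # model.generate() blocks, so run it on a worker thread and consume the
    # streamer's iterator here, yielding partial markdown to the UI.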
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial_response = ""
    try:
        for token in streamer:
            partial_response += token
            yield format_response(partial_response, streamer.get_tps())
        # One final yield with the settled throughput figure.
        yield format_response(partial_response, streamer.get_tps())
    finally:
        thread.join()
demo = gr.ChatInterface(
    fn=chat_response,
    title="DeepSeek-R1-Distill-Qwen-1.5B on CPU",
    description="Running on CPU, so expect a low tokens-per-second rate.",
    examples=[
        "Discuss the future of renewable energy",
        "What's the history of the Roman Empire?",
        "What's the capital of China?",
        "Tell me a fun fact about space",
    ],
)
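# max_tokens is currently fixed at its default of 512. As a hedged sketch
# (untested here), gr.ChatInterface can expose it via additional_inputs,
# whose values are passed to the callback after (message, history):
#     additional_inputs=[gr.Slider(64, 2048, value=512, step=64,
#                                  label="Max new tokens")]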
if __name__ == "__main__":
    demo.queue().launch()