import os
import queue
import threading

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


class RichTextStreamer(TextIteratorStreamer):
    """Streamer that yields per-token metadata (id, text, special flag) instead of plain text."""

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.token_queue = queue.Queue()

    def put(self, value):
        # generate() first feeds the prompt tokens to the streamer; skip them if requested
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        # value is a tensor of token ids; flatten it so both batches and single tokens work
        token_ids = value.reshape(-1).tolist() if hasattr(value, "reshape") else [int(value)]
        # Instead of just decoding here, emit full info per token
        for token_id in token_ids:
            token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
            is_special = token_id in self.tokenizer.all_special_ids
            self.token_queue.put({
                "token_id": token_id,
                "token": token_str,
                "is_special": is_special,
            })

    def end(self):
        # Sentinel tells the consuming iterator that generation has finished
        self.token_queue.put(None)

    def __iter__(self):
        while True:
            token_info = self.token_queue.get(timeout=self.timeout)
            if token_info is None:
                break
            yield token_info


@spaces.GPU
def chat_with_model(messages):
    global current_model, current_tokenizer

    if current_model is None or current_tokenizer is None:
        yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
        return

    pad_id = current_tokenizer.pad_token_id
    if pad_id is None:
        pad_id = current_tokenizer.unk_token_id or 0

    prompt = format_prompt(messages)

    device = torch.device("cuda")
    current_model.to(device).half()

    inputs = current_tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        streamer=streamer,
        eos_token_id=current_tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )

    # Run generation in a background thread so tokens can be streamed as they arrive
    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
    thread.start()

    output_text = ""
    messages = messages.copy()
    messages.append({"role": "assistant", "content": ""})

    for token_info in streamer:
        output_text += token_info["token"]
        messages[-1]["content"] = output_text
        yield messages
        if token_info["is_special"] and token_info["token_id"] == current_tokenizer.eos_token_id:
            break

    # Move the model back to CPU and free GPU memory between requests
    current_model.to("cpu")
    torch.cuda.empty_cache()


# Globals
current_model = None
current_tokenizer = None


def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",  # loaded onto CPU initially; moved to GPU per request
        token=token,
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"
# Format the conversation as a plain-text prompt
def format_prompt(messages):
    prompt = ""
    for msg in messages:
        role = msg["role"]
        if role == "user":
            prompt += f"User: {msg['content'].strip()}\n"
        elif role == "assistant":
            prompt += f"Assistant: {msg['content'].strip()}\n"
    prompt += "Assistant:"
    return prompt


def add_user_message(user_input, history):
    return "", history + [{"role": "user", "content": user_input}]


# Available models
model_choices = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "google/gemma-7b",
]

# UI
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Chatbot (Streaming) — LLaMA, DeepSeek, Gemma")
    default_model = gr.State("meta-llama/Llama-3.2-3B-Instruct")

    with gr.Row():
        model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
        model_status = gr.Textbox(label="Model Status", interactive=False)

    chatbot = gr.Chatbot(label="Chat", type="messages")
    msg = gr.Textbox(label="Your message", placeholder="Enter clinical input...", show_label=False)
    clear = gr.Button("Clear")

    # Load default model on startup
    demo.load(fn=load_model_on_selection, inputs=default_model, outputs=model_status)

    # Load selected model manually
    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)

    # Submit message, then stream the model response
    msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        chat_with_model, chatbot, chatbot
    )

    # Clear chat
    clear.click(lambda: [], None, chatbot, queue=False)

demo.launch()