import os
import sys
import time
import argparse
import logging

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from tqdm import tqdm

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


class ONNXGenerationChatbot:
    def __init__(self, model_path, max_length=100):
        """
        Initialize the ONNX chatbot for text generation.

        Args:
            model_path: Path to the directory containing the ONNX model and tokenizer
            max_length: Maximum sequence length for generation
        """
        # Set up model paths
        self.model_dir = model_path
        self.onnx_path = os.path.join(self.model_dir, "model.onnx")
        self.fp32_path = os.path.join(self.model_dir, "model_fp32.onnx")

        # Check for model files
        if not os.path.exists(self.onnx_path):
            raise FileNotFoundError(f"ONNX model not found at {self.onnx_path}")

        # Get model name for prompt formatting
        self.model_name = os.path.basename(os.path.normpath(model_path))
        logger.info(f"Using model: {self.model_name}")

        # Load tokenizer
        logger.info(f"Loading tokenizer from {self.model_dir}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, local_files_only=True)

        # Ensure tokenizer has necessary tokens
        if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Create optimized session
        logger.info(f"Loading ONNX model from {self.onnx_path}...")
        self.session_options = ort.SessionOptions()
        self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session_options.intra_op_num_threads = 4  # Adjust based on your CPU

        # Create session with appropriate providers
        providers = ['CPUExecutionProvider']
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            logger.info("CUDA is available! Using GPU acceleration.")
            providers.insert(0, 'CUDAExecutionProvider')

        self.session = ort.InferenceSession(
            self.onnx_path,
            sess_options=self.session_options,
            providers=providers
        )

        # Get input and output names from the model
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]
        logger.info(f"Model inputs: {self.input_names}")
        logger.info(f"Model outputs: {self.output_names}")

        # Settings
        self.max_length = max_length
        self.stop_tokens = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []

        # Try to add common stop tokens if they exist in the vocabulary
        stop_words = ["<|endoftext|>", "</s>", "<|end|>"]
        for word in stop_words:
            try:
                token_id = self.tokenizer.convert_tokens_to_ids(word)
                if token_id not in self.stop_tokens and token_id != self.tokenizer.unk_token_id:
                    self.stop_tokens.append(token_id)
            except Exception:
                pass

        logger.info(f"Using stop tokens: {self.stop_tokens}")

        # Conversation history for context
        self.conversation_history = []

    def get_prompt_template(self):
        """
        Get the appropriate prompt template based on the model type.
        """
        if "opt" in self.model_name.lower():
            return "Human: {}\nAssistant:"
        elif "pythia" in self.model_name.lower():
            return "USER: {}\nASSISTANT:"
        elif "llama" in self.model_name.lower() or "alpaca" in self.model_name.lower():
            return "### Human: {}\n### Assistant:"
        elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
            return "User: {}\nBot:"
        else:
            return "Question: {}\nAnswer:"

    def format_prompt_with_history(self, user_message):
        """
        Format the prompt with conversation history for better context.
""" template = self.get_prompt_template() parts = template.split("{}") prefix = parts[0] suffix = parts[1] if len(parts) > 1 else "" # Include history if available (up to 3 turns) formatted_prompt = "" for i, (user, bot) in enumerate(self.conversation_history[-3:]): formatted_prompt += f"{prefix}{user}{suffix} {bot}\n\n" # Add current user message formatted_prompt += f"{prefix}{user_message}{suffix}" return formatted_prompt def run_inference_step(self, input_ids, attention_mask=None): """ Run a single inference step with the ONNX model. Args: input_ids: Token IDs of the input sequence attention_mask: Attention mask for the input sequence Returns: numpy array: Logits for the next token prediction """ # Prepare model inputs model_inputs = {} for name in self.input_names: if name == "input_ids": model_inputs[name] = input_ids elif name == "attention_mask" and attention_mask is not None: model_inputs[name] = attention_mask # Run inference outputs = self.session.run(self.output_names, model_inputs) # Return logits (assumes first output is logits) return outputs[0] def generate_text(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True, show_progress=True): """ Generate text using the ONNX model. Args: prompt: Text prompt to generate from max_new_tokens: Maximum number of tokens to generate temperature: Temperature for sampling (higher = more random) top_k: Number of highest probability tokens to keep for sampling top_p: Cumulative probability threshold for nucleus sampling repetition_penalty: Penalty for repeating tokens do_sample: Whether to sample from the distribution or use greedy decoding show_progress: Whether to show a progress bar during generation Returns: str: Generated text """ # Encode the prompt encoded = self.tokenizer(prompt, return_tensors="np") input_ids = encoded["input_ids"] attention_mask = encoded["attention_mask"] # Track input tokens for repetition penalty prev_tokens = input_ids[0].tolist() # Setup progress bar if requested progress = tqdm(total=max_new_tokens, desc="Generating") if show_progress else None # Generate tokens auto-regressively for _ in range(max_new_tokens): # Run inference to get next token logits logits = self.run_inference_step(input_ids, attention_mask) # Get logits for the last token next_token_logits = logits[0, -1, :] # Apply temperature scaling if temperature > 0: next_token_logits = next_token_logits / max(temperature, 1e-8) # Apply repetition penalty if repetition_penalty > 1.0: for prev_token in set(prev_tokens[-10:]): # Only consider recent tokens if prev_token < len(next_token_logits): next_token_logits[prev_token] /= repetition_penalty # Apply top-k filtering if top_k > 0: indices_to_remove = np.argsort(next_token_logits)[:-top_k] next_token_logits[indices_to_remove] = -float('inf') # Apply top-p (nucleus) filtering if 0 < top_p < 1.0: sorted_logits = np.sort(next_token_logits)[::-1] sorted_indices = np.argsort(next_token_logits)[::-1] cumulative_probs = np.cumsum(np.exp(sorted_logits) / np.sum(np.exp(sorted_logits))) # Remove tokens with cumulative probability above the threshold sorted_indices_to_remove = sorted_indices[cumulative_probs > top_p] next_token_logits[sorted_indices_to_remove] = -float('inf') # Sample from the filtered distribution or use greedy decoding if do_sample: # Apply softmax to get probabilities probs = np.exp(next_token_logits - np.max(next_token_logits)) probs = probs / np.sum(probs) # Handle NaNs if np.isnan(probs).any(): next_token_id = 
                    next_token_id = np.argmax(next_token_logits)
                else:
                    try:
                        # Sample from the distribution
                        next_token_id = np.random.choice(len(probs), p=probs)
                    except Exception:
                        # Fallback to greedy if sampling fails
                        next_token_id = np.argmax(next_token_logits)
            else:
                # Greedy decoding - take highest probability token
                next_token_id = np.argmax(next_token_logits)

            # Add the chosen token to the input
            next_token = np.array([[next_token_id]], dtype=input_ids.dtype)
            input_ids = np.concatenate([input_ids, next_token], axis=1)

            # Update attention mask
            attention_mask = np.ones((1, input_ids.shape[1]), dtype=np.int64)

            # Add token to history for repetition penalty
            prev_tokens.append(int(next_token_id))

            # Update progress bar if active
            if progress is not None:
                progress.update(1)

            # Check for stop tokens or end of text
            if next_token_id in self.stop_tokens:
                break

            # Also stop if we exceed max length
            if input_ids.shape[1] >= self.max_length:
                break

        # Close progress bar if used
        if progress is not None:
            progress.close()

        # Decode the full sequence
        generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        return generated_text

    def extract_assistant_response(self, full_text, prompt):
        """
        Extract just the assistant's response from the full generated text.

        Args:
            full_text: Full generated text including prompt
            prompt: The original prompt

        Returns:
            str: Just the assistant's response
        """
        # Try to extract based on the prompt format
        template = self.get_prompt_template()
        response_start_marker = template.split("{}")[-1]

        # If the text starts with the prompt, extract everything after it
        if full_text.startswith(prompt):
            after_prompt = full_text[len(prompt):]
            # Handle additional newlines or spaces at the beginning
            return after_prompt.lstrip()

        # If the response marker is in the text, extract everything after it
        if response_start_marker.strip() in full_text:
            parts = full_text.split(response_start_marker.strip(), 1)
            if len(parts) > 1:
                return parts[1].strip()

        # Fallback: return everything after the last line of the prompt
        prompt_last_line = prompt.strip().split('\n')[-1]
        if prompt_last_line in full_text:
            parts = full_text.split(prompt_last_line, 1)
            if len(parts) > 1:
                return parts[1].strip()

        # Last resort: return the whole thing
        return full_text

    def chat(self, temperature=0.7, max_new_tokens=100):
        """
        Run an interactive chat session with the model.
        Args:
            temperature: Temperature for text generation
            max_new_tokens: Maximum number of tokens to generate per response
        """
        print("\n===== ONNX Generation Chatbot =====")
        print(f"Model: {self.model_name}")
        print("Type 'exit' to end the conversation")
        print("Type 'reset' to clear conversation history")

        while True:
            # Get user input
            user_input = input("\nYou: ")

            # Check for exit command
            if user_input.lower() in ["exit", "quit", "bye"]:
                print("Goodbye!")
                break

            # Check for reset command
            if user_input.lower() == "reset":
                self.conversation_history = []
                print("Conversation history cleared.")
                continue

            # Create prompt with history
            prompt = self.format_prompt_with_history(user_input)

            print("\nGenerating response...")

            # Generate text
            try:
                start_time = time.time()
                full_text = self.generate_text(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    show_progress=True
                )

                # Extract just the assistant's response
                response = self.extract_assistant_response(full_text, prompt)

                # Clean up any trailing incomplete sentences
                if response and len(response) > 0:
                    # Try to end at a sentence boundary if possible
                    sentence_end = max(
                        response.rfind('.'),
                        response.rfind('!'),
                        response.rfind('?')
                    )
                    if sentence_end > len(response) * 0.5:  # Only trim if we're not losing too much
                        response = response[:sentence_end + 1]

                # Calculate generation time
                gen_time = time.time() - start_time
                gen_speed = max_new_tokens / gen_time if gen_time > 0 else 0

                # Print the response
                print(f"\nBot: {response}")
                print(f"\n[Generated {len(response)} chars in {gen_time:.2f}s ({gen_speed:.1f} tokens/sec)]")

                # Add to conversation history
                self.conversation_history.append((user_input, response))

                # Keep history at a reasonable size
                if len(self.conversation_history) > 10:
                    self.conversation_history = self.conversation_history[-10:]

            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                print("\nBot: I encountered an error while generating a response. Let's try again.")


def main():
    """Run the ONNX chatbot with command line arguments."""
    parser = argparse.ArgumentParser(description="Interactive ONNX Chatbot")
    parser.add_argument("--model", type=str, required=True,
                        help="Path to the ONNX model directory")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Temperature for text generation (default: 0.7)")
    parser.add_argument("--max_tokens", type=int, default=100,
                        help="Maximum tokens to generate per response (default: 100)")
    args = parser.parse_args()

    try:
        # Create and run the chatbot
        chatbot = ONNXGenerationChatbot(args.model)
        chatbot.chat(temperature=args.temperature, max_new_tokens=args.max_tokens)
    except KeyboardInterrupt:
        print("\nExiting chatbot. Goodbye!")
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
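
# ---------------------------------------------------------------------------
# Usage notes (illustrative sketch only; the script filename and model
# directory below are hypothetical examples, not paths defined elsewhere):
#
# Interactive chat from the command line, assuming this file is saved as
# onnx_chatbot.py and the model directory contains model.onnx plus the
# tokenizer files saved alongside it:
#
#     python onnx_chatbot.py --model ./models/distilgpt2-onnx --temperature 0.8 --max_tokens 120
#
# One-shot programmatic generation without the interactive loop:
#
#     bot = ONNXGenerationChatbot("./models/distilgpt2-onnx", max_length=256)
#     text = bot.generate_text("Question: What is ONNX Runtime?\nAnswer:", max_new_tokens=64)
#     print(text)
# ---------------------------------------------------------------------------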