# onnx-models/old_scripts/test_chat.py
import os
import sys
import time
import argparse
import logging
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from tqdm import tqdm
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
class ONNXGenerationChatbot:
def __init__(self, model_path, max_length=100):
"""
Initialize the ONNX chatbot for text generation.
Args:
model_path: Path to the directory containing the ONNX model and tokenizer
max_length: Maximum sequence length for generation
"""
# Set up model paths
self.model_dir = model_path
self.onnx_path = os.path.join(self.model_dir, "model.onnx")
self.fp32_path = os.path.join(self.model_dir, "model_fp32.onnx")
# Check for model files
if not os.path.exists(self.onnx_path):
raise FileNotFoundError(f"ONNX model not found at {self.onnx_path}")
# Get model name for prompt formatting
self.model_name = os.path.basename(os.path.normpath(model_path))
logger.info(f"Using model: {self.model_name}")
# Load tokenizer
logger.info(f"Loading tokenizer from {self.model_dir}...")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, local_files_only=True)
# Ensure tokenizer has necessary tokens
if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
self.tokenizer.pad_token = self.tokenizer.eos_token
# Create optimized session
logger.info(f"Loading ONNX model from {self.onnx_path}...")
self.session_options = ort.SessionOptions()
self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self.session_options.intra_op_num_threads = 4 # Adjust based on your CPU
# Create session with appropriate providers
providers = ['CPUExecutionProvider']
if 'CUDAExecutionProvider' in ort.get_available_providers():
logger.info("CUDA is available! Using GPU acceleration.")
providers.insert(0, 'CUDAExecutionProvider')
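        # Note: CUDAExecutionProvider only appears in get_available_providers() when an
        # ONNX Runtime build with GPU support (e.g. the onnxruntime-gpu package) is installed.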
self.session = ort.InferenceSession(
self.onnx_path,
sess_options=self.session_options,
providers=providers
)
# Get input and output names from the model
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]
logger.info(f"Model inputs: {self.input_names}")
logger.info(f"Model outputs: {self.output_names}")
# Settings
self.max_length = max_length
self.stop_tokens = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
# Try to add common stop tokens if they exist in the vocabulary
stop_words = ["<|endoftext|>", "</s>", "<|end|>"]
for word in stop_words:
            try:
                token_id = self.tokenizer.convert_tokens_to_ids(word)
                if token_id is not None and token_id not in self.stop_tokens and token_id != self.tokenizer.unk_token_id:
                    self.stop_tokens.append(token_id)
            except Exception:
                # Token is not in this tokenizer's vocabulary; skip it
                pass
logger.info(f"Using stop tokens: {self.stop_tokens}")
# Conversation history for context
self.conversation_history = []
def get_prompt_template(self):
"""
Get the appropriate prompt template based on the model type.
"""
if "opt" in self.model_name.lower():
return "Human: {}\nAssistant:"
elif "pythia" in self.model_name.lower():
return "USER: {}\nASSISTANT:"
elif "llama" in self.model_name.lower() or "alpaca" in self.model_name.lower():
return "### Human: {}\n### Assistant:"
elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
return "User: {}\nBot:"
else:
return "Question: {}\nAnswer:"
def format_prompt_with_history(self, user_message):
"""
Format the prompt with conversation history for better context.
"""
template = self.get_prompt_template()
parts = template.split("{}")
prefix = parts[0]
suffix = parts[1] if len(parts) > 1 else ""
# Include history if available (up to 3 turns)
formatted_prompt = ""
        for user, bot in self.conversation_history[-3:]:
formatted_prompt += f"{prefix}{user}{suffix} {bot}\n\n"
# Add current user message
formatted_prompt += f"{prefix}{user_message}{suffix}"
return formatted_prompt
def run_inference_step(self, input_ids, attention_mask=None):
"""
Run a single inference step with the ONNX model.
Args:
input_ids: Token IDs of the input sequence
attention_mask: Attention mask for the input sequence
Returns:
numpy array: Logits for the next token prediction
"""
# Prepare model inputs
model_inputs = {}
for name in self.input_names:
if name == "input_ids":
model_inputs[name] = input_ids
elif name == "attention_mask" and attention_mask is not None:
model_inputs[name] = attention_mask
# Run inference
outputs = self.session.run(self.output_names, model_inputs)
# Return logits (assumes first output is logits)
return outputs[0]
def generate_text(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50, top_p=0.9,
repetition_penalty=1.1, do_sample=True, show_progress=True):
"""
Generate text using the ONNX model.
Args:
prompt: Text prompt to generate from
max_new_tokens: Maximum number of tokens to generate
temperature: Temperature for sampling (higher = more random)
top_k: Number of highest probability tokens to keep for sampling
top_p: Cumulative probability threshold for nucleus sampling
repetition_penalty: Penalty for repeating tokens
do_sample: Whether to sample from the distribution or use greedy decoding
show_progress: Whether to show a progress bar during generation
Returns:
str: Generated text
"""
# Encode the prompt
encoded = self.tokenizer(prompt, return_tensors="np")
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
# Track input tokens for repetition penalty
prev_tokens = input_ids[0].tolist()
# Setup progress bar if requested
progress = tqdm(total=max_new_tokens, desc="Generating") if show_progress else None
# Generate tokens auto-regressively
for _ in range(max_new_tokens):
# Run inference to get next token logits
logits = self.run_inference_step(input_ids, attention_mask)
# Get logits for the last token
next_token_logits = logits[0, -1, :]
# Apply temperature scaling
if temperature > 0:
next_token_logits = next_token_logits / max(temperature, 1e-8)
            # Apply repetition penalty (CTRL-style: dampen tokens that appeared recently)
            if repetition_penalty > 1.0:
                for prev_token in set(prev_tokens[-10:]):  # Only consider recent tokens
                    if prev_token < len(next_token_logits):
                        # Dividing a negative logit would make it more likely, so
                        # divide positive logits and multiply negative ones instead
                        if next_token_logits[prev_token] > 0:
                            next_token_logits[prev_token] /= repetition_penalty
                        else:
                            next_token_logits[prev_token] *= repetition_penalty
# Apply top-k filtering
if top_k > 0:
indices_to_remove = np.argsort(next_token_logits)[:-top_k]
next_token_logits[indices_to_remove] = -float('inf')
            # Apply top-p (nucleus) filtering
            if 0 < top_p < 1.0:
                sorted_indices = np.argsort(next_token_logits)[::-1]
                sorted_logits = next_token_logits[sorted_indices]
                # Softmax over the sorted logits (shifted by the max for numerical stability)
                sorted_probs = np.exp(sorted_logits - sorted_logits[0])
                sorted_probs = sorted_probs / np.sum(sorted_probs)
                cumulative_probs = np.cumsum(sorted_probs)
                # Remove tokens once the cumulative probability exceeds the threshold,
                # shifted right by one so the highest-probability token always survives
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                sorted_indices_to_remove[0] = False
                next_token_logits[sorted_indices[sorted_indices_to_remove]] = -float('inf')
# Sample from the filtered distribution or use greedy decoding
if do_sample:
# Apply softmax to get probabilities
probs = np.exp(next_token_logits - np.max(next_token_logits))
probs = probs / np.sum(probs)
# Handle NaNs
if np.isnan(probs).any():
next_token_id = np.argmax(next_token_logits)
else:
                    try:
                        # Sample from the filtered distribution
                        next_token_id = np.random.choice(len(probs), p=probs)
                    except ValueError:
                        # Fallback to greedy decoding if sampling fails
                        next_token_id = np.argmax(next_token_logits)
else:
# Greedy decoding - take highest probability token
next_token_id = np.argmax(next_token_logits)
# Add the chosen token to the input
            next_token = np.array([[next_token_id]], dtype=input_ids.dtype)
input_ids = np.concatenate([input_ids, next_token], axis=1)
# Update attention mask
attention_mask = np.ones((1, input_ids.shape[1]), dtype=np.int64)
# Add token to history for repetition penalty
prev_tokens.append(int(next_token_id))
# Update progress bar if active
if progress is not None:
progress.update(1)
# Check for stop tokens or end of text
if next_token_id in self.stop_tokens:
break
# Also stop if we exceed max length
if input_ids.shape[1] >= self.max_length:
break
# Close progress bar if used
if progress is not None:
progress.close()
# Decode the full sequence
generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
return generated_text
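    # Example usage (hypothetical model directory and prompt):
    #   bot = ONNXGenerationChatbot("models/distilgpt2-onnx")
    #   text = bot.generate_text("User: Hello\nBot:", max_new_tokens=30, do_sample=False)
    #   print(text)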
def extract_assistant_response(self, full_text, prompt):
"""
Extract just the assistant's response from the full generated text.
Args:
full_text: Full generated text including prompt
prompt: The original prompt
Returns:
str: Just the assistant's response
"""
# Try to extract based on the prompt format
template = self.get_prompt_template()
response_start_marker = template.split("{}")[-1]
# If the prompt is in the text, extract everything after it
if prompt in full_text:
after_prompt = full_text[len(prompt):]
# Handle additional newlines or spaces at the beginning
return after_prompt.lstrip()
# If the response marker is in the text, extract everything after it
if response_start_marker.strip() in full_text:
parts = full_text.split(response_start_marker.strip(), 1)
if len(parts) > 1:
return parts[1].strip()
# Fallback: return everything after the last line of the prompt
prompt_last_line = prompt.strip().split('\n')[-1]
if prompt_last_line in full_text:
parts = full_text.split(prompt_last_line, 1)
if len(parts) > 1:
return parts[1].strip()
# Last resort: return the whole thing
return full_text
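    # Example: if prompt == "User: Hi\nBot:" and full_text == "User: Hi\nBot: Hello there!",
    # the first branch above strips the prompt and returns "Hello there!".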
def chat(self, temperature=0.7, max_new_tokens=100):
"""
Run an interactive chat session with the model.
Args:
temperature: Temperature for text generation
max_new_tokens: Maximum number of tokens to generate per response
"""
print("\n===== ONNX Generation Chatbot =====")
print(f"Model: {self.model_name}")
print(f"Type 'exit' to end the conversation")
print(f"Type 'reset' to clear conversation history")
while True:
# Get user input
user_input = input("\nYou: ")
# Check for exit command
if user_input.lower() in ["exit", "quit", "bye"]:
print("Goodbye!")
break
# Check for reset command
if user_input.lower() == "reset":
self.conversation_history = []
print("Conversation history cleared.")
continue
# Create prompt with history
prompt = self.format_prompt_with_history(user_input)
print("\nGenerating response...")
# Generate text
try:
start_time = time.time()
full_text = self.generate_text(
prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
show_progress=True
)
# Extract just the assistant's response
response = self.extract_assistant_response(full_text, prompt)
# Clean up any trailing incomplete sentences
if response and len(response) > 0:
# Try to end at a sentence boundary if possible
sentence_end = max(
response.rfind('.'),
response.rfind('!'),
response.rfind('?')
)
if sentence_end > len(response) * 0.5: # Only trim if we're not losing too much
response = response[:sentence_end+1]
                # Calculate generation time (speed is an upper bound: generation may stop before max_new_tokens)
gen_time = time.time() - start_time
gen_speed = max_new_tokens / gen_time if gen_time > 0 else 0
# Print the response
print(f"\nBot: {response}")
print(f"\n[Generated {len(response)} chars in {gen_time:.2f}s ({gen_speed:.1f} tokens/sec)]")
# Add to conversation history
self.conversation_history.append((user_input, response))
# Keep history at a reasonable size
if len(self.conversation_history) > 10:
self.conversation_history = self.conversation_history[-10:]
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
print("\nBot: I encountered an error while generating a response. Let's try again.")
def main():
"""Run the ONNX chatbot with command line arguments."""
parser = argparse.ArgumentParser(description="Interactive ONNX Chatbot")
parser.add_argument("--model", type=str, required=True,
help="Path to the ONNX model directory")
parser.add_argument("--temperature", type=float, default=0.7,
help="Temperature for text generation (default: 0.7)")
parser.add_argument("--max_tokens", type=int, default=100,
help="Maximum tokens to generate per response (default: 100)")
args = parser.parse_args()
try:
# Create and run the chatbot
chatbot = ONNXGenerationChatbot(args.model)
chatbot.chat(temperature=args.temperature, max_new_tokens=args.max_tokens)
except KeyboardInterrupt:
print("\nExiting chatbot. Goodbye!")
except Exception as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if __name__ == "__main__":
main()
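# Example invocation (hypothetical model directory):
#   python test_chat.py --model ./onnx-models/distilgpt2 --temperature 0.7 --max_tokens 100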