import os
import sys
import time
import argparse
import logging

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from tqdm import tqdm

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


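# Wraps an exported ONNX causal language model plus its tokenizer and drives
# token-by-token generation for an interactive chat loop (CPU by default,
# CUDA when onnxruntime reports it as available).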
class ONNXGenerationChatbot:

    def __init__(self, model_path, max_length=100):
        """
        Initialize the ONNX chatbot for text generation.

        Args:
            model_path: Path to the directory containing the ONNX model and tokenizer
            max_length: Maximum sequence length for generation
        """
        self.model_dir = model_path
        self.onnx_path = os.path.join(self.model_dir, "model.onnx")
        self.fp32_path = os.path.join(self.model_dir, "model_fp32.onnx")

        if not os.path.exists(self.onnx_path):
            raise FileNotFoundError(f"ONNX model not found at {self.onnx_path}")

        self.model_name = os.path.basename(os.path.normpath(model_path))
        logger.info(f"Using model: {self.model_name}")

        logger.info(f"Loading tokenizer from {self.model_dir}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, local_files_only=True)

        if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info(f"Loading ONNX model from {self.onnx_path}...")
        self.session_options = ort.SessionOptions()
        self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session_options.intra_op_num_threads = 4

        providers = ['CPUExecutionProvider']
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            logger.info("CUDA is available! Using GPU acceleration.")
            providers.insert(0, 'CUDAExecutionProvider')

        self.session = ort.InferenceSession(
            self.onnx_path,
            sess_options=self.session_options,
            providers=providers
        )

        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

        logger.info(f"Model inputs: {self.input_names}")
        logger.info(f"Model outputs: {self.output_names}")

        self.max_length = max_length
        self.stop_tokens = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []

        # Probe a few common end-of-text markers; skip any this tokenizer doesn't know.
        stop_words = ["<|endoftext|>", "</s>", "<|end|>"]
        for word in stop_words:
            try:
                token_id = self.tokenizer.convert_tokens_to_ids(word)
                if token_id is not None and token_id not in self.stop_tokens and token_id != self.tokenizer.unk_token_id:
                    self.stop_tokens.append(token_id)
            except Exception:
                pass

        logger.info(f"Using stop tokens: {self.stop_tokens}")

        self.conversation_history = []

    def get_prompt_template(self):
        """
        Get the appropriate prompt template based on the model type.
        """
        if "opt" in self.model_name.lower():
            return "Human: {}\nAssistant:"
        elif "pythia" in self.model_name.lower():
            return "USER: {}\nASSISTANT:"
        elif "llama" in self.model_name.lower() or "alpaca" in self.model_name.lower():
            return "### Human: {}\n### Assistant:"
        elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
            return "User: {}\nBot:"
        else:
            return "Question: {}\nAnswer:"

    def format_prompt_with_history(self, user_message):
        """
        Format the prompt with conversation history for better context.
        """
        template = self.get_prompt_template()
        parts = template.split("{}")
        prefix = parts[0]
        suffix = parts[1] if len(parts) > 1 else ""

        formatted_prompt = ""
        for user, bot in self.conversation_history[-3:]:
            formatted_prompt += f"{prefix}{user}{suffix} {bot}\n\n"

        formatted_prompt += f"{prefix}{user_message}{suffix}"

        return formatted_prompt

    def run_inference_step(self, input_ids, attention_mask=None):
        """
        Run a single inference step with the ONNX model.

        Args:
            input_ids: Token IDs of the input sequence
            attention_mask: Attention mask for the input sequence

        Returns:
            numpy array: Logits for the next token prediction
        """
        model_inputs = {}
        for name in self.input_names:
            if name == "input_ids":
                model_inputs[name] = input_ids
            elif name == "attention_mask" and attention_mask is not None:
                model_inputs[name] = attention_mask

        outputs = self.session.run(self.output_names, model_inputs)

        return outputs[0]

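    # Note: only input_ids / attention_mask are fed to the session, so each
    # generation step below re-runs the model over the full sequence (no
    # past_key_values / KV cache). Simple and correct, but cost grows with
    # sequence length.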
    def generate_text(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50, top_p=0.9,
                      repetition_penalty=1.1, do_sample=True, show_progress=True):
        """
        Generate text using the ONNX model.

        Args:
            prompt: Text prompt to generate from
            max_new_tokens: Maximum number of tokens to generate
            temperature: Temperature for sampling (higher = more random)
            top_k: Number of highest probability tokens to keep for sampling
            top_p: Cumulative probability threshold for nucleus sampling
            repetition_penalty: Penalty for repeating tokens
            do_sample: Whether to sample from the distribution or use greedy decoding
            show_progress: Whether to show a progress bar during generation

        Returns:
            str: Generated text
        """
        encoded = self.tokenizer(prompt, return_tensors="np")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        prev_tokens = input_ids[0].tolist()

        progress = tqdm(total=max_new_tokens, desc="Generating") if show_progress else None

        for _ in range(max_new_tokens):
            logits = self.run_inference_step(input_ids, attention_mask)

            next_token_logits = logits[0, -1, :]

            if temperature > 0:
                next_token_logits = next_token_logits / max(temperature, 1e-8)

            if repetition_penalty > 1.0:
                for prev_token in set(prev_tokens[-10:]):
                    if prev_token < len(next_token_logits):
                        next_token_logits[prev_token] /= repetition_penalty

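            # Top-k / top-p (nucleus) filtering: mask out unlikely tokens by
            # setting their logits to -inf before sampling.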
            if top_k > 0:
                indices_to_remove = np.argsort(next_token_logits)[:-top_k]
                next_token_logits[indices_to_remove] = -float('inf')

            if 0 < top_p < 1.0:
                sorted_indices = np.argsort(next_token_logits)[::-1]
                sorted_logits = next_token_logits[sorted_indices]
                sorted_probs = np.exp(sorted_logits - np.max(sorted_logits))
                cumulative_probs = np.cumsum(sorted_probs / np.sum(sorted_probs))

                # Mask tokens past the nucleus, shifted by one position so the
                # first token that crosses the top_p threshold is always kept
                # (otherwise the entire distribution could be masked out).
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                sorted_indices_to_remove[0] = False
                next_token_logits[sorted_indices[sorted_indices_to_remove]] = -float('inf')

            if do_sample:
                # Softmax with max subtraction for numerical stability.
                probs = np.exp(next_token_logits - np.max(next_token_logits))
                probs = probs / np.sum(probs)

                if np.isnan(probs).any():
                    next_token_id = np.argmax(next_token_logits)
                else:
                    try:
                        next_token_id = np.random.choice(len(probs), p=probs)
                    except ValueError:
                        # p may not sum exactly to 1 due to floating point error.
                        next_token_id = np.argmax(next_token_logits)
            else:
                next_token_id = np.argmax(next_token_logits)

            next_token = np.array([[next_token_id]], dtype=input_ids.dtype)
            input_ids = np.concatenate([input_ids, next_token], axis=1)

            attention_mask = np.ones((1, input_ids.shape[1]), dtype=np.int64)

            prev_tokens.append(int(next_token_id))

            if progress is not None:
                progress.update(1)

            if next_token_id in self.stop_tokens:
                break

            if input_ids.shape[1] >= self.max_length:
                break

        if progress is not None:
            progress.close()

        generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text

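    # A minimal sketch of programmatic use (the model directory path is
    # hypothetical; it must contain model.onnx plus the matching tokenizer files):
    #
    #   bot = ONNXGenerationChatbot("./onnx_models/distilgpt2", max_length=256)
    #   out = bot.generate_text("User: What is ONNX?\nBot:", max_new_tokens=40)
    #   print(bot.extract_assistant_response(out, "User: What is ONNX?\nBot:"))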
    def extract_assistant_response(self, full_text, prompt):
        """
        Extract just the assistant's response from the full generated text.

        Args:
            full_text: Full generated text including prompt
            prompt: The original prompt

        Returns:
            str: Just the assistant's response
        """
        template = self.get_prompt_template()
        response_start_marker = template.split("{}")[-1]

        if full_text.startswith(prompt):
            after_prompt = full_text[len(prompt):]
            return after_prompt.lstrip()

        if response_start_marker.strip() in full_text:
            parts = full_text.split(response_start_marker.strip(), 1)
            if len(parts) > 1:
                return parts[1].strip()

        prompt_last_line = prompt.strip().split('\n')[-1]
        if prompt_last_line in full_text:
            parts = full_text.split(prompt_last_line, 1)
            if len(parts) > 1:
                return parts[1].strip()

        return full_text

    def chat(self, temperature=0.7, max_new_tokens=100):
        """
        Run an interactive chat session with the model.

        Args:
            temperature: Temperature for text generation
            max_new_tokens: Maximum number of tokens to generate per response
        """
        print("\n===== ONNX Generation Chatbot =====")
        print(f"Model: {self.model_name}")
        print("Type 'exit' to end the conversation")
        print("Type 'reset' to clear conversation history")

        while True:
            user_input = input("\nYou: ")

            if user_input.lower() in ["exit", "quit", "bye"]:
                print("Goodbye!")
                break

            if user_input.lower() == "reset":
                self.conversation_history = []
                print("Conversation history cleared.")
                continue

            prompt = self.format_prompt_with_history(user_input)
            print("\nGenerating response...")

            try:
                start_time = time.time()
                full_text = self.generate_text(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    show_progress=True
                )

                response = self.extract_assistant_response(full_text, prompt)

                if response:
                    # Trim to the last complete sentence if one covers most of the reply.
                    sentence_end = max(
                        response.rfind('.'),
                        response.rfind('!'),
                        response.rfind('?')
                    )
                    if sentence_end > len(response) * 0.5:
                        response = response[:sentence_end + 1]

                gen_time = time.time() - start_time
                # Approximate speed: assumes the full max_new_tokens were generated.
                gen_speed = max_new_tokens / gen_time if gen_time > 0 else 0

                print(f"\nBot: {response}")
                print(f"\n[Generated {len(response)} chars in {gen_time:.2f}s (~{gen_speed:.1f} tokens/sec)]")

                self.conversation_history.append((user_input, response))

                # Keep only the most recent ten exchanges.
                if len(self.conversation_history) > 10:
                    self.conversation_history = self.conversation_history[-10:]

            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                print("\nBot: I encountered an error while generating a response. Let's try again.")

def main():
    """Run the ONNX chatbot with command line arguments."""
    parser = argparse.ArgumentParser(description="Interactive ONNX Chatbot")
    parser.add_argument("--model", type=str, required=True,
                        help="Path to the ONNX model directory")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Temperature for text generation (default: 0.7)")
    parser.add_argument("--max_tokens", type=int, default=100,
                        help="Maximum tokens to generate per response (default: 100)")

    args = parser.parse_args()

    try:
        chatbot = ONNXGenerationChatbot(args.model)
        chatbot.chat(temperature=args.temperature, max_new_tokens=args.max_tokens)
    except KeyboardInterrupt:
        print("\nExiting chatbot. Goodbye!")
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
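
# Example invocation (script name and model directory are illustrative; the
# directory must contain model.onnx plus the exported tokenizer files):
#
#   python onnx_chatbot.py --model ./onnx_models/distilgpt2 --temperature 0.8 --max_tokens 120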