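"""
Interactive chatbot that runs causal-LM text generation with onnxruntime on an
exported ONNX model plus its Hugging Face tokenizer.

Example invocation (script name and model path are placeholders):

    python onnx_generation_chatbot.py --model ./models/distilgpt2-onnx --temperature 0.7 --max_tokens 100

The model directory is expected to contain model.onnx and the tokenizer files
saved alongside it (e.g. tokenizer.json, tokenizer_config.json).
"""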
import os
import sys
import time
import argparse
import logging
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from tqdm import tqdm

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

class ONNXGenerationChatbot:
    def __init__(self, model_path, max_length=100):
        """
        Initialize the ONNX chatbot for text generation.
        
        Args:
            model_path: Path to the directory containing the ONNX model and tokenizer
            max_length: Maximum sequence length for generation
        """
        # Set up model paths
        self.model_dir = model_path
        self.onnx_path = os.path.join(self.model_dir, "model.onnx")
        self.fp32_path = os.path.join(self.model_dir, "model_fp32.onnx")  # optional full-precision variant (not loaded here)
        
        # Check for model files
        if not os.path.exists(self.onnx_path):
            raise FileNotFoundError(f"ONNX model not found at {self.onnx_path}")
        
        # Get model name for prompt formatting
        self.model_name = os.path.basename(os.path.normpath(model_path))
        logger.info(f"Using model: {self.model_name}")
        
        # Load tokenizer
        logger.info(f"Loading tokenizer from {self.model_dir}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, local_files_only=True)
        
        # Ensure tokenizer has necessary tokens
        if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        # Create optimized session
        logger.info(f"Loading ONNX model from {self.onnx_path}...")
        self.session_options = ort.SessionOptions()
        self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session_options.intra_op_num_threads = 4  # Adjust based on your CPU
        
        # Create session with appropriate providers
        providers = ['CPUExecutionProvider']
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            logger.info("CUDA is available! Using GPU acceleration.")
            providers.insert(0, 'CUDAExecutionProvider')
        
        self.session = ort.InferenceSession(
            self.onnx_path,
            sess_options=self.session_options,
            providers=providers
        )
        
        # Get input and output names from the model
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]
        
        logger.info(f"Model inputs: {self.input_names}")
        logger.info(f"Model outputs: {self.output_names}")
        
        # Settings
        self.max_length = max_length
        self.stop_tokens = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
        
        # Try to add common stop tokens if they exist in the vocabulary
        stop_words = ["<|endoftext|>", "</s>", "<|end|>"]
        for word in stop_words:
            try:
                token_id = self.tokenizer.convert_tokens_to_ids(word)
                if token_id is not None and token_id not in self.stop_tokens and token_id != self.tokenizer.unk_token_id:
                    self.stop_tokens.append(token_id)
            except Exception:
                pass
        
        logger.info(f"Using stop tokens: {self.stop_tokens}")
        
        # Conversation history for context
        self.conversation_history = []
    
    def get_prompt_template(self):
        """
        Get the appropriate prompt template based on the model type.
        """
        if "opt" in self.model_name.lower():
            return "Human: {}\nAssistant:"
        elif "pythia" in self.model_name.lower():
            return "USER: {}\nASSISTANT:"
        elif "llama" in self.model_name.lower() or "alpaca" in self.model_name.lower():
            return "### Human: {}\n### Assistant:"
        elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
            return "User: {}\nBot:"
        else:
            return "Question: {}\nAnswer:"
    
    def format_prompt_with_history(self, user_message):
        """
        Format the prompt with conversation history for better context.
        """
        template = self.get_prompt_template()
        parts = template.split("{}")
        prefix = parts[0]
        suffix = parts[1] if len(parts) > 1 else ""
        
        # Include history if available (up to 3 turns)
        formatted_prompt = ""
        for i, (user, bot) in enumerate(self.conversation_history[-3:]):
            formatted_prompt += f"{prefix}{user}{suffix} {bot}\n\n"
            
        # Add current user message
        formatted_prompt += f"{prefix}{user_message}{suffix}"
        
        return formatted_prompt
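
    # Illustration: with a gpt2/distilgpt2 export and an empty conversation history,
    # format_prompt_with_history("Hello") returns "User: Hello\nBot:".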
    
    def run_inference_step(self, input_ids, attention_mask=None):
        """
        Run a single inference step with the ONNX model.
        
        Args:
            input_ids: Token IDs of the input sequence
            attention_mask: Attention mask for the input sequence
            
        Returns:
            numpy array: Logits for the next token prediction
        """
        # Prepare model inputs
        model_inputs = {}
        for name in self.input_names:
            if name == "input_ids":
                model_inputs[name] = input_ids
            elif name == "attention_mask" and attention_mask is not None:
                model_inputs[name] = attention_mask
        
        # Run inference
        outputs = self.session.run(self.output_names, model_inputs)
        
        # Return logits (assumes first output is logits)
        return outputs[0]
        
    def generate_text(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50, top_p=0.9, 
                     repetition_penalty=1.1, do_sample=True, show_progress=True):
        """
        Generate text using the ONNX model.
        
        Args:
            prompt: Text prompt to generate from
            max_new_tokens: Maximum number of tokens to generate
            temperature: Temperature for sampling (higher = more random)
            top_k: Number of highest probability tokens to keep for sampling
            top_p: Cumulative probability threshold for nucleus sampling
            repetition_penalty: Penalty for repeating tokens
            do_sample: Whether to sample from the distribution or use greedy decoding
            show_progress: Whether to show a progress bar during generation
            
        Returns:
            str: Generated text
        """
        # Encode the prompt
        encoded = self.tokenizer(prompt, return_tensors="np")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        
        # Track input tokens for repetition penalty
        prev_tokens = input_ids[0].tolist()
        
        # Setup progress bar if requested
        progress = tqdm(total=max_new_tokens, desc="Generating") if show_progress else None
        
        # Generate tokens auto-regressively
        for _ in range(max_new_tokens):
            # Run inference to get next token logits
            logits = self.run_inference_step(input_ids, attention_mask)
            
            # Get logits for the last token
            next_token_logits = logits[0, -1, :]
            
            # Apply temperature scaling
            if temperature > 0:
                next_token_logits = next_token_logits / max(temperature, 1e-8)
            
            # Apply repetition penalty (CTRL-style: dividing a negative logit would
            # raise its probability, so multiply negative logits instead)
            if repetition_penalty > 1.0:
                for prev_token in set(prev_tokens[-10:]):  # Only consider recent tokens
                    if prev_token < len(next_token_logits):
                        if next_token_logits[prev_token] > 0:
                            next_token_logits[prev_token] /= repetition_penalty
                        else:
                            next_token_logits[prev_token] *= repetition_penalty
            
            # Apply top-k filtering
            if top_k > 0:
                indices_to_remove = np.argsort(next_token_logits)[:-top_k]
                next_token_logits[indices_to_remove] = -float('inf')
            
            # Apply top-p (nucleus) filtering
            if 0 < top_p < 1.0:
                sorted_indices = np.argsort(next_token_logits)[::-1]
                sorted_logits = next_token_logits[sorted_indices]
                # Stable softmax over the sorted logits
                sorted_probs = np.exp(sorted_logits - np.max(sorted_logits))
                sorted_probs = sorted_probs / np.sum(sorted_probs)
                cumulative_probs = np.cumsum(sorted_probs)
                
                # Remove tokens once the cumulative probability exceeds the threshold,
                # shifted by one position so the token that crosses it (and always the
                # single most likely token) is kept
                sorted_indices_to_remove = np.concatenate(([False], cumulative_probs[:-1] > top_p))
                next_token_logits[sorted_indices[sorted_indices_to_remove]] = -float('inf')
            
            # Sample from the filtered distribution or use greedy decoding
            if do_sample:
                # Apply softmax to get probabilities
                probs = np.exp(next_token_logits - np.max(next_token_logits))
                probs = probs / np.sum(probs)
                
                # Handle NaNs
                if np.isnan(probs).any():
                    next_token_id = np.argmax(next_token_logits)
                else:
                    try:
                        # Sample from the distribution
                        next_token_id = np.random.choice(len(probs), p=probs)
                    except ValueError:
                        # Fallback to greedy if sampling fails (e.g. probabilities do not sum to 1)
                        next_token_id = np.argmax(next_token_logits)
            else:
                # Greedy decoding - take highest probability token
                next_token_id = np.argmax(next_token_logits)
            
            # Add the chosen token to the input
            next_token = np.array([[next_token_id]])
            input_ids = np.concatenate([input_ids, next_token], axis=1)
            
            # Update attention mask
            attention_mask = np.ones((1, input_ids.shape[1]), dtype=np.int64)
            
            # Add token to history for repetition penalty
            prev_tokens.append(int(next_token_id))
            
            # Update progress bar if active
            if progress is not None:
                progress.update(1)
            
            # Check for stop tokens or end of text
            if next_token_id in self.stop_tokens:
                break
                
            # Also stop if we exceed max length
            if input_ids.shape[1] >= self.max_length:
                break
        
        # Close progress bar if used
        if progress is not None:
            progress.close()
        
        # Decode the full sequence
        generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text
    
    def extract_assistant_response(self, full_text, prompt):
        """
        Extract just the assistant's response from the full generated text.
        
        Args:
            full_text: Full generated text including prompt
            prompt: The original prompt
            
        Returns:
            str: Just the assistant's response
        """
        # Try to extract based on the prompt format
        template = self.get_prompt_template()
        response_start_marker = template.split("{}")[-1]
        
        # If the prompt is in the text, extract everything after it
        if prompt in full_text:
            after_prompt = full_text[len(prompt):]
            
            # Handle additional newlines or spaces at the beginning
            return after_prompt.lstrip()
        
        # If the response marker is in the text, extract everything after it
        if response_start_marker.strip() in full_text:
            parts = full_text.split(response_start_marker.strip(), 1)
            if len(parts) > 1:
                return parts[1].strip()
        
        # Fallback: return everything after the last line of the prompt
        prompt_last_line = prompt.strip().split('\n')[-1]
        if prompt_last_line in full_text:
            parts = full_text.split(prompt_last_line, 1)
            if len(parts) > 1:
                return parts[1].strip()
        
        # Last resort: return the whole thing
        return full_text
    
    def chat(self, temperature=0.7, max_new_tokens=100):
        """
        Run an interactive chat session with the model.
        
        Args:
            temperature: Temperature for text generation
            max_new_tokens: Maximum number of tokens to generate per response
        """
        print("\n===== ONNX Generation Chatbot =====")
        print(f"Model: {self.model_name}")
        print(f"Type 'exit' to end the conversation")
        print(f"Type 'reset' to clear conversation history")
        
        while True:
            # Get user input
            user_input = input("\nYou: ")
            
            # Check for exit command
            if user_input.lower() in ["exit", "quit", "bye"]:
                print("Goodbye!")
                break
                
            # Check for reset command
            if user_input.lower() == "reset":
                self.conversation_history = []
                print("Conversation history cleared.")
                continue
            
            # Create prompt with history
            prompt = self.format_prompt_with_history(user_input)
            print("\nGenerating response...")
            
            # Generate text
            try:
                start_time = time.time()
                full_text = self.generate_text(
                    prompt, 
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    show_progress=True
                )
                
                # Extract just the assistant's response
                response = self.extract_assistant_response(full_text, prompt)
                
                # Clean up any trailing incomplete sentences
                if response and len(response) > 0:
                    # Try to end at a sentence boundary if possible
                    sentence_end = max(
                        response.rfind('.'), 
                        response.rfind('!'), 
                        response.rfind('?')
                    )
                    if sentence_end > len(response) * 0.5:  # Only trim if we're not losing too much
                        response = response[:sentence_end+1]
                
                # Calculate generation time and approximate speed from the returned response
                gen_time = time.time() - start_time
                gen_tokens = len(self.tokenizer(response, add_special_tokens=False)["input_ids"])
                gen_speed = gen_tokens / gen_time if gen_time > 0 else 0
                
                # Print the response
                print(f"\nBot: {response}")
                print(f"\n[Generated {gen_tokens} tokens in {gen_time:.2f}s ({gen_speed:.1f} tokens/sec)]")
                
                # Add to conversation history
                self.conversation_history.append((user_input, response))
                
                # Keep history at a reasonable size
                if len(self.conversation_history) > 10:
                    self.conversation_history = self.conversation_history[-10:]
                    
            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                print("\nBot: I encountered an error while generating a response. Let's try again.")


def main():
    """Run the ONNX chatbot with command line arguments."""
    parser = argparse.ArgumentParser(description="Interactive ONNX Chatbot")
    parser.add_argument("--model", type=str, required=True,
                       help="Path to the ONNX model directory")
    parser.add_argument("--temperature", type=float, default=0.7,
                       help="Temperature for text generation (default: 0.7)")
    parser.add_argument("--max_tokens", type=int, default=100,
                       help="Maximum tokens to generate per response (default: 100)")
    
    args = parser.parse_args()
    
    try:
        # Create and run the chatbot
        chatbot = ONNXGenerationChatbot(args.model)
        chatbot.chat(temperature=args.temperature, max_new_tokens=args.max_tokens)
    except KeyboardInterrupt:
        print("\nExiting chatbot. Goodbye!")
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
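
# Programmatic use without the interactive loop (a minimal sketch; the model
# directory below is a placeholder and must follow the layout described in the
# module docstring):
#
#   bot = ONNXGenerationChatbot("./models/distilgpt2-onnx", max_length=256)
#   prompt = bot.format_prompt_with_history("What is ONNX Runtime?")
#   full_text = bot.generate_text(prompt, max_new_tokens=40, temperature=0.7)
#   print(bot.extract_assistant_response(full_text, prompt))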