import gradio as gr
import logging
import time
from datetime import datetime
from typing import List, Optional, Tuple
import random
import nltk
# nltk.download('punkt') # Ensure punkt is downloaded if needed
from nltk.tokenize import sent_tokenize
import io
# from joblib import dump, load # Not used currently, commented out

# Import Hugging Face libraries
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset # Added for dataset loading

# Import ML/Data libraries
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Standard libraries
from concurrent.futures import ThreadPoolExecutor # Still useful for embedding generation

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) # Use __name__ for logger

# Download the NLTK sentence tokenizer data ('punkt'); if this fails,
# ResourceItem.create_chunks falls back to simple whitespace-based splitting.
try:
    nltk.download('punkt', quiet=True)
except Exception as e:
    logger.warning(f"Failed to download NLTK data: {e}")

# --- Configuration ---
class Config:
    MODEL_NAME = "microsoft/DialoGPT-medium"
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    MAX_TOKENS_RESPONSE = 150  # Max tokens for the generated response part
    MAX_TOKENS_INPUT = 800    # Max tokens allowed for context + query (adjust based on model limits)
    SIMILARITY_THRESHOLD = 0.3 # Adjusted threshold, tune as needed
    CHUNK_SIZE = 300          # Smaller chunk size might be better for dataset entries
    MAX_WORKERS = 5           # For parallel embedding generation
    DATASET_NAME = "acecalisto3/sspnc" # Hugging Face Dataset ID
    DATASET_SPLIT = "train"     # Which split of the dataset to use
    TEXT_COLUMNS = ["Subject", "Body"] # Columns containing text to index
    SOURCE_INFO_COLUMNS = ["Subject", "Date"] # Columns to use for source attribution

# --- Data Structures ---
class ResourceItem:
    def __init__(self, source_id: str, content: str, resource_type: str):
        self.source_id = source_id # Changed 'url' to 'source_id' for clarity
        self.content = content
        self.type = resource_type
        self.embedding = None # Overall embedding (optional now, as we use chunk embeddings)
        self.chunks = []
        self.chunk_embeddings = []

    def __str__(self):
        return f"ResourceItem(type={self.type}, source_id={self.source_id}, content_length={len(self.content)})"

    def create_chunks(self, chunk_size=Config.CHUNK_SIZE):
        """Split content into overlapping chunks using sentence tokenization for better boundaries"""
        if not self.content:
            logger.warning(f"Content is empty for source_id: {self.source_id}. Skipping chunk creation.")
            return

        try:
            sentences = sent_tokenize(self.content)
        except LookupError:
            logger.warning("NLTK 'punkt' tokenizer not found. Falling back to simple whitespace splitting. Consider running nltk.download('punkt')")
            # Fallback to word splitting if sentence tokenization fails
            words = self.content.split()
            overlap = chunk_size // 4
            for i in range(0, len(words), chunk_size - overlap):
                chunk = ' '.join(words[i : i + chunk_size])
                if chunk:
                    self.chunks.append(chunk)
            return
        except Exception as e:
             logger.error(f"Error during sentence tokenization for {self.source_id}: {e}. Skipping chunk creation.")
             return


        current_chunk = ""
        overlap_sentences = max(1, chunk_size // 100) # Overlap a few sentences
        last_sentences = []

        for sentence in sentences:
            # If adding the next sentence exceeds chunk size (considering words approx)
            if len((current_chunk + " " + sentence).split()) > chunk_size:
                if current_chunk: # Add the completed chunk
                    self.chunks.append(current_chunk.strip())
                # Start new chunk with overlap
                current_chunk = " ".join(last_sentences) + " " + sentence
            else:
                current_chunk += " " + sentence

            # Keep track of last sentences for overlap
            last_sentences.append(sentence)
            if len(last_sentences) > overlap_sentences:
                last_sentences.pop(0)

        # Add the last remaining chunk
        if current_chunk.strip():
            self.chunks.append(current_chunk.strip())

        if not self.chunks:
             logger.warning(f"No chunks created for source_id: {self.source_id}. Content might be too short or tokenization failed.")


# --- Chatbot Core Logic ---
class SchoolChatbot:
    def __init__(self):
        logger.info("Initializing SchoolChatbot...")
        self.setup_models()
        self.resources: List[ResourceItem] = []
        self.load_and_index_dataset() # Changed from crawl_and_index_resources

    def setup_models(self):
        try:
            logger.info("Setting up models...")
            # Consider adding device mapping if GPU is available: device_map="auto"
            self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
            self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)
            self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Ensure tokenizer has a padding token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.model.config.eos_token_id
            logger.info("Models setup completed successfully.")
        except Exception as e:
            logger.error(f"Failed to setup models: {e}")
            raise RuntimeError("Failed to initialize required models") from e

    def load_and_index_dataset(self):
        logger.info(f"Loading dataset: {Config.DATASET_NAME}, split: {Config.DATASET_SPLIT}")
        try:
            # Load the dataset
            dataset = load_dataset(Config.DATASET_NAME, split=Config.DATASET_SPLIT)
            logger.info(f"Dataset loaded successfully. Number of rows: {len(dataset)}")

            # Process dataset rows in parallel (for embedding generation)
            with ThreadPoolExecutor(max_workers=Config.MAX_WORKERS) as executor:
                futures = []
                for i, row in enumerate(dataset):
                    # Combine text from specified columns
                    text_content = " ".join([str(row[col]) for col in Config.TEXT_COLUMNS if row.get(col)])
                    text_content = text_content.strip() # Remove leading/trailing whitespace

                    # Create a source identifier
                    source_parts = [f"{col}: {row[col]}" for col in Config.SOURCE_INFO_COLUMNS if row.get(col)]
                    source_id = f"Dataset Entry {i} ({'; '.join(source_parts)})" # More informative ID

                    if not text_content:
                        logger.warning(f"Row {i} has no content in specified columns. Skipping.")
                        continue

                    # Submit the processing task
                    futures.append(executor.submit(self.process_and_store_resource, source_id, text_content, 'dataset_entry'))

                # Wait for all futures to complete and collect results
                for future in futures:
                    try:
                        result_item = future.result()
                        if result_item:
                            self.resources.append(result_item)
                    except Exception as e:
                        logger.error(f"Error processing dataset entry in thread: {e}")

            logger.info(f"Dataset processing completed. Indexed {len(self.resources)} resources.")

        except Exception as e:
            logger.error(f"Failed to load or process dataset {Config.DATASET_NAME}: {e}")
            # Decide if the app should continue without data or raise an error
            # raise RuntimeError("Failed to load data") from e # Option: halt if data fails

    def process_and_store_resource(self, source_id: str, text_data: str, resource_type: str) -> Optional[ResourceItem]:
        """Creates ResourceItem, chunks, and generates embeddings for a single data entry."""
        try:
            # Create resource item and split into chunks
            item = ResourceItem(source_id, text_data, resource_type)
            item.create_chunks()

            if not item.chunks:
                logger.warning(f"No chunks generated for {source_id}. Skipping storage.")
                return None

            # Generate embeddings for chunks (can be slow, hence the thread pool)
            chunk_embeddings_list = self.embedding_model.encode(item.chunks, show_progress_bar=False) # Batch encode
            item.chunk_embeddings = chunk_embeddings_list
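            # Informational: SentenceTransformer.encode on a list of strings returns a 2-D numpy
            # array, shape (num_chunks, 384) for all-MiniLM-L6-v2, which cosine_similarity
            # consumes directly in find_best_matching_chunks below.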

            # Calculate average embedding (optional, might not be needed if only using chunk search)
            # if item.chunk_embeddings:
            #     item.embedding = np.mean(item.chunk_embeddings, axis=0)

            logger.debug(f"Processed resource: {source_id} (type={resource_type}), {len(item.chunks)} chunks.")
            return item # Return the processed item

        except Exception as e:
            logger.error(f"Error processing/storing resource {source_id}: {e}")
            return None # Return None on error

    # store_resource is now process_and_store_resource and called within the thread pool

    def find_best_matching_chunks(self, query: str, n_chunks: int = 3) -> List[Tuple[str, float, str]]:
        """Finds the most relevant text chunks based on semantic similarity."""
        if not self.resources:
            logger.warning("No resources loaded or indexed. Cannot find matches.")
            return []

        try:
            query_embedding = self.embedding_model.encode(query)
            all_chunks_with_scores = []

            for resource in self.resources:
                if not resource.chunks or not resource.chunk_embeddings:
                    continue # Skip resources with no chunks/embeddings

                # Calculate similarity between query and all chunks of the current resource
                similarities = cosine_similarity([query_embedding], resource.chunk_embeddings)[0]

                for chunk, score in zip(resource.chunks, similarities):
                    if score > Config.SIMILARITY_THRESHOLD:
                        all_chunks_with_scores.append((chunk, float(score), resource.source_id)) # Use source_id

            # Sort by similarity score (descending) and return top n
            all_chunks_with_scores.sort(key=lambda x: x[1], reverse=True)
            return all_chunks_with_scores[:n_chunks]

        except Exception as e:
            logger.error(f"Error finding matching chunks: {e}")
            return []

    def generate_response(self, user_input: str) -> str:
        """Generates a response based on user input and retrieved context."""
        try:
            # 1. Find relevant context chunks
            best_chunks = self.find_best_matching_chunks(user_input)

            if not best_chunks:
                logger.info(f"No relevant chunks found for query: '{user_input}'")
                return "I couldn't find specific information related to your question in the provided documents. Could you please rephrase or ask about a different topic?"

            # 2. Prepare context and source attribution
            context = "\n".join([chunk[0] for chunk in best_chunks])
            # Use source_id from the chunk tuple (index 2)
            source_ids = sorted(list(set(chunk[2] for chunk in best_chunks)))
            sources_text = "\n\nSources:\n" + "\n".join([f"- {sid}" for sid in source_ids])

            # 3. Prepare input for the language model
            # Ensure the input doesn't exceed model limits
            prompt_template = f"Based on the following information:\n{context}\n\nAnswer the question: {user_input}\nAnswer:"
            # prompt_template = f"Context: {context}\nUser: {user_input}\nAssistant:" # Alternative simpler prompt

            # 4. Tokenize and truncate if necessary; keep the attention mask so generate() can
            #    tell real tokens from padding now that pad_token has been set to eos_token.
            encoded_input = self.tokenizer(
                prompt_template,
                return_tensors='pt',
                max_length=Config.MAX_TOKENS_INPUT,
                truncation=True,
            )
            input_ids = encoded_input["input_ids"]
            attention_mask = encoded_input["attention_mask"]

            # 5. Generate response using the language model
            logger.info("Generating response with LLM...")
            output_sequences = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=Config.MAX_TOKENS_RESPONSE, # Control length of *new* tokens
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                num_return_sequences=1 # Generate one response
            )

            # Decode the generated part of the response
            # response_text = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
            # Decode only the newly generated tokens, excluding the prompt
            response_text = self.tokenizer.decode(output_sequences[0][input_ids.shape[-1]:], skip_special_tokens=True)


            # Basic post-processing (optional)
            response_text = response_text.strip()
            # Remove potential repetition of the question if the model includes it
            if user_input.lower() in response_text.lower()[:len(user_input)+10]:
                 response_text = response_text.split(user_input, 1)[-1].strip("? ")


            logger.info(f"Generated response (before sources): {response_text}")

            # 6. Combine response and sources
            full_response = response_text + sources_text
            return full_response

        except Exception as e:
            logger.exception(f"Error generating response: {e}") # Use logger.exception to include stack trace
            return "I apologize, but I encountered an error while processing your question. Please check the logs or try again later."

# --- Gradio Interface ---
def create_gradio_interface(chatbot: SchoolChatbot):
    """Creates and returns the Gradio web interface."""
    def respond(user_input: str) -> str:
        if not user_input:
            return "Please enter a question."
        # Add basic input sanitization if needed
        return chatbot.generate_response(user_input)

    interface = gr.Interface(
        fn=respond,
        inputs=gr.Textbox(
            label="Ask a Question",
            placeholder="Type your question about the school information...",
            lines=3, # Increased lines slightly
        ),
        outputs=gr.Textbox(
            label="Answer",
            placeholder="Response will appear here...",
            lines=10, # Increased lines for longer answers + sources
        ),
        title="School Information Chatbot (Dataset Powered)",
        description="Ask about information contained in the school dataset. The chatbot uses AI to find relevant details and generate answers.",
        examples=[ # Update examples based on dataset content
            ["What are the main subjects covered in the documents?"],
            ["Are there any mentions of specific events or dates?"],
            ["Summarize the key points about [topic from dataset]."]
        ],
        theme=gr.themes.Soft(),
        allow_flagging="never", # Changed from flagging_mode
        # Optional: Add feedback capabilities
        # feedback=["thumbs", "textbox"],
    )
    return interface

# --- Main Execution ---
if __name__ == "__main__":
    # Install necessary libraries if running for the first time:
    # pip install gradio transformers sentence-transformers torch datasets scikit-learn nltk numpy
    print("Starting application...")
    try:
        # 1. Initialize the chatbot (loads models and data)
        school_chatbot = SchoolChatbot()

        # 2. Create the Gradio interface
        app_interface = create_gradio_interface(school_chatbot)

        # 3. Launch the interface (launch() blocks until the server is stopped)
        print("Launching Gradio Interface at http://localhost:7860 (or the relevant network IP)...")
        app_interface.launch(
            server_name="0.0.0.0", # Accessible on the local network
            server_port=7860,
            share=False,           # Set to True to get a public link (use with caution)
            debug=False            # Set to True for more detailed Gradio logs (can be verbose)
        )

    except ImportError as ie:
        logger.error(f"ImportError: {ie}. Make sure all required libraries are installed.")
        print(f"ImportError: {ie}. Please install the missing library (e.g., pip install {ie.name}).")
    except Exception as e:
        logger.critical(f"Failed to start the application: {e}", exc_info=True) # Log critical error with stack trace
        print(f"Critical error during startup: {e}")