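"""School Information Chatbot (dataset powered).

A retrieval-augmented Gradio app: rows from a Hugging Face dataset are chunked,
embedded with a SentenceTransformer model, matched against the user's question by
cosine similarity, and the best-matching chunks are passed as context to a causal
language model (DialoGPT) that generates an answer with source attribution.
"""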
import gradio as gr
import logging
import time
from datetime import datetime
from typing import List, Optional, Tuple
import random
import nltk
# nltk.download('punkt') # Ensure punkt is downloaded if needed
from nltk.tokenize import sent_tokenize
import io
# from joblib import dump, load # Not used currently, commented out
# Import Hugging Face libraries
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset # Added for dataset loading
# Import ML/Data libraries
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# Standard libraries
from concurrent.futures import ThreadPoolExecutor # Still useful for embedding generation
# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) # Use __name__ for logger
# Download NLTK data (optional, might not be strictly needed depending on chunking)
# try:
# nltk.download('punkt', quiet=True)
# except Exception as e:
# logger.warning(f"Failed to download NLTK data: {e}")
# --- Configuration ---
class Config:
    MODEL_NAME = "microsoft/DialoGPT-medium"
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    MAX_TOKENS_RESPONSE = 150  # Max tokens for the generated response part
    MAX_TOKENS_INPUT = 800  # Max tokens allowed for context + query (adjust based on model limits)
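    # Note: DialoGPT inherits GPT-2's 1024-token context window, so
    # MAX_TOKENS_INPUT + MAX_TOKENS_RESPONSE should stay below that limit.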
    SIMILARITY_THRESHOLD = 0.3  # Adjusted threshold, tune as needed
    CHUNK_SIZE = 300  # Smaller chunk size might be better for dataset entries
    MAX_WORKERS = 5  # For parallel embedding generation
    DATASET_NAME = "acecalisto3/sspnc"  # Hugging Face Dataset ID
    DATASET_SPLIT = "train"  # Which split of the dataset to use
    TEXT_COLUMNS = ["Subject", "Body"]  # Columns containing text to index
    SOURCE_INFO_COLUMNS = ["Subject", "Date"]  # Columns to use for source attribution
# --- Data Structures ---
class ResourceItem:
    def __init__(self, source_id: str, content: str, resource_type: str):
        self.source_id = source_id  # Changed 'url' to 'source_id' for clarity
        self.content = content
        self.type = resource_type
        self.embedding = None  # Overall embedding (optional now, as we use chunk embeddings)
        self.chunks = []
        self.chunk_embeddings = []

    def __str__(self):
        return f"ResourceItem(type={self.type}, source_id={self.source_id}, content_length={len(self.content)})"

    def create_chunks(self, chunk_size=Config.CHUNK_SIZE):
        """Split content into overlapping chunks, using sentence tokenization for better boundaries."""
        if not self.content:
            logger.warning(f"Content is empty for source_id: {self.source_id}. Skipping chunk creation.")
            return
        try:
            sentences = sent_tokenize(self.content)
        except LookupError:
            logger.warning("NLTK 'punkt' tokenizer not found. Falling back to simple whitespace splitting. Consider running nltk.download('punkt').")
            # Fall back to word splitting if sentence tokenization is unavailable
            words = self.content.split()
            overlap = chunk_size // 4
            for i in range(0, len(words), chunk_size - overlap):
                chunk = ' '.join(words[i : i + chunk_size])
                if chunk:
                    self.chunks.append(chunk)
            return
        except Exception as e:
            logger.error(f"Error during sentence tokenization for {self.source_id}: {e}. Skipping chunk creation.")
            return
        current_chunk = ""
        overlap_sentences = max(1, chunk_size // 100)  # Overlap a few sentences
        last_sentences = []
        for sentence in sentences:
            # If adding the next sentence exceeds the chunk size (approximate word count)
            if len((current_chunk + " " + sentence).split()) > chunk_size:
                if current_chunk:  # Add the completed chunk
                    self.chunks.append(current_chunk.strip())
                # Start a new chunk that overlaps with the end of the previous one
                current_chunk = " ".join(last_sentences) + " " + sentence
            else:
                current_chunk += " " + sentence
            # Keep track of the last few sentences for overlap
            last_sentences.append(sentence)
            if len(last_sentences) > overlap_sentences:
                last_sentences.pop(0)
        # Add the last remaining chunk
        if current_chunk.strip():
            self.chunks.append(current_chunk.strip())
        if not self.chunks:
            logger.warning(f"No chunks created for source_id: {self.source_id}. Content might be too short or tokenization failed.")
# --- Chatbot Core Logic ---
class SchoolChatbot:
    def __init__(self):
        logger.info("Initializing SchoolChatbot...")
        self.setup_models()
        self.resources: List[ResourceItem] = []
        self.load_and_index_dataset()  # Changed from crawl_and_index_resources

    def setup_models(self):
        try:
            logger.info("Setting up models...")
            # Consider adding device mapping if a GPU is available: device_map="auto"
            self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
            self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)
            self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Ensure the tokenizer has a padding token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.model.config.eos_token_id
            logger.info("Models setup completed successfully.")
        except Exception as e:
            logger.error(f"Failed to setup models: {e}")
            raise RuntimeError("Failed to initialize required models") from e
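
    # Note: DialoGPT's tokenizer ships without a pad token, so setup_models reuses the
    # EOS token for padding; generate() also relies on it as the stop token.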

    def load_and_index_dataset(self):
        logger.info(f"Loading dataset: {Config.DATASET_NAME}, split: {Config.DATASET_SPLIT}")
        try:
            # Load the dataset
            dataset = load_dataset(Config.DATASET_NAME, split=Config.DATASET_SPLIT)
            logger.info(f"Dataset loaded successfully. Number of rows: {len(dataset)}")
            # Process dataset rows in parallel (for embedding generation)
            with ThreadPoolExecutor(max_workers=Config.MAX_WORKERS) as executor:
                futures = []
                for i, row in enumerate(dataset):
                    # Combine text from the specified columns
                    text_content = " ".join([str(row[col]) for col in Config.TEXT_COLUMNS if row.get(col)])
                    text_content = text_content.strip()  # Remove leading/trailing whitespace
                    # Create a source identifier
                    source_parts = [f"{col}: {row[col]}" for col in Config.SOURCE_INFO_COLUMNS if row.get(col)]
                    source_id = f"Dataset Entry {i} ({'; '.join(source_parts)})"  # More informative ID
                    if not text_content:
                        logger.warning(f"Row {i} has no content in the specified columns. Skipping.")
                        continue
                    # Submit the processing task
                    futures.append(executor.submit(self.process_and_store_resource, source_id, text_content, 'dataset_entry'))
                # Wait for all futures to complete and collect results
                for future in futures:
                    try:
                        result_item = future.result()
                        if result_item:
                            self.resources.append(result_item)
                    except Exception as e:
                        logger.error(f"Error processing dataset entry in thread: {e}")
            logger.info(f"Dataset processing completed. Indexed {len(self.resources)} resources.")
        except Exception as e:
            logger.error(f"Failed to load or process dataset {Config.DATASET_NAME}: {e}")
            # Decide whether the app should continue without data or raise an error
            # raise RuntimeError("Failed to load data") from e  # Option: halt if data fails
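
    # The thread pool mainly parallelises embedding generation; most of that time is spent
    # inside the SentenceTransformer model. For large datasets, a single batched encode()
    # call over all chunks may be simpler and faster than per-row threads.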

    def process_and_store_resource(self, source_id: str, text_data: str, resource_type: str) -> Optional[ResourceItem]:
        """Creates a ResourceItem, chunks it, and generates embeddings for a single data entry."""
        try:
            # Create the resource item and split it into chunks
            item = ResourceItem(source_id, text_data, resource_type)
            item.create_chunks()
            if not item.chunks:
                logger.warning(f"No chunks generated for {source_id}. Skipping storage.")
                return None
            # Generate embeddings for all chunks in one batch (can be slow, hence the thread pool)
            chunk_embeddings_list = self.embedding_model.encode(item.chunks, show_progress_bar=False)
            item.chunk_embeddings = chunk_embeddings_list
            # Calculate an average embedding (optional, not needed while only chunk search is used)
            # if item.chunk_embeddings:
            #     item.embedding = np.mean(item.chunk_embeddings, axis=0)
            logger.debug(f"Processed resource: {source_id} (type={resource_type}), {len(item.chunks)} chunks.")
            return item  # Return the processed item
        except Exception as e:
            logger.error(f"Error processing/storing resource {source_id}: {e}")
            return None  # Return None on error
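
    # Note: SentenceTransformer.encode on a list of strings returns a NumPy array of
    # shape (num_chunks, embedding_dim), which cosine_similarity below accepts directly.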
    # store_resource is now process_and_store_resource and is called within the thread pool.

    def find_best_matching_chunks(self, query: str, n_chunks: int = 3) -> List[Tuple[str, float, str]]:
        """Finds the most relevant text chunks based on semantic similarity.

        Returns a list of (chunk_text, similarity_score, source_id) tuples.
        """
        if not self.resources:
            logger.warning("No resources loaded or indexed. Cannot find matches.")
            return []
        try:
            query_embedding = self.embedding_model.encode(query)
            all_chunks_with_scores = []
            for resource in self.resources:
                # Use len() here: chunk_embeddings is a NumPy array after indexing, and a
                # bare truth test on a multi-element array raises a ValueError.
                if not resource.chunks or len(resource.chunk_embeddings) == 0:
                    continue  # Skip resources with no chunks/embeddings
                # Calculate similarity between the query and all chunks of the current resource
                similarities = cosine_similarity([query_embedding], resource.chunk_embeddings)[0]
                for chunk, score in zip(resource.chunks, similarities):
                    if score > Config.SIMILARITY_THRESHOLD:
                        all_chunks_with_scores.append((chunk, float(score), resource.source_id))  # Use source_id
            # Sort by similarity score (descending) and return the top n
            all_chunks_with_scores.sort(key=lambda x: x[1], reverse=True)
            return all_chunks_with_scores[:n_chunks]
        except Exception as e:
            logger.error(f"Error finding matching chunks: {e}")
            return []
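
    # Example (hypothetical) return value of find_best_matching_chunks("When does term start?"):
    #   [("Term dates ...", 0.62, "Dataset Entry 12 (Subject: ...; Date: ...)"), ...]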

    def generate_response(self, user_input: str) -> str:
        """Generates a response based on user input and retrieved context."""
        try:
            # 1. Find relevant context chunks
            best_chunks = self.find_best_matching_chunks(user_input)
            if not best_chunks:
                logger.info(f"No relevant chunks found for query: '{user_input}'")
                return "I couldn't find specific information related to your question in the provided documents. Could you please rephrase or ask about a different topic?"
            # 2. Prepare context and source attribution
            context = "\n".join([chunk[0] for chunk in best_chunks])
            # Use source_id from the chunk tuple (index 2)
            source_ids = sorted(set(chunk[2] for chunk in best_chunks))
            sources_text = "\n\nSources:\n" + "\n".join([f"- {sid}" for sid in source_ids])
            # 3. Prepare input for the language model
            # Ensure the input doesn't exceed model limits
            prompt_template = f"Based on the following information:\n{context}\n\nAnswer the question: {user_input}\nAnswer:"
            # prompt_template = f"Context: {context}\nUser: {user_input}\nAssistant:"  # Alternative simpler prompt
            # 4. Tokenize and truncate if necessary
            input_ids = self.tokenizer.encode(prompt_template, return_tensors='pt', max_length=Config.MAX_TOKENS_INPUT, truncation=True)
            # 5. Generate response using the language model
            logger.info("Generating response with LLM...")
            output_sequences = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=Config.MAX_TOKENS_RESPONSE,  # Control length of *new* tokens
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                num_return_sequences=1  # Generate one response
            )
            # response_text = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)  # Would include the prompt
            # Decode only the newly generated tokens, excluding the prompt
            response_text = self.tokenizer.decode(output_sequences[0][input_ids.shape[-1]:], skip_special_tokens=True)
            # Basic post-processing (optional)
            response_text = response_text.strip()
            # Remove potential repetition of the question if the model includes it
            # (done case-insensitively; the original check was case-insensitive but the split was not)
            idx = response_text.lower()[:len(user_input) + 10].find(user_input.lower())
            if idx != -1:
                response_text = response_text[idx + len(user_input):].strip("? ")
logger.info(f"Generated response (before sources): {response_text}")
# 6. Combine response and sources
full_response = response_text + sources_text
return full_response
except Exception as e:
logger.exception(f"Error generating response: {e}") # Use logger.exception to include stack trace
return "I apologize, but I encountered an error while processing your question. Please check the logs or try again later."
# --- Gradio Interface ---
def create_gradio_interface(chatbot: SchoolChatbot):
    """Creates and returns the Gradio web interface."""
    def respond(user_input: str) -> str:
        if not user_input:
            return "Please enter a question."
        # Add basic input sanitization here if needed
        return chatbot.generate_response(user_input)

    interface = gr.Interface(
        fn=respond,
        inputs=gr.Textbox(
            label="Ask a Question",
            placeholder="Type your question about the school information...",
            lines=3,  # Increased lines slightly
        ),
        outputs=gr.Textbox(
            label="Answer",
            placeholder="Response will appear here...",
            lines=10,  # Increased lines for longer answers + sources
        ),
        title="School Information Chatbot (Dataset Powered)",
        description="Ask about information contained in the school dataset. The chatbot uses AI to find relevant details and generate answers.",
        examples=[  # Update examples based on dataset content
            ["What are the main subjects covered in the documents?"],
            ["Are there any mentions of specific events or dates?"],
            ["Summarize the key points about [topic from dataset]."]
        ],
        theme=gr.themes.Soft(),
        allow_flagging="never",  # Renamed to flagging_mode in newer Gradio releases
        # Optional: add feedback capabilities
        # feedback=["thumbs", "textbox"],
    )
    return interface
# --- Main Execution ---
if __name__ == "__main__":
    # Install the required libraries if running for the first time:
    # pip install gradio transformers sentence-transformers torch datasets scikit-learn nltk numpy
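    # On a Hugging Face Space, these dependencies would normally be pinned in
    # requirements.txt rather than installed by hand.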
print("Starting application...")
try:
# 1. Initialize the chatbot (loads models and data)
school_chatbot = SchoolChatbot()
# 2. Create the Gradio interface
app_interface = create_gradio_interface(school_chatbot)
# 3. Launch the interface
print("Launching Gradio Interface...")
app_interface.launch(
server_name="0.0.0.0", # Accessible on the local network
server_port=7860,
share=False, # Set to True to get a public link (use with caution)
debug=False # Set to True for more detailed Gradio logs (can be verbose)
)
print("Interface launched. Access it at http://localhost:7860 (or the relevant IP)")
except ImportError as ie:
logger.error(f"ImportError: {ie}. Make sure all required libraries are installed.")
print(f"ImportError: {ie}. Please install the missing library (e.g., pip install {ie.name}).")
except Exception as e:
logger.critical(f"Failed to start the application: {e}", exc_info=True) # Log critical error with stack trace
print(f"Critical error during startup: {e}") |