# SouthSpencerQA / app.py
import gradio as gr
import logging
import time
from datetime import datetime
from typing import List, Optional, Tuple
import random
import nltk  # 'punkt' tokenizer data is downloaded below, after logging is configured
from nltk.tokenize import sent_tokenize
import io
# from joblib import dump, load # Not used currently, commented out
# Import Hugging Face libraries
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset # Added for dataset loading
# Import ML/Data libraries
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# Standard libraries
from concurrent.futures import ThreadPoolExecutor # Still useful for embedding generation
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) # Use __name__ for logger
# Download NLTK sentence tokenizer data (used for sentence-based chunking; a whitespace fallback exists)
try:
    nltk.download('punkt', quiet=True)
except Exception as e:
    logger.warning(f"Failed to download NLTK data: {e}")
# --- Configuration ---
class Config:
MODEL_NAME = "microsoft/DialoGPT-medium"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    MAX_TOKENS_RESPONSE = 150 # Max new tokens for the generated answer
    MAX_TOKENS_INPUT = 800 # Max tokens for context + query (DialoGPT-medium has a 1024-token context window)
SIMILARITY_THRESHOLD = 0.3 # Adjusted threshold, tune as needed
CHUNK_SIZE = 300 # Smaller chunk size might be better for dataset entries
MAX_WORKERS = 5 # For parallel embedding generation
DATASET_NAME = "acecalisto3/sspnc" # Hugging Face Dataset ID
DATASET_SPLIT = "train" # Which split of the dataset to use
TEXT_COLUMNS = ["Subject", "Body"] # Columns containing text to index
SOURCE_INFO_COLUMNS = ["Subject", "Date"] # Columns to use for source attribution
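# Illustrative sanity check (a sketch, not run at import time): the configured column names
# can be verified against the loaded dataset before indexing, e.g.:
#
#   from datasets import load_dataset
#   ds = load_dataset(Config.DATASET_NAME, split=Config.DATASET_SPLIT)
#   missing = [c for c in Config.TEXT_COLUMNS + Config.SOURCE_INFO_COLUMNS
#              if c not in ds.column_names]
#   if missing:
#       raise ValueError(f"Dataset is missing expected columns: {missing}")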
# --- Data Structures ---
class ResourceItem:
def __init__(self, source_id: str, content: str, resource_type: str):
        self.source_id = source_id # Identifier for the originating dataset row
self.content = content
self.type = resource_type
self.embedding = None # Overall embedding (optional now, as we use chunk embeddings)
self.chunks = []
self.chunk_embeddings = []
def __str__(self):
return f"ResourceItem(type={self.type}, source_id={self.source_id}, content_length={len(self.content)})"
def create_chunks(self, chunk_size=Config.CHUNK_SIZE):
"""Split content into overlapping chunks using sentence tokenization for better boundaries"""
if not self.content:
logger.warning(f"Content is empty for source_id: {self.source_id}. Skipping chunk creation.")
return
try:
sentences = sent_tokenize(self.content)
except LookupError:
logger.warning("NLTK 'punkt' tokenizer not found. Falling back to simple whitespace splitting. Consider running nltk.download('punkt')")
# Fallback to word splitting if sentence tokenization fails
words = self.content.split()
overlap = chunk_size // 4
for i in range(0, len(words), chunk_size - overlap):
chunk = ' '.join(words[i : i + chunk_size])
if chunk:
self.chunks.append(chunk)
return
except Exception as e:
logger.error(f"Error during sentence tokenization for {self.source_id}: {e}. Skipping chunk creation.")
return
current_chunk = ""
overlap_sentences = max(1, chunk_size // 100) # Overlap a few sentences
last_sentences = []
for sentence in sentences:
            # If adding the next sentence would push the chunk past chunk_size (measured in words)
if len((current_chunk + " " + sentence).split()) > chunk_size:
if current_chunk: # Add the completed chunk
self.chunks.append(current_chunk.strip())
# Start new chunk with overlap
current_chunk = " ".join(last_sentences) + " " + sentence
else:
current_chunk += " " + sentence
# Keep track of last sentences for overlap
last_sentences.append(sentence)
if len(last_sentences) > overlap_sentences:
last_sentences.pop(0)
# Add the last remaining chunk
if current_chunk.strip():
self.chunks.append(current_chunk.strip())
if not self.chunks:
logger.warning(f"No chunks created for source_id: {self.source_id}. Content might be too short or tokenization failed.")
# --- Chatbot Core Logic ---
class SchoolChatbot:
def __init__(self):
logger.info("Initializing SchoolChatbot...")
self.setup_models()
self.resources: List[ResourceItem] = []
        self.load_and_index_dataset()  # Load the HF dataset and build chunk embeddings
def setup_models(self):
try:
logger.info("Setting up models...")
# Consider adding device mapping if GPU is available: device_map="auto"
self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)
self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
# Ensure tokenizer has a padding token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.config.pad_token_id = self.model.config.eos_token_id
logger.info("Models setup completed successfully.")
except Exception as e:
logger.error(f"Failed to setup models: {e}")
raise RuntimeError("Failed to initialize required models") from e
def load_and_index_dataset(self):
logger.info(f"Loading dataset: {Config.DATASET_NAME}, split: {Config.DATASET_SPLIT}")
try:
# Load the dataset
dataset = load_dataset(Config.DATASET_NAME, split=Config.DATASET_SPLIT)
logger.info(f"Dataset loaded successfully. Number of rows: {len(dataset)}")
# Process dataset rows in parallel (for embedding generation)
with ThreadPoolExecutor(max_workers=Config.MAX_WORKERS) as executor:
futures = []
for i, row in enumerate(dataset):
# Combine text from specified columns
text_content = " ".join([str(row[col]) for col in Config.TEXT_COLUMNS if row.get(col)])
text_content = text_content.strip() # Remove leading/trailing whitespace
# Create a source identifier
source_parts = [f"{col}: {row[col]}" for col in Config.SOURCE_INFO_COLUMNS if row.get(col)]
source_id = f"Dataset Entry {i} ({'; '.join(source_parts)})" # More informative ID
if not text_content:
logger.warning(f"Row {i} has no content in specified columns. Skipping.")
continue
# Submit the processing task
futures.append(executor.submit(self.process_and_store_resource, source_id, text_content, 'dataset_entry'))
# Wait for all futures to complete and collect results
for future in futures:
try:
result_item = future.result()
if result_item:
self.resources.append(result_item)
except Exception as e:
logger.error(f"Error processing dataset entry in thread: {e}")
logger.info(f"Dataset processing completed. Indexed {len(self.resources)} resources.")
except Exception as e:
logger.error(f"Failed to load or process dataset {Config.DATASET_NAME}: {e}")
# Decide if the app should continue without data or raise an error
# raise RuntimeError("Failed to load data") from e # Option: halt if data fails
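    # Quick inspection sketch (illustrative, not executed): after indexing, a few resources
    # can be examined to confirm that chunking and embedding worked, e.g.
    #
    #   for item in self.resources[:3]:
    #       print(item)  # uses ResourceItem.__str__
    #       print(len(item.chunks), "chunks,", len(item.chunk_embeddings), "embeddings")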
def process_and_store_resource(self, source_id: str, text_data: str, resource_type: str) -> Optional[ResourceItem]:
"""Creates ResourceItem, chunks, and generates embeddings for a single data entry."""
try:
# Create resource item and split into chunks
item = ResourceItem(source_id, text_data, resource_type)
item.create_chunks()
if not item.chunks:
logger.warning(f"No chunks generated for {source_id}. Skipping storage.")
return None
# Generate embeddings for chunks (can be slow, hence the thread pool)
chunk_embeddings_list = self.embedding_model.encode(item.chunks, show_progress_bar=False) # Batch encode
item.chunk_embeddings = chunk_embeddings_list
# Calculate average embedding (optional, might not be needed if only using chunk search)
# if item.chunk_embeddings:
# item.embedding = np.mean(item.chunk_embeddings, axis=0)
logger.debug(f"Processed resource: {source_id} (type={resource_type}), {len(item.chunks)} chunks.")
return item # Return the processed item
except Exception as e:
logger.error(f"Error processing/storing resource {source_id}: {e}")
return None # Return None on error
# store_resource is now process_and_store_resource and called within the thread pool
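    # Performance note (a sketch under assumptions, not current behaviour): encode() already
    # batches the chunks of one resource; for very large datasets it can be faster to gather
    # every chunk first and embed them in a single call, where `items` is a hypothetical list
    # of processed ResourceItem objects:
    #
    #   all_chunks = [chunk for item in items for chunk in item.chunks]
    #   all_embeddings = self.embedding_model.encode(all_chunks, batch_size=64, show_progress_bar=True)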
def find_best_matching_chunks(self, query: str, n_chunks: int = 3) -> List[Tuple[str, float, str]]:
"""Finds the most relevant text chunks based on semantic similarity."""
if not self.resources:
logger.warning("No resources loaded or indexed. Cannot find matches.")
return []
try:
query_embedding = self.embedding_model.encode(query)
all_chunks_with_scores = []
for resource in self.resources:
if not resource.chunks or not resource.chunk_embeddings:
continue # Skip resources with no chunks/embeddings
# Calculate similarity between query and all chunks of the current resource
similarities = cosine_similarity([query_embedding], resource.chunk_embeddings)[0]
for chunk, score in zip(resource.chunks, similarities):
if score > Config.SIMILARITY_THRESHOLD:
all_chunks_with_scores.append((chunk, float(score), resource.source_id)) # Use source_id
# Sort by similarity score (descending) and return top n
all_chunks_with_scores.sort(key=lambda x: x[1], reverse=True)
return all_chunks_with_scores[:n_chunks]
except Exception as e:
logger.error(f"Error finding matching chunks: {e}")
return []
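    # Illustrative usage (not executed; `chatbot` is an initialized SchoolChatbot): each result
    # is a (chunk_text, similarity_score, source_id) tuple.
    #
    #   matches = chatbot.find_best_matching_chunks("When is the next board meeting?", n_chunks=3)
    #   for text, score, source in matches:
    #       print(f"{score:.2f}  {source}: {text[:80]}")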
def generate_response(self, user_input: str) -> str:
"""Generates a response based on user input and retrieved context."""
try:
# 1. Find relevant context chunks
best_chunks = self.find_best_matching_chunks(user_input)
if not best_chunks:
logger.info(f"No relevant chunks found for query: '{user_input}'")
return "I couldn't find specific information related to your question in the provided documents. Could you please rephrase or ask about a different topic?"
# 2. Prepare context and source attribution
context = "\n".join([chunk[0] for chunk in best_chunks])
# Use source_id from the chunk tuple (index 2)
source_ids = sorted(list(set(chunk[2] for chunk in best_chunks)))
sources_text = "\n\nSources:\n" + "\n".join([f"- {sid}" for sid in source_ids])
# 3. Prepare input for the language model
# Ensure the input doesn't exceed model limits
prompt_template = f"Based on the following information:\n{context}\n\nAnswer the question: {user_input}\nAnswer:"
# prompt_template = f"Context: {context}\nUser: {user_input}\nAssistant:" # Alternative simpler prompt
            # 4. Tokenize and truncate if necessary.
            #    Note: truncation keeps the start of the prompt, so a very long context could cut off
            #    the trailing question; modest CHUNK_SIZE and n_chunks keep the prompt within budget.
input_ids = self.tokenizer.encode(prompt_template, return_tensors='pt', max_length=Config.MAX_TOKENS_INPUT, truncation=True)
# 5. Generate response using the language model
logger.info("Generating response with LLM...")
output_sequences = self.model.generate(
input_ids=input_ids,
max_new_tokens=Config.MAX_TOKENS_RESPONSE, # Control length of *new* tokens
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
temperature=0.7,
top_p=0.9,
do_sample=True,
num_return_sequences=1 # Generate one response
)
# Decode the generated part of the response
# response_text = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
# Decode only the newly generated tokens, excluding the prompt
response_text = self.tokenizer.decode(output_sequences[0][input_ids.shape[-1]:], skip_special_tokens=True)
# Basic post-processing (optional)
response_text = response_text.strip()
            # Remove a leading echo of the question if the model repeats it (case-insensitive)
            lowered = response_text.lower()
            echo_pos = lowered.find(user_input.lower())
            if 0 <= echo_pos <= 10:
                response_text = response_text[echo_pos + len(user_input):].strip("? ")
logger.info(f"Generated response (before sources): {response_text}")
# 6. Combine response and sources
full_response = response_text + sources_text
return full_response
except Exception as e:
logger.exception(f"Error generating response: {e}") # Use logger.exception to include stack trace
return "I apologize, but I encountered an error while processing your question. Please check the logs or try again later."
# --- Gradio Interface ---
def create_gradio_interface(chatbot: SchoolChatbot):
"""Creates and returns the Gradio web interface."""
def respond(user_input: str) -> str:
if not user_input:
return "Please enter a question."
# Add basic input sanitization if needed
return chatbot.generate_response(user_input)
interface = gr.Interface(
fn=respond,
inputs=gr.Textbox(
label="Ask a Question",
placeholder="Type your question about the school information...",
lines=3, # Increased lines slightly
),
outputs=gr.Textbox(
label="Answer",
placeholder="Response will appear here...",
lines=10, # Increased lines for longer answers + sources
),
title="School Information Chatbot (Dataset Powered)",
description="Ask about information contained in the school dataset. The chatbot uses AI to find relevant details and generate answers.",
examples=[ # Update examples based on dataset content
["What are the main subjects covered in the documents?"],
["Are there any mentions of specific events or dates?"],
["Summarize the key points about [topic from dataset]."]
],
theme=gr.themes.Soft(),
        allow_flagging="never",  # renamed to flagging_mode="never" in Gradio 5+
# Optional: Add feedback capabilities
# feedback=["thumbs", "textbox"],
)
return interface
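# Alternative UI sketch (an assumption, not what this app uses): a chat-style front end could be
# built with gr.ChatInterface, which passes (message, history) to its callback, e.g.
#
#   chat_ui = gr.ChatInterface(fn=lambda message, history: chatbot.generate_response(message))
#   chat_ui.launch()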
# --- Main Execution ---
if __name__ == "__main__":
    # Required libraries (first run):
    # pip install gradio transformers sentence-transformers torch datasets scikit-learn nltk numpy
print("Starting application...")
try:
# 1. Initialize the chatbot (loads models and data)
school_chatbot = SchoolChatbot()
# 2. Create the Gradio interface
app_interface = create_gradio_interface(school_chatbot)
# 3. Launch the interface
print("Launching Gradio Interface...")
app_interface.launch(
server_name="0.0.0.0", # Accessible on the local network
server_port=7860,
share=False, # Set to True to get a public link (use with caution)
debug=False # Set to True for more detailed Gradio logs (can be verbose)
)
print("Interface launched. Access it at http://localhost:7860 (or the relevant IP)")
except ImportError as ie:
logger.error(f"ImportError: {ie}. Make sure all required libraries are installed.")
print(f"ImportError: {ie}. Please install the missing library (e.g., pip install {ie.name}).")
except Exception as e:
logger.critical(f"Failed to start the application: {e}", exc_info=True) # Log critical error with stack trace
print(f"Critical error during startup: {e}")