import os
import logging
import json
import time
import traceback
import numpy as np
import faiss
from typing import List, Dict, Any
import gradio as gr
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chat_models import ChatOpenAI
from sentence_transformers import SentenceTransformer
# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Load API key from environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("API key is missing. Set OPENAI_API_KEY in Hugging Face Secrets.")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# logging.debug(f"Using OpenAI API Key: {OPENAI_API_KEY[:5]}... (truncated for security)")
# Load FAISS index and chunked data
logging.debug("Loading FAISS index and chunked data...")
faiss_index = faiss.read_index("fp16_faiss_embeddings.index")
with open("all_chunked_data.json", "r") as f:
    all_chunked_data = json.load(f)
logging.debug("FAISS index and chunked data loaded successfully.")

# Log a random chunk for verification
random_index = np.random.randint(0, len(all_chunked_data))
logging.debug(f"Random FAISS index verification: {random_index}")
logging.debug(f"Corresponding chunk: {all_chunked_data[random_index]['text'][:100]}...")
logging.debug("Loading and configuring the embedding model...") | |
model = SentenceTransformer( | |
"dunzhang/stella_en_400M_v5", | |
trust_remote_code=True, | |
device="cpu", | |
config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False} | |
) | |
logging.debug("Embedding model loaded successfully.") | |
# Test embedding model
start_time = time.time()
logging.debug("Testing embedding model with a sample query...")
try:
    query_embedding = model.encode(["test query"], show_progress_bar=False)
    logging.debug(f"Embedding shape: {query_embedding.shape}")
    logging.debug(f"Encoding took {time.time() - start_time:.2f} seconds")
except Exception as e:
    logging.error(f"Error in embedding model test: {repr(e)}")
    logging.error(f"Traceback: {traceback.format_exc()}")
# =======================
# Test Embeddings
# =======================
# Check the size of the FAISS index
# logging.debug(f"Number of embeddings in FAISS index: {faiss_index.ntotal}")

# # Retrieve embeddings from the FAISS index (first 'k' embeddings)
# k = 2  # Number of embeddings to retrieve for verification
# stored_embeddings = np.zeros((k, 1024), dtype='float32')  # 1024 is the embedding dimension
# faiss_index.reconstruct_n(0, k, stored_embeddings)

# # Compare with freshly encoded embeddings for the first k chunks
# # (note: encode expects the chunk text, not the full chunk dicts)
# original_embeddings = model.encode([chunk['text'] for chunk in all_chunked_data[:k]])
# logging.debug(f"Original Embeddings: {original_embeddings}")
# logging.debug(f"Stored Embeddings from FAISS index: {stored_embeddings}")

# # Query one of the chunks and check if FAISS returns it as its own nearest neighbor
# query_embedding = model.encode([all_chunked_data[0]['text']])  # Encode the first chunk's text
# D, I = faiss_index.search(np.array(query_embedding, dtype='float32'), k=1)  # Top-1 match
# logging.debug(f"Distance: {D}, Index: {I}")
# logging.debug(f"Queried Chunk: {all_chunked_data[0]['text'][:100]}")
# logging.debug(f"Matched Chunk: {all_chunked_data[I[0][0]]['text'][:100]}")

# # Check the dimensionality of the FAISS index
# logging.debug(f"Dimension of embeddings in FAISS index: {faiss_index.d}")
CHUNK_SIZE = 400       # Roughly 400 words per chunk
CHUNK_OVERLAP = 50     # 50 words of overlap between consecutive chunks
LLM_MODEL_NAME = "gpt-4o-mini"  # "o1-mini" gives noticeably better answers but costs more
LLM_TEMPERATURE = 0
TOP_K_RETRIEVAL = 3
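
# The chunker that produced all_chunked_data.json is not part of this file; the sketch
# below shows one plausible word-based implementation matching CHUNK_SIZE and
# CHUNK_OVERLAP. The function name and metadata handling are illustrative only.
def chunk_text_sketch(text: str, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Split text into ~CHUNK_SIZE-word chunks with CHUNK_OVERLAP words of overlap."""
    words = text.split()
    chunks = []
    step = CHUNK_SIZE - CHUNK_OVERLAP  # advance 350 words, so 50 words repeat
    for start in range(0, len(words), step):
        piece = " ".join(words[start:start + CHUNK_SIZE])
        if piece:
            chunks.append({"text": piece, "metadata": metadata})
    return chunks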
# =======================
# Prompt Configuration
# =======================
def create_chat_prompt():
    """Create a chat prompt template for the AI model."""
    chat_prompt_template = """
    You are AQUABOTICA, the most advanced AI assistant specializing in aquaculture information.
    Given a specific query, analyze the provided context extracted from academic documents and draw on your own knowledge to generate a precise and concise answer. If the context contains quantitative figures, mention them.
    Avoid LaTeX or complex math formatting; use plain text for maths.
    **Query:** {question}
    **Context:** {context}
    **Response:**
    """
    prompt = PromptTemplate(
        template=chat_prompt_template,
        input_variables=['context', 'question']
    )
    chat_prompt = ChatPromptTemplate(
        input_variables=['context', 'question'],
        metadata={
            'lc_hub_owner': 'aquabotica',
            'lc_hub_repo': 'aquaculture-research',
            'lc_hub_commit_hash': 'a7b9c123abc12345f6789e123456def123456789'  # Adjust commit hash if required
        },
        messages=[
            HumanMessagePromptTemplate(prompt=prompt)
        ]
    )
    return chat_prompt
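
# Usage sketch (illustrative question, not from the original code):
#   chat_prompt = create_chat_prompt()
#   text = chat_prompt.format(context="...retrieved chunks...",
#                             question="How deep should a tilapia pond be?")
# format() renders the human message as a single string, which is what main() below
# passes to llm.invoke().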
# =======================
# Metadata Formatting
# =======================
def format_metadata(chunk_id: int, all_chunked_data: List[Dict[str, Any]]) -> str:
    """Format metadata directly from the chunked data for a given chunk ID."""
    chunk = all_chunked_data[chunk_id]
    metadata = chunk.get('metadata', {})
    logging.debug(f"Chunk Retrieved: {chunk['text'][:100]}...")  # Log first 100 characters
    logging.debug(f"Metadata: {metadata}")
    return f"Chunk {chunk_id}: {metadata}"
# =======================
# Language Model and Retrieval Setup
# =======================
def initialize_llm(model_name=LLM_MODEL_NAME, temperature=LLM_TEMPERATURE):
    """Initialize the language model."""
    logging.debug("Initializing LLM model...")
    return ChatOpenAI(model_name=model_name, temperature=temperature, openai_api_key=OPENAI_API_KEY)
def main(QUESTION=""): | |
logging.debug(f"Received user query: {QUESTION}") | |
chat_prompt = create_chat_prompt() | |
llm = initialize_llm() | |
# Query FAISS Index | |
try: | |
logging.debug("Encoding query for FAISS retrieval...") | |
query_embedding = model.encode([QUESTION]) | |
logging.debug(f"Query embedding: {query_embedding[:5]}... (truncated)") | |
D, I = faiss_index.search(np.array(query_embedding, dtype='float32'), k=3) | |
relevant_chunk_ids = I[0] | |
logging.debug(f"Retrieved chunk IDs: {relevant_chunk_ids}, Distances: {D}") | |
relevant_chunks = [all_chunked_data[i]['text'] for i in relevant_chunk_ids] | |
#### | |
#### | |
context_display = "\n\n".join([ | |
f"Chunk {idx+1}: {chunk[:]}...\nMetadata: {all_chunked_data[i]['metadata']}" | |
for idx, (i, chunk) in enumerate(zip(relevant_chunk_ids, relevant_chunks)) | |
]) | |
#### | |
#### | |
# context = "\n\n".join([f"Retrieved Chunk: {chunk}\nMetadata: {all_chunked_data[i]['metadata']}" for i, chunk in zip(relevant_chunk_ids, relevant_chunks)]) | |
context = " ".join(relevant_chunks) | |
except Exception as e: | |
logging.error(f"Error during FAISS search: {e}") | |
return f"Error during FAISS search: {e}" | |
# Generate Response | |
try: | |
logging.debug("Formatting input for LLM...") | |
prompt_input = chat_prompt.format(context=context, question=QUESTION) | |
logging.debug(f"Formatted prompt: {prompt_input}") | |
result = llm.invoke(prompt_input) | |
answer = result.content if hasattr(result, 'content') else "No answer found." | |
logging.debug("LLM successfully generated response.") | |
except Exception as e: | |
logging.error(f"Error during LLM execution: {e}") | |
return f"Error during LLM execution: {e}" | |
return answer, context_display | |
# relevant_chunks_metadata = [format_metadata(chunk_id, all_chunked_data) for chunk_id in relevant_chunk_ids] | |
# return f"\n{answer}\n\n" + context | |
# return f"\n{answer}\n\n" + "\n"+ "\n".join(relevant_chunks_metadata) | |
custom_css = """ | |
/* Style for labels across all components */ | |
.question-input label span, | |
.solution-output label span, | |
.metadata-output label span { | |
font-size: 20px !important; | |
font-weight: bold !important; | |
color: orange !important; | |
} | |
/* Correct style for the submit button */ | |
.submit-btn button { | |
background-color: orange !important; | |
color: black !important; | |
font-weight: bold !important; | |
border: none !important; | |
border-radius: 8px !important; | |
padding: 10px 20px !important; | |
cursor: pointer !important; | |
} | |
/* Hover effect for submit button */ | |
.submit-btn button:hover { | |
background-color: darkorange !important; | |
} | |
/* Preserve newlines and enable horizontal scrolling in retrieved documents */ | |
.metadata-output textarea { | |
white-space: pre !important; | |
overflow-x: auto !important; | |
padding: 8px !important; | |
} | |
""" | |
with gr.Blocks(css=custom_css) as demo:
    with gr.Column():
        question_input = gr.Textbox(
            label="Ask a Question",
            lines=2,
            placeholder="Enter your question here",
            elem_classes="question-input"
        )
        submit_btn = gr.Button(
            "Submit",
            elem_classes="submit-btn"
        )
        solution_output = gr.Textbox(
            label="Response",
            interactive=False,
            lines=5,
            elem_classes="solution-output"
        )
        retrieved_chunks = gr.Textbox(
            label="Retrieved Data/Documents",
            interactive=False,
            lines=5,
            elem_classes="metadata-output"
        )
    submit_btn.click(
        main,
        inputs=question_input,
        outputs=[solution_output, retrieved_chunks]
    )

demo.launch()
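
# If the Space needs to handle concurrent users, Gradio's request queue can be
# enabled before launching (optional, not in the original code):
#   demo.queue().launch()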