Spaces:
Running
Running
import os | |
from dotenv import load_dotenv | |
import logging | |
import sys | |
from transformers import pipeline | |
# Configure logging
# Records go to stdout (rather than the default stderr) at INFO level and
# above, each prefixed with timestamp, logger name and severity.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
# Define the RAG prompt template
# Two placeholders are filled in with str.format(): {context} receives the
# retrieved transcript text and {question} receives the user's question.
RAG_PROMPT_TEMPLATE = "\n".join(
    (
        "You are an AI assistant analyzing YouTube video transcripts. "
        "Your task is to answer questions based on the provided transcript context.",
        "Context from transcript:",
        "{context}",
        "User Question: {question}",
        "Please provide a clear, concise answer based only on the information "
        "given in the context. If the context doesn't contain enough information "
        "to fully answer the question, acknowledge this in your response.",
    )
)
class RAGSystem:
    """Answer questions about a transcript using retrieval plus a local LLM.

    Documents are fetched through ``data_processor.search(...)`` and then
    passed, together with the user's question, to a Hugging Face
    ``transformers`` pipeline running ``google/flan-t5-base`` on CPU.
    """

    def __init__(self, data_processor):
        """Store the retriever and load the generation pipeline.

        Args:
            data_processor: Object exposing
                ``search(query, num_results=..., method=..., index_name=...)``
                that returns a sequence of dicts with a ``'content'`` key.
        """
        self.data_processor = data_processor
        # flan-t5 is an encoder-decoder (seq2seq) checkpoint, so it has to be
        # served through the "text2text-generation" task. The
        # "text-generation" task targets decoder-only causal LMs and does not
        # match this model.
        self.model = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",  # Using a smaller model suitable for Spaces
            device=-1,  # Use CPU
        )
        logger.info("Initialized RAG system with flan-t5-base model")

    def generate(self, prompt):
        """Run the model on ``prompt`` and return the generated text.

        Returns:
            The ``'generated_text'`` of the first returned sequence, or
            ``None`` if anything goes wrong (the error is logged).
        """
        try:
            outputs = self.model(
                prompt,
                max_length=512,
                min_length=64,
                num_return_sequences=1,
            )
            return outputs[0]['generated_text']
        except Exception as e:
            # Best-effort by design: callers treat None as "generation failed".
            logger.error(f"Error generating response: {e}")
            return None

    def get_prompt(self, user_query, relevant_docs):
        """Fill the RAG prompt template with the retrieved context.

        Args:
            user_query: The question to answer.
            relevant_docs: Iterable of dicts; each one's ``'content'`` value
                is joined with newlines to form the context.
        """
        context = "\n".join(doc['content'] for doc in relevant_docs)
        return RAG_PROMPT_TEMPLATE.format(context=context, question=user_query)

    def query(self, user_query, search_method='hybrid', index_name=None):
        """Retrieve context for ``user_query`` and generate an answer.

        Args:
            user_query: The question to answer.
            search_method: Forwarded to ``data_processor.search`` as ``method``.
            index_name: Index to search; required.

        Returns:
            A ``(answer, prompt)`` tuple. ``prompt`` is the empty string when
            no prompt was built (missing index, no documents, or an error);
            in those cases ``answer`` is a human-readable message.
        """
        try:
            if not index_name:
                # Raised inside the try on purpose so that it is reported
                # through the same "An error occurred: ..." path as other
                # failures.
                raise ValueError("No index name provided. Please select a video and ensure it has been processed.")

            relevant_docs = self.data_processor.search(
                user_query,
                num_results=3,
                method=search_method,
                index_name=index_name,
            )
            if not relevant_docs:
                logger.warning("No relevant documents found for the query.")
                return "I couldn't find any relevant information to answer your query.", ""

            prompt = self.get_prompt(user_query, relevant_docs)
            answer = self.generate(prompt)
            if not answer:
                return "I encountered an error generating the response.", prompt
            return answer, prompt
        except Exception as e:
            logger.error(f"An error occurred in the RAG system: {e}")
            return f"An error occurred: {str(e)}", ""

    def _rewrite_with_prompt(self, query, prompt):
        """Generate from ``prompt``; fall back to the original ``query``.

        Returns:
            ``(rewritten_or_original_query, prompt)``.
        """
        response = self.generate(prompt)
        if response:
            return response, prompt
        return query, prompt

    def rewrite_cot(self, query):
        """Rewrite ``query`` using a chain-of-thought style prompt.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` is the original
            ``query`` when generation yields nothing.
        """
        prompt = (
            "\n"
            "Think through this step by step:\n"
            f"1. Original query: {query}\n"
            "2. What are the key components of this query?\n"
            "3. How can we break this down into a clearer question?\n"
            "Rewritten query:\n"
        )
        return self._rewrite_with_prompt(query, prompt)

    def rewrite_react(self, query):
        """Rewrite ``query`` using a ReAct-style step-by-step prompt.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` is the original
            ``query`` when generation yields nothing.
        """
        prompt = (
            "\n"
            "Let's approach this step-by-step:\n"
            f"1. Question: {query}\n"
            "2. What information do we need?\n"
            "3. What's the best way to structure this query?\n"
            "Rewritten query:\n"
        )
        return self._rewrite_with_prompt(query, prompt)