# File: rag-youtube-assistant/app/query_rewriter.py
# Page header residue: "ganesh3's picture" (author avatar alt text)
# Commit: "Update app/query_rewriter.py" (f5736db, verified)
import os
import logging
import sys
from transformers import pipeline
# Logging setup: records at INFO and above are written to stdout using a
# "timestamp - logger name - level - message" layout.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
class QueryRewriter:
    """Rewrite user queries with a local Hugging Face generation pipeline.

    Two prompting strategies are offered, ``rewrite_cot`` (step-by-step
    reasoning) and ``rewrite_react`` (thought/action/observation). Both return
    a ``(rewritten_query, prompt)`` tuple and fall back to the original query
    whenever generation fails or yields nothing usable.
    """

    # Marker that closes both prompt templates. If the model echoes the prompt
    # back, only the text after the last occurrence of the marker is kept.
    _ANSWER_MARKER = "Rewritten query:"

    # Prompt templates, spelled out line by line so the text sent to the model
    # does not depend on how the surrounding source code is indented.
    _COT_TEMPLATE = (
        "\n"
        "Rewrite the following query using step-by-step reasoning:\n"
        "Original query: {query}\n"
        "Steps:\n"
        "1. What is the main question being asked?\n"
        "2. What are the key components?\n"
        "3. How can we make it clearer?\n"
        "Rewritten query:\n"
    )
    _REACT_TEMPLATE = (
        "\n"
        "Rewrite the following query using a systematic approach:\n"
        "Original query: {query}\n"
        "Thought: What information are we looking for?\n"
        "Action: Break down the query into key components\n"
        "Observation: Identify the main focus\n"
        "Rewritten query:\n"
    )

    def __init__(self):
        """Load the flan-t5-base pipeline on CPU."""
        # flan-t5 is an encoder-decoder (seq2seq) checkpoint, so it must be
        # served by the "text2text-generation" task; "text-generation" expects
        # a decoder-only causal language model.
        # NOTE(review): confirm the task name against the pinned transformers
        # version for this project.
        self.model = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",  # Using a smaller model suitable for Spaces
            device=-1,  # Use CPU
        )
        logger.info("Initialized QueryRewriter with flan-t5-base model")

    def generate(self, prompt):
        """Run the model on ``prompt`` and return the generated text.

        Args:
            prompt: Full prompt string passed to the pipeline.

        Returns:
            The ``generated_text`` of the first returned sequence, or ``None``
            if the pipeline call or result lookup raised.
        """
        # Deliberately broad: any model/runtime failure is logged and reported
        # to the caller as ``None`` so a rewrite failure never aborts a search.
        try:
            outputs = self.model(
                prompt,
                max_length=256,
                # NOTE(review): a 32-token minimum looks high for a short
                # query rewrite -- confirm this is intended.
                min_length=32,
                num_return_sequences=1,
            )
            return outputs[0]["generated_text"]
        except Exception as e:
            logger.error("Error generating response: %s", e)
            return None

    def _rewrite(self, query, template, strategy):
        """Fill ``template`` with ``query``, generate, and extract the rewrite.

        Args:
            query: The original user query.
            template: Prompt template containing a ``{query}`` placeholder.
            strategy: Human-readable strategy name used in log messages.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` is the original
            ``query`` when generation failed or produced no usable text.
        """
        prompt = template.format(query=query)
        generated = self.generate(prompt)
        if generated is None:
            logger.error("Error in %s rewriting for query: %s", strategy, query)
            return query, prompt  # Return original query if rewriting fails

        # Keep everything after the last marker (handles an echoed prompt).
        try:
            final_query = generated.split(self._ANSWER_MARKER)[-1].strip()
        except (AttributeError, TypeError) as e:
            # Only reachable when the pipeline returned a non-string payload.
            logger.error("Error extracting rewritten query: %s", e)
            return query, prompt

        if not final_query:
            # An empty rewrite would make downstream retrieval search for
            # nothing, so the original query is the safer result.
            logger.warning(
                "Empty %s rewrite for query: %s; using original query",
                strategy,
                query,
            )
            return query, prompt
        return final_query, prompt

    def rewrite_cot(self, query):
        """Rewrite ``query`` using a chain-of-thought style prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned in
            place of the rewrite if generation fails or is empty.
        """
        return self._rewrite(query, self._COT_TEMPLATE, "CoT")

    def rewrite_react(self, query):
        """Rewrite ``query`` using a ReAct (thought/action/observation) prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned in
            place of the rewrite if generation fails or is empty.
        """
        return self._rewrite(query, self._REACT_TEMPLATE, "ReAct")