import logging
import os
import sys
from typing import Optional, Tuple

from transformers import pipeline

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)


class QueryRewriter:
    """Rewrite user queries with a FLAN-T5 model using CoT or ReAct style prompts.

    Both public rewrite methods return a ``(query, prompt)`` tuple and fall
    back to the caller's original query whenever generation fails or yields
    nothing usable, so callers always get a non-``None`` query back.
    """

    # Marker that ends every prompt; the rewritten query is whatever follows
    # its last occurrence in the generated text.
    _ANSWER_MARKER = "Rewritten query:"

    _COT_TEMPLATE = (
        "Rewrite the following query using step-by-step reasoning:\n"
        "Original query: {query}\n"
        "\n"
        "Steps:\n"
        "1. What is the main question being asked?\n"
        "2. What are the key components?\n"
        "3. How can we make it clearer?\n"
        "\n"
        "Rewritten query:"
    )

    _REACT_TEMPLATE = (
        "Rewrite the following query using a systematic approach:\n"
        "Original query: {query}\n"
        "\n"
        "Thought: What information are we looking for?\n"
        "Action: Break down the query into key components\n"
        "Observation: Identify the main focus\n"
        "\n"
        "Rewritten query:"
    )

    def __init__(self, model_name: str = "google/flan-t5-base", device: int = -1):
        """Load the generation pipeline.

        Args:
            model_name: Hugging Face model id. Defaults to flan-t5-base, a
                smaller model suitable for Spaces.
            device: Device index handed to ``pipeline``; ``-1`` means CPU.
        """
        # FLAN-T5 is an encoder-decoder (seq2seq) model, which the pipeline
        # API serves through "text2text-generation". The causal-LM
        # "text-generation" task does not support T5 checkpoints.
        self.model = pipeline(
            "text2text-generation",
            model=model_name,
            device=device,
        )
        logger.info("Initialized QueryRewriter with %s model", model_name)

    def generate(self, prompt: str) -> Optional[str]:
        """Run the model on *prompt* and return the generated text.

        Returns:
            The ``generated_text`` of the first returned sequence, or ``None``
            if the pipeline call raised (the error is logged, not propagated).
        """
        try:
            outputs = self.model(
                prompt,
                max_length=256,
                min_length=32,
                num_return_sequences=1,
            )
            return outputs[0]['generated_text']
        except Exception as e:
            # Deliberately broad: rewriting is best-effort and callers fall
            # back to the original query when this returns None.
            logger.error("Error generating response: %s", e)
            return None

    def _rewrite(self, query: str, prompt: str, strategy: str) -> Tuple[str, str]:
        """Generate from *prompt* and extract the rewritten query, with fallback."""
        generated = self.generate(prompt)
        if generated is None:
            logger.error("Error in %s rewriting for query: %s", strategy, query)
            return query, prompt  # Return original query if rewriting fails

        # Keep only what follows the last marker. If the model did not echo
        # the marker, rpartition leaves the whole output in the tail.
        final_query = generated.rpartition(self._ANSWER_MARKER)[2].strip()
        if not final_query:
            logger.warning(
                "Empty %s rewrite for query: %s; returning original query",
                strategy,
                query,
            )
            return query, prompt
        return final_query, prompt

    def rewrite_cot(self, query: str) -> Tuple[str, str]:
        """Rewrite *query* with a chain-of-thought style prompt.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` is the original
            *query* when generation fails or produces an empty rewrite.
        """
        prompt = self._COT_TEMPLATE.format(query=query)
        return self._rewrite(query, prompt, "CoT")

    def rewrite_react(self, query: str) -> Tuple[str, str]:
        """Rewrite *query* with a ReAct (Thought/Action/Observation) style prompt.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` is the original
            *query* when generation fails or produces an empty rewrite.
        """
        prompt = self._REACT_TEMPLATE.format(query=query)
        return self._rewrite(query, prompt, "ReAct")