# Spaces: Running
import os | |
import logging | |
import sys | |
from transformers import pipeline | |
# Configure logging: emit INFO-and-above records to stdout, each prefixed with
# a timestamp, the logger name and the level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout,
)

# Module-level logger used by everything defined in this file.
logger = logging.getLogger(__name__)
class QueryRewriter:
    """Rewrite user queries with a local FLAN-T5 text2text pipeline.

    Two prompting strategies are offered, ``rewrite_cot`` (step-by-step
    reasoning) and ``rewrite_react`` (thought / action / observation). Both
    return a ``(rewritten_query, prompt)`` tuple and fall back to the original
    query whenever the model fails or produces nothing usable.
    """

    # Marker that ends every prompt; if the model echoes it back, only the
    # text after its last occurrence is kept as the rewritten query.
    _ANSWER_MARKER = "Rewritten query:"

    # Prompt templates, filled with ``str.format(query=...)``. Each starts and
    # ends with a newline.
    _COT_TEMPLATE = (
        "\n"
        "Rewrite the following query using step-by-step reasoning:\n"
        "Original query: {query}\n"
        "Steps:\n"
        "1. What is the main question being asked?\n"
        "2. What are the key components?\n"
        "3. How can we make it clearer?\n"
        "Rewritten query:\n"
    )
    _REACT_TEMPLATE = (
        "\n"
        "Rewrite the following query using a systematic approach:\n"
        "Original query: {query}\n"
        "Thought: What information are we looking for?\n"
        "Action: Break down the query into key components\n"
        "Observation: Identify the main focus\n"
        "Rewritten query:\n"
    )

    def __init__(self, model_name="google/flan-t5-base", device=-1):
        """Load the generation pipeline.

        Args:
            model_name: Model identifier handed to ``pipeline``. Defaults to
                ``google/flan-t5-base``, a smaller model suitable for Spaces.
            device: Device index handed to ``pipeline``; ``-1`` selects the CPU.
        """
        # FLAN-T5 is an encoder-decoder model, so it has to be loaded through
        # the "text2text-generation" task. "text-generation" only accepts
        # decoder-only (causal LM) checkpoints and cannot load a T5 model.
        self.model = pipeline(
            "text2text-generation",
            model=model_name,
            device=device,
        )
        logger.info("Initialized QueryRewriter with %s model", model_name)

    def generate(self, prompt):
        """Run the model on ``prompt`` and return the generated text.

        Returns:
            The ``generated_text`` of the first returned sequence, or ``None``
            if the call raised (the error is logged, not propagated).
        """
        try:
            outputs = self.model(
                prompt,
                max_length=256,
                min_length=32,
                num_return_sequences=1,
            )
            return outputs[0]['generated_text']
        except Exception as e:
            # Best-effort boundary: callers treat None as "rewriting failed"
            # and fall back to the original query.
            logger.error("Error generating response: %s", e)
            return None

    def _rewrite(self, query, template, strategy):
        """Fill ``template`` with ``query``, generate, and extract the rewrite.

        Args:
            query: The original user query.
            template: Prompt template containing a ``{query}`` placeholder.
            strategy: Strategy label used in the failure log message.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` is the original
            ``query`` when generation fails or yields no usable text.
        """
        prompt = template.format(query=query)
        generated = self.generate(prompt)
        if generated is None:
            logger.error("Error in %s rewriting for query: %s", strategy, query)
            return query, prompt  # Return original query if rewriting fails

        try:
            # Keep everything after the last marker; when the marker is
            # absent, rpartition leaves the whole text in the final slot.
            final_query = generated.rpartition(self._ANSWER_MARKER)[2].strip()
        except (AttributeError, TypeError) as e:
            # Only reachable if the pipeline returned a non-string payload.
            logger.error("Error extracting rewritten query: %s", e)
            return query, prompt

        # An empty rewrite is useless downstream; keep the original instead.
        return (final_query or query), prompt

    def rewrite_cot(self, query):
        """Rewrite ``query`` using the step-by-step reasoning prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned when
            rewriting fails.
        """
        return self._rewrite(query, self._COT_TEMPLATE, "CoT")

    def rewrite_react(self, query):
        """Rewrite ``query`` using the thought / action / observation prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned when
            rewriting fails.
        """
        return self._rewrite(query, self._REACT_TEMPLATE, "ReAct")