File size: 2,803 Bytes
507c938
 
ae8d0fc
f5736db
507c938
f5736db
5de591d
 
 
 
 
507c938
dbd33b2
 
 
f5736db
 
 
 
 
 
507c938
 
 
f5736db
 
 
 
 
 
 
507c938
 
f5736db
dbd33b2
 
 
f5736db
 
 
 
 
 
 
 
dbd33b2
 
 
f5736db
507c938
f5736db
 
507c938
f5736db
 
 
 
 
 
 
 
dbd33b2
 
 
f5736db
dbd33b2
f5736db
dbd33b2
f5736db
 
 
dbd33b2
f5736db
dbd33b2
f5736db
507c938
f5736db
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import logging
import sys
from transformers import pipeline

# Route INFO-and-above log records to stdout with a timestamped format, then
# create the logger that the rest of this module writes to.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class QueryRewriter:
    """Rewrite user queries by prompting a small instruction-tuned model.

    Two prompting strategies are offered: a chain-of-thought style prompt
    (``rewrite_cot``) and a ReAct style prompt (``rewrite_react``). Both
    return ``(rewritten_query, prompt)`` and fall back to the original query
    whenever generation fails or yields nothing usable.
    """

    # Text that ends every prompt; if the model echoes the prompt back, only
    # what follows the last occurrence of this marker is kept.
    _ANSWER_MARKER = "Rewritten query:"

    def __init__(self, model_name="google/flan-t5-base",
                 task="text2text-generation", device=-1):
        """Load the generation pipeline.

        Args:
            model_name: Hugging Face model id. Defaults to the small
                ``google/flan-t5-base`` checkpoint, suitable for Spaces.
            task: Pipeline task. FLAN-T5 is an encoder-decoder model, so it
                has to be served by ``text2text-generation``; the causal-LM
                ``text-generation`` task cannot load it.
            device: Device index handed to ``pipeline``; ``-1`` selects CPU.
        """
        self.model = pipeline(task, model=model_name, device=device)
        logger.info("Initialized QueryRewriter with %s model", model_name)

    def generate(self, prompt):
        """Run the model on ``prompt`` and return the generated text.

        Returns:
            The ``generated_text`` of the first returned sequence, or ``None``
            if the pipeline call (or reading its result) raised.
        """
        try:
            outputs = self.model(
                prompt,
                max_length=256,
                min_length=32,
                num_return_sequences=1,
            )
            return outputs[0]["generated_text"]
        except Exception as e:
            # Deliberate best-effort boundary: callers fall back to the
            # original query on None, so log with traceback and carry on.
            logger.exception("Error generating response: %s", e)
            return None

    def _rewrite(self, query, prompt, strategy):
        """Generate from ``prompt`` and extract the rewritten query.

        Args:
            query: The original query, used as the fallback result.
            prompt: Fully formatted prompt to send to the model.
            strategy: Human-readable strategy name used in log messages.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` is ``query``
            itself when generation failed or produced no usable text.
        """
        generated = self.generate(prompt)
        if generated is None:
            logger.error("Error in %s rewriting for query: %s", strategy, query)
            return query, prompt

        if not isinstance(generated, str):
            logger.error(
                "Error extracting rewritten query: unexpected output type %s",
                type(generated).__name__,
            )
            return query, prompt

        # Keep only what follows the marker; if the marker is absent the
        # whole generated text is used.
        final_query = generated.split(self._ANSWER_MARKER)[-1].strip()
        if not final_query:
            # An empty rewrite is useless downstream; keep the original.
            logger.warning(
                "Empty %s rewrite for query: %s; keeping original query",
                strategy,
                query,
            )
            return query, prompt
        return final_query, prompt

    def rewrite_cot(self, query):
        """Rewrite ``query`` using a step-by-step (chain-of-thought) prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned
            unchanged if rewriting fails.
        """
        prompt = f"""
        Rewrite the following query using step-by-step reasoning:
        
        Original query: {query}
        
        Steps:
        1. What is the main question being asked?
        2. What are the key components?
        3. How can we make it clearer?
        
        Rewritten query:
        """
        return self._rewrite(query, prompt, "CoT")

    def rewrite_react(self, query):
        """Rewrite ``query`` using a Thought/Action/Observation (ReAct) prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned
            unchanged if rewriting fails.
        """
        prompt = f"""
        Rewrite the following query using a systematic approach:
        
        Original query: {query}
        
        Thought: What information are we looking for?
        Action: Break down the query into key components
        Observation: Identify the main focus
        
        Rewritten query:
        """
        return self._rewrite(query, prompt, "ReAct")