ganesh3 commited on
Commit
f5736db
·
verified ·
1 Parent(s): 9da39b7

Update app/query_rewriter.py

Browse files
Files changed (1) hide show
  1. app/query_rewriter.py +53 -28
app/query_rewriter.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
2
- import ollama
3
  import logging
4
  import sys
 
5
 
6
- # Configure logging for stdout only
7
  logging.basicConfig(
8
  level=logging.INFO,
9
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
@@ -13,50 +13,75 @@ logger = logging.getLogger(__name__)
13
 
14
  class QueryRewriter:
15
  def __init__(self):
16
- self.model = os.getenv('OLLAMA_MODEL', "phi3")
17
- self.ollama_host = os.getenv('OLLAMA_HOST', 'http://ollama:11434')
 
 
 
 
18
 
19
  def generate(self, prompt):
20
  try:
21
- response = ollama.chat(
22
- model=self.model,
23
- messages=[{"role": "user", "content": prompt}]
24
- )
25
- return response['message']['content']
 
 
26
  except Exception as e:
27
  logger.error(f"Error generating response: {e}")
28
- return f"Error: {str(e)}"
29
 
30
  def rewrite_cot(self, query):
31
  prompt = f"""
32
- Rewrite the following query using Chain-of-Thought reasoning:
33
- Query: {query}
 
 
 
 
 
 
34
 
35
  Rewritten query:
36
  """
 
37
  rewritten_query = self.generate(prompt)
38
- if rewritten_query.startswith("Error:"):
39
- logger.error(f"Error in CoT rewriting: {rewritten_query}")
40
  return query, prompt # Return original query if rewriting fails
41
- return rewritten_query, prompt
 
 
 
 
 
 
 
42
 
43
  def rewrite_react(self, query):
44
  prompt = f"""
45
- Rewrite the following query using the ReAct framework (Reasoning and Acting):
46
- Query: {query}
47
 
48
- Thought 1:
49
- Action 1:
50
- Observation 1:
51
 
52
- Thought 2:
53
- Action 2:
54
- Observation 2:
55
 
56
- Final rewritten query:
57
  """
 
58
  rewritten_query = self.generate(prompt)
59
- if rewritten_query.startswith("Error:"):
60
- logger.error(f"Error in ReAct rewriting: {rewritten_query}")
61
- return query, prompt # Return original query if rewriting fails
62
- return rewritten_query, prompt
 
 
 
 
 
 
 
 
1
  import os
 
2
  import logging
3
  import sys
4
+ from transformers import pipeline
5
 
6
+ # Configure logging
7
  logging.basicConfig(
8
  level=logging.INFO,
9
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 
13
 
14
  class QueryRewriter:
15
  def __init__(self):
16
+ self.model = pipeline(
17
+ "text-generation",
18
+ model="google/flan-t5-base", # Using a smaller model suitable for Spaces
19
+ device=-1 # Use CPU
20
+ )
21
+ logger.info("Initialized QueryRewriter with flan-t5-base model")
22
 
23
  def generate(self, prompt):
24
  try:
25
+ response = self.model(
26
+ prompt,
27
+ max_length=256,
28
+ min_length=32,
29
+ num_return_sequences=1
30
+ )[0]['generated_text']
31
+ return response
32
  except Exception as e:
33
  logger.error(f"Error generating response: {e}")
34
+ return None
35
 
36
  def rewrite_cot(self, query):
37
  prompt = f"""
38
+ Rewrite the following query using step-by-step reasoning:
39
+
40
+ Original query: {query}
41
+
42
+ Steps:
43
+ 1. What is the main question being asked?
44
+ 2. What are the key components?
45
+ 3. How can we make it clearer?
46
 
47
  Rewritten query:
48
  """
49
+
50
  rewritten_query = self.generate(prompt)
51
+ if rewritten_query is None:
52
+ logger.error(f"Error in CoT rewriting for query: {query}")
53
  return query, prompt # Return original query if rewriting fails
54
+
55
+ # Extract the rewritten query (everything after "Rewritten query:")
56
+ try:
57
+ final_query = rewritten_query.split("Rewritten query:")[-1].strip()
58
+ return final_query, prompt
59
+ except Exception as e:
60
+ logger.error(f"Error extracting rewritten query: {e}")
61
+ return query, prompt
62
 
63
  def rewrite_react(self, query):
64
  prompt = f"""
65
+ Rewrite the following query using a systematic approach:
 
66
 
67
+ Original query: {query}
 
 
68
 
69
+ Thought: What information are we looking for?
70
+ Action: Break down the query into key components
71
+ Observation: Identify the main focus
72
 
73
+ Rewritten query:
74
  """
75
+
76
  rewritten_query = self.generate(prompt)
77
+ if rewritten_query is None:
78
+ logger.error(f"Error in ReAct rewriting for query: {query}")
79
+ return query, prompt
80
+
81
+ # Extract the rewritten query
82
+ try:
83
+ final_query = rewritten_query.split("Rewritten query:")[-1].strip()
84
+ return final_query, prompt
85
+ except Exception as e:
86
+ logger.error(f"Error extracting rewritten query: {e}")
87
+ return query, prompt