Datawithsarah committed on
Commit
58aeeb9
·
1 Parent(s): 0c02b02

switched to qwen

Browse files
Files changed (3) hide show
  1. agent.py +20 -22
  2. app.py +9 -4
  3. requirements.txt +2 -0
agent.py CHANGED
@@ -18,7 +18,6 @@ from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
  import re
20
 
21
- # === Load environment ===
22
  load_dotenv()
23
 
24
  # === Tools ===
@@ -72,7 +71,7 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
72
  system_prompt = f.read()
73
  sys_msg = SystemMessage(content=system_prompt)
74
 
75
- # === Embedding & Vector DB ===
76
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
77
  supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
78
  vector_store = SupabaseVectorStore(
@@ -85,31 +84,30 @@ vector_store = SupabaseVectorStore(
85
  # === Tools ===
86
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
87
 
88
- # === LangGraph Agent Definition ===
89
- def build_graph(provider: str = "claude"):
90
- if provider == "claude":
91
- llm = ChatAnthropic(
92
- model="claude-3-opus-20240229",
93
- temperature=0,
94
- anthropic_api_key=os.getenv("CLAUDE_API_KEY")
 
95
  )
96
- elif provider == "groq":
97
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
98
- elif provider == "google":
99
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
100
- elif provider == "huggingface":
101
- llm = ChatHuggingFace(llm=HuggingFaceEndpoint(
102
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
103
- temperature=0))
104
  else:
105
- raise ValueError("Invalid provider")
106
 
107
  llm_with_tools = llm.bind_tools(tools)
108
 
109
  def retriever(state: MessagesState):
110
  query = state["messages"][-1].content
111
  similar = vector_store.similarity_search(query)
112
- return {"messages": [sys_msg, state["messages"][-1], HumanMessage(content=f"Reference: {similar[0].page_content}")]}
 
 
 
 
 
 
113
 
114
  def assistant(state: MessagesState):
115
  response = llm_with_tools.invoke(state["messages"])
@@ -117,7 +115,6 @@ def build_graph(provider: str = "claude"):
117
 
118
  def formatter(state: MessagesState):
119
  last = state["messages"][-1].content.strip()
120
-
121
  cleaned = re.sub(r"<.*?>", "", last)
122
  cleaned = re.sub(r"(Final\s*Answer:|Answer:)", "", cleaned, flags=re.IGNORECASE)
123
  cleaned = cleaned.strip().split("\n")[0].strip()
@@ -128,6 +125,7 @@ def build_graph(provider: str = "claude"):
128
  builder.add_node("assistant", assistant)
129
  builder.add_node("tools", ToolNode(tools))
130
  builder.add_node("formatter", formatter)
 
131
  builder.add_edge(START, "retriever")
132
  builder.add_edge("retriever", "assistant")
133
  builder.add_conditional_edges("assistant", tools_condition)
@@ -136,9 +134,9 @@ def build_graph(provider: str = "claude"):
136
 
137
  return builder.compile()
138
 
139
- # === Test ===
140
  if __name__ == "__main__":
141
- graph = build_graph("claude")
142
  result = graph.invoke({"messages": [HumanMessage(content="What is the capital of France?")]})
143
  for m in result["messages"]:
144
  m.pretty_print()
 
18
  from supabase.client import Client, create_client
19
  import re
20
 
 
21
  load_dotenv()
22
 
23
  # === Tools ===
 
71
  system_prompt = f.read()
72
  sys_msg = SystemMessage(content=system_prompt)
73
 
74
+ # === Embeddings & Vector Store ===
75
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
76
  supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
77
  vector_store = SupabaseVectorStore(
 
84
  # === Tools ===
85
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
86
 
87
+ # === LangGraph Builder ===
88
+ def build_graph(provider: str = "huggingface"):
89
+ if provider == "huggingface":
90
+ llm = ChatHuggingFace(
91
+ llm=HuggingFaceEndpoint(
92
+ repo_id="Qwen/Qwen1.5-7B-Chat",
93
+ temperature=0,
94
+ )
95
  )
 
 
 
 
 
 
 
 
96
  else:
97
+ raise ValueError("Only 'huggingface' (Qwen3) is supported in this build.")
98
 
99
  llm_with_tools = llm.bind_tools(tools)
100
 
101
  def retriever(state: MessagesState):
102
  query = state["messages"][-1].content
103
  similar = vector_store.similarity_search(query)
104
+ return {
105
+ "messages": [
106
+ sys_msg,
107
+ state["messages"][-1],
108
+ HumanMessage(content=f"Reference: {similar[0].page_content}")
109
+ ]
110
+ }
111
 
112
  def assistant(state: MessagesState):
113
  response = llm_with_tools.invoke(state["messages"])
 
115
 
116
  def formatter(state: MessagesState):
117
  last = state["messages"][-1].content.strip()
 
118
  cleaned = re.sub(r"<.*?>", "", last)
119
  cleaned = re.sub(r"(Final\s*Answer:|Answer:)", "", cleaned, flags=re.IGNORECASE)
120
  cleaned = cleaned.strip().split("\n")[0].strip()
 
125
  builder.add_node("assistant", assistant)
126
  builder.add_node("tools", ToolNode(tools))
127
  builder.add_node("formatter", formatter)
128
+
129
  builder.add_edge(START, "retriever")
130
  builder.add_edge("retriever", "assistant")
131
  builder.add_conditional_edges("assistant", tools_condition)
 
134
 
135
  return builder.compile()
136
 
137
+ # === Run Test ===
138
  if __name__ == "__main__":
139
+ graph = build_graph()
140
  result = graph.invoke({"messages": [HumanMessage(content="What is the capital of France?")]})
141
  for m in result["messages"]:
142
  m.pretty_print()
app.py CHANGED
@@ -13,8 +13,8 @@ cached_answers = []
13
 
14
  class ChatAgent:
15
  def __init__(self):
16
- print("ChatAgent initialized with Claude LangGraph workflow.")
17
- self.graph = build_graph("claude")
18
 
19
  def __call__(self, question: str) -> str:
20
  print(f"Processing question: {question[:60]}...")
@@ -59,7 +59,11 @@ def run_agent_only(profile: gr.OAuthProfile | None):
59
  cached_answers.append({"task_id": task_id, "submitted_answer": answer})
60
  results_log.append({"Task ID": task_id, "Question": question, "Submitted Answer": answer})
61
  except Exception as e:
62
- results_log.append({"Task ID": task_id, "Question": question, "Submitted Answer": f"AGENT ERROR: {e}"})
 
 
 
 
63
 
64
  return "Agent finished. Now click 'Submit Cached Answers'", pd.DataFrame(results_log)
65
 
@@ -83,7 +87,8 @@ def submit_cached_answers(profile: gr.OAuthProfile | None):
83
  result = response.json()
84
  final_status = (
85
  f"Submission Successful!\nUser: {result.get('username')}\n"
86
- f"Score: {result.get('score', 'N/A')}% ({result.get('correct_count', '?')}/{result.get('total_attempted', '?')})"
 
87
  )
88
  return final_status, None
89
  except Exception as e:
 
13
 
14
  class ChatAgent:
15
  def __init__(self):
16
+ print("ChatAgent initialized with Qwen LangGraph workflow.")
17
+ self.graph = build_graph("huggingface") # Uses Qwen endpoint
18
 
19
  def __call__(self, question: str) -> str:
20
  print(f"Processing question: {question[:60]}...")
 
59
  cached_answers.append({"task_id": task_id, "submitted_answer": answer})
60
  results_log.append({"Task ID": task_id, "Question": question, "Submitted Answer": answer})
61
  except Exception as e:
62
+ results_log.append({
63
+ "Task ID": task_id,
64
+ "Question": question,
65
+ "Submitted Answer": f"AGENT ERROR: {e}"
66
+ })
67
 
68
  return "Agent finished. Now click 'Submit Cached Answers'", pd.DataFrame(results_log)
69
 
 
87
  result = response.json()
88
  final_status = (
89
  f"Submission Successful!\nUser: {result.get('username')}\n"
90
+ f"Score: {result.get('score', 'N/A')}% "
91
+ f"({result.get('correct_count', '?')}/{result.get('total_attempted', '?')})"
92
  )
93
  return final_status, None
94
  except Exception as e:
requirements.txt CHANGED
@@ -11,9 +11,11 @@ langchain-tavily
11
  langchain-chroma
12
  langgraph
13
  huggingface_hub
 
14
  supabase
15
  arxiv
16
  pymupdf
17
  wikipedia
18
  pgvector
19
  python-dotenv
 
 
11
  langchain-chroma
12
  langgraph
13
  huggingface_hub
14
+ sentence-transformers
15
  supabase
16
  arxiv
17
  pymupdf
18
  wikipedia
19
  pgvector
20
  python-dotenv
21
+ tqdm