0r0b0r0s committed on
Commit
443c6c3
·
verified ·
1 Parent(s): 25b21d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -58
app.py CHANGED
@@ -5,83 +5,111 @@ import re # Added missing import
5
  import pandas as pd
6
  from langgraph.graph import StateGraph, END
7
  from huggingface_hub import InferenceClient
8
- import time # Added missing import
 
 
 
 
 
9
 
10
  # --- Constants ---
11
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
 
13
- # --- Optimized Agent Implementation ---
14
- MODELS = [
15
- "Qwen/Qwen2-0.5B-Instruct",
16
- "google/flan-t5-xxl",
17
- "mistralai/Mistral-7B-Instruct-v0.2"
18
- ]
19
 
20
- clients = [InferenceClient(model=model, token=os.environ["HF_TOKEN"]) for model in MODELS]
 
 
 
 
21
 
22
- def model_router(state: dict) -> dict:
23
- """Rotate through available models"""
24
- state["current_model"] = (state["current_model"] + 1) % len(MODELS)
25
- return state
26
-
27
- def query_model(state: dict) -> dict:
28
- """Generate answer with error handling"""
29
- try:
30
- response = clients[state["current_model"]].text_generation(
31
- prompt=f"""<|im_start|>system
32
- Answer with ONLY the exact value requested.<|im_end|>
33
- <|im_start|>user
34
- {state['question']}<|im_end|>
35
- <|im_start|>assistant
36
- """,
37
- temperature=0.01,
38
- max_new_tokens=50,
39
- stop_sequences=["<|im_end|>"]
40
  )
41
- # Fixed answer extraction
42
- answer_part = response.split("<|im_start|>assistant")[-1]
43
- answer = answer_part.split("<|im_end|>")[0].strip()
44
- state["answer"] = re.sub(r'[^a-zA-Z0-9]', '', answer).lower()
45
- except Exception as e:
46
- print(f"Model error: {str(e)}")
47
- state["answer"] = ""
48
- return state
 
 
49
 
50
- def should_continue(state: dict) -> str:
51
- """Conditional edge function (not a node)"""
52
- return END if state["answer"] else "route_model"
 
 
 
53
 
54
- # Build workflow
55
- workflow = StateGraph(dict)
56
- workflow.add_node("route_model", model_router)
57
- workflow.add_node("query", query_model)
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- workflow.set_entry_point("route_model")
60
- workflow.add_edge("route_model", "query")
61
- workflow.add_conditional_edges(
62
- "query",
63
- should_continue,
64
- {END: END, "route_model": "route_model"}
65
- )
66
 
67
- compiled_agent = workflow.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- class BasicAgent:
70
  def __call__(self, question: str) -> str:
 
71
  state = {
72
  "question": question,
73
- "retries": 0,
74
- "current_model": 0,
75
- "answer": ""
76
  }
77
 
78
  for _ in range(3): # Max 3 attempts
79
- state = compiled_agent.invoke(state)
80
  if state["answer"]:
81
- return state["answer"]
82
- time.sleep(1)
83
-
84
- return ""
85
 
86
  def run_and_submit_all( profile: gr.OAuthProfile | None):
87
  """
 
5
import os
import re
from typing import Annotated, TypedDict

import pandas as pd
from huggingface_hub import InferenceClient
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import StateGraph, END
14
 
15
  # --- Constants ---
16
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
17
 
18
+ # Configuration
19
+ MODEL_ID = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ" # 4.2GB quantized
20
+ EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
21
+ FALLBACK_MODELS = ["google/flan-t5-base", "mistralai/Mistral-7B-Instruct-v0.2"]
 
 
22
 
23
+ class AgentState(TypedDict):
24
+ question: str
25
+ context: str
26
+ answer: str
27
+ attempts: Annotated[int, lambda x, y: x + 1]
28
 
29
+ class BasicAgent:
30
+ def __init__(self):
31
+ # Initialize components
32
+ self.client = InferenceClient(
33
+ model=MODEL_ID,
34
+ token=os.environ["HF_TOKEN"],
35
+ timeout=120
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
+
38
+ # Initialize vector store (add your documents here)
39
+ self.vectorstore = Chroma.from_texts(
40
+ texts=["GAIA knowledge content..."], # Replace with your documents
41
+ embedding=EMBEDDING_MODEL,
42
+ persist_directory="./chroma_db"
43
+ )
44
+
45
+ # Build LangGraph workflow
46
+ self.workflow = self._build_graph()
47
 
48
+ def _build_graph(self):
49
+ # Define nodes
50
+ def retrieve(state: AgentState):
51
+ docs = self.vectorstore.similarity_search(state["question"], k=3)
52
+ state["context"] = "\n".join([d.page_content for d in docs])
53
+ return state
54
 
55
+ def generate(state: AgentState):
56
+ try:
57
+ response = self.client.text_generation(
58
+ f"""<s>[INST]Answer using ONLY this context:
59
+ {state['context']}
60
+ Question: {state['question']}
61
+ Answer: [/INST]""",
62
+ temperature=0.1,
63
+ max_new_tokens=100,
64
+ stop_sequences=["</s>"]
65
+ )
66
+ state["answer"] = response.split("[/INST]")[-1].strip()
67
+ except Exception:
68
+ state["answer"] = ""
69
+ return state
70
 
71
+ def validate(state: AgentState):
72
+ if len(state["answer"]) > 5 and state["attempts"] < 3:
73
+ return "final"
74
+ return "retry"
 
 
 
75
 
76
+ # Build workflow
77
+ workflow = StateGraph(AgentState)
78
+ workflow.add_node("retrieve", retrieve)
79
+ workflow.add_node("generate", generate)
80
+ workflow.add_node("validate", validate)
81
+
82
+ workflow.set_entry_point("retrieve")
83
+ workflow.add_edge("retrieve", "generate")
84
+ workflow.add_edge("generate", "validate")
85
+
86
+ workflow.add_conditional_edges(
87
+ "validate",
88
+ lambda x: "retry" if x["answer"] == "" else "final",
89
+ {
90
+ "retry": "retrieve",
91
+ "final": END
92
+ }
93
+ )
94
+
95
+ return workflow.compile()
96
 
 
97
  def __call__(self, question: str) -> str:
98
+ # GAIA-compliant formatting
99
  state = {
100
  "question": question,
101
+ "context": "",
102
+ "answer": "",
103
+ "attempts": 0
104
  }
105
 
106
  for _ in range(3): # Max 3 attempts
107
+ state = self.workflow.invoke(state)
108
  if state["answer"]:
109
+ answer = re.sub(r'[^a-zA-Z0-9]', '', state["answer"]).lower()
110
+ return answer[:100] # GAIA length constraint
111
+
112
+ return "" # Preserve scoring eligibility
113
 
114
  def run_and_submit_all( profile: gr.OAuthProfile | None):
115
  """