0r0b0r0s committed
Commit 1f55538 · verified · 1 Parent(s): 4505e70

Update app.py

Files changed (1)
  1. app.py +72 -38
app.py CHANGED
@@ -4,6 +4,7 @@ import requests
 import inspect
 import pandas as pd
 from huggingface_hub import HfApi, InferenceClient, login
+from langgraph.graph import StateGraph, END
 
 
 # (Keep Constants as is)
@@ -15,52 +16,85 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 # --- GAIA-Optimized Agent Implementation ---
 from huggingface_hub import InferenceClient, login
 
-class BasicAgent:
+# Configure models
+MODELS = [
+    "Qwen/Qwen2-0.5B-Instruct",
+    "google/flan-t5-xxl",
+    "mistralai/Mistral-7B-Instruct-v0.2"
+]
+
+class AgentState:
     def __init__(self):
-        login(token=os.environ["HF_TOKEN"])  # Required authentication
-
-        # Primary model (7B quantized)
-        self.client = InferenceClient(
-            model="Qwen/Qwen2-0.5B-Instruct",  # 1.2GB, free-tier compatible
-            token=os.environ["HF_TOKEN"],
-            timeout=60
-        )
-
-        # Verify model access
-        test_response = self._call_model("2+2=")
-        if "4" not in test_response:
-            raise RuntimeError("Model initialization failed")
-
-    def _call_model(self, question: str) -> str:
-        """Optimized prompt engineering for GAIA"""
-        prompt = f"""<|im_start|>system
-Answer with ONLY the exact value requested. No explanations.<|im_end|>
-<|im_start|>user
-{question}<|im_end|>
-<|im_start|>assistant
-"""
-        return self.client.text_generation(
-            prompt=prompt,
-            temperature=0.01,
+        self.question = ""
+        self.retries = 0
+        self.current_model = 0
+        self.answer = ""
+
+# Initialize clients
+clients = [InferenceClient(model=model, token=os.environ["HF_TOKEN"]) for model in MODELS]
+
+def model_router(state):
+    """Rotate through available models"""
+    state.current_model = (state.current_model + 1) % len(MODELS)
+    return state
+
+def query_model(state):
+    """Attempt to get answer from current model"""
+    try:
+        response = clients[state.current_model].text_generation(
+            prompt=f"GAIA Question: {state.question}\nAnswer:",
             max_new_tokens=50,
-            stop_sequences=["<|im_end|>"]
+            temperature=0.01
         )
+        state.answer = response.split("Answer:")[-1].strip()
+        return state
+    except Exception as e:
+        print(f"Error with {MODELS[state.current_model]}: {str(e)}")
+        state.retries += 1
+        return state
+
+def validate_answer(state):
+    """Basic GAIA answer validation"""
+    if len(state.answer) > 0 and 2 <= len(state.answer) <= 100:
+        return "final_answer"
+    return "retry"
+
+# Build workflow
+workflow = StateGraph(AgentState)
 
+workflow.add_node("route_model", model_router)
+workflow.add_node("query", query_model)
+workflow.add_node("validate", validate_answer)
+
+workflow.add_edge("route_model", "query")
+workflow.add_edge("query", "validate")
+
+workflow.add_conditional_edges(
+    "validate",
+    lambda x: "final_answer" if x.answer else "retry",
+    {
+        "final_answer": END,
+        "retry": "route_model"
+    }
+)
+
+workflow.set_entry_point("route_model")
+agent = workflow.compile()
+
+class BasicAgent:
     def __call__(self, question: str) -> str:
-        try:
-            raw_response = self._call_model(question)
-
-            # Robust answer extraction
-            answer = raw_response.split("<|im_start|>assistant")[-1]
-            answer = answer.split("<|im_end|>")[0].strip()
+        state = AgentState()
+        state.question = question
+
+        for _ in range(3):  # Max 3 attempts
+            state = agent.invoke(state)
+            if state.answer:
+                return state.answer
+            time.sleep(1)  # Backoff
 
-            # GAIA-compliant normalization
-            return re.sub(r'[^a-zA-Z0-9]', '', answer).lower()
-        except Exception as e:
-            print(f"Error: {str(e)}")
-            return ""
+        return ""  # Return empty to preserve scoring
 
+
 def run_and_submit_all( profile: gr.OAuthProfile | None):
     """
     Fetches all questions, runs the BasicAgent on them, submits all answers,
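
Note on the new workflow: the committed code passes a plain Python class to StateGraph and registers validate_answer as a node even though it returns a routing string rather than state; current langgraph versions expect a dict-like state schema (e.g. a TypedDict), with routing logic living in the conditional edge. It also calls time.sleep, which only works if time is imported in the part of app.py not shown in these hunks. Below is a minimal sketch of the same rotate-and-retry pattern in that style, not the author's code: the model list, 3-attempt cap, prompt, and generation parameters mirror the commit, while names such as route_model, query, and decide are illustrative.

import os
import time
from typing import TypedDict

from huggingface_hub import InferenceClient
from langgraph.graph import StateGraph, END

MODELS = [
    "Qwen/Qwen2-0.5B-Instruct",
    "google/flan-t5-xxl",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
clients = [InferenceClient(model=m, token=os.environ["HF_TOKEN"]) for m in MODELS]

class AgentState(TypedDict):
    question: str
    retries: int
    current_model: int
    answer: str

def route_model(state: AgentState) -> dict:
    # Rotate to the next model; langgraph nodes return partial state updates.
    return {"current_model": (state["current_model"] + 1) % len(MODELS)}

def query(state: AgentState) -> dict:
    # Query the current model and strip the echoed prompt from the output.
    try:
        response = clients[state["current_model"]].text_generation(
            prompt=f"GAIA Question: {state['question']}\nAnswer:",
            max_new_tokens=50,
            temperature=0.01,
        )
        return {"answer": response.split("Answer:")[-1].strip()}
    except Exception as e:
        print(f"Error with {MODELS[state['current_model']]}: {e}")
        time.sleep(1)  # backoff before trying the next model
        return {"retries": state["retries"] + 1, "answer": ""}

def decide(state: AgentState) -> str:
    # Routing belongs in the conditional edge, not in a node; the 2-100
    # character check mirrors the commit's validate_answer heuristic.
    if 2 <= len(state["answer"]) <= 100:
        return "final_answer"
    return "final_answer" if state["retries"] >= 3 else "retry"

workflow = StateGraph(AgentState)
workflow.add_node("route_model", route_model)
workflow.add_node("query", query)
workflow.add_edge("route_model", "query")
workflow.add_conditional_edges(
    "query", decide, {"final_answer": END, "retry": "route_model"}
)
workflow.set_entry_point("route_model")
agent = workflow.compile()

result = agent.invoke(
    {"question": "2+2=", "retries": 0, "current_model": 0, "answer": ""}
)
print(result["answer"])

Bounding the retry count inside decide (rather than only in BasicAgent.__call__) keeps the graph itself from looping past langgraph's recursion limit when every model fails.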