0r0b0r0s commited on
Commit
f2bd555
·
verified ·
1 Parent(s): 6ba9a09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -42
app.py CHANGED
@@ -3,8 +3,8 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- from huggingface_hub import HfApi, InferenceClient, login
7
  from langgraph.graph import StateGraph, END
 
8
 
9
 
10
  # (Keep Constants as is)
@@ -14,54 +14,56 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
14
  # --- Basic Agent Definition ---
15
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
16
  # --- GAIA-Optimized Agent Implementation ---
17
- from huggingface_hub import InferenceClient, login
18
-
19
- # Configure models
20
  MODELS = [
21
  "Qwen/Qwen2-0.5B-Instruct",
22
  "google/flan-t5-xxl",
23
  "mistralai/Mistral-7B-Instruct-v0.2"
24
  ]
25
 
26
- class AgentState:
27
- def __init__(self):
28
- self.question = ""
29
- self.retries = 0
30
- self.current_model = 0
31
- self.answer = ""
32
-
33
- # Initialize clients
34
  clients = [InferenceClient(model=model, token=os.environ["HF_TOKEN"]) for model in MODELS]
35
 
36
- def model_router(state):
 
 
 
 
 
 
 
 
37
  """Rotate through available models"""
38
- state.current_model = (state.current_model + 1) % len(MODELS)
39
  return state
40
 
41
- def query_model(state):
42
  """Attempt to get answer from current model"""
43
  try:
44
- response = clients[state.current_model].text_generation(
45
- prompt=f"GAIA Question: {state.question}\nAnswer:",
 
 
 
 
 
 
46
  max_new_tokens=50,
47
- temperature=0.01
48
  )
49
- state.answer = response.split("Answer:")[-1].strip()
50
- return state
51
  except Exception as e:
52
- print(f"Error with {MODELS[state.current_model]}: {str(e)}")
53
- state.retries += 1
54
- return state
55
 
56
- def validate_answer(state):
57
- """Basic GAIA answer validation"""
58
- if len(state.answer) > 0 and 2 <= len(state.answer) <= 100:
59
- return "final_answer"
60
- return "retry"
61
 
62
  # Build workflow
63
- workflow = StateGraph(AgentState)
64
-
65
  workflow.add_node("route_model", model_router)
66
  workflow.add_node("query", query_model)
67
  workflow.add_node("validate", validate_answer)
@@ -71,29 +73,26 @@ workflow.add_edge("query", "validate")
71
 
72
  workflow.add_conditional_edges(
73
  "validate",
74
- lambda x: "final_answer" if x.answer else "retry",
75
- {
76
- "final_answer": END,
77
- "retry": "route_model"
78
- }
79
  )
80
 
81
  workflow.set_entry_point("route_model")
82
- agent = workflow.compile()
83
 
 
84
  class BasicAgent:
85
  def __call__(self, question: str) -> str:
86
- state = AgentState()
87
- state.question = question
88
 
89
  for _ in range(3): # Max 3 attempts
90
- state = agent.invoke(state)
91
- if state.answer:
92
- return state.answer
93
  time.sleep(1) # Backoff
94
 
95
  return "" # Return empty to preserve scoring
96
-
97
 
98
  def run_and_submit_all( profile: gr.OAuthProfile | None):
99
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
 
6
  from langgraph.graph import StateGraph, END
7
+ from huggingface_hub import InferenceClient
8
 
9
 
10
  # (Keep Constants as is)
 
14
  # --- Basic Agent Definition ---
15
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
16
  # --- GAIA-Optimized Agent Implementation ---
17
# Configure fallback models
# Candidate inference models, tried in rotation order by model_router.
MODELS = [
    "Qwen/Qwen2-0.5B-Instruct",
    "google/flan-t5-xxl",
    "mistralai/Mistral-7B-Instruct-v0.2"
]

# Initialize clients with automatic retry
# One InferenceClient per entry in MODELS; indices line up with
# state["current_model"].
# NOTE(review): this reads os.environ["HF_TOKEN"] at import time — it raises
# KeyError if the token is unset, and relies on an `import os` earlier in the
# file that is not visible in this chunk; confirm both.
clients = [InferenceClient(model=model, token=os.environ["HF_TOKEN"]) for model in MODELS]

# Define state structure using dictionary
# Template for the per-question workflow state; BasicAgent.__call__ copies it
# for each new question.
initial_state = {
    "question": "",       # GAIA question text, filled in by BasicAgent
    "retries": 0,         # NOTE(review): never incremented anywhere visible — dead field?
    "current_model": 0,   # index into MODELS / clients
    "answer": ""          # normalized answer; empty string means "no answer yet"
}
34
+
35
def model_router(state: dict) -> dict:
    """Advance the state to the next model in the fallback rotation.

    Increments ``current_model`` and wraps back to the first model after
    the last one, then hands the state on unchanged otherwise.
    """
    next_index = state["current_model"] + 1
    state["current_model"] = next_index % len(MODELS)
    return state
39
 
40
def query_model(state: dict) -> dict:
    """Attempt to get answer from current model"""
    # Queries clients[current_model] with a ChatML-style prompt and stores a
    # normalized (lowercase alphanumeric) answer back into the state.
    try:
        response = clients[state["current_model"]].text_generation(
            prompt=f"""<|im_start|>system
Answer with ONLY the exact value requested.<|im_end|>
<|im_start|>user
{state['question']}<|im_end|>
<|im_start|>assistant
""",
            temperature=0.01,
            max_new_tokens=50,
            stop_sequences=["<|im_end|>"]
        )
        # Keep the text after the assistant tag ([-1] yields the whole
        # response when the server does not echo the prompt) and drop
        # anything past the first <|im_end|>.
        state["answer"] = response.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
        # NOTE(review): this also strips spaces, collapsing multi-word GAIA
        # answers into one token — confirm scoring tolerates that. It also
        # requires an `import re` earlier in the file (not visible here).
        state["answer"] = re.sub(r'[^a-zA-Z0-9]', '', state["answer"]).lower()
    except Exception as e:
        # Best-effort fallback: log and leave answer empty so validation
        # routes the workflow back through model_router for a retry.
        print(f"Model error: {str(e)}")
        state["answer"] = ""
    return state
60
 
61
def validate_answer(state: dict) -> str:
    """Return the routing label for the current state.

    ``"final_answer"`` when the state already holds a non-empty answer,
    ``"retry"`` otherwise.
    """
    if state["answer"]:
        return "final_answer"
    return "retry"
 
 
64
 
65
# Build workflow
# NOTE(review): StateGraph is given the plain `dict` type as its state
# schema; recent langgraph releases expect a TypedDict/dataclass schema —
# confirm the pinned langgraph version accepts a bare dict.
workflow = StateGraph(dict)
workflow.add_node("route_model", model_router)
workflow.add_node("query", query_model)
# NOTE(review): validate_answer returns a routing string rather than a state
# dict, yet it is registered as a regular node; the conditional edge below
# re-derives the route from the state itself — verify this is intentional.
workflow.add_node("validate", validate_answer)

workflow.add_conditional_edges(
    "validate",
    # Finish once an answer exists; otherwise rotate to the next model.
    lambda x: "final_answer" if x["answer"] else "retry",
    {"final_answer": END, "retry": "route_model"}
)

workflow.set_entry_point("route_model")
compiled_agent = workflow.compile()
82
 
83
+ # GAIA Interface
84
  class BasicAgent:
85
  def __call__(self, question: str) -> str:
86
+ state = initial_state.copy()
87
+ state["question"] = question
88
 
89
  for _ in range(3): # Max 3 attempts
90
+ state = compiled_agent.invoke(state)
91
+ if state["answer"]:
92
+ return state["answer"]
93
  time.sleep(1) # Backoff
94
 
95
  return "" # Return empty to preserve scoring
 
96
 
97
  def run_and_submit_all( profile: gr.OAuthProfile | None):
98
  """