0r0b0r0s commited on
Commit
8beb49e
·
verified ·
1 Parent(s): 713b236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -86
app.py CHANGED
@@ -20,96 +20,18 @@ MODEL_ID = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ" # 4.2GB quantized
20
  EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
21
  FALLBACK_MODELS = ["google/flan-t5-base", "mistralai/Mistral-7B-Instruct-v0.2"]
22
 
23
# Workflow state carried between LangGraph nodes. Functional TypedDict
# syntax; same fields and the same additive reducer on `attempts`.
AgentState = TypedDict(
    "AgentState",
    {
        "question": str,
        "context": str,
        "answer": str,
        # Reducer sums updates so each pass increments the attempt counter.
        "attempts": Annotated[int, lambda x, y: x + 1],
    },
)
class BasicAgent:
    """Retrieval-augmented GAIA agent.

    Wires a three-node LangGraph workflow (retrieve -> generate -> validate)
    around a Chroma vector store and a hosted Mistral-GPTQ inference
    endpoint. Final answers are normalized to lowercase alphanumerics.
    """

    def __init__(self):
        # Initialize components: remote text-generation client for MODEL_ID.
        # Reads HF_TOKEN from the environment (raises KeyError if unset).
        self.client = InferenceClient(
            model=MODEL_ID,
            token=os.environ["HF_TOKEN"],
            timeout=120
        )

        # Initialize vector store (add your documents here).
        # NOTE(review): `embedding` is given the model *name* string
        # (EMBEDDING_MODEL); langchain's Chroma.from_texts expects an
        # embeddings object — confirm this actually embeds with the
        # installed langchain version, or wrap the name in an embeddings
        # class first.
        self.vectorstore = Chroma.from_texts(
            texts=["GAIA knowledge content..."],  # Replace with your documents
            embedding=EMBEDDING_MODEL,
            persist_directory="./chroma_db"
        )

        # Build LangGraph workflow once; reused for every question.
        self.workflow = self._build_graph()

    def _build_graph(self):
        """Compile the retrieve -> generate -> validate state machine."""

        # Node: fetch the top-3 most similar documents and join their text
        # into a single newline-separated context string.
        def retrieve(state: AgentState):
            docs = self.vectorstore.similarity_search(state["question"], k=3)
            state["context"] = "\n".join([d.page_content for d in docs])
            return state

        # Node: prompt the instruct model with the retrieved context only.
        # Any client/network failure degrades to an empty answer so the
        # workflow can retry instead of crashing.
        def generate(state: AgentState):
            try:
                response = self.client.text_generation(
                    f"""<s>[INST]Answer using ONLY this context:
{state['context']}
Question: {state['question']}
Answer: [/INST]""",
                    temperature=0.1,
                    max_new_tokens=100,
                    stop_sequences=["</s>"]
                )
                # Defensive split in case the endpoint echoes the prompt.
                state["answer"] = response.split("[/INST]")[-1].strip()
            except Exception:
                state["answer"] = ""
            return state

        # NOTE(review): registered as a *node* below, but returns a route
        # string instead of a state update; LangGraph nodes are expected to
        # return state. Routing is actually decided by the lambda passed to
        # add_conditional_edges, so this return value looks unused — verify
        # against the langgraph version in use.
        def validate(state: AgentState):
            if len(state["answer"]) > 5 and state["attempts"] < 3:
                return "final"
            return "retry"

        # Build workflow graph.
        workflow = StateGraph(AgentState)
        workflow.add_node("retrieve", retrieve)
        workflow.add_node("generate", generate)
        workflow.add_node("validate", validate)

        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", "validate")

        # Empty answer -> loop back to retrieval; otherwise finish.
        workflow.add_conditional_edges(
            "validate",
            lambda x: "retry" if x["answer"] == "" else "final",
            {
                "retry": "retrieve",
                "final": END
            }
        )

        return workflow.compile()

    def __call__(self, question: str) -> str:
        """Run the workflow and return a GAIA-formatted answer string.

        Returns "" when no non-empty answer is produced after 3 attempts,
        which keeps the submission scoreable instead of raising.
        """
        # GAIA-compliant formatting: fresh state per question.
        state = {
            "question": question,
            "context": "",
            "answer": "",
            "attempts": 0
        }

        for _ in range(3):  # Max 3 attempts
            state = self.workflow.invoke(state)
            if state["answer"]:
                # NOTE(review): stripping every non-alphanumeric character
                # also removes spaces, fusing multi-word answers together —
                # confirm this matches the normalization GAIA scoring
                # actually requires.
                answer = re.sub(r'[^a-zA-Z0-9]', '', state["answer"]).lower()
                return answer[:100]  # GAIA length constraint

        return ""  # Preserve scoring eligibility
113
 
114
  def run_and_submit_all( profile: gr.OAuthProfile | None):
115
  """
 
20
  EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
21
  FALLBACK_MODELS = ["google/flan-t5-base", "mistralai/Mistral-7B-Instruct-v0.2"]
22
 
23
+ # --- Basic Agent Definition ---
24
+ # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
 
 
 
 
25
class BasicAgent:
    """Trivial placeholder agent: logs the question, replies with a canned answer."""

    def __init__(self):
        # No state to set up; just announce construction.
        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        """Log a preview of the incoming question and return the fixed answer."""
        preview = question[:50]
        print(f"Agent received question (first 50 chars): {preview}...")
        answer = "This is a default answer."
        print(f"Agent returning fixed answer: {answer}")
        return answer
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def run_and_submit_all( profile: gr.OAuthProfile | None):
37
  """