Toumaima commited on
Commit
5d6fe2e
·
verified ·
1 Parent(s): 459d4dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import gradio as gr
3
  import requests
4
  import pandas as pd
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
 
7
  # ---------- Imports for Advanced Agent ----------
8
  import re
@@ -10,7 +10,6 @@ from langgraph.graph import StateGraph, MessagesState
10
  from langgraph.prebuilt import tools_condition, ToolNode
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
  from langchain_core.tools import tool
13
- from langchain_google_genai import ChatGoogleGenerativeAI
14
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
15
  from langchain_community.tools.tavily_search import TavilySearchResults
16
  from groq import Groq
@@ -41,12 +40,17 @@ def arvix_search(query: str) -> str:
41
  docs = ArxivLoader(query=query, load_max_docs=3).load()
42
  return "\n\n".join([doc.page_content[:1000] for doc in docs])
43
 
 
44
  def build_tool_graph(system_prompt):
45
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
46
- llm_with_tools = llm.bind_tools([wiki_search, web_search, arvix_search])
47
 
48
  def assistant(state: MessagesState):
49
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
50
 
51
  builder = StateGraph(MessagesState)
52
  builder.add_node("assistant", assistant)
@@ -57,6 +61,7 @@ def build_tool_graph(system_prompt):
57
  builder.add_edge("tools", "assistant")
58
  return builder.compile()
59
 
 
60
  # --- Advanced BasicAgent Class ---
61
  class BasicAgent:
62
  def __init__(self):
 
2
  import gradio as gr
3
  import requests
4
  import pandas as pd
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
  # ---------- Imports for Advanced Agent ----------
8
  import re
 
10
  from langgraph.prebuilt import tools_condition, ToolNode
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
  from langchain_core.tools import tool
 
13
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
14
  from langchain_community.tools.tavily_search import TavilySearchResults
15
  from groq import Groq
 
40
  docs = ArxivLoader(query=query, load_max_docs=3).load()
41
  return "\n\n".join([doc.page_content[:1000] for doc in docs])
42
 
43
+ # Tool-based LangGraph builder
44
  def build_tool_graph(system_prompt):
45
+ llm = AutoModelForCausalLM.from_pretrained("gpt2") # Load Hugging Face GPT-2 model
46
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
47
 
48
  def assistant(state: MessagesState):
49
+ input_text = state["messages"][-1]["content"]
50
+ inputs = tokenizer(input_text, return_tensors="pt")
51
+ outputs = llm.generate(**inputs)
52
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
+ return {"messages": [{"content": result}]}
54
 
55
  builder = StateGraph(MessagesState)
56
  builder.add_node("assistant", assistant)
 
61
  builder.add_edge("tools", "assistant")
62
  return builder.compile()
63
 
64
+
65
  # --- Advanced BasicAgent Class ---
66
  class BasicAgent:
67
  def __init__(self):