wt002 commited on
Commit
c170ebd
·
verified ·
1 Parent(s): 01bade2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -21,28 +21,21 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
21
  # --- Basic Agent Definition ---
22
 
23
  class BasicAgent:
24
- def __init__(self, model: str = "google/gemma-7b", searx_url: str = "https://search.example.org"):
25
- self.client = InferenceClient(model=model)
26
- self.searx_url = searx_url
27
- print(f"Initialized with model: {model}")
28
-
 
 
29
  def __call__(self, question: str) -> str:
30
- print(f"Processing: {question[:50]}...")
31
- return self.client.text_generation(
32
- prompt=question,
33
  max_new_tokens=100,
 
34
  temperature=0.7
35
  )
36
-
37
- def web_search(self, query: str) -> List[Dict]:
38
- """Search using SearxNG"""
39
- params = {"q": query, "format": "json", "engines": "google,bing"}
40
- try:
41
- response = requests.get(self.searx_url, params=params, timeout=10)
42
- return response.json().get("results", [])
43
- except Exception as e:
44
- print(f"Search failed: {e}")
45
- return []
46
 
47
  def wikipedia_search(self, query: str) -> str:
48
  """Get Wikipedia summary"""
 
21
  # --- Basic Agent Definition ---
22
 
23
  class BasicAgent:
24
+ def __init__(self, model: str = "google/gemma-7b"):
25
+ self.generator = pipeline(
26
+ "text-generation",
27
+ model=model,
28
+ device_map="auto" # Uses GPU if available
29
+ )
30
+
31
  def __call__(self, question: str) -> str:
32
+ outputs = self.generator(
33
+ question,
 
34
  max_new_tokens=100,
35
+ do_sample=True,
36
  temperature=0.7
37
  )
38
+ return outputs[0]['generated_text']
 
 
 
 
 
 
 
 
 
39
 
40
  def wikipedia_search(self, query: str) -> str:
41
  """Get Wikipedia summary"""