wt002 committed on
Commit
0c9facb
·
verified ·
1 Parent(s): c170ebd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -9,7 +9,8 @@ import wikipediaapi
9
  import pandas as pd
10
  from transformers import pipeline # or HfAgent if you want the higher-level agent
11
  from huggingface_hub import InferenceClient # Updated import
12
-
 
13
 
14
  load_dotenv()
15
 
@@ -20,22 +21,18 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
20
 
21
  # --- Basic Agent Definition ---
22
 
 
 
23
  class BasicAgent:
24
- def __init__(self, model: str = "google/gemma-7b"):
25
- self.generator = pipeline(
26
- "text-generation",
27
- model=model,
28
- device_map="auto" # Uses GPU if available
29
- )
30
-
31
  def __call__(self, question: str) -> str:
32
- outputs = self.generator(
33
- question,
34
- max_new_tokens=100,
35
- do_sample=True,
36
- temperature=0.7
37
- )
38
- return outputs[0]['generated_text']
39
 
40
  def wikipedia_search(self, query: str) -> str:
41
  """Get Wikipedia summary"""
 
9
  import pandas as pd
10
  from transformers import pipeline # or HfAgent if you want the higher-level agent
11
  from huggingface_hub import InferenceClient # Updated import
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+ import torch
14
 
15
  load_dotenv()
16
 
 
21
 
22
  # --- Basic Agent Definition ---
23
 
24
+
25
+
26
  class BasicAgent:
27
+ def __init__(self, model="google/gemma-7b"):
28
+ self.tokenizer = AutoTokenizer.from_pretrained(model)
29
+ self.model = AutoModelForCausalLM.from_pretrained(model)
30
+ print("BasicAgent initialized with AutoModel")
31
+
 
 
32
  def __call__(self, question: str) -> str:
33
+ inputs = self.tokenizer(question, return_tensors="pt")
34
+ outputs = self.model.generate(**inputs, max_new_tokens=100)
35
+ return self.tokenizer.decode(outputs[0])
 
 
 
 
36
 
37
  def wikipedia_search(self, query: str) -> str:
38
  """Get Wikipedia summary"""