wt002 commited on
Commit
564212d
·
verified ·
1 Parent(s): 244ae60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -15
app.py CHANGED
@@ -23,21 +23,14 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
23
 
24
 
25
  class BasicAgent:
26
- def __init__(self, model="mistral-7b-instruct-v0.1"): # Smaller model recommended
27
- self.tokenizer = AutoTokenizer.from_pretrained(model)
28
- self.model = AutoModelForCausalLM.from_pretrained(
29
- model,
30
- device_map="auto",
31
- torch_dtype=torch.float32, # Explicitly use float32 for CPU
32
- low_cpu_mem_usage=True # Reduces memory spikes
33
- )
34
- print(f"Initialized on device: {self.model.device}")
35
-
36
- def __call__(self, question: str, max_tokens: int = 100) -> str:
37
- inputs = self.tokenizer(question, return_tensors="pt").to(self.model.device)
38
- with torch.no_grad(): # Reduces memory usage
39
- outputs = self.model.generate(**inputs, max_new_tokens=max_tokens)
40
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
41
 
42
  def wikipedia_search(self, query: str) -> str:
43
  """Get Wikipedia summary"""
 
23
 
24
 
25
  class BasicAgent:
26
+ def __init__(self, model="google/gemma-7b"):
27
+ self.pipe = pipeline("text-generation", model=model)
28
+ print("BasicAgent initialized.")
29
+
30
+ def __call__(self, question: str) -> str:
31
+ print(f"Question: {question[:50]}...")
32
+ response = self.pipe(question, max_new_tokens=100)
33
+ return response[0]['generated_text']
 
 
 
 
 
 
 
34
 
35
  def wikipedia_search(self, query: str) -> str:
36
  """Get Wikipedia summary"""