wt002 commited on
Commit
6c7b3e9
·
verified ·
1 Parent(s): 0c9facb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -76
app.py CHANGED
@@ -22,92 +22,29 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
22
  # --- Basic Agent Definition ---
23
 
24
 
25
-
26
  class BasicAgent:
27
- def __init__(self, model="google/gemma-7b"):
28
  self.tokenizer = AutoTokenizer.from_pretrained(model)
29
- self.model = AutoModelForCausalLM.from_pretrained(model)
30
- print("BasicAgent initialized with AutoModel")
 
 
 
 
 
31
 
32
- def __call__(self, question: str) -> str:
33
- inputs = self.tokenizer(question, return_tensors="pt")
34
- outputs = self.model.generate(**inputs, max_new_tokens=100)
35
- return self.tokenizer.decode(outputs[0])
 
36
 
37
  def wikipedia_search(self, query: str) -> str:
38
  """Get Wikipedia summary"""
39
  page = self.wiki.page(query)
40
  return page.summary if page.exists() else "No Wikipedia page found"
41
 
42
- def process_document(self, file_path: str) -> str:
43
- """Handle PDF, Word, CSV, Excel files"""
44
- if not os.path.exists(file_path):
45
- return "File not found"
46
-
47
- ext = os.path.splitext(file_path)[1].lower()
48
-
49
- try:
50
- if ext == '.pdf':
51
- return self._process_pdf(file_path)
52
- elif ext in ('.doc', '.docx'):
53
- return self._process_word(file_path)
54
- elif ext == '.csv':
55
- return pd.read_csv(file_path).to_string()
56
- elif ext in ('.xls', '.xlsx'):
57
- return pd.read_excel(file_path).to_string()
58
- else:
59
- return "Unsupported file format"
60
- except Exception as e:
61
- return f"Error processing document: {str(e)}"
62
-
63
- def _process_pdf(self, file_path: str) -> str:
64
- """Process PDF using Gemini's vision capability"""
65
- try:
66
- # For Gemini 1.5 or later which supports file uploads
67
- with open(file_path, "rb") as f:
68
- file = genai.upload_file(f)
69
- response = self.model.generate_content(
70
- ["Extract and summarize the key points from this document:", file]
71
- )
72
- return response.text
73
- except:
74
- # Fallback for older Gemini versions
75
- try:
76
- import PyPDF2
77
- with open(file_path, 'rb') as f:
78
- reader = PyPDF2.PdfReader(f)
79
- return "\n".join([page.extract_text() for page in reader.pages])
80
- except ImportError:
81
- return "PDF processing requires PyPDF2 (pip install PyPDF2)"
82
-
83
- def _process_word(self, file_path: str) -> str:
84
- """Process Word documents"""
85
- try:
86
- from docx import Document
87
- doc = Document(file_path)
88
- return "\n".join([para.text for para in doc.paragraphs])
89
- except ImportError:
90
- return "Word processing requires python-docx (pip install python-docx)"
91
 
92
- def process_request(self, request: Union[str, Dict]) -> str:
93
- """
94
- Handle different request types:
95
- - Direct text queries
96
- - File processing requests
97
- - Complex multi-step requests
98
- """
99
- if isinstance(request, dict):
100
- if 'steps' in request:
101
- results = []
102
- for step in request['steps']:
103
- if step['type'] == 'search':
104
- results.append(self.web_search(step['query']))
105
- elif step['type'] == 'process':
106
- results.append(self.process_document(step['file']))
107
- return self.generate_response(f"Process these results: {results}")
108
- return "Unsupported request format"
109
-
110
- return self.generate_response(request)
111
 
112
 
113
  def run_and_submit_all( profile: gr.OAuthProfile | None):
 
22
  # --- Basic Agent Definition ---
23
 
24
 
 
25
  class BasicAgent:
26
+ def __init__(self, model="google/gemma-2b"): # Smaller 2B version recommended
27
  self.tokenizer = AutoTokenizer.from_pretrained(model)
28
+ self.model = AutoModelForCausalLM.from_pretrained(
29
+ model,
30
+ device_map="auto",
31
+ torch_dtype=torch.float32, # Explicitly use float32 for CPU
32
+ low_cpu_mem_usage=True # Reduces memory spikes
33
+ )
34
+ print(f"Initialized on device: {self.model.device}")
35
 
36
+ def __call__(self, question: str, max_tokens: int = 100) -> str:
37
+ inputs = self.tokenizer(question, return_tensors="pt").to(self.model.device)
38
+ with torch.no_grad(): # Reduces memory usage
39
+ outputs = self.model.generate(**inputs, max_new_tokens=max_tokens)
40
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
41
 
42
  def wikipedia_search(self, query: str) -> str:
43
  """Get Wikipedia summary"""
44
  page = self.wiki.page(query)
45
  return page.summary if page.exists() else "No Wikipedia page found"
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  def run_and_submit_all( profile: gr.OAuthProfile | None):