openfree committed on
Commit
8cbcecb
·
verified ·
1 Parent(s): 408eb2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -92
app.py CHANGED
@@ -19,58 +19,25 @@ tokenizer = None
19
  def fetch_arxiv_paper(arxiv_input):
20
  """Fetch paper details from arXiv URL or ID using requests."""
21
  try:
22
- # Extract arXiv ID from URL or use directly
23
  if 'arxiv.org' in arxiv_input:
24
  parsed = urlparse(arxiv_input)
25
- path = parsed.path
26
- arxiv_id = path.split('/')[-1].replace('.pdf', '')
27
  else:
28
  arxiv_id = arxiv_input.strip()
29
-
30
- # Fetch metadata using arXiv API
31
  api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'
32
  response = requests.get(api_url)
33
-
34
  if response.status_code != 200:
35
- return {
36
- "title": "",
37
- "abstract": "",
38
- "success": False,
39
- "message": "Error fetching paper from arXiv API"
40
- }
41
-
42
- # Parse the response XML
43
  root = ET.fromstring(response.text)
44
-
45
- # ArXiv API uses namespaces
46
  ns = {'arxiv': 'http://www.w3.org/2005/Atom'}
47
-
48
- # Extract title and abstract
49
  entry = root.find('.//arxiv:entry', ns)
50
  if entry is None:
51
- return {
52
- "title": "",
53
- "abstract": "",
54
- "success": False,
55
- "message": "Paper not found"
56
- }
57
-
58
  title = entry.find('arxiv:title', ns).text.strip()
59
  abstract = entry.find('arxiv:summary', ns).text.strip()
60
-
61
- return {
62
- "title": title,
63
- "abstract": abstract,
64
- "success": True,
65
- "message": "Paper fetched successfully!"
66
- }
67
  except Exception as e:
68
- return {
69
- "title": "",
70
- "abstract": "",
71
- "success": False,
72
- "message": f"Error fetching paper: {str(e)}"
73
- }
74
 
75
  @spaces.GPU(duration=60, enable_queue=True)
76
  def predict(title, abstract):
@@ -78,50 +45,48 @@ def predict(title, abstract):
78
  abstract = abstract.replace("\n", " ").strip().replace("''", "'")
79
  global model, tokenizer
80
  if model is None:
 
81
  try:
82
- # Always load in full float32 precision
83
  model = AutoModelForSequenceClassification.from_pretrained(
84
  model_path,
85
  num_labels=1,
86
  device_map=None,
87
- torch_dtype=torch.float32
 
 
 
88
  )
89
- # 명시적으로 device에 올리기
90
- model.to(device)
91
  except Exception as e:
92
- print(f"Standard loading failed, retrying in float32: {str(e)}")
93
- # Fallback: basic 로딩, 역시 float32
94
  model = AutoModelForSequenceClassification.from_pretrained(
95
  model_path,
96
  num_labels=1,
97
  torch_dtype=torch.float32
98
  )
 
 
99
  model.to(device)
 
 
100
  tokenizer = AutoTokenizer.from_pretrained(model_path)
101
  model.eval()
102
-
103
  text = (
104
  f"Given a certain paper, Title: {title}\n"
105
  f"Abstract: {abstract}.\n"
106
  "Predict its normalized academic impact (between 0 and 1):"
107
  )
108
-
109
  try:
110
  inputs = tokenizer(text, return_tensors="pt")
111
- # inputs를 device로 이동
112
  inputs = {k: v.to(device) for k, v in inputs.items()}
113
-
114
  with torch.no_grad():
115
  outputs = model(**inputs)
116
- probability = torch.sigmoid(outputs.logits).item()
117
-
118
- # 소폭 올림 보정
119
- score = min(1.0, probability + 0.05)
120
  return round(score, 4)
121
-
122
  except Exception as e:
123
- print(f"Prediction error: {str(e)}")
124
- return 0.0 # 오류 시 기본값
125
 
126
  def get_grade_and_emoji(score):
127
  if score >= 0.900: return "AAA 🌟"
@@ -158,46 +123,97 @@ example_papers = [
158
  def validate_input(title, abstract):
159
  title = title.replace("\n", " ").strip().replace("''", "'")
160
  abstract = abstract.replace("\n", " ").strip().replace("''", "'")
161
-
162
- non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
163
- non_latin_in_title = non_latin_pattern.findall(title)
164
- non_latin_in_abstract = non_latin_pattern.findall(abstract)
165
-
166
- if len(title.strip().split(' ')) < 3:
167
  return False, "The title must be at least 3 words long."
168
- if len(abstract.strip().split(' ')) < 50:
169
  return False, "The abstract must be at least 50 words long."
170
- if len((title + abstract).split(' ')) > 1024:
171
- return True, "Warning, the input length is approaching tokenization limits (1024) and may be truncated without further warning!"
172
- if non_latin_in_title:
173
- return False, f"The title contains invalid characters: {', '.join(non_latin_in_title)}. Only English letters and special symbols are allowed."
174
- if non_latin_in_abstract:
175
- return False, f"The abstract contains invalid characters: {', '.join(non_latin_in_abstract)}. Only English letters and special symbols are allowed."
176
-
177
  return True, "Inputs are valid!"
178
 
179
  def update_button_status(title, abstract):
180
- valid, message = validate_input(title, abstract)
181
  if not valid:
182
- return gr.update(value="Error: " + message), gr.update(interactive=False)
183
- return gr.update(value=message), gr.update(interactive=True)
184
 
185
  def process_arxiv_input(arxiv_input):
186
- """Process arXiv input and update title/abstract fields."""
187
  if not arxiv_input.strip():
188
  return "", "", "Please enter an arXiv URL or ID"
189
-
190
  result = fetch_arxiv_paper(arxiv_input)
191
  if result["success"]:
192
  return result["title"], result["abstract"], result["message"]
193
- else:
194
- return "", "", result["message"]
195
 
196
  css = """
197
  .gradio-container {
198
  font-family: 'Arial', sans-serif;
199
  }
200
- /* ... 이하 CSS는 동일 ... */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  """
202
 
203
  with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
@@ -291,21 +307,9 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
291
  """
292
  )
293
 
294
- title_input.change(
295
- update_button_status,
296
- inputs=[title_input, abstract_input],
297
- outputs=[validation_status, submit_button]
298
- )
299
- abstract_input.change(
300
- update_button_status,
301
- inputs=[title_input, abstract_input],
302
- outputs=[validation_status, submit_button]
303
- )
304
- fetch_button.click(
305
- process_arxiv_input,
306
- inputs=[arxiv_input],
307
- outputs=[title_input, abstract_input, validation_status]
308
- )
309
 
310
  def process_prediction(title, abstract):
311
  score = predict(title, abstract)
 
19
def fetch_arxiv_paper(arxiv_input):
    """Fetch a paper's title and abstract from arXiv.

    Accepts either a full arxiv.org URL (abs or pdf link) or a bare arXiv ID.
    Returns a dict with keys: "title", "abstract", "success", "message".
    Never raises: every failure path is reported via success=False.
    """
    try:
        # Accept both full URLs and bare IDs; strip a trailing ".pdf" from PDF links.
        if 'arxiv.org' in arxiv_input:
            parsed = urlparse(arxiv_input)
            arxiv_id = parsed.path.split('/')[-1].replace('.pdf', '')
        else:
            arxiv_id = arxiv_input.strip()
        api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'
        # Timeout so a stalled arXiv API call cannot hang the UI indefinitely
        # (requests has no default timeout; a ConnectTimeout/ReadTimeout falls
        # through to the except below and is reported as a message).
        response = requests.get(api_url, timeout=30)
        if response.status_code != 200:
            return {"title": "", "abstract": "", "success": False, "message": "Error fetching paper from arXiv API"}
        # The arXiv API returns an Atom feed; entries live in the Atom namespace.
        root = ET.fromstring(response.text)
        ns = {'arxiv': 'http://www.w3.org/2005/Atom'}
        entry = root.find('.//arxiv:entry', ns)
        if entry is None:
            return {"title": "", "abstract": "", "success": False, "message": "Paper not found"}
        title = entry.find('arxiv:title', ns).text.strip()
        abstract = entry.find('arxiv:summary', ns).text.strip()
        return {"title": title, "abstract": abstract, "success": True, "message": "Paper fetched successfully!"}
    except Exception as e:
        # Covers network errors, XML parse errors, and missing title/summary nodes.
        return {"title": "", "abstract": "", "success": False, "message": f"Error fetching paper: {e}"}
 
 
 
 
 
41
 
42
@spaces.GPU(duration=60, enable_queue=True)
def predict(title, abstract):
    """Predict a paper's normalized academic impact in [0, 1] from title + abstract.

    Lazily loads the model/tokenizer on first call (full float32, no
    quantization), builds the scoring prompt, and returns the sigmoid of the
    single regression logit (plus a small calibration offset), rounded to 4
    decimals. Returns 0.0 on any inference error.
    """
    # NOTE(review): the diff context hides the title-normalization line; it is
    # assumed to mirror the abstract normalization below — confirm against repo.
    title = title.replace("\n", " ").strip().replace("''", "'")
    abstract = abstract.replace("\n", " ").strip().replace("''", "'")
    global model, tokenizer
    if model is None:
        # 1) Load in full float32 with every quantized/offload path disabled.
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                num_labels=1,
                device_map=None,
                torch_dtype=torch.float32,
                load_in_8bit=False,
                load_in_4bit=False,
                low_cpu_mem_usage=False
            )
        except Exception as e:
            # Fallback: retry with the minimal argument set, still float32.
            print(f" 로딩 실패, 재시도: {e}")
            model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                num_labels=1,
                torch_dtype=torch.float32
            )
        # 2) Move to the target device; ignore "unsupported" ValueErrors so a
        #    CPU-only fallback still works.
        try:
            model.to(device)
        except ValueError as e:
            print(f"model.to() 무시: {e}")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model.eval()

    text = (
        f"Given a certain paper, Title: {title}\n"
        f"Abstract: {abstract}.\n"
        "Predict its normalized academic impact (between 0 and 1):"
    )
    try:
        # Bug fix: explicitly truncate to the 1024-token budget the UI warns
        # about, so over-long inputs degrade gracefully instead of raising.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        prob = torch.sigmoid(outputs.logits).item()
        # Small upward calibration offset, capped at 1.0.
        score = min(1.0, prob + 0.05)
        return round(score, 4)
    except Exception as e:
        print(f"Prediction error: {e}")
        return 0.0
90
 
91
  def get_grade_and_emoji(score):
92
  if score >= 0.900: return "AAA 🌟"
 
123
  def validate_input(title, abstract):
124
  title = title.replace("\n", " ").strip().replace("''", "'")
125
  abstract = abstract.replace("\n", " ").strip().replace("''", "'")
126
+ non_latin = re.compile(r'[^\u0000-\u007F]')
127
+ if len(title.split()) < 3:
 
 
 
 
128
  return False, "The title must be at least 3 words long."
129
+ if len(abstract.split()) < 50:
130
  return False, "The abstract must be at least 50 words long."
131
+ if non_latin.search(title):
132
+ return False, "Title에 영어 문자가 포함되어 있습니다."
133
+ if non_latin.search(abstract):
134
+ return False, "Abstract에 영어 문자가 포함되어 있습니다."
 
 
 
135
  return True, "Inputs are valid!"
136
 
137
def update_button_status(title, abstract):
    """Refresh the validation banner and submit-button state from the current inputs."""
    is_valid, message = validate_input(title, abstract)
    status_text = message if is_valid else "Error: " + message
    return gr.update(value=status_text), gr.update(interactive=is_valid)
142
 
143
def process_arxiv_input(arxiv_input):
    """Resolve an arXiv URL/ID into (title, abstract, status_message) for the UI fields."""
    if not arxiv_input.strip():
        return "", "", "Please enter an arXiv URL or ID"
    result = fetch_arxiv_paper(arxiv_input)
    if not result["success"]:
        return "", "", result["message"]
    return result["title"], result["abstract"], result["message"]
 
150
 
151
# Custom CSS injected into the Gradio Blocks app: global font, gradient main
# title, card-style sections (input/result/methodology/example), the large
# grade badge, and the arXiv input/link/note styling.
css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.main-title {
    text-align: center;
    color: #2563eb;
    font-size: 2.5rem !important;
    margin-bottom: 1rem !important;
    background: linear-gradient(45deg, #2563eb, #1d4ed8);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.sub-title {
    text-align: center;
    color: #4b5563;
    font-size: 1.5rem !important;
    margin-bottom: 2rem !important;
}
.input-section {
    background: white;
    padding: 2rem;
    border-radius: 1rem;
    box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1);
}
.result-section {
    background: #f8fafc;
    padding: 2rem;
    border-radius: 1rem;
    margin-top: 2rem;
}
.methodology-section {
    background: #ecfdf5;
    padding: 2rem;
    border-radius: 1rem;
    margin-top: 2rem;
}
.example-section {
    background: #fff7ed;
    padding: 2rem;
    border-radius: 1rem;
    margin-top: 2rem;
}
.grade-display {
    font-size: 3rem;
    text-align: center;
    margin: 1rem 0;
}
.arxiv-input {
    margin-bottom: 1.5rem;
    padding: 1rem;
    background: #f3f4f6;
    border-radius: 0.5rem;
}
.arxiv-link {
    color: #2563eb;
    text-decoration: underline;
    font-size: 0.9em;
    margin-top: 0.5em;
}
.arxiv-note {
    color: #666;
    font-size: 0.9em;
    margin-top: 0.5em;
    margin-bottom: 0.5em;
}
"""
218
 
219
  with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
 
307
  """
308
  )
309
 
310
+ title_input.change(update_button_status, [title_input, abstract_input], [validation_status, submit_button])
311
+ abstract_input.change(update_button_status, [title_input, abstract_input], [validation_status, submit_button])
312
+ fetch_button.click(process_arxiv_input, [arxiv_input], [title_input, abstract_input, validation_status])
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  def process_prediction(title, abstract):
315
  score = predict(title, abstract)