mike23415 committed on
Commit
f67e43c
·
verified ·
1 Parent(s): f9dfaf0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -23
app.py CHANGED
@@ -4,6 +4,7 @@ import tempfile
4
  import jinja2
5
  import pdfkit
6
  import torch
 
7
  from threading import Thread
8
  from flask import Flask, request, send_file, jsonify
9
  from flask_cors import CORS
@@ -13,7 +14,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
13
  os.environ['HF_HOME'] = '/app/.cache'
14
  os.environ['XDG_CACHE_HOME'] = '/app/.cache'
15
 
 
 
 
 
 
16
 
 
17
  app = Flask(__name__)
18
  CORS(app)
19
 
@@ -22,11 +29,20 @@ model_loaded = False
22
  load_error = None
23
  generator = None
24
 
 
 
 
 
 
25
  def load_model():
26
  global model_loaded, load_error, generator
27
  try:
 
 
28
  # Detect device and dtype automatically
29
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  "gpt2-medium",
@@ -47,17 +63,15 @@ def load_model():
47
  )
48
 
49
  model_loaded = True
50
- print(f"Model loaded on {model.device}")
51
 
52
  except Exception as e:
53
  load_error = str(e)
54
- print(f"Model loading failed: {load_error}")
55
 
56
  # Start model loading in background thread
57
  Thread(target=load_model).start()
58
 
59
-
60
-
61
  # --------------------------------------------------
62
  # IEEE Format Template
63
  # --------------------------------------------------
@@ -102,14 +116,12 @@ IEEE_TEMPLATE = """
102
  {{ abstract }}
103
  <div class="keywords">Keywords— {{ keywords }}</div>
104
  </div>
105
-
106
  <div class="two-column">
107
  {% for section in sections %}
108
  <h2>{{ section.title }}</h2>
109
  {{ section.content }}
110
  {% endfor %}
111
  </div>
112
-
113
  <div class="references">
114
  <h2>References</h2>
115
  {% for ref in references %}
@@ -125,42 +137,58 @@ IEEE_TEMPLATE = """
125
  # --------------------------------------------------
126
  @app.route('/health', methods=['GET'])
127
  def health_check():
 
 
128
  if load_error:
 
129
  return jsonify({
130
  "status": "error",
131
  "message": f"Model failed to load: {load_error}"
132
  }), 500
133
 
 
 
 
 
134
  return jsonify({
135
  "status": "ready" if model_loaded else "loading",
136
  "model_loaded": model_loaded,
137
- "device": "cuda" if torch.cuda.is_available() else "cpu"
138
- }), 200 if model_loaded else 503
139
 
140
  @app.route('/generate', methods=['POST'])
141
  def generate_pdf():
142
  # Check model status
143
  if not model_loaded:
 
144
  return jsonify({
145
  "error": "Model not loaded yet",
146
  "status": "loading"
147
  }), 503
148
 
149
  try:
 
 
150
  # Validate input
151
  data = request.json
152
  if not data:
 
153
  return jsonify({"error": "No data provided"}), 400
154
 
155
  required = ['title', 'authors', 'content']
156
  if missing := [field for field in required if field not in data]:
 
157
  return jsonify({
158
  "error": f"Missing fields: {', '.join(missing)}"
159
  }), 400
160
 
161
- # Format content
 
 
 
162
  formatted = format_content(data['content'])
163
 
 
164
  # Generate HTML
165
  html = jinja2.Template(IEEE_TEMPLATE).render(
166
  title=data['title'],
@@ -183,33 +211,178 @@ def generate_pdf():
183
  }
184
 
185
  # Create temporary PDF
186
- with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as f:
187
- pdfkit.from_string(html, f.name, options=options)
188
- return send_file(f.name, mimetype='application/pdf')
 
 
 
 
 
 
 
 
 
 
189
 
 
 
 
 
190
  except Exception as e:
 
191
  return jsonify({"error": str(e)}), 500
 
192
  finally:
193
- if 'f' in locals():
194
- try: os.remove(f.name)
195
- except: pass
 
 
 
 
196
 
197
  # --------------------------------------------------
198
  # Content Formatting
199
  # --------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def format_content(content):
 
201
  try:
202
- prompt = f"Format this research content to IEEE standards:\n{str(content)}"
203
- return generator(
 
 
204
  prompt,
205
- max_new_tokens=512,
206
- temperature=0.7,
207
  do_sample=True,
208
- truncation=True
209
- )[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  except Exception as e:
211
- print(f"Formatting error: {str(e)}")
212
- return content
 
 
 
 
 
 
213
 
214
  if __name__ == '__main__':
215
  app.run(host='0.0.0.0', port=5000)
 
4
  import jinja2
5
  import pdfkit
6
  import torch
7
+ import logging
8
  from threading import Thread
9
  from flask import Flask, request, send_file, jsonify
10
  from flask_cors import CORS
 
14
  os.environ['HF_HOME'] = '/app/.cache'
15
  os.environ['XDG_CACHE_HOME'] = '/app/.cache'
16
 
17
+ # Configure logging
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format='%(asctime)s [%(levelname)s] %(message)s'
21
+ )
22
 
23
+ # Initialize Flask app
24
  app = Flask(__name__)
25
  CORS(app)
26
 
 
29
  load_error = None
30
  generator = None
31
 
32
# Configure wkhtmltopdf
# Use xvfb-run for headless PDF generation.
WKHTMLTOPDF_CMD = 'xvfb-run -a wkhtmltopdf'


def _create_wkhtmltopdf_wrapper(command):
    """Write an executable shell wrapper around *command* and return its path.

    pdfkit's configuration takes the path of a single executable and checks
    it by opening that path, so a command line with arguments (such as
    ``xvfb-run -a wkhtmltopdf``) cannot be passed directly. The wrapper
    forwards every argument pdfkit supplies to the real command.

    Args:
        command: Shell command line that launches wkhtmltopdf.

    Returns:
        Filesystem path of the generated, executable wrapper script.
    """
    fd, wrapper_path = tempfile.mkstemp(prefix='wkhtmltopdf-', suffix='.sh')
    with os.fdopen(fd, 'w') as script:
        # "exec" replaces the shell so exit codes and signals pass through.
        script.write(f'#!/bin/sh\nexec {command} "$@"\n')
    os.chmod(wrapper_path, 0o755)
    return wrapper_path


pdf_config = pdfkit.configuration(
    wkhtmltopdf=_create_wkhtmltopdf_wrapper(WKHTMLTOPDF_CMD)
)
36
+
37
  def load_model():
38
  global model_loaded, load_error, generator
39
  try:
40
+ app.logger.info("Starting model loading process")
41
+
42
  # Detect device and dtype automatically
43
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
44
+ device = "cuda" if torch.cuda.is_available() else "cpu"
45
+ app.logger.info(f"Device set to use {device}")
46
 
47
  model = AutoModelForCausalLM.from_pretrained(
48
  "gpt2-medium",
 
63
  )
64
 
65
  model_loaded = True
66
+ app.logger.info(f"Model loaded successfully on {model.device}")
67
 
68
  except Exception as e:
69
  load_error = str(e)
70
+ app.logger.error(f"Model loading failed: {load_error}", exc_info=True)
71
 
72
  # Start model loading in background thread
73
  Thread(target=load_model).start()
74
 
 
 
75
  # --------------------------------------------------
76
  # IEEE Format Template
77
  # --------------------------------------------------
 
116
  {{ abstract }}
117
  <div class="keywords">Keywords— {{ keywords }}</div>
118
  </div>
 
119
  <div class="two-column">
120
  {% for section in sections %}
121
  <h2>{{ section.title }}</h2>
122
  {{ section.content }}
123
  {% endfor %}
124
  </div>
 
125
  <div class="references">
126
  <h2>References</h2>
127
  {% for ref in references %}
 
137
  # --------------------------------------------------
138
  @app.route('/health', methods=['GET'])
139
  def health_check():
140
+ app.logger.info("Health check requested")
141
+
142
  if load_error:
143
+ app.logger.error(f"Health check failed: {load_error}")
144
  return jsonify({
145
  "status": "error",
146
  "message": f"Model failed to load: {load_error}"
147
  }), 500
148
 
149
+ status_code = 200 if model_loaded else 503
150
+ device_info = "cuda" if torch.cuda.is_available() else "cpu"
151
+
152
+ app.logger.info(f"Health check returning status: {'ready' if model_loaded else 'loading'}, device: {device_info}")
153
  return jsonify({
154
  "status": "ready" if model_loaded else "loading",
155
  "model_loaded": model_loaded,
156
+ "device": device_info
157
+ }), status_code
158
 
159
  @app.route('/generate', methods=['POST'])
160
  def generate_pdf():
161
  # Check model status
162
  if not model_loaded:
163
+ app.logger.error("PDF generation requested but model not loaded")
164
  return jsonify({
165
  "error": "Model not loaded yet",
166
  "status": "loading"
167
  }), 503
168
 
169
  try:
170
+ app.logger.info("Processing PDF generation request")
171
+
172
  # Validate input
173
  data = request.json
174
  if not data:
175
+ app.logger.error("No data provided in request")
176
  return jsonify({"error": "No data provided"}), 400
177
 
178
  required = ['title', 'authors', 'content']
179
  if missing := [field for field in required if field not in data]:
180
+ app.logger.error(f"Missing required fields: {missing}")
181
  return jsonify({
182
  "error": f"Missing fields: {', '.join(missing)}"
183
  }), 400
184
 
185
+ app.logger.info(f"Received request with title: {data['title']}")
186
+
187
+ # Format content with model
188
+ app.logger.info("Formatting content using the model")
189
  formatted = format_content(data['content'])
190
 
191
+ app.logger.info("Creating HTML from template")
192
  # Generate HTML
193
  html = jinja2.Template(IEEE_TEMPLATE).render(
194
  title=data['title'],
 
211
  }
212
 
213
  # Create temporary PDF
214
+ app.logger.info("Generating PDF file")
215
+ pdf_path = None
216
+
217
+ try:
218
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as f:
219
+ pdf_path = f.name
220
+
221
+ # Generate PDF using wkhtmltopdf with xvfb
222
+ pdfkit.from_string(html, pdf_path, options=options, configuration=pdf_config)
223
+
224
+ app.logger.info(f"PDF generated successfully at {pdf_path}")
225
+ return send_file(pdf_path, mimetype='application/pdf', as_attachment=True,
226
+ download_name=f"{data['title'].replace(' ', '_')}.pdf")
227
 
228
+ except Exception as e:
229
+ app.logger.error(f"PDF generation failed: {str(e)}", exc_info=True)
230
+ raise
231
+
232
  except Exception as e:
233
+ app.logger.error(f"Request processing failed: {str(e)}", exc_info=True)
234
  return jsonify({"error": str(e)}), 500
235
+
236
  finally:
237
+ # Clean up temporary file
238
+ if 'pdf_path' in locals() and pdf_path:
239
+ try:
240
+ app.logger.info(f"Cleaning up temporary file {pdf_path}")
241
+ os.remove(pdf_path)
242
+ except Exception as e:
243
+ app.logger.warning(f"Failed to remove temporary file: {str(e)}")
244
 
245
  # --------------------------------------------------
246
  # Content Formatting
247
  # --------------------------------------------------
248
def _looks_like_heading(stripped_line):
    """Return True for a short line that starts with an uppercase letter or digit."""
    if not stripped_line:
        return False
    first = stripped_line[0]
    return (first.isupper() or first.isdigit()) and len(stripped_line.split()) <= 6


def parse_formatted_content(text):
    """Parse the generated text into structured sections.

    The text is scanned line by line with simple heuristics:

    * abstract: non-empty lines after a line reading "Abstract", up to the
      first line starting with "keyword" (or the end of the text);
    * keywords: comma-separated values after "—" (or, failing that, ":") on
      the first line starting with "keyword";
    * sections: a short line (at most six words) starting with an uppercase
      letter or digit opens a section; following lines are its content;
    * references: every non-empty line after a line reading "References".

    Args:
        text: Generated text to parse.

    Returns:
        A dict with keys 'abstract' (str), 'keywords' (list of str),
        'sections' (list of {'title', 'content'} dicts) and 'references'
        (list of str). If parsing raises, a fallback dict holding the whole
        text in a single "Content" section is returned instead.
    """
    app.logger.info("Parsing formatted content")

    try:
        lines = text.split('\n')
        stripped = [line.strip() for line in lines]
        lowered = [entry.lower() for entry in stripped]

        # Default structure
        result = {
            'abstract': '',
            'keywords': ['IEEE', 'format', 'research', 'paper'],
            'sections': [],
            'references': []
        }

        abstract_at = next(
            (i for i, low in enumerate(lowered) if low == 'abstract'), None
        )
        keyword_at = next(
            (i for i, low in enumerate(lowered) if low.startswith('keyword')), None
        )
        references_at = lowered.index('references') if 'references' in lowered else None

        # Extract abstract: everything after the heading up to the keyword line.
        if abstract_at is not None:
            abstract_parts = []
            for entry, low in zip(stripped[abstract_at + 1:], lowered[abstract_at + 1:]):
                if low.startswith('keyword'):
                    break
                if entry:
                    abstract_parts.append(entry)
            result['abstract'] = ' '.join(abstract_parts)

        # Extract keywords; keep the defaults when the line yields nothing.
        if keyword_at is not None:
            keyword_line = lines[keyword_at]
            if '—' in keyword_line:
                tail = keyword_line.split('—')[1]
            else:
                tail = keyword_line.partition(':')[2]
            keywords = [k.strip() for k in tail.split(',') if k.strip()]
            if keywords:
                result['keywords'] = keywords

        # Extract sections. The references heading ends the body regardless of
        # its capitalisation, so only lines before it are scanned.
        body_end = len(lines) if references_at is None else references_at
        current_section = None
        section_content = []
        for index in range(body_end):
            # Abstract and keywords are rendered separately, so their lines
            # close the running section instead of opening a new one.
            front_matter = index in (abstract_at, keyword_at)
            if front_matter or _looks_like_heading(stripped[index]):
                if current_section is not None:
                    result['sections'].append({
                        'title': current_section,
                        'content': '\n'.join(section_content)
                    })
                current_section = None if front_matter else stripped[index]
                section_content = []
            elif current_section is not None:
                section_content.append(lines[index])

        # Save the final section; without this it would be dropped.
        if current_section is not None:
            result['sections'].append({
                'title': current_section,
                'content': '\n'.join(section_content)
            })

        # Extract references: non-empty lines after the heading.
        if references_at is not None:
            result['references'] = [
                entry
                for entry, low in zip(stripped[references_at + 1:], lowered[references_at + 1:])
                if entry and low != 'references'
            ]

        app.logger.info(f"Content parsed into {len(result['sections'])} sections and {len(result['references'])} references")
        return result

    except Exception as e:
        app.logger.error(f"Error parsing formatted content: {str(e)}", exc_info=True)
        # Return a basic structure if parsing fails
        return {
            'abstract': 'Error parsing content.',
            'keywords': ['IEEE', 'format'],
            'sections': [{'title': 'Content', 'content': text}],
            'references': []
        }
348
+
349
  def format_content(content):
350
+ """Format the content using the ML model"""
351
  try:
352
+ app.logger.info("Formatting content with ML model")
353
+ prompt = f"Format this research content to IEEE standards with sections, abstract, and references:\n\n{str(content)}"
354
+
355
+ response = generator(
356
  prompt,
357
+ max_new_tokens=1024, # Increased for more complete generation
358
+ temperature=0.5, # More deterministic output
359
  do_sample=True,
360
+ truncation=True,
361
+ num_return_sequences=1
362
+ )
363
+
364
+ generated_text = response[0]['generated_text']
365
+
366
+ # Remove the prompt from the generated text
367
+ if prompt in generated_text:
368
+ formatted_text = generated_text[len(prompt):].strip()
369
+ else:
370
+ formatted_text = generated_text
371
+
372
+ app.logger.info("Content formatted successfully")
373
+
374
+ # Parse the formatted text into structured sections
375
+ return parse_formatted_content(formatted_text)
376
+
377
  except Exception as e:
378
+ app.logger.error(f"Error formatting content: {str(e)}", exc_info=True)
379
+ # Return the original content if formatting fails
380
+ return {
381
+ 'abstract': 'Content processing error.',
382
+ 'keywords': ['IEEE', 'format'],
383
+ 'sections': [{'title': 'Content', 'content': str(content)}],
384
+ 'references': []
385
+ }
386
 
387
  if __name__ == '__main__':
388
  app.run(host='0.0.0.0', port=5000)