mike23415 committed
Commit 9d13a5a · verified · 1 Parent(s): 1b26681

Update app.py

Files changed (1)
app.py +109 -53
app.py CHANGED
@@ -1,37 +1,63 @@
+import os
+import time
+import tempfile
+import jinja2
+import pdfkit
+import torch
+from threading import Thread
 from flask import Flask, request, send_file, jsonify
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-import pdfkit
-import jinja2
-import torch
-import tempfile
-import os
-os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
-os.environ['HF_DATASETS_CACHE'] = '/app/.cache'
-os.environ['XDG_CACHE_HOME'] = '/app/.cache'
+
+# Configure cache directories
 os.environ['HF_HOME'] = '/app/.cache'
+os.environ['XDG_CACHE_HOME'] = '/app/.cache'

 app = Flask(__name__)
 CORS(app)

-# Initialize model and tokenizer
-try:
-    model = AutoModelForCausalLM.from_pretrained(
-        "gpt2-medium",
-        from_tf=False,
-        use_safetensors=True
-    )
-    tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
-    generator = pipeline(
-        'text-generation',
-        model=model,
-        tokenizer=tokenizer,
-        device=0 if torch.cuda.is_available() else -1
-    )
-except Exception as e:
-    print(f"Model loading failed: {str(e)}")
-    generator = None
+# Global state tracking
+model_loaded = False
+load_error = None
+generator = None
+
+# --------------------------------------------------
+# Asynchronous Model Loading
+# --------------------------------------------------
+def load_model():
+    global model_loaded, load_error, generator
+    try:
+        # Initialize model with low-memory settings
+        model = AutoModelForCausalLM.from_pretrained(
+            "gpt2-medium",
+            use_safetensors=True,
+            device_map="auto",
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
+
+        generator = pipeline(
+            'text-generation',
+            model=model,
+            tokenizer=tokenizer,
+            device=0 if torch.cuda.is_available() else -1
+        )
+
+        model_loaded = True
+        print("Model loaded successfully")
+
+    except Exception as e:
+        load_error = str(e)
+        print(f"Model loading failed: {load_error}")
+
+# Start model loading in background thread
+Thread(target=load_model).start()

+# --------------------------------------------------
+# IEEE Format Template
+# --------------------------------------------------
 IEEE_TEMPLATE = """
 <!DOCTYPE html>
 <html>
@@ -45,14 +71,8 @@ IEEE_TEMPLATE = """
     font-size: 12pt;
     line-height: 1.5;
 }
-.header {
-    text-align: center;
-    margin-bottom: 24pt;
-}
-.two-column {
-    column-count: 2;
-    column-gap: 0.5in;
-}
+.header { text-align: center; margin-bottom: 24pt; }
+.two-column { column-count: 2; column-gap: 0.5in; }
 h1 { font-size: 14pt; margin: 12pt 0; }
 h2 { font-size: 12pt; margin: 12pt 0 6pt 0; }
 .abstract { margin-bottom: 24pt; }
@@ -97,31 +117,46 @@ IEEE_TEMPLATE = """
 </html>
 """

-def format_content(content):
-    if not generator:
-        return content  # Fallback if model failed to load
-
-    try:
-        prompt = f"Format this research content to IEEE standards:\n{str(content)}"
-        return generator(
-            prompt,
-            max_length=1024,
-            num_return_sequences=1,
-            clean_up_tokenization_spaces=True
-        )[0]['generated_text']
-    except Exception as e:
-        print(f"Formatting failed: {str(e)}")
-        return content
+# --------------------------------------------------
+# API Endpoints
+# --------------------------------------------------
+@app.route('/health', methods=['GET'])
+def health_check():
+    if load_error:
+        return jsonify({
+            "status": "error",
+            "message": f"Model failed to load: {load_error}"
+        }), 500
+
+    return jsonify({
+        "status": "ready" if model_loaded else "loading",
+        "model_loaded": model_loaded,
+        "device": "cuda" if torch.cuda.is_available() else "cpu"
+    }), 200 if model_loaded else 503

 @app.route('/generate', methods=['POST'])
 def generate_pdf():
+    # Check model status
+    if not model_loaded:
+        return jsonify({
+            "error": "Model not loaded yet",
+            "status": "loading"
+        }), 503
+
     try:
+        # Validate input
         data = request.json
-        if not data or 'title' not in data or 'authors' not in data:
-            return jsonify({"error": "Missing required fields"}), 400
+        if not data:
+            return jsonify({"error": "No data provided"}), 400
+
+        required = ['title', 'authors', 'content']
+        if missing := [field for field in required if field not in data]:
+            return jsonify({
+                "error": f"Missing fields: {', '.join(missing)}"
+            }), 400

-        # Format content using AI
-        formatted = format_content(data.get('content', {}))
+        # Format content
+        formatted = format_content(data['content'])

         # Generate HTML
         html = jinja2.Template(IEEE_TEMPLATE).render(
@@ -144,13 +179,34 @@ def generate_pdf():
             'quiet': ''
         }

-        # Create PDF
+        # Create temporary PDF
         with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as f:
             pdfkit.from_string(html, f.name, options=options)
             return send_file(f.name, mimetype='application/pdf')

     except Exception as e:
         return jsonify({"error": str(e)}), 500
+    finally:
+        if 'f' in locals():
+            try: os.remove(f.name)
+            except: pass
+
+# --------------------------------------------------
+# Content Formatting
+# --------------------------------------------------
+def format_content(content):
+    try:
+        prompt = f"Format this research content to IEEE standards:\n{str(content)}"
+        return generator(
+            prompt,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True,
+            truncation=True
+        )[0]['generated_text']
+    except Exception as e:
+        print(f"Formatting error: {str(e)}")
+        return content

 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5000)
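For context, a minimal client sketch (not part of the commit) against the two endpoints this change introduces: it polls `/health` until the background model load finishes, then posts the three fields that `/generate` now validates and saves the returned PDF. The base URL and the `requests` dependency are assumptions; the `title`/`authors`/`content` field names come from the `required` list in the diff.

```python
import time
import requests  # assumed client-side dependency, not used by app.py itself

BASE = "http://localhost:5000"  # assumed host/port, matching app.run() above

# Poll /health: it returns 503 with status "loading" until the background
# thread finishes, 200 with status "ready" afterwards, 500 on load failure.
while True:
    status = requests.get(f"{BASE}/health").json()
    if status["status"] == "ready":
        break
    if status["status"] == "error":
        raise RuntimeError(status["message"])
    time.sleep(5)

# POST the three fields that /generate validates.
payload = {
    "title": "A Sample Paper",
    "authors": "A. Author, B. Author",
    "content": "Abstract and body text to be formatted...",
}
resp = requests.post(f"{BASE}/generate", json=payload)
resp.raise_for_status()  # a 4xx/5xx here carries a JSON error body

# On success the response body is the PDF bytes.
with open("paper.pdf", "wb") as out:
    out.write(resp.content)
```

The polling step matters because the commit moves model loading onto a background thread: requests sent before `model_loaded` flips to True get a 503 rather than blocking.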