Update app.py
Browse files
app.py
CHANGED
@@ -1,37 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from flask import Flask, request, send_file, jsonify
|
2 |
from flask_cors import CORS
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
4 |
-
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import tempfile
|
8 |
-
import os
|
9 |
-
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
|
10 |
-
os.environ['HF_DATASETS_CACHE'] = '/app/.cache'
|
11 |
-
os.environ['XDG_CACHE_HOME'] = '/app/.cache'
|
12 |
os.environ['HF_HOME'] = '/app/.cache'
|
|
|
13 |
|
14 |
app = Flask(__name__)
|
15 |
CORS(app)
|
16 |
|
17 |
-
#
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
|
|
|
|
|
|
35 |
IEEE_TEMPLATE = """
|
36 |
<!DOCTYPE html>
|
37 |
<html>
|
@@ -45,14 +71,8 @@ IEEE_TEMPLATE = """
|
|
45 |
font-size: 12pt;
|
46 |
line-height: 1.5;
|
47 |
}
|
48 |
-
.header {
|
49 |
-
|
50 |
-
margin-bottom: 24pt;
|
51 |
-
}
|
52 |
-
.two-column {
|
53 |
-
column-count: 2;
|
54 |
-
column-gap: 0.5in;
|
55 |
-
}
|
56 |
h1 { font-size: 14pt; margin: 12pt 0; }
|
57 |
h2 { font-size: 12pt; margin: 12pt 0 6pt 0; }
|
58 |
.abstract { margin-bottom: 24pt; }
|
@@ -97,31 +117,46 @@ IEEE_TEMPLATE = """
|
|
97 |
</html>
|
98 |
"""
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
return
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
115 |
|
116 |
@app.route('/generate', methods=['POST'])
|
117 |
def generate_pdf():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
try:
|
|
|
119 |
data = request.json
|
120 |
-
if not data
|
121 |
-
return jsonify({"error": "
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
# Format content
|
124 |
-
formatted = format_content(data
|
125 |
|
126 |
# Generate HTML
|
127 |
html = jinja2.Template(IEEE_TEMPLATE).render(
|
@@ -144,13 +179,34 @@ def generate_pdf():
|
|
144 |
'quiet': ''
|
145 |
}
|
146 |
|
147 |
-
# Create PDF
|
148 |
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as f:
|
149 |
pdfkit.from_string(html, f.name, options=options)
|
150 |
return send_file(f.name, mimetype='application/pdf')
|
151 |
|
152 |
except Exception as e:
|
153 |
return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
if __name__ == '__main__':
|
156 |
app.run(host='0.0.0.0', port=5000)
|
|
|
1 |
+
# Configure Hugging Face cache directories BEFORE importing transformers:
# huggingface_hub resolves HF_HOME / XDG_CACHE_HOME at import time, so
# setting them after the import silently leaves the default cache in use.
import os

os.environ['HF_HOME'] = '/app/.cache'
os.environ['XDG_CACHE_HOME'] = '/app/.cache'

import tempfile
import time
from threading import Thread

import jinja2
import pdfkit
import torch
from flask import Flask, request, send_file, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

app = Flask(__name__)
CORS(app)
18 |
|
19 |
+
# Global state tracking: written by the loader thread, read by the routes.
model_loaded = False
load_error = None
generator = None

# --------------------------------------------------
# Asynchronous Model Loading
# --------------------------------------------------
def load_model():
    """Load the gpt2-medium text-generation pipeline in the background.

    On success sets the module-level ``generator`` and flips
    ``model_loaded`` to True; on any failure records the message in
    ``load_error`` so ``/health`` can report it. Never raises.
    """
    global model_loaded, load_error, generator
    try:
        # Initialize model with low-memory settings; fp16 only helps on GPU,
        # so fall back to fp32 on CPU.
        model = AutoModelForCausalLM.from_pretrained(
            "gpt2-medium",
            use_safetensors=True,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )

        tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

        # NOTE: the model was dispatched by accelerate (device_map="auto"),
        # so the pipeline must NOT also receive an explicit `device=` —
        # transformers rejects moving an accelerate-loaded model.
        generator = pipeline(
            'text-generation',
            model=model,
            tokenizer=tokenizer
        )

        model_loaded = True
        print("Model loaded successfully")

    except Exception as e:
        load_error = str(e)
        print(f"Model loading failed: {load_error}")

# Start model loading in a daemon thread so a hung download cannot keep
# the interpreter alive after the web server exits.
Thread(target=load_model, daemon=True).start()
|
57 |
|
58 |
+
# --------------------------------------------------
|
59 |
+
# IEEE Format Template
|
60 |
+
# --------------------------------------------------
|
61 |
IEEE_TEMPLATE = """
|
62 |
<!DOCTYPE html>
|
63 |
<html>
|
|
|
71 |
font-size: 12pt;
|
72 |
line-height: 1.5;
|
73 |
}
|
74 |
+
.header { text-align: center; margin-bottom: 24pt; }
|
75 |
+
.two-column { column-count: 2; column-gap: 0.5in; }
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
h1 { font-size: 14pt; margin: 12pt 0; }
|
77 |
h2 { font-size: 12pt; margin: 12pt 0 6pt 0; }
|
78 |
.abstract { margin-bottom: 24pt; }
|
|
|
117 |
</html>
|
118 |
"""
|
119 |
|
120 |
+
# --------------------------------------------------
# API Endpoints
# --------------------------------------------------
@app.route('/health', methods=['GET'])
def health_check():
    """Report model-loading state: 500 on load failure, 503 while loading, 200 when ready."""
    if load_error:
        payload = {
            "status": "error",
            "message": f"Model failed to load: {load_error}"
        }
        return jsonify(payload), 500

    status_code = 200 if model_loaded else 503
    payload = {
        "status": "ready" if model_loaded else "loading",
        "model_loaded": model_loaded,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }
    return jsonify(payload), status_code
|
136 |
|
137 |
@app.route('/generate', methods=['POST'])
|
138 |
def generate_pdf():
|
139 |
+
# Check model status
|
140 |
+
if not model_loaded:
|
141 |
+
return jsonify({
|
142 |
+
"error": "Model not loaded yet",
|
143 |
+
"status": "loading"
|
144 |
+
}), 503
|
145 |
+
|
146 |
try:
|
147 |
+
# Validate input
|
148 |
data = request.json
|
149 |
+
if not data:
|
150 |
+
return jsonify({"error": "No data provided"}), 400
|
151 |
+
|
152 |
+
required = ['title', 'authors', 'content']
|
153 |
+
if missing := [field for field in required if field not in data]:
|
154 |
+
return jsonify({
|
155 |
+
"error": f"Missing fields: {', '.join(missing)}"
|
156 |
+
}), 400
|
157 |
|
158 |
+
# Format content
|
159 |
+
formatted = format_content(data['content'])
|
160 |
|
161 |
# Generate HTML
|
162 |
html = jinja2.Template(IEEE_TEMPLATE).render(
|
|
|
179 |
'quiet': ''
|
180 |
}
|
181 |
|
182 |
+
# Create temporary PDF
|
183 |
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as f:
|
184 |
pdfkit.from_string(html, f.name, options=options)
|
185 |
return send_file(f.name, mimetype='application/pdf')
|
186 |
|
187 |
except Exception as e:
|
188 |
return jsonify({"error": str(e)}), 500
|
189 |
+
finally:
|
190 |
+
if 'f' in locals():
|
191 |
+
try: os.remove(f.name)
|
192 |
+
except: pass
|
193 |
+
|
194 |
+
# --------------------------------------------------
# Content Formatting
# --------------------------------------------------
def format_content(content):
    """Rewrite *content* toward IEEE style using the text-generation model.

    Best-effort: returns the input unchanged when the generator is not
    ready yet or generation fails, so PDF creation still proceeds.
    """
    if generator is None:
        # Model is still loading in the background (or failed to load);
        # pass the content through untouched rather than crash.
        return content
    try:
        prompt = f"Format this research content to IEEE standards:\n{str(content)}"
        result = generator(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            truncation=True,
            # Without this the pipeline echoes the instruction prompt back
            # at the start of 'generated_text', and the prompt would end up
            # printed inside the generated PDF.
            return_full_text=False
        )
        return result[0]['generated_text']
    except Exception as e:
        print(f"Formatting error: {str(e)}")
        return content
|
210 |
|
211 |
if __name__ == '__main__':
    # Bind on all interfaces (container deployment); Flask dev server,
    # port 5000. Production should front this with a WSGI server.
    app.run(host='0.0.0.0', port=5000)
|