Commit 5c3b44c · 1 Parent(s): 74cee4f

Updated main.py to app.py

Files changed:
- app/{main.py → app.py} +45 -51
- runtime.txt +1 -1
app/{main.py → app.py}
RENAMED
@@ -48,12 +48,17 @@ RELOAD_INTERVAL = 300 # 5 minutes
 def load_models_impl():
     """Implementation of model loading logic with proper error handling"""
     global embedder, ai_tokenizer, ai_model, model_status
-
+
     # Track attempt time
     model_status["last_reload_attempt"] = time.time()
     model_status["last_error"] = None
-
+
     try:
+        # Placeholder for the code that should be inside the try block
+        pass
+    except Exception as e:
+        logger.error(f"An error occurred: {e}")
+        raise HTTPException(status_code=500, detail="An internal error occurred.")
         # Check Hugging Face Hub connectivity
         response = requests.head("https://huggingface.co", timeout=5)
         if response.status_code == 200:
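Note on this hunk: the added lines wrap a bare pass in their own try/except and raise a generic 500, but they land ahead of the pre-existing connectivity check, which is left at a now-invalid indentation depth. For reference, a standalone sketch of that connectivity probe (the helper name and logging setup are my assumptions, not code from the commit):

import logging

import requests

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def hub_is_accessible(timeout: float = 5.0) -> bool:
    """Send a HEAD request to https://huggingface.co and report whether it answered 200."""
    try:
        response = requests.head("https://huggingface.co", timeout=timeout)
        if response.status_code == 200:
            return True
        logger.error(f"Failed to connect to Hugging Face Hub: {response.status_code}")
    except Exception as e:
        logger.error(f"Error checking Hugging Face Hub connectivity: {e}")
    return False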
@@ -63,34 +68,34 @@ def load_models_impl():
             logger.error(f"Failed to connect to Hugging Face Hub: {response.status_code}")
     except Exception as e:
         logger.error(f"Error checking Hugging Face Hub connectivity: {e}")
-
+
     try:
         # Load SentenceTransformer model for embeddings
         logger.info("Loading SentenceTransformer model...")
         embedder = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
-
+
         # Load AI detection model
         ai_model_name = "ChrispamWrites/roberta-ai-detector-20250401_232702"
         logger.info(f"Loading AI detection model: {ai_model_name}")
-
+
         # Use local cache if available or download from HF
         ai_tokenizer = AutoTokenizer.from_pretrained(
             ai_model_name,
             local_files_only=not model_status["hub_accessible"],
             cache_dir="./model_cache"
         )
-
+
         # Load the config first
         ai_config = AutoConfig.from_pretrained(
             ai_model_name,
             local_files_only=not model_status["hub_accessible"],
             cache_dir="./model_cache"
         )
-
+
         # Modify the config to match the checkpoint's expected dimensions
-        ai_config.max_position_embeddings = 514
-        ai_config.type_vocab_size = 1
-
+        ai_config.max_position_embeddings = 514
+        ai_config.type_vocab_size = 1
+
         # Load the model with this config
         ai_model = AutoModelForSequenceClassification.from_pretrained(
             ai_model_name,
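The two config overrides mirror stock RoBERTa: the checkpoint stores 514 position slots (512 usable tokens plus a two-slot padding offset) and a single token type. A quick sketch to confirm against a stock checkpoint (downloads roberta-base on first run):

from transformers import AutoConfig

# Editor's sketch: the hard-coded dimensions above match stock RoBERTa.
config = AutoConfig.from_pretrained("roberta-base")
print(config.max_position_embeddings)  # 514: 512 usable positions + 2-slot padding offset
print(config.type_vocab_size)          # 1: RoBERTa has no segment embeddings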
@@ -98,7 +103,7 @@ def load_models_impl():
             local_files_only=not model_status["hub_accessible"],
             cache_dir="./model_cache"
         )
-
+
         # If the above doesn't work, try with ignore_mismatched_sizes
         if ai_model is None:
             logger.info("Attempting to load model with ignore_mismatched_sizes=True")
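One caveat the diff preserves: from_pretrained raises on failure rather than returning None, so the if ai_model is None fallback can never trigger. A sketch of a fallback that can, using exception handling (the helper name and the caught exception types are my assumptions):

from transformers import AutoModelForSequenceClassification

def load_detector(name: str, cache_dir: str = "./model_cache"):
    """Editor's sketch: retry with ignore_mismatched_sizes=True only after a real failure."""
    try:
        return AutoModelForSequenceClassification.from_pretrained(name, cache_dir=cache_dir)
    except (RuntimeError, ValueError):
        # Size mismatches surface as exceptions; retry and let mismatched heads re-initialize
        return AutoModelForSequenceClassification.from_pretrained(
            name, cache_dir=cache_dir, ignore_mismatched_sizes=True
        )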
@@ -108,18 +113,18 @@
             cache_dir="./model_cache",
             ignore_mismatched_sizes=True
         )
-
+
         # Verify models are loaded by testing them
         test_sentence = "This is a test sentence to verify model loading."
-
+
         # Test sentence transformer
         _ = embedder.encode(test_sentence)
-
+
         # Test AI detection model
         inputs = ai_tokenizer(test_sentence, return_tensors="pt", max_length=512, truncation=True)
         with torch.no_grad():
             _ = ai_model(**inputs)
-
+
         model_status["model_loaded"] = True
         logger.info("Models loaded successfully!")
         return True
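The verification step pushes one real input through each model so a bad download fails at load time rather than on the first request. The same smoke test as a reusable helper (a sketch; the function name is mine):

import torch

def smoke_test(tokenizer, model, text: str = "This is a test sentence to verify model loading."):
    """Editor's sketch: run one example through a tokenizer/classifier pair."""
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    assert outputs.logits.shape[0] == 1  # one row of logits per input sequence
    return outputs.logits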
@@ -141,7 +146,7 @@ async def load_models():
         retries += 1
         logger.info(f"Retrying model loading ({retries}/{MAX_RETRIES})...")
         time.sleep(5) # Wait 5 seconds before retrying
-
+
     if not model_status["model_loaded"]:
         logger.warning(f"Failed to load models after {MAX_RETRIES} attempts. API will start, but analyze endpoint won't work.")
 
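Model loading is retried a fixed number of times with a 5-second pause, and the API starts even if every attempt fails; /analyze then returns 503 until a reload succeeds. The control flow in isolation (a sketch; MAX_RETRIES is defined elsewhere in app.py, so its value here is assumed):

import time

MAX_RETRIES = 3  # assumed value for illustration

def load_with_retries(load_fn) -> bool:
    """Editor's sketch of the retry loop around load_models_impl."""
    retries = 0
    while True:
        if load_fn():
            return True
        retries += 1
        if retries >= MAX_RETRIES:
            return False
        print(f"Retrying model loading ({retries}/{MAX_RETRIES})...")
        time.sleep(5)  # wait 5 seconds before retrying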
@@ -176,7 +181,7 @@ def detect_ai_generated(text):
     probs = torch.softmax(logits, dim=1).squeeze()
     predicted_class = torch.argmax(probs).item()
     confidence = probs[predicted_class].item()
-
+
     return {
         "label": "AI-generated" if predicted_class == 1 else "Human-written",
         "confidence": round(confidence * 100, 2)
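detect_ai_generated turns raw logits into a label and a percentage: softmax over the two classes, argmax for the label, and the winning probability as confidence. The same arithmetic on a stand-in tensor:

import torch

# Editor's sketch with a stand-in for the model's output logits
logits = torch.tensor([[0.2, 1.3]])
probs = torch.softmax(logits, dim=1).squeeze()
predicted_class = torch.argmax(probs).item()   # 1 -> "AI-generated", 0 -> "Human-written"
confidence = round(probs[predicted_class].item() * 100, 2)
print({"label": "AI-generated" if predicted_class == 1 else "Human-written",
       "confidence": confidence})              # here: AI-generated at ~75.03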
@@ -192,12 +197,12 @@ async def health_check() -> Dict[str, Any]:
         (model_status["last_reload_attempt"] is None or
          current_time - model_status["last_reload_attempt"] > RELOAD_INTERVAL)
     )
-
+
     return {
         **model_status,
         "reload_needed": reload_needed,
-        "last_reload_attempt_time": time.strftime('%Y-%m-%d %H:%M:%S',
-                                    time.localtime(model_status["last_reload_attempt"]))
+        "last_reload_attempt_time": time.strftime('%Y-%m-%d %H:%M:%S',
+                                    time.localtime(model_status["last_reload_attempt"]))
         if model_status["last_reload_attempt"] else None
     }
 
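/health exposes a machine-readable reload_needed flag plus a human-readable timestamp of the last reload attempt, or None if there has been none. The strftime/localtime formatting on its own (a sketch with a stand-in epoch value):

import time

last_reload_attempt = time.time()  # stand-in for model_status["last_reload_attempt"]
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(last_reload_attempt)))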
@@ -206,20 +211,20 @@ async def reload_models(background_tasks: BackgroundTasks):
     """Endpoint to manually trigger model reloading"""
     # Check if enough time has passed since last reload attempt
     current_time = time.time()
-    if (model_status["last_reload_attempt"] is not None and
+    if (model_status["last_reload_attempt"] is not None and
         current_time - model_status["last_reload_attempt"] < 60): # Prevent reloading more than once per minute
         return JSONResponse(content={
             "message": "Too many reload attempts. Please wait before trying again.",
             "seconds_until_next_attempt": 60 - int(current_time - model_status["last_reload_attempt"])
         }, status_code=429)
-
+
     background_tasks.add_task(background_model_reload, background_tasks)
     return {"message": "Model reload initiated in background"}
 
 @app.post("/analyze")
 async def analyze_essay(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
     global model_status, embedder, ai_tokenizer, ai_model
-
+
     # Check if models are loaded
     if not model_status["model_loaded"]:
         # Check if we should attempt to reload models
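/reload-models enforces a one-minute cooldown and answers 429 with the remaining wait when called too soon. The same bookkeeping as a pure function (a sketch; the helper name is mine, the 60-second window comes from the diff):

import time

def cooldown_remaining(last_attempt, window: int = 60) -> int:
    """Editor's sketch: seconds a caller must still wait; 0 means a reload may start."""
    if last_attempt is None:
        return 0
    return max(0, window - int(time.time() - last_attempt))

A result of 0 corresponds to queuing the background task; anything positive maps onto the 429 payload above.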
@@ -228,64 +233,53 @@ async def analyze_essay(file: UploadFile = File(...), background_tasks: Backgrou
             model_status["last_reload_attempt"] is None or
             current_time - model_status["last_reload_attempt"] > RELOAD_INTERVAL
         )
-
+
         if reload_needed and background_tasks:
             # Start a background reload
             background_tasks.add_task(background_model_reload, background_tasks)
             message = "Models are being reloaded in the background. Please try again in a few minutes."
         else:
             message = "Model not loaded. Check /health endpoint for details or try /reload-models endpoint."
-
+
         raise HTTPException(status_code=503, detail=message)
-
+
     # Check if models are actually initialized
     if embedder is None or ai_tokenizer is None or ai_model is None:
         logger.error("Models appear loaded but variables are None")
         raise HTTPException(status_code=503, detail="Model initialization incomplete. Please try again later.")
-
+
     if not file.filename.endswith(".pdf"):
         raise HTTPException(status_code=400, detail="Only PDF files are supported")
-
+
     with tempfile.TemporaryDirectory() as tmpdir:
         file_path = os.path.join(tmpdir, f"{uuid.uuid4()}.pdf")
         with open(file_path, "wb") as buffer:
             shutil.copyfileobj(file.file, buffer)
-
+
         try:
             essay_text = extract_text_from_pdf(file_path)
         except RuntimeError as e:
             raise HTTPException(status_code=500, detail=str(e))
-
+
         if not essay_text.strip():
             raise HTTPException(status_code=400, detail="The PDF seems to contain no extractable text.")
-
+
         try:
             # Run AI content detection
             ai_result = detect_ai_generated(essay_text)
-
+
             # Run internal plagiarism detection
             chunks = chunk_text(essay_text)
             if len(chunks) < 2:
                 raise HTTPException(status_code=400, detail="Not enough text chunks to assess internal plagiarism.")
-
+
             embeddings = embedder.encode(chunks)
             similarities = []
-            for i in range(len(embeddings)):
-                for j in range(i + 1, len(embeddings)):
-                    sim = cosine_similarity([embeddings[i]], [embeddings[j]])[0][0]
-                    similarities.append(sim)
-
-            max_similarity = max(similarities) if similarities else 0
-            avg_similarity = sum(similarities) / len(similarities) if similarities else 0
-            internal_score = round(avg_similarity * 100, 2)
         except Exception as e:
-            logger.error(f"
-            raise HTTPException(status_code=500, detail=
-
-
-
-
-
-            })
-
-
+            logger.error(f"An error occurred during analysis: {e}")
+            raise HTTPException(status_code=500, detail="An error occurred during analysis.")
+        except Exception as e:
+            logger.error(f"An error occurred during analysis: {e}")
+            raise HTTPException(status_code=500, detail="An error occurred during analysis.")
+
+
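Two things are worth flagging in this hunk. The minus side drops the pairwise cosine-similarity pass with no replacement, so the endpoint no longer computes internal_score (the old handler's message strings are also truncated in this capture and left as-is above). And the plus side ends with two identical except Exception blocks, the second of which is unreachable. The removed scoring, reconstructed from the deleted lines as a standalone function:

from sklearn.metrics.pairwise import cosine_similarity

def internal_similarity_stats(embeddings):
    """Editor's reconstruction of the scoring removed above, from the '-' lines."""
    similarities = []
    for i in range(len(embeddings)):
        for j in range(i + 1, len(embeddings)):
            sim = cosine_similarity([embeddings[i]], [embeddings[j]])[0][0]
            similarities.append(sim)
    max_similarity = max(similarities) if similarities else 0
    avg_similarity = sum(similarities) / len(similarities) if similarities else 0
    internal_score = round(avg_similarity * 100, 2)  # average similarity as a percentage
    return max_similarity, avg_similarity, internal_score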
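Taken together, the app.py changes apply one pattern: handlers that previously exposed exception text to clients now log the detail server-side and raise a generic HTTPException. A minimal sketch of that pattern (the endpoint and failing helper are illustrative, not the app's real code):

import logging

from fastapi import FastAPI, HTTPException

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()

def run_analysis() -> dict:
    raise ValueError("simulated failure with internal details")

@app.post("/demo-analyze")
async def demo_analyze():
    try:
        return run_analysis()
    except Exception as e:
        logger.error(f"An error occurred during analysis: {e}")  # full detail stays in the log
        raise HTTPException(status_code=500, detail="An error occurred during analysis.")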
runtime.txt
CHANGED
@@ -1 +1 @@
-python-3.
+python-3.10
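runtime.txt pins the Python version the hosting platform provisions; the old pin is truncated in this capture, and the new file requests Python 3.10. A start-up guard can assert the interpreter matches the pin (a sketch, my own convention, not part of the commit):

import sys

# Editor's sketch: fail fast if the interpreter doesn't match the runtime.txt pin.
assert sys.version_info[:2] == (3, 10), f"expected Python 3.10, got {sys.version.split()[0]}"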