Update main.py
main.py CHANGED
@@ -11,26 +11,11 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 import tempfile
 import uuid
 import shutil
+from contextlib import asynccontextmanager
 
 # Disable numba JIT to avoid caching issues
 os.environ["NUMBA_DISABLE_JIT"] = "1"
 
-# Initialize FastAPI app
-app = FastAPI(
-    title="Quran Recitation Comparison API",
-    description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
-    version="1.0.0"
-)
-
-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # Allows all origins
-    allow_credentials=True,
-    allow_methods=["*"],  # Allows all methods
-    allow_headers=["*"],  # Allows all headers
-)
-
 # Global variables
 MODEL = None
 PROCESSOR = None
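A note on the `NUMBA_DISABLE_JIT` line this hunk keeps: the variable is conventionally set before any numba-backed module is imported, so the JIT is disabled before anything gets compiled, which is presumably why it sits near the top of main.py. A minimal sketch of that ordering, with librosa as a hypothetical consumer (the Space's actual audio imports are not shown in this diff):

    import os

    # Set before importing any numba-backed library, so the JIT is
    # disabled before any function is compiled.
    os.environ["NUMBA_DISABLE_JIT"] = "1"

    import librosa  # hypothetical numba consumer; the real imports may differ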
@@ -59,10 +44,10 @@ def initialize_model():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Loading model on device: {device}")
 
-    # Load model and processor
+    # Load model and processor using the updated parameter `token`
     if hf_token:
-        PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=hf_token)
-        MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=hf_token)
+        PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, token=hf_token)
+        MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, token=hf_token)
     else:
         PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name)
         MODEL = Wav2Vec2ForCTC.from_pretrained(model_name)
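For context on the change above: recent transformers releases deprecated `use_auth_token` in favor of `token` on the `from_pretrained` methods, which is what this hunk adopts. A minimal sketch of the same loading pattern, assuming a hypothetical model id and an `HF_TOKEN` secret (neither name is confirmed by this diff):

    import os
    from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

    model_name = "facebook/wav2vec2-base-960h"  # hypothetical id; the Space's model is not shown here
    hf_token = os.environ.get("HF_TOKEN")       # hypothetical secret name

    # Pass `token` only when one is configured, mirroring the if/else in the hunk.
    kwargs = {"token": hf_token} if hf_token else {}
    processor = Wav2Vec2Processor.from_pretrained(model_name, **kwargs)
    model = Wav2Vec2ForCTC.from_pretrained(model_name, **kwargs)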
@@ -280,13 +265,24 @@ async def health_check():
         )
     return {"status": "ok", "model_loaded": True}
 
-# Initialize model on startup
-@app.on_event("startup")
-async def startup_event():
+# Use lifespan context manager instead of on_event decorators
+@asynccontextmanager
+async def lifespan(app: FastAPI):
     initialize_model()
+    yield
+
+# Initialize FastAPI app with lifespan handler
+app = FastAPI(
+    title="Quran Recitation Comparison API",
+    description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+# Note: Ensure that all route definitions are declared AFTER the app initialization above.
 
 # Run the FastAPI app
 if __name__ == "__main__":
     import uvicorn
     port = int(os.environ.get("PORT", 7860))  # Default to port 7860 for Hugging Face Spaces
-    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
+    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
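The hunk above swaps the deprecated `@app.on_event("startup")` hook for FastAPI's lifespan protocol: startup work goes before the `yield`, shutdown work after it, and the context manager is handed to the `FastAPI` constructor. A minimal self-contained sketch of the pattern (stand-in names, not the Space's full app):

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Runs once before the app starts serving requests.
        print("loading model...")  # stand-in for initialize_model()
        yield
        # Runs once on shutdown; cleanup would go here.

    app = FastAPI(lifespan=lifespan)

    @app.get("/health")
    async def health():
        return {"status": "ok"}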
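One consequence the new "route definitions AFTER the app initialization" note glosses over: the CORS middleware removed in the first hunk is never re-added, so if cross-origin requests are still required it would have to be re-attached to the newly created `app`. A hedged sketch, assuming the same wide-open policy the old code used:

    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware

    app = FastAPI()  # stand-in for the app created with lifespan=lifespan above

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],   # Allows all origins
        allow_credentials=True,
        allow_methods=["*"],   # Allows all methods
        allow_headers=["*"],   # Allows all headers
    )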