Hammad712 commited on
Commit
b05966a
·
verified ·
1 Parent(s): 9d9096a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +19 -23
main.py CHANGED
@@ -11,26 +11,11 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
11
  import tempfile
12
  import uuid
13
  import shutil
 
14
 
15
  # Disable numba JIT to avoid caching issues
16
  os.environ["NUMBA_DISABLE_JIT"] = "1"
17
 
18
- # Initialize FastAPI app
19
- app = FastAPI(
20
- title="Quran Recitation Comparison API",
21
- description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
22
- version="1.0.0"
23
- )
24
-
25
- # Add CORS middleware
26
- app.add_middleware(
27
- CORSMiddleware,
28
- allow_origins=["*"], # Allows all origins
29
- allow_credentials=True,
30
- allow_methods=["*"], # Allows all methods
31
- allow_headers=["*"], # Allows all headers
32
- )
33
-
34
  # Global variables
35
  MODEL = None
36
  PROCESSOR = None
@@ -59,10 +44,10 @@ def initialize_model():
59
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
  print(f"Loading model on device: {device}")
61
 
62
- # Load model and processor
63
  if hf_token:
64
- PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=hf_token)
65
- MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=hf_token)
66
  else:
67
  PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name)
68
  MODEL = Wav2Vec2ForCTC.from_pretrained(model_name)
@@ -280,13 +265,24 @@ async def health_check():
280
  )
281
  return {"status": "ok", "model_loaded": True}
282
 
283
- # Initialize model on startup
284
- @app.on_event("startup")
285
- async def startup_event():
286
  initialize_model()
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  # Run the FastAPI app
289
  if __name__ == "__main__":
290
  import uvicorn
291
  port = int(os.environ.get("PORT", 7860)) # Default to port 7860 for Hugging Face Spaces
292
- uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
 
11
  import tempfile
12
  import uuid
13
  import shutil
14
+ from contextlib import asynccontextmanager
15
 
16
  # Disable numba JIT to avoid caching issues
17
  os.environ["NUMBA_DISABLE_JIT"] = "1"
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Global variables
20
  MODEL = None
21
  PROCESSOR = None
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
  print(f"Loading model on device: {device}")
46
 
47
+ # Load model and processor using the updated parameter `token`
48
  if hf_token:
49
+ PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, token=hf_token)
50
+ MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, token=hf_token)
51
  else:
52
  PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name)
53
  MODEL = Wav2Vec2ForCTC.from_pretrained(model_name)
 
265
  )
266
  return {"status": "ok", "model_loaded": True}
267
 
268
+ # Use lifespan context manager instead of on_event decorators
269
+ @asynccontextmanager
270
+ async def lifespan(app: FastAPI):
271
  initialize_model()
272
+ yield
273
+
274
+ # Initialize FastAPI app with lifespan handler
275
+ app = FastAPI(
276
+ title="Quran Recitation Comparison API",
277
+ description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
278
+ version="1.0.0",
279
+ lifespan=lifespan
280
+ )
281
+
282
+ # Note: Ensure that all route definitions are declared AFTER the app initialization above.
283
 
284
  # Run the FastAPI app
285
  if __name__ == "__main__":
286
  import uvicorn
287
  port = int(os.environ.get("PORT", 7860)) # Default to port 7860 for Hugging Face Spaces
288
+ uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)