File size: 1,417 Bytes
3ec25da
db2db2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ec25da
b08d3ea
3ec25da
b08d3ea
db2db2a
 
4ca551f
db2db2a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from scripts.s3 import download_model_from_s3
from router.disaster import router as disaster_router
from router.sentiment import router as sentiment_router
from router.image_clf import router as image_router
from utils.logger import logger


# Application instance. The title/description/version below are what FastAPI
# publishes in the generated OpenAPI schema served at `openapi_url`.
# The description lists every model family mounted further down
# (sentiment, disaster tweets, images).
app = FastAPI(
    title="ML API",
    description=(
        "ML API for sentiment analysis, disaster tweet classification "
        "and image classification"
    ),
    version="0.0.1",
    openapi_url="/openapi.json",
)

# Browser origins allowed to call this API, given as a comma-separated list in
# the CORS_ALLOW_ORIGINS environment variable. Leaving it unset (or blank)
# keeps the previous behaviour of allowing any origin ("*").
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# NOTE(security): the CORS spec does not permit a wildcard origin together with
# credentials, and the FastAPI/Starlette docs say origins must be listed
# explicitly for allow_credentials=True to work. With the "*" default, Starlette
# may instead echo the caller's Origin header back, depending on version
# (current releases do so when cookies are present). That effectively lets any
# site send credentialed requests. Set CORS_ALLOW_ORIGINS to the real front-end
# origin(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Common prefix prepended to each model directory when downloading
# (presumably the S3 key prefix -- TODO confirm against scripts/s3.py).
MODEL_PATH = "ml-models/"
sentiment_model_path = "tinybert-sentiment-analysis/"
disaster_model_path = "tinybert-disaster-tweet/"
image_model_path = "vit-human-pose-classification/"

# Fetch every model before the app starts serving, in a fixed order. Each call
# receives the prefixed remote location and the bare directory name; the latter
# is presumably the local target the routers load from -- TODO confirm.
# NOTE(review): the router modules are imported at the top of the file, before
# this loop runs. If any of them loads weights at import time, it will not see
# freshly downloaded files. Confirm that they load lazily.
logger.info("Ensuring models are downloaded...")
for _model_dir in (
    sentiment_model_path,
    disaster_model_path,
    image_model_path,
):
    download_model_from_s3(MODEL_PATH + _model_dir, _model_dir)
logger.info("All models are ready.")


@app.get("/")
def read_root():
    # Root status probe: any successful GET on "/" shows the API process is up
    # and serving requests. Deliberately docstring-free, because FastAPI would
    # publish a docstring as this route's OpenAPI description.
    return dict(Status="Running")
  

# Mount each feature router under the versioned API prefix, keeping the
# original registration order. The tag groups the router's endpoints in the
# generated OpenAPI docs.
for _feature_router, _feature_tag in (
    (disaster_router, "Disaster"),
    (sentiment_router, "Sentiment"),
    (image_router, "Image"),
):
    app.include_router(_feature_router, prefix="/api/v1", tags=[_feature_tag])