|
from fastapi import FastAPI, BackgroundTasks
|
|
from fastapi.responses import JSONResponse, HTMLResponse
|
|
from pydantic import BaseModel
|
|
import uvicorn
|
|
from uvicorn.config import logger
|
|
import os
|
|
import argparse
|
|
from tasks.training import TrainingTask
|
|
from config import enable_test_mode
|
|
|
|
# ASGI application object; the route decorators in this module register
# their endpoints on it.
app = FastAPI()
|
|
|
|
@app.post("/train/start", response_class=JSONResponse)
async def start_model_training(background_tasks: BackgroundTasks):
    """
    Schedule a model re-training run unless one already exists.

    When no TrainingTask instance exists, the instance returned by
    TrainingTask.get_instance() is handed to FastAPI's background task
    scheduler, so the request returns before the task is run. Its state
    can afterwards be polled via /train/get_state.

    Args:
        background_tasks (BackgroundTasks): BG Tasks scheduler provided by FastAPI

    Returns:
        dict: A dictionary with a single "message" key describing the outcome
            of the request.
    """
    # Guard clause: an instance already exists, so nothing new is scheduled.
    if TrainingTask.has_instance():
        return {
            "message": "A training instance is already running.",
        }

    # The task instance itself is registered as the background callable.
    training_task = TrainingTask.get_instance()
    background_tasks.add_task(training_task)

    return {
        "message": "Model training was scheduled and will begin shortly.",
    }
|
|
|
|
@app.post("/train/get_state", response_class=JSONResponse)
async def poll_model_training_state():
    """
    Report the state of the current training task, if there is one.

    If a TrainingTask instance exists, its done and error flags are
    returned. Once the task reports that it is done, the instance is
    cleared, so the next poll gets the "no instance" message instead.

    Returns:
        dict: Either {"done": ..., "error": ...} for an existing instance,
            or {"message": ...} when there is none.
    """
    # Guard clause: nothing to report when no instance exists.
    if not TrainingTask.has_instance():
        return {
            "message": "No training instance running!",
        }

    current_task: TrainingTask = TrainingTask.get_instance()

    # Read both flags before the instance is possibly cleared below.
    finished = current_task.is_done()
    failed = current_task.has_error()

    if finished:
        TrainingTask.clear_instance()

    return {
        "done": finished,
        "error": failed,
    }
|
|
|
|
class InferenceRequest(BaseModel):
    """
    Request body schema for the /inference endpoint.

    The body carries a single field, ``messages``, holding the list of
    message strings to be evaluated.
    """

    # Messages to run inference on, one string per message.
    messages: list[str]
|
|
|
|
@app.post("/inference", response_class=JSONResponse)
async def inference(data: InferenceRequest):
    """
    Classify a list of messages and return the predictions.

    The messages from the request body are passed to
    ``infer_task.predict`` and its result is returned unchanged.

    Args:
        data (InferenceRequest): Structure containing the list of
            messages that shall be evaluated

    Returns:
        json: A json list containing the sentiment analysis for each message.
            Each element is documented to be a dictionary with the keys:
            positive, neutral, negative
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably this defers loading the inference model until
    # the first request -- confirm before moving it to the top of the file.
    from tasks.inference import infer_task

    predictions = infer_task.predict(data.messages)
    return predictions
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """
    Root endpoint of the hosted application.

    It serves only a short greeting, which is enough to show that the
    application is up and running.

    Returns:
        str: A html response containing a hello world-like string
    """
    greeting = "Hi there! It's a nice blank page, isn't it?"
    return greeting
|
|
|
|
def parse_test_mode_flag(argv=None):
    """
    Decide from the command line whether test mode was requested.

    Two spellings are accepted:

    * the ``--test`` flag, which the entrypoint documentation advertises
      but the parser previously rejected as an unrecognized argument;
    * the positional value ``yes`` (e.g. ``python main.py yes``), kept for
      backward compatibility. Any other positional value, or none at all,
      leaves test mode off.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        bool: True if test mode was requested, False otherwise.
    """
    parser = argparse.ArgumentParser()
    # Positional form: only the exact value "yes" enables test mode.
    parser.add_argument("test", nargs="?", default="no")
    # Flag form: stored under its own dest so it cannot collide with the
    # positional argument, which already uses the dest "test".
    parser.add_argument("--test", dest="test_flag", action="store_true")
    args = parser.parse_args(argv)

    return args.test_flag or args.test == "yes"


if __name__ == "__main__":
    # Entrypoint for the application executed via command-line.
    # It accepts an optional "--test" flag (or the positional value "yes")
    # to enable the test mode.
    if parse_test_mode_flag():
        enable_test_mode()

    # The listen port is read from the APP_LISTEN_PORT environment variable.
    # Startup fails with KeyError if it is unset and with ValueError if it
    # is not an integer.
    config = uvicorn.Config(
        "main:app",
        host="0.0.0.0",
        port=int(os.environ["APP_LISTEN_PORT"]),
        log_level="debug",
    )
    server = uvicorn.Server(config)
    server.run()
|
|
|