File size: 3,521 Bytes
058f1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92b2ebb
 
 
058f1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fastapi import FastAPI, BackgroundTasks
from fastapi.responses import JSONResponse, HTMLResponse
from pydantic import BaseModel
import uvicorn
from uvicorn.config import logger
import os
import argparse
from tasks.training import TrainingTask
from config import enable_test_mode

app = FastAPI()

@app.post("/train/start", response_class=JSONResponse)
async def start_model_training(background_tasks: BackgroundTasks):
	"""

	Endpoint on which a request can be sent to start model re-training, 

	if there's no training task currently running.

	The task will be carried out in background and its status can be

	polled via /train/get_state.



	Args:

		background_tasks (BackgroundTasks): BG Tasks scheduler provided by FastAPI



	Returns:

		dict: A dictionary containing a message of the outcome for the request.

	"""
	if not TrainingTask.has_instance():
		background_tasks.add_task(TrainingTask.get_instance())

		return {
			"message": "Model training was scheduled and will begin shortly.",
		}
	
	return {
		"message": "A training instance is already running.",
	}

@app.post("/train/get_state", response_class=JSONResponse)
async def poll_model_training_state():
	"""

	Checks if there is currently a training task ongoing.

	If so, returns whether it's done and/or if an error occurred.

	Otherwise if no instance is running, returns only a message.



	Returns:

		dict: Dictionary containing either done/error or message.

	"""
	if TrainingTask.has_instance():
		train_instance : TrainingTask = TrainingTask.get_instance()
		is_done = train_instance.is_done()
		has_error = train_instance.has_error()

		if is_done:
			TrainingTask.clear_instance()

		return {
			"done": is_done,
			"error": has_error,
		}

	return {
		"message": "No training instance running!",
	}

class InferenceRequest(BaseModel):
	"""

	Provides a model/schema for the accepted request body for incoming

	inference requests.

	"""
	messages: list[str]

@app.post("/inference", response_class=JSONResponse)
async def inference(data: InferenceRequest):
	"""

	Endpoint on which you can send a list of messages that need to be

	classified. A list of predictions will be returned in response,

	containing for each message all of the probabilities for each label.



	Args:

		data (InferenceRequest): Structure containing a list of 

  								 messages that shall be evaluated



	Returns:

		json: A json list containing the sentiment analysis for each message.

			  Each element consists of a dictionary with the following keys:

			  positive, neutral, negative

	"""
	
	from tasks.inference import infer_task
	return infer_task.predict(data.messages)

@app.get("/", response_class=HTMLResponse)
async def root():
	"""

	The root endpoint for our hosted application. Only shows a message

	showing that it's up and running.



	Returns:

		str: A html response containing a hello world-like string

	"""
	return "Hi there! It's a nice blank page, isn't it?"

if __name__ == "__main__":
	"""

	Entrypoint for the application executed via command-line.

	It accepts an optional argument "--test" to enable the test mode.

	"""
	parser = argparse.ArgumentParser()
	parser.add_argument("test", nargs="?", default="no")
	args = parser.parse_args()

	if args.test == "yes":
		enable_test_mode()

	config = uvicorn.Config("main:app", host="0.0.0.0", port=int(os.environ["APP_LISTEN_PORT"]), log_level="debug")
	server = uvicorn.Server(config)
	server.run()