File size: 3,117 Bytes
fdc1efd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from pathlib import Path

import numpy as np
import torch
from torchvision import transforms

from src.modeling import ASTPretrained, FeatureExtractor, PreprocessPipeline, StudentAST

# Directory, next to this file, that the loaders below read weights and thresholds from.
MODELS_FOLDER = Path(__file__).parent / "models"

# Class labels in index order: position i is the label for model output index i.
# NOTE(review): these look like instrument abbreviations — confirm against the training data.
CLASSES = ["tru", "sax", "vio", "gac", "org", "cla", "flu", "voi", "gel", "cel", "pia"]


def load_model(model_type: str):
    """
    Builds an AST model and restores its saved weights on the CPU.

    :param model_type: ``"accuracy"`` selects the ``ASTPretrained`` model; any other
        value selects the ``StudentAST`` model.
    :type model_type: str
    :return: The model with its weights loaded, switched to evaluation mode.
    :rtype: ASTPretrained or StudentAST
    """

    use_accuracy_model = model_type == "accuracy"
    if use_accuracy_model:
        network = ASTPretrained(n_classes=11, download_weights=False)
        weights_file = f"{MODELS_FOLDER}/acc_model_ast.pth"
    else:
        network = StudentAST(n_classes=11, hidden_size=192, num_heads=3)
        weights_file = f"{MODELS_FOLDER}/speed_model_ast.pth"

    # Map every stored tensor onto the CPU, whatever device it was saved from.
    state_dict = torch.load(weights_file, map_location=torch.device("cpu"))
    network.load_state_dict(state_dict)

    # Evaluation mode: the returned model is meant for inference only.
    network.eval()
    return network


def load_labels():
    """
    Builds the index-to-label mapping for the AST model from ``CLASSES``.

    :return: A dictionary where the keys are the class indices (starting at 0, in
        ``CLASSES`` order) and the values are the class labels.
    :rtype: Dict[int, str]
    """

    # enumerate() pairs each label with its position; dict() turns the pairs into the mapping.
    return dict(enumerate(CLASSES))


def load_thresholds(model_type: str):
    """
    Reads the saved prediction thresholds that belong to the chosen model.

    :param model_type: ``"accuracy"`` selects ``acc_model_thresh.npy``; any other value
        selects ``speed_model_thresh.npy``.
    :type model_type: str
    :return: The prediction thresholds for each class, as stored in the selected file.
    :rtype: np.ndarray
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays; keep these
    # files trusted, or drop the flag if the arrays are plain numeric.
    if model_type != "accuracy":
        return np.load(f"{MODELS_FOLDER}/speed_model_thresh.npy", allow_pickle=True)
    return np.load(f"{MODELS_FOLDER}/acc_model_thresh.npy", allow_pickle=True)


class ModelServiceAST:
    """
    Bundles a loaded AST model with its labels, thresholds and preprocessing transform.
    """

    def __init__(self, model_type: str):
        """
        Initializes a ModelServiceAST instance with the specified model type.

        :param model_type: The type of model to load; forwarded to ``load_model`` and
            ``load_thresholds``.
        :type model_type: str
        """

        self.model = load_model(model_type)
        self.labels = load_labels()
        self.thresholds = load_thresholds(model_type)
        self.transform = transforms.Compose([PreprocessPipeline(target_sr=16000), FeatureExtractor(sr=16000)])

    def get_prediction(self, audio):
        """
        Gets the binary predictions for the given audio.

        :param audio: The input audio to make predictions for; it is passed unchanged to
            ``self.transform``. NOTE(review): earlier documentation described this as a
            file object — confirm against the preprocessing pipeline.
        :return: A dictionary where the keys are the class labels and the values are binary
            predictions (1 when the sigmoid score reaches the class threshold, otherwise 0).
        :rtype: Dict[str, int]
        """
        processed = self.transform(audio)
        with torch.no_grad():
            # The model expects seq_len x num_features, so swap the last two dimensions.
            scores = torch.sigmoid(self.model(processed.mT))
            scores = scores.squeeze().numpy().astype(float)

        # Use the instance's own index -> label mapping so that labels, scores and
        # thresholds are all addressed by the same class index.
        return {
            label: int(scores[index] >= self.thresholds[index])
            for index, label in self.labels.items()
        }