File size: 2,153 Bytes
63c6811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9abff41
63c6811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd303e0
63c6811
 
 
 
9abff41
3d97360
 
63c6811
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
import pandas as pd
import numpy as np
from unidecode import unidecode
import tensorflow as tf 
import cloudpickle
from transformers import AlbertTokenizerFast
import os

def load_model():
    """Load every artifact needed for sentiment inference.

    Returns:
        A 4-tuple ``(interpreter, text_preprocessor, label_encoder,
        tokenizer)``: the TFLite interpreter built from the ALBERT model
        file, the two objects unpickled from the preprocessor/label-encoder
        file, and the pretrained ``albert-base-v2`` fast tokenizer.
    """
    tflite_path = os.path.join("models/albert_sentiment_analysis.tflite")
    tflite_interpreter = tf.lite.Interpreter(model_path=tflite_path)

    # NOTE(review): unpickling can execute arbitrary code from the file;
    # only ship artifacts that come from a trusted source.
    with open("models/sentiment_preprocessor_labelencoder.bin", "rb") as artifact_file:
        preprocessor, encoder = cloudpickle.load(artifact_file)

    albert_tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
    return tflite_interpreter, preprocessor, encoder, albert_tokenizer

# Load the model artifacts at module level; inference() reads these four
# names as globals.
interpreter, text_preprocessor, label_encoder, tokenizer = load_model()


def inference(text):
    """Classify the sentiment of one piece of text with the TFLite model.

    Args:
        text: Raw review text to classify.

    Returns:
        A string of the form ``"<label> (<score>)"``, where the label is
        decoded from the highest-scoring output entry and the score is
        rounded to 5 decimals. Returns ``"Can't Predict"`` when the
        preprocessed text equals the placeholder
        ``"this is an empty message"``.
    """
    cleaned = text_preprocessor.preprocess(pd.Series(text))[0]
    if cleaned == "this is an empty message":
        return "Can't Predict"

    # Pad/truncate to exactly 150 tokens.
    # NOTE(review): presumably 150 is the input length the exported model
    # expects — confirm against the .tflite file.
    encoded = tokenizer(
        cleaned,
        max_length=150,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    interpreter.allocate_tensors()
    inputs_meta = interpreter.get_input_details()
    output_index = interpreter.get_output_details()[0]["index"]

    # NOTE(review): assumes model input 0 is the attention mask and input 1
    # the token ids — confirm against the exported model's input order.
    interpreter.set_tensor(inputs_meta[0]["index"], encoded["attention_mask"])
    interpreter.set_tensor(inputs_meta[1]["index"], encoded["input_ids"])
    interpreter.invoke()

    # First row of the output tensor holds the scores for the single input.
    scores = interpreter.get_tensor(output_index)[0]
    best = np.argmax(scores)
    label = label_encoder.inverse_transform([best])[0]
    score_text = str(np.round(scores[best], 5))
    return f"{label} ({score_text})"


def main():
    """Render the sentiment-analysis page and show a prediction on submit.

    The result is shown with ``st.success`` when the prediction string
    contains ``"positive"``, and with ``st.error`` otherwise (including the
    "Can't Predict" fallback returned by ``inference``).
    """
    st.title("Sentiment Analysis")
    st.write("This model is trained on Amazon reviews dataset.")
    review = st.text_area("Enter a product review:", "", height=200)

    # Nothing to show until the user presses the button.
    if not st.button("Submit"):
        return

    prediction = inference(review)
    show = st.success if "positive" in prediction else st.error
    show(f"{prediction}")

# Build the page only when this file is run as the main module.
if __name__ == "__main__":
    main()