import os
import torch
import streamlit as st
from transformers import pipeline, AutoImageProcessor
from PIL import Image
from utils import download_model_from_s3
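
# For reference only: a hypothetical sketch of what `download_model_from_s3`
# in utils.py might look like (assumptions not confirmed by this repo: boto3 is
# installed and the bucket name comes from a BUCKET_NAME environment variable).
# The app calls the real helper imported above; this sketch is never invoked.
def _download_model_from_s3_sketch(local_path: str, s3_prefix: str) -> None:
    import boto3  # assumed dependency of utils.py

    s3 = boto3.client("s3")
    bucket = os.getenv("BUCKET_NAME", "my-models-bucket")  # hypothetical bucket name
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=s3_prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):  # skip folder placeholder keys
                continue
            # Mirror the S3 key layout into the local model directory
            target = os.path.join(local_path, os.path.relpath(key, s3_prefix))
            os.makedirs(os.path.dirname(target), exist_ok=True)
            s3.download_file(bucket, key, target)
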
model_paths = {
    "TinyBert Sentiment Analysis": "ml-models/tinybert-sentiment-analysis/",
    "TinyBert Disaster Classification": "ml-models/tinybert-disaster-tweet/",
    "VIT Pose Classification": "ml-models/vit-human-pose-classification/"
}
st.title("Machine Learning Model Deployment")
model_choice = st.selectbox(
    "Select Model:",
    [
        "TinyBert Sentiment Analysis",
        "TinyBert Disaster Classification",
        "VIT Pose Classification"
    ]
)
# Local directory the model is downloaded to, e.g. "tinybert-sentiment-analysis"
local_path = model_choice.lower().replace(" ", "-")
s3_prefix = model_paths[model_choice]

# Run inference on GPU when available, otherwise fall back to CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if "downloaded_models" not in st.session_state:
st.session_state.downloaded_models = set()
if model_choice not in st.session_state.downloaded_models:
if st.button(f"Download {model_choice}"):
with st.spinner(f"Downloading {model_choice}... Please wait!"):
download_model_from_s3(local_path, s3_prefix)
st.session_state.downloaded_models.add(model_choice)
st.toast(f"β
{model_choice} Succesfuly Download!", icon="π")
# **1. Sentiment Analysis Model**
if model_choice == "TinyBert Sentiment Analysis":
    text = st.text_area("Enter Text:", "This movie was horrible, the plot was really boring. Acting was okay.")
    predict = st.button("Predict Sentiment")
    try:
        classifier = pipeline("text-classification", model=local_path, device=device)
    except OSError:
        st.error("❌ Model not found. Please download the model first.")
        classifier = None
    if predict and classifier:
        with st.spinner("Predicting..."):
            output = classifier(text)
            st.write(output)
# **2. Disaster Classification**
if model_choice == "TinyBert Disaster Classification":
    text = st.text_area("Enter Text:", "There is a fire in the building")
    predict = st.button("Predict Disaster")
    try:
        classifier = pipeline("text-classification", model=local_path, device=device)
    except OSError:
        st.error("❌ Model not found. Please download the model first.")
        classifier = None
    if predict and classifier:
        with st.spinner("Predicting..."):
            output = classifier(text)
            st.write(output)
# **3. Image Classification**
if model_choice == "VIT Pose Classification":
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
    predict = st.button("Predict Image")
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        col1, col2 = st.columns([1, 1])
        with col1:
            st.image(image, caption="Your Image", width=300)
        with col2:
            try:
                image_processor = AutoImageProcessor.from_pretrained(local_path, use_fast=True)
                pipe = pipeline("image-classification", model=local_path, image_processor=image_processor, device=device)
            except OSError:
                st.error("❌ Model not found. Please download the model first.")
                pipe = None
            if predict and pipe:
                with st.spinner("Predicting..."):
                    output = pipe(image)
                    st.write("### Prediction Results:")
                    st.json(output)
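
# To run locally (standard Streamlit invocation, assuming this file is saved as app.py):
#   streamlit run app.py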