fahmiaziz98 commited on
Commit
6f3af4f
Β·
1 Parent(s): e069d9a

first commit

Browse files
Files changed (3) hide show
  1. app.py +87 -3
  2. requirements.txt +3 -0
  3. utils.py +43 -0
app.py CHANGED
@@ -1,8 +1,92 @@
1
- import time
 
2
  import streamlit as st
 
3
  from PIL import Image
 
 
 
 
 
 
 
 
4
 
5
  st.title("Machine Learning Model Deployment")
6
- x = st.slider('Select a value')
7
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
1
+ import os
2
+ import torch
3
  import streamlit as st
4
+ from transformers import pipeline, AutoImageProcessor
5
  from PIL import Image
6
+ from utils import download_model_from_s3
7
+
8
+ model_path = {
9
+ "TinyBert Sentiment Analysis": "ml-models/tinybert-sentiment-analysis/",
10
+ "TinyBert Disaster Classification": "ml-models/tinybert-disaster-tweet/",
11
+ "VIT Pose Classification": "ml-models/vit-human-pose-classification/"
12
+ }
13
+
14
 
15
  st.title("Machine Learning Model Deployment")
16
+
17
+ model_choice = st.selectbox(
18
+ "Select Model:",[
19
+ "TinyBert Sentiment Analysis",
20
+ "TinyBert Disaster Classification",
21
+ "VIT Pose Classification"
22
+ ]
23
+ )
24
+
25
+ local_path = model_choice.lower().replace(" ", "-")
26
+ s3_prefix = model_paths[model_choice]
27
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
28
+
29
+ if "downloaded_models" not in st.session_state:
30
+ st.session_state.downloaded_models = set()
31
+
32
+
33
+ if model_choice not in st.session_state.downloaded_models:
34
+ if st.button(f"Download {model_choice}"):
35
+ with st.spinner(f"Downloading {model_choice}... Please wait!"):
36
+ download_model_from_s3(local_path, s3_prefix)
37
+ st.session_state.downloaded_models.add(model_choice)
38
+ st.toast(f"βœ… {model_choice} Succesfuly Download!", icon="πŸŽ‰")
39
+
40
+ # **1. Sentiment Analysis Model**
41
+ if model_choice == "TinyBERT Sentiment Analysis":
42
+ text = st.text_area("Enter Text:", "Your review...")
43
+ predict = st.button("Predict Sentiment")
44
+
45
+ classifier = pipeline("text-classification", model=local_path, device=device)
46
+
47
+ if predict:
48
+ with st.spinner("Predicting..."):
49
+ output = classifier(text)
50
+ st.write(output)
51
+
52
+ # **2. Disaster Classification**
53
+ if model_choice == "TinyBert Disaster Classification":
54
+ text = st.text_area("Enter Text:", "Your Tweet...")
55
+ predict = st.button("Predict Sentiment")
56
+
57
+ classifier = pipeline("text-classification", model=local_path, device=device)
58
+
59
+ if predict:
60
+ with st.spinner("Predicting..."):
61
+ output = classifier(text)
62
+ st.write(output)
63
+
64
+ # **3. Image Classification**
65
+ if model_choice == "VIT Pose Classification":
66
+ uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
67
+ predict = st.button("Predict Image")
68
+
69
+ if uploaded_file is not None:
70
+ image = Image.open(uploaded_file)
71
+ st.image(image, caption="Your Image", use_column_width=True)
72
+
73
+ image_processor = AutoImageProcessor.from_pretrained(local_directory, use_fast=True)
74
+ pipe = pipeline('image-classification', model=local_path, image_processor=image_processor, device=device)
75
+
76
+ if predict:
77
+ with st.spinner("Predicting..."):
78
+ output = pipe(image)
79
+ st.write(output)
80
+
81
+
82
+
83
+
84
+ user_input = st.text_area("Enter text here...")
85
+ predict = st.button("Predict")
86
+ if predict:
87
+ with st.spinner("Downloading Model...")
88
+
89
+
90
+
91
+
92
 
requirements.txt CHANGED
@@ -1 +1,4 @@
1
  streamlit
 
 
 
 
1
  streamlit
2
+ torch
3
+ boto3
4
+ transformers
utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import boto3
3
+ from path import Path
4
+
5
+ # credentials aws
6
+ aws_access_key = os.getenv("AWS_ACCESS_KEY_ID")
7
+ aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
8
+ bucket_name = os.getenv("BUCKET_NAME")
9
+
10
+
11
+ s3 = boto3.client(
12
+ "s3",
13
+ aws_access_key_id=aws_access_key,
14
+ aws_secret_access_key=aws_secret_key
15
+ )
16
+
17
+
18
+ def download_model_from_s3(local_path: Path, s3_prefix: str):
19
+ if os.path.exists(local_path) and os.listdir(local_path):
20
+ st.toast(f"βœ… Model {local_path} Available!", icon="πŸŽ‰")
21
+ return
22
+
23
+ os.makedirs(local_path, exist_ok=True)
24
+ paginator = s3.get_paginator("list_objects_v2")
25
+
26
+ for result in paginator.paginate(Bucket=bucket_name, Prefix=s3_prefix):
27
+ if "Contents" in result:
28
+ for key in result["Contents"]:
29
+ s3_key = key["Key"]
30
+ local_file = os.path.join(local_path, os.path.relpath(s3_key, s3_prefix))
31
+
32
+ os.makedirs(os.path.dirname(local_file), exist_ok=True)
33
+ s3.download_file(bucket_name, s3_key, local_file)
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+