rdsarjito committed · e88e274
Parent(s): c0cfde6

Files changed:
- app.py (+47 -218)
- requirements.txt (+9 -5)
app.py
CHANGED
@@ -1,249 +1,78 @@
 import streamlit as st
 import torch
 import torch.nn as nn
+import numpy as np
+import pandas as pd
 import re
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
-import os
-import numpy as np
-
-# Set page config
-st.set_page_config(
-    page_title="Deteksi Alergen Resep",
-    page_icon="🍽️",
-    layout="wide"
-)
 
-#
-
-
-
-
-""")
-
-# Set device
+# Load tokenizer and model
+MODEL_PATH = 'model/alergen_model.pt'
+MODEL_NAME = 'indobenchmark/indobert-base-p1'
+TARGET_COLUMNS = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
+MAX_LEN = 128
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
-target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
-allergen_descriptions = {
-    'susu': 'Produk susu (milk products)',
-    'kacang': 'Kacang-kacangan (nuts)',
-    'telur': 'Telur (eggs)',
-    'makanan_laut': 'Makanan laut (seafood)',
-    'gandum': 'Gandum/gluten (wheat/gluten)'
-}
-
-# Clean text function
-@st.cache_data
-def clean_text(text):
-    # Convert dashes to spaces for better tokenization
-    text = text.replace('--', ' ')
-    # Basic cleaning
-    text = re.sub(r"http\S+", "", text)
-    text = re.sub('\n', ' ', text)
-    text = re.sub("[^a-zA-Z0-9\s]", " ", text)
-    text = re.sub(" {2,}", " ", text)
-    text = text.strip()
-    text = text.lower()
-    return text
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
-# Define model for multilabel classification
 class MultilabelBertClassifier(nn.Module):
     def __init__(self, model_name, num_labels):
         super(MultilabelBertClassifier, self).__init__()
         self.bert = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
-        # Replace the classification head with our own for multilabel
         self.bert.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
 
     def forward(self, input_ids, attention_mask):
         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
         return outputs.logits
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            model.load_state_dict(checkpoint['model_state_dict'])
-        else:
-            st.warning("Model file not found. Please upload your model file.")
-
-        model.to(device)
-        model.eval()
-
-        return model, tokenizer
-    except Exception as e:
-        st.error(f"Error loading model: {e}")
-        return None, None
+model = MultilabelBertClassifier(MODEL_NAME, len(TARGET_COLUMNS))
+model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
+model.to(device)
+model.eval()
 
-#
-def
-
-    return None
-
-    # Clean the text
-    cleaned_text = clean_text(ingredients_text)
-
-    # Tokenize
+# Preprocessing function
+def clean_text(text):
+    text = text.replace('--', ' ')
+    text = re.sub(r"http\S+", "", text)
+    text = re.sub(r"\n", " ", text)
+    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
+    text = re.sub(r" {2,}", " ", text)
+    text = text.strip()
+    text = text.lower()
+    return text
+
+# Prediction function
+def predict(text):
+    cleaned = clean_text(text)
     encoding = tokenizer.encode_plus(
-
+        cleaned,
         add_special_tokens=True,
-        max_length=
-        truncation=True,
+        max_length=MAX_LEN,
         return_tensors='pt',
-        padding='max_length'
+        padding='max_length',
+        truncation=True
     )
-
     input_ids = encoding['input_ids'].to(device)
     attention_mask = encoding['attention_mask'].to(device)
-
-    with torch.no_grad():
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        predictions = torch.sigmoid(outputs)
-        predictions_np = predictions.cpu().numpy()[0]
-        binary_predictions = (predictions > 0.5).float().cpu().numpy()[0]
-
-    result = {}
-    confidence = {}
-    for i, target in enumerate(target_columns):
-        result[target] = bool(binary_predictions[i])
-        confidence[target] = float(predictions_np[i])
-
-    return result, confidence
+    with torch.no_grad():
+        logits = model(input_ids=input_ids, attention_mask=attention_mask)
+        probs = torch.sigmoid(logits).cpu().numpy().flatten()
+    results = {TARGET_COLUMNS[i]: float(probs[i]) for i in range(len(TARGET_COLUMNS))}
+    return results
 
-
-
-
-
-
-    if uploaded_model is not None:
-        with open("alergen_model.pt", "wb") as f:
-            f.write(uploaded_model.getbuffer())
-        st.success("Model uploaded successfully!")
-        st.cache_resource.clear()
-
-    st.markdown("---")
-    st.markdown("### Tentang Aplikasi")
-    st.markdown("""
-    Aplikasi ini menggunakan model deep learning berbasis IndoBERT untuk mendeteksi
-    potensi alergen dalam resep makanan. Model dilatih untuk mendeteksi lima jenis alergen
-    umum dalam makanan.
-    """)
-
-# Load model and tokenizer
-model, tokenizer = load_model_and_tokenizer()
 
-#
-st.
+# STREAMLIT UI
+st.title("Deteksi Alergen dari Bahan Makanan")
+st.markdown("Masukkan daftar bahan makanan, dan sistem akan memprediksi kemungkinan alergen.")
 
-ingredients = st.text_area(
-    "Daftar Bahan (satu per baris atau dengan format yang umum digunakan)",
-    height=150,
-    placeholder="Contoh:\n1 bungkus Lontong homemade\n2 butir Telur ayam\n2 kotak kecil Tahu coklat\n4 butir kecil Kentang\n..."
-)
+user_input = st.text_area("Bahan makanan (contoh: 2 butir telur, 1 gelas susu, kacang tanah...)")
 
-if
-
-    st.warning("Silakan masukkan daftar bahan terlebih dahulu.")
-elif not model:
-    st.error("Model belum tersedia. Silakan upload model terlebih dahulu.")
-else:
-    with st.spinner("
-
-
-
-        st.
-
-        # Display detected allergens
-        detected_allergens = [allergen for allergen, present in results.items() if present]
-
-        if detected_allergens:
-            st.markdown("### ⚠️ Alergen Terdeteksi:")
-
-            # Create columns for the allergen cards
-            cols = st.columns(len(detected_allergens) if len(detected_allergens) < 3 else 3)
-
-            for i, allergen in enumerate(detected_allergens):
-                col_idx = i % 3
-                with cols[col_idx]:
-                    st.markdown(f"""
-                    <div style="padding: 10px; border-radius: 5px; background-color: #ffeeee; margin-bottom: 10px;">
-                        <h4 style="color: #cc0000;">{allergen_descriptions[allergen]}</h4>
-                        <p>Tingkat kepercayaan: {confidence[allergen]*100:.1f}%</p>
-                    </div>
-                    """, unsafe_allow_html=True)
-        else:
-            st.success("✅ Tidak ada alergen yang terdeteksi dalam resep ini.")
-
-        # Display detailed analysis
-        with st.expander("Lihat Analisis Detail"):
-            st.markdown("### Tingkat Kepercayaan Per Alergen")
-            for allergen in target_columns:
-                conf_value = confidence[allergen]
-                st.markdown(f"**{allergen_descriptions[allergen]}:** {conf_value*100:.1f}%")
-                st.progress(conf_value)
-    else:
-        st.error("Terjadi kesalahan dalam prediksi. Silakan coba lagi.")
+if st.button("Prediksi Alergen"):
+    if user_input.strip() == "":
+        st.warning("Silakan masukkan bahan makanan terlebih dahulu.")
+    else:
+        with st.spinner("Memproses..."):
+            predictions = predict(user_input)
+        st.subheader("Hasil Prediksi:")
+        for allergen, score in predictions.items():
+            st.write(f"- **{allergen}**: {'✅ Terdeteksi' if score > 0.5 else '❌ Tidak terdeteksi'} (Probabilitas: {score:.2f})")
 
-# Example recipe section
-with st.expander("Lihat Contoh Resep"):
-    st.markdown("""
-    **Gado-gado:**
-    1 bungkus Lontong homemade
-    2 butir Telur ayam
-    2 kotak kecil Tahu coklat
-    4 butir kecil Kentang
-    2 buah Tomat merah
-    1 buah Ketimun lalap
-    4 lembar Selada keriting
-    2 lembar Kol putih
-    2 porsi Saus kacang homemade
-    4 buah Kerupuk udang goreng
-    Secukupnya emping goreng
-    2 sdt Bawang goreng
-    Secukupnya Kecap manis
-    """)
-
-    if st.button("Gunakan Contoh Ini"):
-        st.session_state.example_used = True
-        # Will be processed in next rerun
-
-# Handle example
-if 'example_used' in st.session_state and st.session_state.example_used:
-    example_recipe = """1 bungkus Lontong homemade
-2 butir Telur ayam
-2 kotak kecil Tahu coklat
-4 butir kecil Kentang
-2 buah Tomat merah
-1 buah Ketimun lalap
-4 lembar Selada keriting
-2 lembar Kol putih
-2 porsi Saus kacang homemade
-4 buah Kerupuk udang goreng
-Secukupnya emping goreng
-2 sdt Bawang goreng
-Secukupnya Kecap manis"""
-
-    st.session_state.example_used = False
-    st.text_area(
-        "Daftar Bahan (satu per baris atau dengan format yang umum digunakan)",
-        value=example_recipe,
-        height=150,
-        key="ingredients_example"
-    )
-
-# Footer
-st.markdown("---")
-st.markdown("*Aplikasi ini hanya untuk tujuan informasi. Silakan konsultasikan dengan ahli gizi untuk konfirmasi alergen dalam makanan.*")
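Note on the new inference path: predict() now returns one sigmoid probability per label, and the UI thresholds each at 0.5. Below is a minimal standalone sketch of that multilabel decision rule, using made-up logits (no checkpoint or Streamlit needed; TARGET_COLUMNS is copied from app.py):

import torch

TARGET_COLUMNS = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']

# Fake logits for a single recipe; in app.py these come from the model.
logits = torch.tensor([[2.1, -0.3, 0.8, -1.7, -0.1]])

# Sigmoid scores each label independently (unlike softmax, which would
# force the five scores to compete), so several allergens can fire at once.
probs = torch.sigmoid(logits).flatten()

results = {col: float(p) for col, p in zip(TARGET_COLUMNS, probs)}
detected = [col for col, p in results.items() if p > 0.5]
print(results)
print(detected)  # ['susu', 'telur']

The 0.5 cutoff mirrors the score > 0.5 check in the Streamlit loop; for an allergen detector, where a missed allergen is costlier than a false alarm, one might lower the threshold and trade precision for recall.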
requirements.txt
CHANGED
@@ -1,5 +1,9 @@
-streamlit
-
-
-
-
+streamlit
+pandas
+numpy
+torch
+transformers
+scikit-learn
+tqdm
+matplotlib
+sentencepiece
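
To reproduce the Space locally (assuming the usual Spaces layout, with app.py at the repository root and the checkpoint at model/alergen_model.pt), the standard steps are `pip install -r requirements.txt` followed by `streamlit run app.py`. Note that scikit-learn, tqdm, and matplotlib are not imported by the new app.py and appear to be training-time dependencies.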