rdsarjito committed on
Commit 554b605 · 1 Parent(s): b1b9a76
app.py CHANGED
@@ -1,250 +1,85 @@
 
  import streamlit as st
  import torch
  import torch.nn as nn
  import re
- from transformers import AutoTokenizer
- import os
  import numpy as np

- # Set page config
- st.set_page_config(
-     page_title="Allergen Detection App",
-     page_icon="🍲",
-     layout="wide"
- )
-
- # Set device
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # Define target columns (allergens)
  target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']

- # Clean text function
  def clean_text(text):
-     # Convert dashes to spaces for better tokenization
      text = text.replace('--', ' ')
-     # Basic cleaning
      text = re.sub(r"http\S+", "", text)
      text = re.sub('\n', ' ', text)
      text = re.sub("[^a-zA-Z0-9\s]", " ", text)
      text = re.sub(" {2,}", " ", text)
-     text = text.strip()
-     text = text.lower()
      return text

- # Define model for multilabel classification
  class MultilabelBertClassifier(nn.Module):
      def __init__(self, model_name, num_labels):
          super(MultilabelBertClassifier, self).__init__()
-         # Replace with a simpler initialization for inference only
-         from transformers import AutoConfig, AutoModel
-         self.config = AutoConfig.from_pretrained(model_name)
-         self.bert = AutoModel.from_pretrained(model_name)
-         self.classifier = nn.Linear(self.config.hidden_size, num_labels)
-
      def forward(self, input_ids, attention_mask):
          outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
-         pooled_output = outputs.last_hidden_state[:, 0, :]  # Use [CLS] token
-         return self.classifier(pooled_output)

- # Function to remove 'module.' prefix from state dict keys
- def remove_module_prefix(state_dict):
-     new_state_dict = {}
-     for key, value in state_dict.items():
-         if key.startswith('module.'):
-             new_key = key[7:]  # Remove 'module.' prefix
-         else:
-             new_key = key
-         new_state_dict[new_key] = value
-     return new_state_dict

- # Load model function
- @st.cache_resource
- def load_model():
-     # Load tokenizer
-     tokenizer = AutoTokenizer.from_pretrained('indobenchmark/indobert-base-p2')
-
-     # Initialize model
-     model = MultilabelBertClassifier('indobenchmark/indobert-base-p1', len(target_columns))
-
-     # Check if model exists
-     model_path = "model/alergen_model.pt"
-
-     if os.path.exists(model_path):
-         try:
-             # Load model weights
-             checkpoint = torch.load(model_path, map_location=device)
-
-             # Check if state_dict is directly in checkpoint or under 'model_state_dict' key
-             if 'model_state_dict' in checkpoint:
-                 state_dict = checkpoint['model_state_dict']
-             else:
-                 state_dict = checkpoint
-
-             # Remove 'module.' prefix if it exists
-             state_dict = remove_module_prefix(state_dict)
-
-             # Load the processed state dict
-             model.load_state_dict(state_dict)
-
-             model.to(device)
-             model.eval()
-             return model, tokenizer
-         except Exception as e:
-             st.error(f"Error loading model: {str(e)}")
-             return None, tokenizer
-     else:
-         st.error("Model file not found. Please upload the model file.")
-         return None, tokenizer

- # Function to predict allergens
- def predict_allergens(model, tokenizer, ingredients_text, max_length=128):
-     if not model:
-         return {}
-
-     # Clean the text
-     cleaned_text = clean_text(ingredients_text)
-
-     # Tokenize
-     encoding = tokenizer.encode_plus(
-         cleaned_text,
          add_special_tokens=True,
          max_length=max_length,
          truncation=True,
          return_tensors='pt',
          padding='max_length'
      )
-
-     input_ids = encoding['input_ids'].to(device)
-     attention_mask = encoding['attention_mask'].to(device)
-
      with torch.no_grad():
-         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-         predictions = torch.sigmoid(outputs)
-         predictions = (predictions > 0.5).float().cpu().numpy()[0]
-
-     result = {}
-     for i, target in enumerate(target_columns):
-         result[target] = bool(predictions[i])
-
-     return result

- # UI components
- def main():
-     st.title("🍲 Allergen Detection in Indonesian Recipes")
-     st.write("This app predicts common allergens in your recipe based on ingredients.")
-
-     # Create directory for model if it doesn't exist
-     os.makedirs("model", exist_ok=True)
-
-     # Sidebar for model upload
-     with st.sidebar:
-         st.header("Model Settings")
-         uploaded_model = st.file_uploader("Upload model file (alergen_model.pt)", type=["pt"])
-
-         if uploaded_model:
-             # Save uploaded model
-             with open("model/alergen_model.pt", "wb") as f:
-                 f.write(uploaded_model.getbuffer())
-             st.success("Model uploaded successfully!")
-
-         st.markdown("---")
-         st.write("Allergen Categories:")
-         for allergen in target_columns:
-             if allergen == 'susu':
-                 st.write("- Susu (Milk)")
-             elif allergen == 'kacang':
-                 st.write("- Kacang (Nuts)")
-             elif allergen == 'telur':
-                 st.write("- Telur (Eggs)")
-             elif allergen == 'makanan_laut':
-                 st.write("- Makanan Laut (Seafood)")
-             elif allergen == 'gandum':
-                 st.write("- Gandum (Wheat/Gluten)")
-
-     # Load model
-     model, tokenizer = load_model()
-
-     # Input area
-     st.header("Recipe Ingredients")
-
-     # Example button
-     if st.button("Load Example"):
-         example_text = "1 bungkus Lontong homemade 2 butir Telur ayam 2 kotak kecil Tahu coklat 4 butir kecil Kentang 2 buah Tomat merah 1 buah Ketimun lalap 4 lembar Selada keriting 2 lembar Kol putih 2 porsi Saus kacang homemade 4 buah Kerupuk udang goreng Secukupnya emping goreng 2 sdt Bawang goreng Secukupnya Kecap manis (bila suka)"
-         st.session_state.ingredients = example_text
-
-     # Text input
-     ingredients_text = st.text_area(
-         "Enter recipe ingredients (in Indonesian):",
-         height=150,
-         key="ingredients"
-     )
-
-     # Predict button
-     if st.button("Detect Allergens"):
-         if ingredients_text.strip() == "":
-             st.warning("Please enter ingredients first.")
-         elif model is None:
-             st.error("Please upload the model file first.")
-         else:
-             with st.spinner("Analyzing ingredients..."):
-                 # Make prediction
-                 allergens = predict_allergens(model, tokenizer, ingredients_text)
-
-                 # Display results
-                 st.header("Results")
-
-                 # Create columns for results
-                 col1, col2 = st.columns(2)
-
-                 with col1:
-                     st.subheader("Detected Allergens:")
-                     has_allergens = False
-                     for allergen, present in allergens.items():
-                         if present:
-                             has_allergens = True
-                             if allergen == 'susu':
-                                 st.warning("🥛 Susu (Milk)")
-                             elif allergen == 'kacang':
-                                 st.warning("🥜 Kacang (Nuts)")
-                             elif allergen == 'telur':
-                                 st.warning("🥚 Telur (Eggs)")
-                             elif allergen == 'makanan_laut':
-                                 st.warning("🦐 Makanan Laut (Seafood)")
-                             elif allergen == 'gandum':
-                                 st.warning("🌾 Gandum (Wheat/Gluten)")
-
-                     if not has_allergens:
-                         st.success("✅ No allergens detected!")
-
-                 with col2:
-                     st.subheader("All Categories:")
-                     for allergen, present in allergens.items():
-                         if allergen == 'susu':
-                             st.write("🥛 Susu (Milk): " + ("Detected ⚠️" if present else "Not detected ✓"))
-                         elif allergen == 'kacang':
-                             st.write("🥜 Kacang (Nuts): " + ("Detected ⚠️" if present else "Not detected ✓"))
-                         elif allergen == 'telur':
-                             st.write("🥚 Telur (Eggs): " + ("Detected ⚠️" if present else "Not detected ✓"))
-                         elif allergen == 'makanan_laut':
-                             st.write("🦐 Makanan Laut (Seafood): " + ("Detected ⚠️" if present else "Not detected ✓"))
-                         elif allergen == 'gandum':
-                             st.write("🌾 Gandum (Wheat/Gluten): " + ("Detected ⚠️" if present else "Not detected ✓"))
-
-                 # Show cleaned text
-                 with st.expander("Processed Text"):
-                     st.code(clean_text(ingredients_text))

-     # Instructions and information
-     with st.expander("How to Use"):
-         st.write("""
-         1. First, upload the trained model file (`alergen_model.pt`) using the sidebar uploader
-         2. Enter your recipe ingredients in the text box (in Indonesian)
-         3. Click the "Detect Allergens" button to analyze the recipe
-         4. View the results showing which allergens are present in your recipe
-
-         The model detects five common allergen categories: milk, nuts, eggs, seafood, and wheat/gluten.
-         """)

- if __name__ == "__main__":
-     main()
+ # app.py
  import streamlit as st
  import torch
  import torch.nn as nn
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import re
  import numpy as np

+ # Target labels
  target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']

+ # Clean text
  def clean_text(text):
      text = text.replace('--', ' ')
      text = re.sub(r"http\S+", "", text)
      text = re.sub('\n', ' ', text)
      text = re.sub("[^a-zA-Z0-9\s]", " ", text)
      text = re.sub(" {2,}", " ", text)
+     text = text.strip().lower()
      return text

+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("tokenizer_dir")
+ max_length = 128
+
+ # Define model architecture
  class MultilabelBertClassifier(nn.Module):
      def __init__(self, model_name, num_labels):
          super(MultilabelBertClassifier, self).__init__()
+         self.bert = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
+         self.bert.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
+
      def forward(self, input_ids, attention_mask):
          outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         return outputs.logits

+ # Load model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = torch.load("model/alergen_model_full.pt", map_location=device)

+ # If the model is wrapped in DataParallel, unwrap the underlying model
+ if hasattr(model, "module"):
+     model = model.module
+
+ model.to(device)
+ model.eval()

+ # Prediction function
+ def predict_alergens(text):
+     cleaned = clean_text(text)
+     inputs = tokenizer.encode_plus(
+         cleaned,
          add_special_tokens=True,
          max_length=max_length,
          truncation=True,
          return_tensors='pt',
          padding='max_length'
      )
+     input_ids = inputs['input_ids'].to(device)
+     attention_mask = inputs['attention_mask'].to(device)
+
      with torch.no_grad():
+         logits = model(input_ids=input_ids, attention_mask=attention_mask)
+         probs = torch.sigmoid(logits)
+         preds = (probs > 0.5).float().cpu().numpy()[0]

+     return {target: bool(preds[i]) for i, target in enumerate(target_columns)}

+ # Streamlit UI
+ st.title("Deteksi Alergen dari Resep Masakan 🧪🍲")

+ recipe_input = st.text_area("Masukkan bahan-bahan resep di sini:", height=200)
+
+ if st.button("Deteksi Alergen"):
+     if recipe_input.strip() == "":
+         st.warning("Silakan masukkan teks resep terlebih dahulu.")
+     else:
+         with st.spinner("Menganalisis..."):
+             result = predict_alergens(recipe_input)
+             st.subheader("Hasil Prediksi Alergen:")
+             for allergen, is_present in result.items():
+                 if is_present:
+                     st.error(f"⚠️ {allergen}")
+                 else:
+                     st.success(f"✅ Bebas dari {allergen}")
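The rewritten app.py no longer rebuilds the network and loads a state_dict; it deserializes a whole module from model/alergen_model_full.pt and reads the tokenizer from tokenizer_dir/. The export step that produces those artifacts is not part of this commit, so the following is only a minimal sketch, assuming the same IndoBERT base checkpoint (indobenchmark/indobert-base-p1, as in the previous revision) and the same five labels, of code that could generate compatible files. The class definition has to match what app.py declares, because torch.load of a pickled module re-creates the object by class name at load time.

```python
# Hypothetical export sketch (not part of this commit): produce the artifacts
# that the new app.py expects. Names such as BASE and NUM_LABELS are assumptions.
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

NUM_LABELS = 5  # susu, kacang, telur, makanan_laut, gandum
BASE = "indobenchmark/indobert-base-p1"  # assumption: same base as the previous revision

class MultilabelBertClassifier(nn.Module):
    # Must mirror the class in app.py, since torch.load re-creates it by name.
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.bert.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        return self.bert(input_ids=input_ids, attention_mask=attention_mask).logits

model = MultilabelBertClassifier(BASE, NUM_LABELS)
# ... training would happen here ...

# Serialize the whole module (architecture + weights), not just a state_dict.
torch.save(model, "model/alergen_model_full.pt")

# Write the tokenizer files that app.py loads from tokenizer_dir/
# (vocab.txt, tokenizer.json, tokenizer_config.json, special_tokens_map.json).
tokenizer = AutoTokenizer.from_pretrained(BASE)
tokenizer.save_pretrained("tokenizer_dir")
```

Shipping the tokenizer alongside the checkpoint also removes the mismatch in the old revision, where the tokenizer came from indobert-base-p2 while the model was initialized from indobert-base-p1.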
model/{alergen_model.pt → alergen_model_full.pt} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:28df831b272894c11265ef5f4cf1ac2a2ca89e765b26bff928f34c388ff015d5
- size 497868974
+ oid sha256:a7b5bbb0945b811482c8bb868a13bd655572de100833a50fd516efc0e52b7c17
+ size 497911105
requirements.txt CHANGED
@@ -1,5 +1,4 @@
- streamlit>=1.25.0
- torch>=2.0.0
- transformers>=4.30.0
- numpy>=1.22.0
- protobuf>=3.20.0
+ streamlit==1.30.0
+ torch==2.0.1
+ transformers==4.36.2
+ numpy==1.25.2
tokenizer_dir/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer_dir/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_dir/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "extra_special_tokens": {},
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
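The files under tokenizer_dir/ are the usual output of a BERT tokenizer's save_pretrained, and tokenizer_config.json describes a lower-casing BertTokenizer, consistent with clean_text() lower-casing its input in app.py. A quick, purely illustrative sanity check (not part of the commit) that the directory loads the way app.py expects:

```python
# Verify the committed tokenizer_dir/ loads and tokenizes as app.py assumes.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tokenizer_dir")
print(type(tok).__name__)  # a BERT tokenizer class
print(tok.cls_token, tok.sep_token, tok.pad_token, tok.unk_token, tok.mask_token)

# Same encoding settings as predict_alergens(): padded/truncated to 128 tokens.
enc = tok("Telur dan susu", max_length=128, padding="max_length",
          truncation=True, return_tensors="pt")
print(enc["input_ids"].shape)  # expected: torch.Size([1, 128])
```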
tokenizer_dir/vocab.txt ADDED
The diff for this file is too large to render. See raw diff