File size: 9,641 Bytes
552cd20
 
 
c0cfde6
9de5935
314c91a
9de5935
314c91a
9de5935
314c91a
9de5935
 
314c91a
 
 
 
552cd20
 
9de5935
314c91a
552cd20
314c91a
e88e274
314c91a
e88e274
314c91a
e88e274
 
 
 
 
 
 
552cd20
9de5935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1b9a76
 
 
 
 
 
 
 
 
 
 
9de5935
314c91a
 
9de5935
 
 
 
 
 
 
 
 
 
b1b9a76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de5935
 
 
314c91a
 
 
9de5935
314c91a
 
 
 
 
 
552cd20
314c91a
552cd20
314c91a
 
552cd20
314c91a
552cd20
314c91a
552cd20
 
314c91a
e88e274
314c91a
 
9de5935
314c91a
9de5935
314c91a
9de5935
314c91a
 
c0cfde6
9de5935
314c91a
9de5935
 
 
b1b9a76
 
 
9de5935
 
 
b1b9a76
9de5935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314c91a
 
 
9de5935
314c91a
9de5935
 
 
 
314c91a
9de5935
 
 
 
 
 
314c91a
9de5935
 
 
 
 
 
 
 
 
 
 
 
 
314c91a
9de5935
314c91a
 
 
9de5935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314c91a
9de5935
 
314c91a
 
9de5935
 
 
 
 
 
 
 
 
 
 
 
314c91a
9de5935
 
 
 
 
 
314c91a
b1b9a76
9de5935
 
 
314c91a
9de5935
314c91a
c0cfde6
314c91a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import streamlit as st
import torch
import torch.nn as nn
import re
from transformers import AutoTokenizer
import os
import numpy as np

# Configure the Streamlit page: browser-tab title, icon, and wide layout.
st.set_page_config(
    page_title="Allergen Detection App",
    page_icon="🍲",
    layout="wide"
)

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allergen labels (Indonesian names) predicted by the model. The classifier is
# built with len(target_columns) outputs and predictions are mapped back to
# these names by index, so the order is significant.
# NOTE(review): presumably this must match the label order used at training
# time -- confirm against the training code.
target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']

# Clean text function
def clean_text(text):
    """Normalize raw ingredient text before it is tokenized.

    The steps, in order: replace ``--`` with a space, drop anything that looks
    like a URL (``http`` followed by non-whitespace), replace every character
    that is not an ASCII letter, digit or whitespace with a space, collapse
    every run of whitespace (spaces, tabs, newlines, carriage returns) into a
    single space, strip the ends and lower-case the result.

    Args:
        text: Raw ingredient text (a ``str``).

    Returns:
        The cleaned, lower-cased string; ``""`` for empty or all-punctuation
        input.
    """
    # Convert dashes to spaces for better tokenization. This must happen
    # before URL removal so the order of operations (and therefore the output
    # for inputs such as "http://a--b") stays the same as before.
    text = text.replace('--', ' ')
    text = re.sub(r"http\S+", "", text)
    # Raw string: "\s" in a normal string literal is an invalid escape
    # sequence and triggers a warning on modern Python versions.
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    # Collapse *all* whitespace, not only literal spaces, so tabs and "\r"
    # from pasted or Windows-style text do not leak into the cleaned output.
    text = re.sub(r"\s+", " ", text)
    return text.strip().lower()

# Define model for multilabel classification
class MultilabelBertClassifier(nn.Module):
    """Transformer encoder followed by a single linear multilabel head.

    The encoder and its configuration are loaded by name via
    ``AutoModel`` / ``AutoConfig``; the head maps the encoder's hidden size
    to ``num_labels`` raw scores (no activation is applied here).

    Args:
        model_name: Name or path passed to ``from_pretrained``.
        num_labels: Number of output scores produced per input sequence.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # Local import: only this constructor needs these two classes.
        from transformers import AutoConfig, AutoModel

        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_dim = self.config.hidden_size
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the head's raw scores for a batch of token ids.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The output of the linear head applied to the encoder's last
            hidden state at sequence position 0 (the [CLS] token).
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_vector = encoder_out.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_vector)
        return logits

# Function to remove 'module.' prefix from state dict keys
def remove_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``'module.'`` stripped.

    Only one leading occurrence is removed per key; keys without the prefix
    are copied unchanged. Values are passed through as-is and the iteration
    order of the input is kept.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        A new ``dict`` with the renamed keys.
    """
    prefix = 'module.'
    offset = len(prefix)
    return {
        (name[offset:] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }

# Load model function
@st.cache_resource
def _load_tokenizer():
    """Load the tokenizer once and share it across Streamlit reruns."""
    return AutoTokenizer.from_pretrained('indobenchmark/indobert-base-p2')


@st.cache_resource(max_entries=1)
def _load_model_weights(model_path, cache_key):
    """Build the classifier and load the checkpoint stored at ``model_path``.

    Args:
        model_path: Path of the checkpoint file.
        cache_key: Not used in the body. It is part of the cache key, so a
            different value (the caller passes the file's mtime and size)
            forces a fresh load when the file on disk changes.

    Returns:
        The classifier on ``device`` in eval mode.

    Raises:
        Exception: Whatever ``torch.load`` / ``load_state_dict`` raise. The
            exception propagates, so a failed load is never cached.
    """
    # NOTE(review): the encoder is built from 'indobert-base-p1' while the
    # tokenizer comes from 'indobert-base-p2'. The names are kept exactly as
    # they were; confirm which base the checkpoint was trained from.
    model = MultilabelBertClassifier('indobenchmark/indobert-base-p1', len(target_columns))

    # SECURITY: torch.load unpickles the file. The checkpoint can come from
    # the sidebar uploader, so a malicious file could execute arbitrary code.
    # Consider weights_only=True if the installed torch version and the
    # checkpoint format allow it.
    checkpoint = torch.load(model_path, map_location=device)

    # The weights are either the checkpoint itself or stored under
    # 'model_state_dict'.
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Remove 'module.' prefix if it exists
    model.load_state_dict(remove_module_prefix(state_dict))

    model.to(device)
    model.eval()
    return model


def load_model():
    """Return ``(model, tokenizer)`` for inference.

    The tokenizer is always returned. ``model`` is ``None`` (and an error is
    shown in the UI) when ``model/alergen_model.pt`` is missing or cannot be
    loaded.

    Only successful loads are cached. The cache is keyed on the checkpoint's
    modification time and size, so a checkpoint that appears or is replaced
    after the first run is picked up on the next rerun. Previously a
    "file not found" result stayed cached and the uploaded model was never
    loaded.

    Returns:
        tuple: ``(model_or_None, tokenizer)``.
    """
    tokenizer = _load_tokenizer()
    model_path = "model/alergen_model.pt"

    try:
        file_stat = os.stat(model_path)
    except OSError:
        st.error("Model file not found. Please upload the model file.")
        return None, tokenizer

    try:
        model = _load_model_weights(
            model_path, (file_stat.st_mtime_ns, file_stat.st_size)
        )
    except Exception as e:  # UI boundary: report the failure, keep the app alive
        st.error(f"Error loading model: {str(e)}")
        return None, tokenizer

    return model, tokenizer

# Function to predict allergens
def predict_allergens(model, tokenizer, ingredients_text, max_length=128, threshold=0.5):
    """Predict which allergens are present in an ingredient list.

    Args:
        model: Classifier called as ``model(input_ids=..., attention_mask=...)``
            and returning one raw score per entry of ``target_columns`` for
            each input row, or ``None`` when no model is loaded.
        tokenizer: Tokenizer callable that returns ``input_ids`` and
            ``attention_mask`` tensors.
        ingredients_text: Raw ingredient text; it is passed through
            ``clean_text`` before tokenization.
        max_length: Sequence length the input is truncated / padded to.
        threshold: A label is reported as present when its sigmoid
            probability is strictly greater than this value.

    Returns:
        dict: ``{allergen_name: bool}`` for every entry of ``target_columns``,
        or ``{}`` when ``model`` is ``None``.
    """
    if model is None:
        return {}

    # Clean the text
    cleaned_text = clean_text(ingredients_text)

    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # the arguments are the same.
    encoding = tokenizer(
        cleaned_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        # Single input text -> take row 0 of the batch.
        probabilities = torch.sigmoid(logits)[0].cpu().tolist()

    return {
        target: probability > threshold
        for target, probability in zip(target_columns, probabilities)
    }

# UI components
# Display data for each allergen key: (emoji, human-readable label).
_ALLERGEN_DISPLAY = {
    'susu': ("🥛", "Susu (Milk)"),
    'kacang': ("🥜", "Kacang (Nuts)"),
    'telur': ("🥚", "Telur (Eggs)"),
    'makanan_laut': ("🦐", "Makanan Laut (Seafood)"),
    'gandum': ("🌾", "Gandum (Wheat/Gluten)"),
}


def _allergen_display(allergen):
    """Return ``(emoji, label)`` for an allergen key, falling back to the raw key."""
    return _ALLERGEN_DISPLAY.get(allergen, ("⚠️", allergen))


def main():
    """Render the Streamlit UI: model upload, ingredient input and results.

    The sidebar lets the user upload a checkpoint, which is saved to
    ``model/alergen_model.pt``. The main area takes an ingredient list, runs
    ``predict_allergens`` when "Detect Allergens" is pressed and shows the
    detected allergens, the status of every category and the cleaned text.
    """
    st.title("🍲 Allergen Detection in Indonesian Recipes")
    st.write("This app predicts common allergens in your recipe based on ingredients.")

    # Create directory for model if it doesn't exist
    os.makedirs("model", exist_ok=True)

    # Sidebar for model upload
    with st.sidebar:
        st.header("Model Settings")
        uploaded_model = st.file_uploader("Upload model file (alergen_model.pt)", type=["pt"])

        if uploaded_model:
            # The uploader keeps returning the same file on every rerun, so
            # only write it to disk when the upload itself has changed.
            upload_id = (
                getattr(uploaded_model, "file_id", None),
                uploaded_model.name,
                uploaded_model.size,
            )
            if st.session_state.get("saved_model_upload") != upload_id:
                with open("model/alergen_model.pt", "wb") as f:
                    f.write(uploaded_model.getbuffer())
                st.session_state["saved_model_upload"] = upload_id
                # A loader wrapped by st.cache_resource exposes .clear();
                # drop any stale cached result so the new checkpoint is used.
                clear_cached_model = getattr(load_model, "clear", None)
                if callable(clear_cached_model):
                    clear_cached_model()
            st.success("Model uploaded successfully!")

        st.markdown("---")
        st.write("Allergen Categories:")
        for allergen in target_columns:
            _, label = _allergen_display(allergen)
            st.write(f"- {label}")

    # Load model
    model, tokenizer = load_model()

    # Input area
    st.header("Recipe Ingredients")

    # Example button: writes into the text area's session-state key before
    # the widget below is created.
    if st.button("Load Example"):
        example_text = "1 bungkus Lontong homemade 2 butir Telur ayam 2 kotak kecil Tahu coklat 4 butir kecil Kentang 2 buah Tomat merah 1 buah Ketimun lalap 4 lembar Selada keriting 2 lembar Kol putih 2 porsi Saus kacang homemade 4 buah Kerupuk udang goreng Secukupnya emping goreng 2 sdt Bawang goreng Secukupnya Kecap manis (bila suka)"
        st.session_state.ingredients = example_text

    # Text input
    ingredients_text = st.text_area(
        "Enter recipe ingredients (in Indonesian):",
        height=150,
        key="ingredients"
    )

    # Predict button
    if st.button("Detect Allergens"):
        if not ingredients_text.strip():
            st.warning("Please enter ingredients first.")
        elif model is None:
            st.error("Please upload the model file first.")
        else:
            with st.spinner("Analyzing ingredients..."):
                allergens = predict_allergens(model, tokenizer, ingredients_text)

            # Display results
            st.header("Results")
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Detected Allergens:")
                detected = [name for name, present in allergens.items() if present]
                for allergen in detected:
                    icon, label = _allergen_display(allergen)
                    st.warning(f"{icon} {label}")
                if not detected:
                    st.success("✅ No allergens detected!")

            with col2:
                st.subheader("All Categories:")
                for allergen, present in allergens.items():
                    icon, label = _allergen_display(allergen)
                    status = "Detected ⚠️" if present else "Not detected ✓"
                    st.write(f"{icon} {label}: {status}")

            # Show cleaned text
            with st.expander("Processed Text"):
                st.code(clean_text(ingredients_text))

    # Instructions and information
    with st.expander("How to Use"):
        st.write("""
        1. First, upload the trained model file (`alergen_model.pt`) using the sidebar uploader
        2. Enter your recipe ingredients in the text box (in Indonesian)
        3. Click the "Detect Allergens" button to analyze the recipe
        4. View the results showing which allergens are present in your recipe
        
        The model detects five common allergen categories: milk, nuts, eggs, seafood, and wheat/gluten.
        """)

# Build the UI only when this file is executed as the main script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()