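"""Streamlit app for detecting common allergens in Indonesian recipes.

An IndoBERT-based multilabel classifier flags five allergen categories
(susu/milk, kacang/nuts, telur/eggs, makanan_laut/seafood, gandum/wheat-gluten)
from a free-text ingredient list. The trained weights (alergen_model.pt) are
uploaded through the sidebar at runtime.

Run locally with `streamlit run <this file>`.
"""
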
import streamlit as st
import torch
import torch.nn as nn
import re
from transformers import AutoTokenizer
import os
import numpy as np

# Set page config
st.set_page_config(
    page_title="Allergen Detection App",
    page_icon="🍲",
    layout="wide"
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define target columns (allergens)
target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
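# Label key: susu = milk, kacang = nuts, telur = eggs,
# makanan_laut = seafood, gandum = wheat/gluten.
# The index order must match the order used when the saved model was trained.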

# Clean text function
def clean_text(text):
    # Convert dashes to spaces for better tokenization
    text = text.replace('--', ' ')
    # Basic cleaning
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"\n", " ", text)
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    text = re.sub(r" {2,}", " ", text)
    text = text.strip()
    text = text.lower()
    return text

# Define model for multilabel classification
class MultilabelBertClassifier(nn.Module):
    def __init__(self, model_name, num_labels):
        super(MultilabelBertClassifier, self).__init__()
        # Simpler initialization, sufficient for inference only
        from transformers import AutoConfig, AutoModel
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # Use the [CLS] token representation
        return self.classifier(pooled_output)
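
# Note: forward() returns raw logits; predict_allergens() below applies a sigmoid
# and thresholds each label independently at 0.5.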

# Function to remove 'module.' prefix from state dict keys
def remove_module_prefix(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('module.'):
            new_key = key[7:]  # Remove 'module.' prefix
        else:
            new_key = key
        new_state_dict[new_key] = value
    return new_state_dict

# Load model function
@st.cache_resource
def load_model():
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained('indobenchmark/indobert-base-p2')

    # Initialize model
    model = MultilabelBertClassifier('indobenchmark/indobert-base-p1', len(target_columns))

    # Check if model exists
    model_path = "model/alergen_model.pt"
    if os.path.exists(model_path):
        try:
            # Load model weights
            checkpoint = torch.load(model_path, map_location=device)

            # Check if state_dict is directly in checkpoint or under 'model_state_dict' key
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Remove 'module.' prefix if it exists
            state_dict = remove_module_prefix(state_dict)

            # Load the processed state dict
            model.load_state_dict(state_dict)
            model.to(device)
            model.eval()
            return model, tokenizer
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return None, tokenizer
    else:
        st.error("Model file not found. Please upload the model file.")
        return None, tokenizer

# Function to predict allergens
def predict_allergens(model, tokenizer, ingredients_text, max_length=128):
    if not model:
        return {}

    # Clean the text
    cleaned_text = clean_text(ingredients_text)

    # Tokenize
    encoding = tokenizer.encode_plus(
        cleaned_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
        padding='max_length'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.sigmoid(outputs)
        predictions = (predictions > 0.5).float().cpu().numpy()[0]

    result = {}
    for i, target in enumerate(target_columns):
        result[target] = bool(predictions[i])

    return result

# UI components
def main():
    st.title("🍲 Allergen Detection in Indonesian Recipes")
    st.write("This app predicts common allergens in your recipe based on its ingredients.")

    # Create directory for model if it doesn't exist
    os.makedirs("model", exist_ok=True)

    # Sidebar for model upload
    with st.sidebar:
        st.header("Model Settings")
        uploaded_model = st.file_uploader("Upload model file (alergen_model.pt)", type=["pt"])
        if uploaded_model:
            # Save uploaded model
            with open("model/alergen_model.pt", "wb") as f:
                f.write(uploaded_model.getbuffer())
            st.success("Model uploaded successfully!")

        st.markdown("---")
        st.write("Allergen Categories:")
        for allergen in target_columns:
            if allergen == 'susu':
                st.write("- Susu (Milk)")
            elif allergen == 'kacang':
                st.write("- Kacang (Nuts)")
            elif allergen == 'telur':
                st.write("- Telur (Eggs)")
            elif allergen == 'makanan_laut':
                st.write("- Makanan Laut (Seafood)")
            elif allergen == 'gandum':
                st.write("- Gandum (Wheat/Gluten)")

    # Load model
    model, tokenizer = load_model()

    # Input area
    st.header("Recipe Ingredients")

    # Example button
    if st.button("Load Example"):
        example_text = "1 bungkus Lontong homemade 2 butir Telur ayam 2 kotak kecil Tahu coklat 4 butir kecil Kentang 2 buah Tomat merah 1 buah Ketimun lalap 4 lembar Selada keriting 2 lembar Kol putih 2 porsi Saus kacang homemade 4 buah Kerupuk udang goreng Secukupnya emping goreng 2 sdt Bawang goreng Secukupnya Kecap manis (bila suka)"
        st.session_state.ingredients = example_text

    # Text input
    ingredients_text = st.text_area(
        "Enter recipe ingredients (in Indonesian):",
        height=150,
        key="ingredients"
    )

    # Predict button
    if st.button("Detect Allergens"):
        if ingredients_text.strip() == "":
            st.warning("Please enter ingredients first.")
        elif model is None:
            st.error("Please upload the model file first.")
        else:
            with st.spinner("Analyzing ingredients..."):
                # Make prediction
                allergens = predict_allergens(model, tokenizer, ingredients_text)

                # Display results
                st.header("Results")

                # Create columns for results
                col1, col2 = st.columns(2)

                with col1:
                    st.subheader("Detected Allergens:")
                    has_allergens = False
                    for allergen, present in allergens.items():
                        if present:
                            has_allergens = True
                            if allergen == 'susu':
                                st.warning("🥛 Susu (Milk)")
                            elif allergen == 'kacang':
                                st.warning("🥜 Kacang (Nuts)")
                            elif allergen == 'telur':
                                st.warning("🥚 Telur (Eggs)")
                            elif allergen == 'makanan_laut':
                                st.warning("🦐 Makanan Laut (Seafood)")
                            elif allergen == 'gandum':
                                st.warning("🌾 Gandum (Wheat/Gluten)")

                    if not has_allergens:
                        st.success("✅ No allergens detected!")

                with col2:
                    st.subheader("All Categories:")
                    for allergen, present in allergens.items():
                        if allergen == 'susu':
                            st.write("🥛 Susu (Milk): " + ("Detected ⚠️" if present else "Not detected ✓"))
                        elif allergen == 'kacang':
                            st.write("🥜 Kacang (Nuts): " + ("Detected ⚠️" if present else "Not detected ✓"))
                        elif allergen == 'telur':
                            st.write("🥚 Telur (Eggs): " + ("Detected ⚠️" if present else "Not detected ✓"))
                        elif allergen == 'makanan_laut':
                            st.write("🦐 Makanan Laut (Seafood): " + ("Detected ⚠️" if present else "Not detected ✓"))
                        elif allergen == 'gandum':
                            st.write("🌾 Gandum (Wheat/Gluten): " + ("Detected ⚠️" if present else "Not detected ✓"))

                # Show cleaned text
                with st.expander("Processed Text"):
                    st.code(clean_text(ingredients_text))

    # Instructions and information
    with st.expander("How to Use"):
        st.write("""
1. First, upload the trained model file (`alergen_model.pt`) using the sidebar uploader.
2. Enter your recipe ingredients in the text box (in Indonesian).
3. Click the "Detect Allergens" button to analyze the recipe.
4. View the results showing which allergens are present in your recipe.

The model detects five common allergen categories: milk, nuts, eggs, seafood, and wheat/gluten.
        """)


if __name__ == "__main__":
    main()