# File-viewer header captured together with the source (not part of the program):
# rdsarjito · 6 commit · b1b9a76 · raw · history blame · 9.64 kB
import streamlit as st
import torch
import torch.nn as nn
import re
from transformers import AutoTokenizer
import os
import numpy as np
# Configure the Streamlit page: browser-tab title, icon and wide layout.
st.set_page_config(
    page_title="Allergen Detection App",
    page_icon="🍲",
    layout="wide",
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allergen label names; model outputs are matched to these names by position.
target_columns = [
    'susu',
    'kacang',
    'telur',
    'makanan_laut',
    'gandum',
]
# Clean text function
def clean_text(text):
    """Normalize raw ingredient text before it is tokenized.

    Steps, in order: replace ``--`` with a space, drop anything that looks
    like a URL (``http`` followed by non-whitespace), turn newlines into
    spaces, replace every character that is not an ASCII letter, digit or
    whitespace with a space, collapse runs of two or more spaces, trim both
    ends and lowercase.

    Args:
        text: Raw ingredient text (a ``str``).

    Returns:
        The cleaned, lowercased string.
    """
    # Double dashes become a single space up front; any remaining single
    # dash is handled by the character filter below.
    text = text.replace('--', ' ')
    text = re.sub(r"http\S+", "", text)
    text = text.replace('\n', ' ')
    # Raw string literal: "\s" inside a normal string is an invalid escape
    # sequence and triggers a warning on current Python versions.
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    # Only runs of plain spaces are collapsed; tabs are left untouched.
    text = re.sub(r" {2,}", " ", text)
    return text.strip().lower()
# Define model for multilabel classification
class MultilabelBertClassifier(nn.Module):
    """Pretrained transformer encoder followed by a linear multilabel head."""

    def __init__(self, model_name, num_labels):
        """Build the encoder and the classification head.

        Args:
            model_name: Identifier passed to ``AutoConfig.from_pretrained``
                and ``AutoModel.from_pretrained``.
            num_labels: Number of outputs of the linear head.
        """
        super().__init__()
        # Imported lazily, only when a classifier is actually constructed.
        from transformers import AutoConfig, AutoModel

        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the head's raw outputs (no sigmoid applied here).

        The encoder's hidden state at sequence position 0 (the [CLS] token)
        is used as the summary vector fed to the linear head.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        logits = self.classifier(first_token_state)
        return logits
# Function to remove 'module.' prefix from state dict keys
def remove_module_prefix(state_dict):
    """Return a new dict with one leading ``'module.'`` stripped from each key.

    Keys that do not start with ``'module.'`` are copied unchanged, and only
    the first occurrence of the prefix is removed. The input mapping is not
    modified.

    Args:
        state_dict: Mapping of string keys to values.

    Returns:
        A plain ``dict`` with the renamed keys, in the input's order.
    """
    prefix = 'module.'
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): value
        for name, value in state_dict.items()
    }
# Load model function
@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned classifier from disk.

    Returns:
        ``(model, tokenizer)`` on success. ``(None, tokenizer)`` when the
        checkpoint file is missing or loading fails; in both cases an error
        is reported through ``st.error``.
    """
    # NOTE(review): the tokenizer comes from "-p2" while the encoder is built
    # from "-p1" -- confirm this pairing matches how the checkpoint was trained.
    tokenizer = AutoTokenizer.from_pretrained('indobenchmark/indobert-base-p2')
    model = MultilabelBertClassifier('indobenchmark/indobert-base-p1', len(target_columns))

    model_path = "model/alergen_model.pt"
    if not os.path.exists(model_path):
        st.error("Model file not found. Please upload the model file.")
        return None, tokenizer

    try:
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be loaded (consider weights_only=True where the
        # installed torch version supports it).
        checkpoint = torch.load(model_path, map_location=device)
        # The weights are either stored under 'model_state_dict' or are the
        # checkpoint object itself.
        weights = (
            checkpoint['model_state_dict']
            if 'model_state_dict' in checkpoint
            else checkpoint
        )
        # Drop a leading 'module.' from the keys before loading.
        weights = remove_module_prefix(weights)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, tokenizer
# Function to predict allergens
def predict_allergens(model, tokenizer, ingredients_text, max_length=128, threshold=0.5):
    """Predict which allergens are present in an ingredient list.

    Args:
        model: Classifier called with ``input_ids`` and ``attention_mask``
            keyword arguments, or ``None`` when no model is loaded.
        tokenizer: Hugging Face tokenizer used to encode the cleaned text.
        ingredients_text: Raw ingredient text; it is passed through
            ``clean_text`` before tokenization.
        max_length: Length the encoded input is truncated/padded to.
        threshold: Sigmoid probability above which an allergen is reported
            as present (default 0.5).

    Returns:
        Dict mapping each name in ``target_columns`` to a ``bool``; an empty
        dict when ``model`` is ``None``.
    """
    # Explicit None check: truthiness of a model object is not a reliable
    # "is it loaded" test (container modules define __len__).
    if model is None:
        return {}

    cleaned_text = clean_text(ingredients_text)

    # Calling the tokenizer directly replaces the older encode_plus API.
    encoding = tokenizer(
        cleaned_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        # Index 0: a single text is encoded, so only the first row is used.
        probabilities = torch.sigmoid(logits)[0]

    flags = (probabilities > threshold).cpu().tolist()
    return {target: bool(flag) for target, flag in zip(target_columns, flags)}
# UI components
def main():
    """Render the Streamlit UI: model upload, ingredient input and results."""
    # (emoji, label) shown for each allergen key; one table replaces the
    # three if/elif chains that previously duplicated these strings.
    allergen_display = {
        'susu': ("🥛", "Susu (Milk)"),
        'kacang': ("🥜", "Kacang (Nuts)"),
        'telur': ("🥚", "Telur (Eggs)"),
        'makanan_laut': ("🦐", "Makanan Laut (Seafood)"),
        'gandum': ("🌾", "Gandum (Wheat/Gluten)"),
    }

    def display_for(allergen):
        # Unknown keys fall back to their raw name instead of being skipped.
        return allergen_display.get(allergen, ("⚠️", allergen))

    st.title("🍲 Allergen Detection in Indonesian Recipes")
    st.write("This app predicts common allergens in your recipe based on ingredients.")

    # Create directory for model if it doesn't exist
    os.makedirs("model", exist_ok=True)

    # Sidebar for model upload
    with st.sidebar:
        st.header("Model Settings")
        uploaded_model = st.file_uploader("Upload model file (alergen_model.pt)", type=["pt"])
        if uploaded_model is not None:
            # Identify the upload so it is written to disk once, not on
            # every script rerun while the file stays in the uploader.
            upload_token = (
                getattr(uploaded_model, "file_id", None),
                uploaded_model.name,
                uploaded_model.size,
            )
            if st.session_state.get("_saved_model_upload") != upload_token:
                with open("model/alergen_model.pt", "wb") as f:
                    f.write(uploaded_model.getbuffer())
                # load_model() is cached; without clearing, a result cached
                # before the upload (including "no model") would be reused.
                load_model.clear()
                st.session_state["_saved_model_upload"] = upload_token
            st.success("Model uploaded successfully!")

        st.markdown("---")
        st.write("Allergen Categories:")
        for allergen in target_columns:
            _, label = display_for(allergen)
            st.write(f"- {label}")

    # Load model
    model, tokenizer = load_model()

    # Input area
    st.header("Recipe Ingredients")

    # Example button: pre-fills the text area through its session-state key.
    if st.button("Load Example"):
        example_text = "1 bungkus Lontong homemade 2 butir Telur ayam 2 kotak kecil Tahu coklat 4 butir kecil Kentang 2 buah Tomat merah 1 buah Ketimun lalap 4 lembar Selada keriting 2 lembar Kol putih 2 porsi Saus kacang homemade 4 buah Kerupuk udang goreng Secukupnya emping goreng 2 sdt Bawang goreng Secukupnya Kecap manis (bila suka)"
        st.session_state.ingredients = example_text

    # Text input
    ingredients_text = st.text_area(
        "Enter recipe ingredients (in Indonesian):",
        height=150,
        key="ingredients",
    )

    # Predict button
    if st.button("Detect Allergens"):
        if not ingredients_text.strip():
            st.warning("Please enter ingredients first.")
        elif model is None:
            st.error("Please upload the model file first.")
        else:
            with st.spinner("Analyzing ingredients..."):
                allergens = predict_allergens(model, tokenizer, ingredients_text)

            # Display results
            st.header("Results")
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Detected Allergens:")
                detected = [name for name, present in allergens.items() if present]
                for allergen in detected:
                    emoji, label = display_for(allergen)
                    st.warning(f"{emoji} {label}")
                if not detected:
                    st.success("✅ No allergens detected!")

            with col2:
                st.subheader("All Categories:")
                for allergen, present in allergens.items():
                    emoji, label = display_for(allergen)
                    status = "Detected ⚠️" if present else "Not detected ✓"
                    st.write(f"{emoji} {label}: {status}")

            # Show the text as it was fed to the tokenizer
            with st.expander("Processed Text"):
                st.code(clean_text(ingredients_text))

    # Instructions and information
    with st.expander("How to Use"):
        st.write("""
        1. First, upload the trained model file (`alergen_model.pt`) using the sidebar uploader
        2. Enter your recipe ingredients in the text box (in Indonesian)
        3. Click the "Detect Allergens" button to analyze the recipe
        4. View the results showing which allergens are present in your recipe
        The model detects five common allergen categories: milk, nuts, eggs, seafood, and wheat/gluten.
        """)
# Start the UI only when this file is executed as the main script.
if __name__ == "__main__":
    main()