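"""Streamlit app that flags common allergens in Indonesian recipe ingredient
lists using a fine-tuned IndoBERT multilabel classifier.

The five labels (susu, kacang, telur, makanan_laut, gandum) correspond to
milk, nuts, eggs, seafood, and wheat/gluten.
"""
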
import os
import re

import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer

# Set page config
st.set_page_config(
    page_title="Allergen Detection App",
    page_icon="🍲",
    layout="wide"
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define target columns (allergen labels, in Indonesian)
target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
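
# Display metadata for each allergen: an emoji icon and a bilingual label,
# gathered here so the sidebar and the results view stay consistent.
allergen_info = {
    'susu': ("🥛", "Susu (Milk)"),
    'kacang': ("🥜", "Kacang (Nuts)"),
    'telur': ("🥚", "Telur (Eggs)"),
    'makanan_laut': ("🦐", "Makanan Laut (Seafood)"),
    'gandum': ("🌾", "Gandum (Wheat/Gluten)"),
}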

# Clean text function
def clean_text(text):
    # Convert dashes to spaces for better tokenization
    text = text.replace('--', ' ')
    # Basic cleaning: strip URLs and newlines, drop punctuation, collapse spaces
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"\n", " ", text)
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    text = re.sub(r" {2,}", " ", text)
    text = text.strip()
    text = text.lower()
    return text
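
# For instance, clean_text("2 butir Telur -- segar\n") returns
# "2 butir telur segar".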

# Define model for multilabel classification
class MultilabelBertClassifier(nn.Module):
    def __init__(self, model_name, num_labels):
        super(MultilabelBertClassifier, self).__init__()
        # Plain backbone plus a linear head; sufficient for inference
        from transformers import AutoConfig, AutoModel
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # Use the [CLS] token
        return self.classifier(pooled_output)
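
# Note: forward() returns raw logits; predict_allergens() applies a sigmoid,
# which assumes the head was trained with a sigmoid-based multilabel loss
# such as BCEWithLogitsLoss.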

# Function to remove the 'module.' prefix from state dict keys
def remove_module_prefix(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('module.'):
            new_key = key[7:]  # Remove the 'module.' prefix
        else:
            new_key = key
        new_state_dict[new_key] = value
    return new_state_dict
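
# The 'module.' prefix appears when a checkpoint was saved from a model wrapped
# in nn.DataParallel, e.g. 'module.classifier.weight' -> 'classifier.weight'.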

# Load model function
def load_model():
    # Load tokenizer
    # NOTE: the tokenizer checkpoint (indobert-base-p2) differs from the model
    # backbone below (indobert-base-p1); verify that this matches how the
    # model was actually trained.
    tokenizer = AutoTokenizer.from_pretrained('indobenchmark/indobert-base-p2')

    # Initialize model
    model = MultilabelBertClassifier('indobenchmark/indobert-base-p1', len(target_columns))

    # Check if the model file exists
    model_path = "model/alergen_model.pt"
    if os.path.exists(model_path):
        try:
            # Load model weights
            checkpoint = torch.load(model_path, map_location=device)

            # The state dict may be stored directly or under 'model_state_dict'
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Remove the 'module.' prefix if present
            state_dict = remove_module_prefix(state_dict)

            # Load the processed state dict
            model.load_state_dict(state_dict)
            model.to(device)
            model.eval()
            return model, tokenizer
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return None, tokenizer
    else:
        st.error("Model file not found. Please upload the model file.")
        return None, tokenizer
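
# If the model never changed between reruns, wrapping load_model() with
# st.cache_resource would avoid reloading the weights on every interaction;
# it is left uncached so a newly uploaded file takes effect immediately.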

# Function to predict allergens
def predict_allergens(model, tokenizer, ingredients_text, max_length=128):
    if not model:
        return {}

    # Clean the text
    cleaned_text = clean_text(ingredients_text)

    # Tokenize
    encoding = tokenizer.encode_plus(
        cleaned_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
        padding='max_length'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.sigmoid(outputs)
        # Binarize at a 0.5 probability threshold
        predictions = (predictions > 0.5).float().cpu().numpy()[0]

    result = {}
    for i, target in enumerate(target_columns):
        result[target] = bool(predictions[i])
    return result
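
# Minimal usage sketch (illustrative values, not actual model output):
#   model, tokenizer = load_model()
#   predict_allergens(model, tokenizer, "2 butir telur, 200 ml susu")
#   -> {'susu': True, 'kacang': False, 'telur': True,
#       'makanan_laut': False, 'gandum': False}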

# UI components
def main():
    st.title("🍲 Allergen Detection in Indonesian Recipes")
    st.write("This app predicts common allergens in your recipe based on its ingredients.")

    # Create a directory for the model if it doesn't exist
    os.makedirs("model", exist_ok=True)

    # Sidebar for model upload
    with st.sidebar:
        st.header("Model Settings")
        uploaded_model = st.file_uploader("Upload model file (alergen_model.pt)", type=["pt"])
        if uploaded_model:
            # Save the uploaded model
            with open("model/alergen_model.pt", "wb") as f:
                f.write(uploaded_model.getbuffer())
            st.success("Model uploaded successfully!")

        st.markdown("---")
        st.write("Allergen Categories:")
        for allergen in target_columns:
            _, label = allergen_info[allergen]
            st.write(f"- {label}")

    # Load model
    model, tokenizer = load_model()

    # Input area
    st.header("Recipe Ingredients")

    # Example button
    if st.button("Load Example"):
        example_text = "1 bungkus Lontong homemade 2 butir Telur ayam 2 kotak kecil Tahu coklat 4 butir kecil Kentang 2 buah Tomat merah 1 buah Ketimun lalap 4 lembar Selada keriting 2 lembar Kol putih 2 porsi Saus kacang homemade 4 buah Kerupuk udang goreng Secukupnya emping goreng 2 sdt Bawang goreng Secukupnya Kecap manis (bila suka)"
        st.session_state.ingredients = example_text

    # Text input
    ingredients_text = st.text_area(
        "Enter recipe ingredients (in Indonesian):",
        height=150,
        key="ingredients"
    )

    # Predict button
    if st.button("Detect Allergens"):
        if ingredients_text.strip() == "":
            st.warning("Please enter ingredients first.")
        elif model is None:
            st.error("Please upload the model file first.")
        else:
            with st.spinner("Analyzing ingredients..."):
                # Make prediction
                allergens = predict_allergens(model, tokenizer, ingredients_text)

                # Display results
                st.header("Results")

                # Create columns for results
                col1, col2 = st.columns(2)

                with col1:
                    st.subheader("Detected Allergens:")
                    has_allergens = False
                    for allergen, present in allergens.items():
                        if present:
                            has_allergens = True
                            emoji, label = allergen_info[allergen]
                            st.warning(f"{emoji} {label}")
                    if not has_allergens:
                        st.success("✅ No allergens detected!")

                with col2:
                    st.subheader("All Categories:")
                    for allergen, present in allergens.items():
                        emoji, label = allergen_info[allergen]
                        status = "Detected ⚠️" if present else "Not detected ✅"
                        st.write(f"{emoji} {label}: {status}")

                # Show cleaned text
                with st.expander("Processed Text"):
                    st.code(clean_text(ingredients_text))

    # Instructions and information
    with st.expander("How to Use"):
        st.write("""
        1. First, upload the trained model file (`alergen_model.pt`) using the sidebar uploader
        2. Enter your recipe ingredients in the text box (in Indonesian)
        3. Click the "Detect Allergens" button to analyze the recipe
        4. View the results showing which allergens are present in your recipe

        The model detects five common allergen categories: milk, nuts, eggs, seafood, and wheat/gluten.
        """)


if __name__ == "__main__":
    main()