"""Streamlit app that predicts likely food allergens from an ingredient list.

The script loads a fine-tuned IndoBERT sequence classifier from
``MODEL_PATH``, cleans the text typed by the user, and shows one sigmoid
probability per label in ``TARGET_COLUMNS``.
"""

import re

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer and model configuration.
MODEL_PATH = 'model/alergen_model.pt'
MODEL_NAME = 'indobenchmark/indobert-base-p1'
TARGET_COLUMNS = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
MAX_LEN = 128

# A label is reported as detected when its probability exceeds this value.
DETECTION_THRESHOLD = 0.5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Streamlit re-executes this script on every widget interaction, so the
# tokenizer and model are cached across reruns when the installed Streamlit
# provides ``cache_resource``. Older releases fall back to loading on each
# run, which is what this script did originally.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


class MultilabelBertClassifier(nn.Module):
    """Pretrained sequence classifier whose head is replaced by a new linear layer.

    The attribute names (``bert`` and ``bert.classifier``) determine the
    state-dict keys, so they must stay as they are for the checkpoint at
    ``MODEL_PATH`` to load.
    """

    def __init__(self, model_name, num_labels):
        """Build the pretrained backbone and attach a ``num_labels``-wide head.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            num_labels: Number of output logits.
        """
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        self.bert.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the raw logits produced by the wrapped model."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits


@_cache_resource
def _load_tokenizer():
    """Load the tokenizer that matches ``MODEL_NAME``."""
    return AutoTokenizer.from_pretrained(MODEL_NAME)


@_cache_resource
def _load_model():
    """Build the classifier, load the weights from ``MODEL_PATH``, set eval mode."""
    classifier = MultilabelBertClassifier(MODEL_NAME, len(TARGET_COLUMNS))
    # NOTE(review): torch.load unpickles the file, so MODEL_PATH must only
    # ever point at a trusted checkpoint. Consider ``weights_only=True`` if
    # the installed torch version supports it.
    classifier.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    classifier.to(device)
    classifier.eval()
    return classifier


tokenizer = _load_tokenizer()
model = _load_model()


# Preprocessing
def clean_text(text: str) -> str:
    """Normalize raw ingredient text before tokenization.

    Replaces ``--`` with a space, removes ``http...`` tokens, turns newlines
    into spaces, replaces every character that is not an ASCII letter, a
    digit or whitespace with a space, collapses runs of spaces, strips the
    ends and lowercases the result.

    Args:
        text: Raw text entered by the user.

    Returns:
        The cleaned, lowercased text.
    """
    text = text.replace('--', ' ')
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"\n", " ", text)
    # Raw string: "\s" in a normal string literal is an invalid escape.
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    text = re.sub(r" {2,}", " ", text)
    return text.strip().lower()


# Prediction
def predict(text: str) -> dict:
    """Return the predicted probability of each allergen for ``text``.

    Args:
        text: Raw ingredient text; it is cleaned with ``clean_text`` first.

    Returns:
        A dict mapping each name in ``TARGET_COLUMNS`` to the sigmoid of the
        corresponding logit, as a Python float.
    """
    cleaned = clean_text(text)
    # Calling the tokenizer directly is the current spelling of encode_plus.
    encoding = tokenizer(
        cleaned,
        add_special_tokens=True,
        max_length=MAX_LEN,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # Independent sigmoid per label: several allergens may be present at once.
    probs = torch.sigmoid(logits).cpu().numpy().flatten()
    return {label: float(prob) for label, prob in zip(TARGET_COLUMNS, probs)}


# STREAMLIT UI
st.title("🔍 Deteksi Alergen dari Bahan Makanan")
st.markdown("Masukkan daftar bahan makanan, dan sistem akan memprediksi kemungkinan alergen.")

user_input = st.text_area("🧾 Bahan makanan (contoh: 2 butir telur, 1 gelas susu, kacang tanah...)")

if st.button("Prediksi Alergen"):
    if not user_input.strip():
        st.warning("Silakan masukkan bahan makanan terlebih dahulu.")
    else:
        with st.spinner("Memproses..."):
            predictions = predict(user_input)

        st.subheader("📊 Hasil Prediksi:")
        for allergen, score in predictions.items():
            status = '✅ Terdeteksi' if score > DETECTION_THRESHOLD else '❌ Tidak terdeteksi'
            st.write(f"- **{allergen}**: {status} (Probabilitas: {score:.2f})")