# rdsarjito
# 3 commit
# e88e274
# raw
# history blame
# 2.88 kB
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the tokenizer and define the model configuration.
# Path of the fine-tuned weights; read with torch.load and applied via load_state_dict.
MODEL_PATH = 'model/alergen_model.pt'
# Pretrained model identifier, used for both the tokenizer and the classifier backbone.
MODEL_NAME = 'indobenchmark/indobert-base-p1'
# Output label names (Indonesian: milk, nuts, egg, seafood, wheat); index i of the
# model output is reported under TARGET_COLUMNS[i].
TARGET_COLUMNS = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
# Token length every input is padded/truncated to before inference.
MAX_LEN = 128
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
class MultilabelBertClassifier(nn.Module):
    """Sequence classifier that returns one raw logit per label.

    Wraps ``AutoModelForSequenceClassification`` and replaces its
    classification head with a newly created ``nn.Linear`` layer mapping the
    encoder's hidden size to ``num_labels`` outputs.

    Args:
        model_name: Identifier or path handed to ``from_pretrained``.
        num_labels: Number of labels, i.e. logits produced per example.
    """

    def __init__(self, model_name: str, num_labels: int) -> None:
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        # Swap in a freshly initialised linear head over the hidden size.
        self.bert.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its logits.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Attention mask forwarded to the wrapped model.

        Returns:
            The ``logits`` attribute of the wrapped model's output; no
            activation is applied here.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits
@st.cache_resource
def _load_model():
    """Build the classifier, load the fine-tuned weights and return it in eval mode.

    Decorated with ``st.cache_resource`` so the model is constructed and the
    checkpoint read once, instead of on every Streamlit script rerun.
    """
    classifier = MultilabelBertClassifier(MODEL_NAME, len(TARGET_COLUMNS))
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source; consider weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    classifier.load_state_dict(state_dict)
    classifier.to(device)
    classifier.eval()  # inference only: disables dropout etc.
    return classifier


model = _load_model()
# Preprocessing
# Patterns are compiled once and written as raw strings so that ``\S`` / ``\s``
# are regex escapes rather than (invalid) Python string escapes.
_URL_RE = re.compile(r"http\S+")
_NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9\s]")
_MULTI_SPACE_RE = re.compile(r" {2,}")


def clean_text(text):
    """Normalise raw ingredient text before tokenisation.

    Steps, in order: replace ``--`` with a space, drop anything starting with
    ``http`` up to the next whitespace, turn newlines into spaces, replace
    every character that is not an ASCII letter, digit or whitespace with a
    space, collapse runs of two or more spaces, strip the ends and lowercase.

    Only runs of the space character are collapsed; other whitespace such as
    tabs is kept as-is.

    Args:
        text: Input string.

    Returns:
        The cleaned, lowercased string (possibly empty).
    """
    text = text.replace("--", " ")
    text = _URL_RE.sub("", text)
    text = text.replace("\n", " ")
    text = _NON_ALNUM_RE.sub(" ", text)
    text = _MULTI_SPACE_RE.sub(" ", text)
    return text.strip().lower()
# Prediction
def predict(text):
    """Return a probability for each allergen label for the given text.

    Args:
        text: Raw ingredient text; it is normalised with ``clean_text`` first.

    Returns:
        dict mapping each name in ``TARGET_COLUMNS`` to a Python float, the
        sigmoid of the corresponding model logit.
    """
    cleaned = clean_text(text)
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # the arguments are unchanged.
    encoding = tokenizer(
        cleaned,
        add_special_tokens=True,
        max_length=MAX_LEN,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # Sigmoid (not softmax): each label gets its own independent probability.
    probs = torch.sigmoid(logits).cpu().numpy().flatten()
    return {label: float(prob) for label, prob in zip(TARGET_COLUMNS, probs)}
# Streamlit UI: collect an ingredient list, run the classifier, and show one
# line per allergen label with its probability.
st.title("🔍 Deteksi Alergen dari Bahan Makanan")
st.markdown("Masukkan daftar bahan makanan, dan sistem akan memprediksi kemungkinan alergen.")
user_input = st.text_area("🧾 Bahan makanan (contoh: 2 butir telur, 1 gelas susu, kacang tanah...)")

if st.button("Prediksi Alergen"):
    if not user_input.strip():
        # Nothing but whitespace was entered: ask for input instead of predicting.
        st.warning("Silakan masukkan bahan makanan terlebih dahulu.")
    else:
        with st.spinner("Memproses..."):
            predictions = predict(user_input)
        st.subheader("📊 Hasil Prediksi:")
        for allergen, score in predictions.items():
            # A label counts as detected when its probability exceeds 0.5.
            status = "✅ Terdeteksi" if score > 0.5 else "❌ Tidak terdeteksi"
            st.write(f"- **{allergen}**: {status} (Probabilitas: {score:.2f})")