# Hugging Face Space status: Running
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
import pandas as pd | |
import re | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Model configuration and tokenizer loading.
MODEL_PATH = 'model/alergen_model.pt'  # fine-tuned weights (state dict), relative to the CWD
MODEL_NAME = 'indobenchmark/indobert-base-p1'  # Hugging Face model id for tokenizer + backbone
# Output labels (Indonesian: milk, nuts, egg, seafood, wheat). The order is used to
# name the model's output units by position, so it must match the checkpoint.
TARGET_COLUMNS = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
MAX_LEN = 128  # token length inputs are padded/truncated to
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_tokenizer(model_name):
    """Load and cache the tokenizer for ``model_name``.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes those reruns reuse a single tokenizer instead
    of calling ``from_pretrained`` (disk and possibly network I/O) each time.
    """
    return AutoTokenizer.from_pretrained(model_name)


tokenizer = _load_tokenizer(MODEL_NAME)
class MultilabelBertClassifier(nn.Module):
    """Multi-label classifier wrapping a Hugging Face sequence-classification model.

    The wrapped model's ``classifier`` head is replaced with a freshly
    initialised ``nn.Linear(hidden_size, num_labels)``. ``forward`` returns the
    raw logits with no activation applied, so callers choose the squashing
    function themselves.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        num_labels: number of independent output labels.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        backbone = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        hidden_size = backbone.config.hidden_size
        # Plain linear head over the encoder's hidden size. Parameters end up
        # under ``bert.classifier.*``, so the attribute name ``bert`` below
        # must stay in sync with any saved state dict.
        backbone.classifier = nn.Linear(hidden_size, num_labels)
        self.bert = backbone

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its raw logits.

        Presumably shaped (batch, num_labels) -- confirm against the backbone.
        """
        return self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
@st.cache_resource
def _load_model(model_path, model_name, num_labels, device_name):
    """Build the classifier, load the fine-tuned weights and switch to eval mode.

    Cached with ``st.cache_resource`` so Streamlit's rerun-on-every-interaction
    execution model does not rebuild the network and re-read the checkpoint
    each time. ``device_name`` is a plain string (e.g. ``"cpu"``/``"cuda"``) so
    the cache key is trivially hashable.
    """
    net = MultilabelBertClassifier(model_name, num_labels)
    # NOTE(review): torch.load unpickles the file and is only safe for a trusted
    # checkpoint; consider weights_only=True where the installed torch supports it.
    state_dict = torch.load(model_path, map_location=device_name)
    net.load_state_dict(state_dict)
    net.to(device_name)
    net.eval()  # inference only: disables dropout
    return net


model = _load_model(MODEL_PATH, MODEL_NAME, len(TARGET_COLUMNS), str(device))
# Preprocessing
def clean_text(text):
    """Normalise raw ingredient text before tokenisation.

    Steps, applied in order:
      1. replace ``--`` with a space;
      2. remove URLs (``http`` followed by non-whitespace characters);
      3. replace newlines with spaces;
      4. replace every character that is not an ASCII letter, digit or
         whitespace with a space (so accented/non-Latin letters are dropped);
      5. collapse runs of two or more *spaces* into one (other whitespace,
         e.g. tabs, is left as-is);
      6. strip surrounding whitespace and lowercase.

    Args:
        text: raw input string.

    Returns:
        The cleaned string (possibly empty).
    """
    text = text.replace('--', ' ')
    text = re.sub(r"http\S+", "", text)
    text = text.replace('\n', ' ')
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning on Python 3.12+).
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    text = re.sub(r" {2,}", " ", text)
    return text.strip().lower()
# Prediction
def predict(text):
    """Return an independent probability for each allergen label in ``text``.

    The text is cleaned with ``clean_text`` and tokenised to exactly
    ``MAX_LEN`` tokens (padded/truncated). It is then run through the
    module-level ``model`` on ``device`` without gradient tracking. A sigmoid
    is applied per label, so the probabilities do not sum to 1.

    Args:
        text: raw ingredient list entered by the user.

    Returns:
        dict mapping each name in ``TARGET_COLUMNS`` to a float in [0, 1].
    """
    cleaned = clean_text(text)
    # Calling the tokenizer directly is the documented replacement for the
    # deprecated ``encode_plus``; the arguments are unchanged.
    encoding = tokenizer(
        cleaned,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    # Single input, so flattening the batch dimension leaves one value per label.
    probs = torch.sigmoid(logits).cpu().numpy().flatten()
    return {label: float(prob) for label, prob in zip(TARGET_COLUMNS, probs)}
# STREAMLIT UI
# NOTE(review): the emoji in these strings were mojibake in this copy of the
# file (UTF-8 emoji mis-decoded as ISO-8859-7, e.g. "π§Ύ" for 🧾). 🧾 decodes
# unambiguously; the title, subheader and status icons are best-effort
# reconstructions -- confirm against the original app.
DETECTION_THRESHOLD = 0.5  # probability above which an allergen is reported

st.title("🔍 Deteksi Alergen dari Bahan Makanan")
st.markdown("Masukkan daftar bahan makanan, dan sistem akan memprediksi kemungkinan alergen.")
user_input = st.text_area("🧾 Bahan makanan (contoh: 2 butir telur, 1 gelas susu, kacang tanah...)")
if st.button("Prediksi Alergen"):
    if not user_input.strip():
        st.warning("Silakan masukkan bahan makanan terlebih dahulu.")
    else:
        with st.spinner("Memproses..."):
            predictions = predict(user_input)
        st.subheader("📊 Hasil Prediksi:")
        for allergen, score in predictions.items():
            status = "✅ Terdeteksi" if score > DETECTION_THRESHOLD else "❌ Tidak terdeteksi"
            st.write(f"- **{allergen}**: {status} (Probabilitas: {score:.2f})")