Spaces:
Running
Running
File size: 2,884 Bytes
552cd20 e88e274 c0cfde6 552cd20 e88e274 552cd20 e88e274 552cd20 e88e274 552cd20 e88e274 552cd20 e88e274 552cd20 e88e274 552cd20 e88e274 552cd20 e88e274 c0cfde6 e88e274 c0cfde6 e88e274 c0cfde6 e88e274 c0cfde6 e88e274 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Model configuration and tokenizer loading
MODEL_PATH = 'model/alergen_model.pt'  # fine-tuned weights (a state_dict); resolved against the working directory
MODEL_NAME = 'indobenchmark/indobert-base-p1'  # Hugging Face id used for both tokenizer and base model
# Output labels, in the same order as the model's logits:
# milk, nuts, egg, seafood, wheat.
TARGET_COLUMNS = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
MAX_LEN = 128  # inputs are padded/truncated to this many tokens
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` once per server process.

    Streamlit re-runs this whole script on every widget interaction;
    ``st.cache_resource`` (available since Streamlit 1.18) keeps the
    tokenizer in memory instead of reloading it on each rerun.
    """
    return AutoTokenizer.from_pretrained(model_name)


tokenizer = _load_tokenizer(MODEL_NAME)
class MultilabelBertClassifier(nn.Module):
    """Multi-label text classifier on top of a pretrained BERT model.

    Wraps ``AutoModelForSequenceClassification`` and swaps its
    classification head for a fresh ``nn.Linear(hidden_size, num_labels)``.
    ``forward`` returns raw logits (one per label); the caller applies the
    activation (this file uses a sigmoid per label).

    The attribute paths ``bert`` and ``bert.classifier`` determine the
    parameter names in ``state_dict``, which ``load_state_dict`` matches by
    name, so they must not be renamed.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        backbone = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        head_in = backbone.config.hidden_size
        backbone.classifier = nn.Linear(head_in, num_labels)
        self.bert = backbone

    def forward(self, input_ids, attention_mask):
        """Run the backbone and return its ``logits`` tensor."""
        return self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
@st.cache_resource
def _load_model():
    """Build the classifier, load the fine-tuned weights and switch to eval mode.

    Streamlit re-runs the whole script on every widget interaction; caching
    with ``st.cache_resource`` (available since Streamlit 1.18) loads the BERT
    weights once per server process instead of on every rerun.
    """
    net = MultilabelBertClassifier(MODEL_NAME, len(TARGET_COLUMNS))
    # NOTE(review): torch.load unpickles the file; only point MODEL_PATH at
    # trusted checkpoints.
    net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    net.to(device)
    net.eval()  # inference only: disables dropout
    return net


model = _load_model()
# Text preprocessing
def clean_text(text):
    """Normalise raw ingredient text before tokenisation.

    Steps, in order:
      1. replace ``--`` with a space;
      2. drop URLs (``http`` followed by non-whitespace);
      3. replace every character that is not an ASCII letter, digit or
         whitespace with a space;
      4. collapse every whitespace run (spaces, tabs, newlines, CR) to one space;
      5. strip surrounding whitespace and lowercase.

    Args:
        text: Raw user input.

    Returns:
        The cleaned, lowercased string (possibly empty).
    """
    text = text.replace('--', ' ')
    text = re.sub(r"http\S+", "", text)
    # Raw strings: "\s" in a plain literal is an invalid escape sequence.
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    # One pass handles newlines, tabs and repeated spaces alike.
    text = re.sub(r"\s+", " ", text)
    return text.strip().lower()
# Prediction function
def predict(text):
    """Return the predicted probability of each allergen for ``text``.

    The text is normalised with ``clean_text``, tokenised (padded/truncated to
    ``MAX_LEN`` tokens) and run through ``model``. Each label is scored
    independently with a sigmoid. The probabilities therefore do not sum to 1,
    and several allergens can be flagged at once.

    Args:
        text: Raw ingredient list entered by the user.

    Returns:
        dict mapping each name in ``TARGET_COLUMNS`` to a float in [0, 1].
    """
    cleaned = clean_text(text)
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``encode_plus``; it accepts the same arguments.
    encoding = tokenizer(
        cleaned,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # A single text is a batch of one; flatten to one probability per label.
    probs = torch.sigmoid(logits).cpu().numpy().flatten()
    return {label: float(p) for label, p in zip(TARGET_COLUMNS, probs)}
# Streamlit UI
# NOTE(review): the emoji in these labels were reconstructed from mojibake in
# the scraped copy of this file, which split the status f-string across two
# lines (a SyntaxError). 🧾/✅/❌ match the garbled bytes; the title and
# results-header emoji are best guesses. Confirm against the original.
DETECTION_THRESHOLD = 0.5  # probability above which an allergen is reported as detected

st.title("🍔 Deteksi Alergen dari Bahan Makanan")
st.markdown("Masukkan daftar bahan makanan, dan sistem akan memprediksi kemungkinan alergen.")
user_input = st.text_area("🧾 Bahan makanan (contoh: 2 butir telur, 1 gelas susu, kacang tanah...)")

if st.button("Prediksi Alergen"):
    if not user_input.strip():
        st.warning("Silakan masukkan bahan makanan terlebih dahulu.")
    else:
        with st.spinner("Memproses..."):
            predictions = predict(user_input)
        st.subheader("🔍 Hasil Prediksi:")
        for allergen, score in predictions.items():
            status = "✅ Terdeteksi" if score > DETECTION_THRESHOLD else "❌ Tidak terdeteksi"
            st.write(f"- **{allergen}**: {status} (Probabilitas: {score:.2f})")
|