rdsarjito commited on
Commit
b1b9a76
·
1 Parent(s): 9de5935
Files changed (1) hide show
  1. app.py +38 -12
app.py CHANGED
@@ -47,6 +47,17 @@ class MultilabelBertClassifier(nn.Module):
47
  pooled_output = outputs.last_hidden_state[:, 0, :] # Use [CLS] token
48
  return self.classifier(pooled_output)
49
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Load model function
51
  @st.cache_resource
52
  def load_model():
@@ -60,16 +71,28 @@ def load_model():
60
  model_path = "model/alergen_model.pt"
61
 
62
  if os.path.exists(model_path):
63
- # Load model weights
64
- checkpoint = torch.load(model_path, map_location=device)
65
- if 'model_state_dict' in checkpoint:
66
- model.load_state_dict(checkpoint['model_state_dict'])
67
- else:
68
- model.load_state_dict(checkpoint)
69
-
70
- model.to(device)
71
- model.eval()
72
- return model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
73
  else:
74
  st.error("Model file not found. Please upload the model file.")
75
  return None, tokenizer
@@ -111,10 +134,13 @@ def main():
111
  st.title("🍲 Allergen Detection in Indonesian Recipes")
112
  st.write("This app predicts common allergens in your recipe based on ingredients.")
113
 
 
 
 
114
  # Sidebar for model upload
115
  with st.sidebar:
116
  st.header("Model Settings")
117
- uploaded_model = st.file_uploader("Upload model file (model/alergen_model.pt)", type=["pt"])
118
 
119
  if uploaded_model:
120
  # Save uploaded model
@@ -212,7 +238,7 @@ def main():
212
  # Instructions and information
213
  with st.expander("How to Use"):
214
  st.write("""
215
- 1. First, upload the trained model file (`model/alergen_model.pt`) using the sidebar uploader
216
  2. Enter your recipe ingredients in the text box (in Indonesian)
217
  3. Click the "Detect Allergens" button to analyze the recipe
218
  4. View the results showing which allergens are present in your recipe
 
47
  pooled_output = outputs.last_hidden_state[:, 0, :] # Use [CLS] token
48
  return self.classifier(pooled_output)
49
 
50
+ # Function to remove 'module.' prefix from state dict keys
51
+ def remove_module_prefix(state_dict):
52
+ new_state_dict = {}
53
+ for key, value in state_dict.items():
54
+ if key.startswith('module.'):
55
+ new_key = key[7:] # Remove 'module.' prefix
56
+ else:
57
+ new_key = key
58
+ new_state_dict[new_key] = value
59
+ return new_state_dict
60
+
61
  # Load model function
62
  @st.cache_resource
63
  def load_model():
 
71
  model_path = "model/alergen_model.pt"
72
 
73
  if os.path.exists(model_path):
74
+ try:
75
+ # Load model weights
76
+ checkpoint = torch.load(model_path, map_location=device)
77
+
78
+ # Check if state_dict is directly in checkpoint or under 'model_state_dict' key
79
+ if 'model_state_dict' in checkpoint:
80
+ state_dict = checkpoint['model_state_dict']
81
+ else:
82
+ state_dict = checkpoint
83
+
84
+ # Remove 'module.' prefix if it exists
85
+ state_dict = remove_module_prefix(state_dict)
86
+
87
+ # Load the processed state dict
88
+ model.load_state_dict(state_dict)
89
+
90
+ model.to(device)
91
+ model.eval()
92
+ return model, tokenizer
93
+ except Exception as e:
94
+ st.error(f"Error loading model: {str(e)}")
95
+ return None, tokenizer
96
  else:
97
  st.error("Model file not found. Please upload the model file.")
98
  return None, tokenizer
 
134
  st.title("🍲 Allergen Detection in Indonesian Recipes")
135
  st.write("This app predicts common allergens in your recipe based on ingredients.")
136
 
137
+ # Create directory for model if it doesn't exist
138
+ os.makedirs("model", exist_ok=True)
139
+
140
  # Sidebar for model upload
141
  with st.sidebar:
142
  st.header("Model Settings")
143
+ uploaded_model = st.file_uploader("Upload model file (alergen_model.pt)", type=["pt"])
144
 
145
  if uploaded_model:
146
  # Save uploaded model
 
238
  # Instructions and information
239
  with st.expander("How to Use"):
240
  st.write("""
241
+ 1. First, upload the trained model file (`alergen_model.pt`) using the sidebar uploader
242
  2. Enter your recipe ingredients in the text box (in Indonesian)
243
  3. Click the "Detect Allergens" button to analyze the recipe
244
  4. View the results showing which allergens are present in your recipe