Recompense commited on
Commit
32142a3
·
verified ·
1 Parent(s): 4954de9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -101
app.py CHANGED
@@ -3,14 +3,46 @@ import streamlit as st
3
  from streamlit.runtime.uploaded_file_manager import UploadedFile
4
  import tensorflow as tf
5
  import pandas as pd
6
- from PIL import Image # Needed for image display consistency potentially
7
 
8
  # 🔹 Expand the Page Layout
9
- st.set_page_config(layout="wide") # Use Streamlit's built-in wide layout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # --- Constants and Data ---
12
  current_model = "Model Mini"
13
- new_model = "Food Vision" # Define the second model name
 
14
  class_names = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
15
  'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
16
  'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
@@ -33,7 +65,7 @@ class_names = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef
33
  'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
34
 
35
  top_ten_dict = {
36
- "class_name": ["edamame", "macarons", "oysters", "pho", "mussels", # Corrected 'mussles' -> 'mussels'
37
  "sashimi", "seaweed_salad", "dumplings", "guacamole", "onion_rings"],
38
  "f1-score": [0.964427, 0.900433, 0.853119, 0.852652, 0.850622,
39
  0.844794, 0.834356, 0.833006, 0.83209, 0.831967]
@@ -45,72 +77,100 @@ last_ten_dict = {
45
  0.340426, 0.340045, 0.339785, 0.324826, 0.282407]
46
  }
47
 
48
- # 🔹 Custom CSS for Centered Content within elements and layout stability
49
  st.markdown(
50
  """
51
  <style>
52
- /* Center content vertically and horizontally using flexbox */
53
  .centered {
54
  display: flex;
55
  flex-direction: column;
56
- align-items: center;
57
- justify-content: center; /* Can adjust to flex-start if needed */
58
  text-align: center;
59
- width: 100%; /* Take full width of its container (e.g., column) */
60
- min-height: 300px; /* Give containers minimum height to reduce collapse */
61
- padding-top: 20px; /* Add some padding */
62
- padding-bottom: 20px;
 
63
  }
64
 
65
- /* Style file uploader for better centering if needed */
66
- /* Streamlit structure might change, this targets common patterns */
 
 
 
 
 
 
67
  div[data-testid="stFileUploader"] > section {
68
- padding: 0; /* Reduce default padding if it pushes content */
 
 
69
  }
70
- div[data-testid="stFileUploader"] > section > input {
71
- /* Hide default input if necessary */
 
 
 
72
  }
73
- div[data-testid="stFileUploader"] label {
74
- /* Style the label if needed */
75
- }
76
 
77
 
78
  /* Center images and standardize size */
79
- .centered img { /* Target images specifically within centered divs */
80
  display: block;
81
  margin-left: auto;
82
  margin-right: auto;
83
- max-width: 200px; /* Use max-width for responsiveness */
84
- max-height: 200px; /* Use max-height */
85
- width: auto; /* Allow auto width */
86
- height: auto; /* Allow auto height */
87
- object-fit: contain; /* Contain ensures the whole image fits */
88
  border-radius: 20px;
89
- margin-bottom: 15px; /* Add space below image */
90
  }
91
 
92
- /* Ensure columns try to vertically align content */
93
  div[data-testid="stVerticalBlock"] div[data-testid="stHorizontalBlock"] {
94
- align-items: center;
95
  }
96
 
97
  /* Style the radio buttons */
98
  div[data-testid="stRadio"] > label {
99
- font-weight: bold; /* Make label bold */
100
- margin-bottom: 10px;
 
101
  }
102
  div[data-testid="stRadio"] > div {
103
- display: flex;
104
- justify-content: center; /* Center radio options */
105
- gap: 15px; /* Add space between radio buttons */
 
106
  }
107
 
108
  /* Style the button */
109
  div[data-testid="stButton"] > button {
110
- width: 80%; /* Make button wider */
111
- margin-top: 20px; /* Add space above button */
 
 
 
 
 
 
112
  }
113
 
 
 
 
 
 
 
 
 
 
 
 
114
  </style>
115
  """,
116
  unsafe_allow_html=True
@@ -122,6 +182,7 @@ st.header("A food vision app using a CNN model fine-tuned on EfficientNet.")
122
  st.divider()
123
 
124
  # --- Explanations (Collapsible) ---
 
125
  with st.expander("Learn More: What is a CNN?"):
126
  st.write("""
127
  A Neural Network is a system inspired by the human brain, composed of interconnected nodes (neurons) organized in layers: an input layer, one or more hidden layers, and an output layer.
@@ -310,36 +371,27 @@ with st.expander("What is the F1-Score?"):
310
  st.subheader("Top and Least Performing Classes (by F1-Score)")
311
  with st.container():
312
  top_ten_df = pd.DataFrame(top_ten_dict).sort_values("f1-score", ascending=False)
313
- last_ten_df = pd.DataFrame(last_ten_dict).sort_values("f1-score", ascending=True) # Already sorted ascendingly in dict creation usually
314
 
315
- # Format class names for display
316
  top_ten_df['class_name_display'] = top_ten_df['class_name'].str.replace('_', ' ').str.title()
317
  last_ten_df['class_name_display'] = last_ten_df['class_name'].str.replace('_', ' ').str.title()
318
 
319
-
320
  col1, col2 = st.columns(2)
321
  with col1:
322
  st.write("**Top 10 Classes**")
323
- st.bar_chart(top_ten_df.set_index('class_name_display')['f1-score'],
324
- # horizontal=True, # Bar chart auto-detects horizontal best here
325
- use_container_width=True)
326
  with col2:
327
  st.write("**Bottom 10 Classes**")
328
- st.bar_chart(last_ten_df.set_index('class_name_display')['f1-score'],
329
- # horizontal=True,
330
- use_container_width=True, color="#ff748c") # Red color for low scores
331
- st.divider()
332
-
333
 
334
  # --- Helper Functions ---
335
- @st.cache_resource # Cache the loaded model
336
  def load_model(filepath):
337
  """Loads a Tensorflow Keras Model."""
338
- st.write(f"Cache miss: Loading model from {filepath}") # Debug message
339
  try:
340
  model = tf.keras.models.load_model(filepath)
341
- # You might need a warm-up prediction for GPU memory allocation
342
- # For example: model.predict(tf.zeros([1, 224, 224, 3]))
343
  return model
344
  except Exception as e:
345
  st.error(f"Error loading model from {filepath}: {e}")
@@ -348,21 +400,10 @@ def load_model(filepath):
348
  def load_prep_image(image_input: UploadedFile, img_shape=224):
349
  """Reads and preprocesses an image for EfficientNet prediction."""
350
  try:
351
- # Read image file buffer
352
  bytes_data = image_input.getvalue()
353
- # Decode image
354
  image_tensor = tf.io.decode_image(bytes_data, channels=3)
355
- # Resize image
356
- # Use tf.image.resize with method='nearest' or 'bilinear' (default)
357
  image_tensor_resized = tf.image.resize(image_tensor, [img_shape, img_shape])
358
- # Expand dimensions to create batch_size 1 -> (1, H, W, C)
359
  image_tensor_expanded = tf.expand_dims(image_tensor_resized, axis=0)
360
- # EfficientNet models usually have their own preprocessing layer/function
361
- # or expect inputs scaled 0-255. Check the specific model's requirement.
362
- # If it expects 0-1 scaling and doesn't do it internally:
363
- # image_tensor_scaled = image_tensor_expanded / 255.0
364
- # return image_tensor_scaled
365
- # Assuming EfficientNet B0 handles scaling or expects 0-255:
366
  return image_tensor_expanded
367
  except Exception as e:
368
  st.error(f"Error processing image: {e}")
@@ -385,19 +426,20 @@ def predict_using_model(image_input: UploadedFile, model_path: str) -> tuple[str
385
  try:
386
  with st.spinner("🤖 Model is predicting..."):
387
  pred_prob = model.predict(processed_image)
388
- predicted_index = tf.argmax(pred_prob, axis=1).numpy()[0] # Get index of highest probability
389
  predicted_class_name = class_names[predicted_index]
390
- predicted_probability = float(tf.reduce_max(pred_prob).numpy()) # Get the highest probability
391
  return predicted_class_name, predicted_probability
392
  except Exception as e:
393
  st.error(f"Prediction failed: {e}")
394
  return None, None
395
 
396
  # --- Interactive Demo Section ---
397
- st.divider()
398
  st.header(f"Try the Models: :blue[{current_model}] & :blue[{new_model}]")
399
  st.caption("_Model performance may vary. Models are periodically updated._")
400
 
 
401
  # Initialize session state keys if they don't exist
402
  if "prediction_result" not in st.session_state:
403
  st.session_state.prediction_result = None
@@ -408,18 +450,19 @@ if "predicted_prob" not in st.session_state:
408
 
409
 
410
  # Use columns for layout
411
- cols = st.columns([3, 0.5, 2, 0.5, 3], gap="medium") # Adjusted column ratios and gaps
 
412
 
413
  # --- Column 1: Image Input ---
414
  with cols[0]:
415
  st.markdown('<div class="centered">', unsafe_allow_html=True) # Apply centering
416
- st.subheader("1. Provide an Image")
417
  image_source = st.radio(
418
  "Choose image source:",
419
  ("Upload Image", "Use Camera"),
420
  key="image_source",
421
  horizontal=True,
422
- label_visibility="collapsed" # Hide the radio label itself
423
  )
424
 
425
  uploaded_image = None
@@ -442,22 +485,25 @@ with cols[0]:
442
 
443
  # Display uploaded image preview
444
  if uploaded_image:
445
- image_bytes_for_state = uploaded_image.getvalue() # Store bytes for state
446
- st.image(image_bytes_for_state, caption="Your image", use_column_width='auto') # Auto width fits container
447
- st.success("Image ready!")
448
  else:
449
  st.info("Upload or take a picture.")
450
 
451
- st.markdown('</div>', unsafe_allow_html=True) # Close centered div
 
452
 
453
  # --- Column 2: Arrow 1 ---
454
  with cols[1]:
455
- st.markdown('<div class="centered" style="justify-content: center; min-height: 300px;">➡️</div>', unsafe_allow_html=True)
 
 
456
 
457
  # --- Column 3: Model Selection & Prediction ---
458
  with cols[2]:
459
- st.markdown('<div class="centered">', unsafe_allow_html=True) # Apply centering
460
- st.subheader("2. Select Model")
461
 
462
  chosen_model = st.radio(
463
  "Pick a Model:",
@@ -470,17 +516,16 @@ with cols[2]:
470
  model_path_to_use = ""
471
  model_image_path = ""
472
 
473
- if chosen_model == current_model: # Model Mini
474
- model_image_path = "content/brain.png" # Make sure this file exists
475
- model_path_to_use = "model_mini_Food101.keras" # Make sure this path is correct
476
- elif chosen_model == new_model: # Food Vision
477
- model_image_path = "content/creativity_15557951.png" # Make sure this file exists
478
- model_path_to_use = "FoodVision.keras" # Make sure this path is correct
479
 
480
- # Display model icon/image if path is valid
481
  try:
482
  if model_image_path:
483
- st.image(model_image_path, width=150) # Control model image size
484
  except Exception as e:
485
  st.warning(f"Could not load model image: {model_image_path}")
486
 
@@ -489,41 +534,39 @@ with cols[2]:
489
  label="Predict Food!",
490
  icon="⚛️",
491
  type="primary",
492
- use_container_width=True, # Make button fill column width
493
- disabled=not uploaded_image or not model_path_to_use # Disable if no image or path
494
  )
495
 
496
  if predict_button:
497
  if uploaded_image and model_path_to_use:
498
- # Perform prediction
499
  result_class, result_prob = predict_using_model(uploaded_image, model_path=model_path_to_use)
500
- # Store results in session state
501
  st.session_state.prediction_result = result_class
502
  st.session_state.predicted_prob = result_prob
503
- st.session_state.predicted_image_bytes = image_bytes_for_state # Store the bytes of the image used
504
  else:
505
  st.warning("Please provide an image and select a valid model.")
506
 
507
- st.markdown('</div>', unsafe_allow_html=True) # Close centered div
 
508
 
509
  # --- Column 4: Arrow 2 ---
510
  with cols[3]:
511
- st.markdown('<div class="centered" style="justify-content: center; min-height: 300px;">➡️</div>', unsafe_allow_html=True)
 
 
512
 
513
  # --- Column 5: Output ---
514
  with cols[4]:
515
- st.markdown('<div class="centered">', unsafe_allow_html=True) # Apply centering
516
- st.subheader("3. Prediction Result")
517
 
518
- # Display result from session state
519
  if st.session_state.prediction_result and st.session_state.predicted_image_bytes:
520
- # Display the image associated with the prediction
521
  st.image(st.session_state.predicted_image_bytes, caption="Image Analyzed", use_column_width='auto')
522
 
523
  result_class = st.session_state.prediction_result
524
  probability = st.session_state.predicted_prob
525
 
526
- # Format class name nicely
527
  if "_" in result_class:
528
  modified_class = result_class.replace("_", " ").title()
529
  else:
@@ -531,15 +574,15 @@ with cols[4]:
531
 
532
  st.success(f"Prediction: **:blue[{modified_class}]**")
533
  if probability:
534
- st.write(f"Confidence: {probability:.2%}") # Display confidence
535
 
536
  elif predict_button:
537
- # If button was clicked but prediction failed or had no result
538
- st.error("Prediction could not be completed. Check logs or try again.")
539
  else:
540
- st.info("Result will appear here after prediction.")
 
 
541
 
542
- st.markdown('</div>', unsafe_allow_html=True) # Close centered div
543
 
544
  # --- Footer or Final Divider ---
545
- st.divider()
 
3
  from streamlit.runtime.uploaded_file_manager import UploadedFile
4
  import tensorflow as tf
5
  import pandas as pd
6
+ from PIL import Image
7
 
8
  # 🔹 Expand the Page Layout
9
+ st.set_page_config(layout="wide")
10
+
11
+ # 2. Inject CSS to override the default max-width and set it to 90%
12
+ st.markdown(
13
+ """
14
+ <style>
15
+ /* Target the main block container */
16
+ .block-container {
17
+ /* Set the max-width to 90% of the viewport */
18
+ max-width: 90% !important;
19
+
20
+ /* Optional: Adjust padding if needed */
21
+ /* padding-left: 2rem !important; */
22
+ /* padding-right: 2rem !important; */
23
+ /* padding-top: 1rem !important; */ /* Adjust top padding */
24
+ /* padding-bottom: 1rem !important; */
25
+
26
+ /* Ensure it remains centered */
27
+ margin: auto !important;
28
+ }
29
+
30
+ /* You might need to target a more specific element depending on the Streamlit version */
31
+ /* If the above doesn't work, try inspecting the element in your browser */
32
+ /* and use a selector like: */
33
+ /* div[data-testid="stAppViewContainer"] > section > div[data-testid="stBlock"] { */
34
+ /* max-width: 90% !important; */
35
+ /* } */
36
+
37
+ </style>
38
+ """,
39
+ unsafe_allow_html=True
40
+ )
41
 
42
  # --- Constants and Data ---
43
  current_model = "Model Mini"
44
+ new_model = "Food Vision"
45
+ # ... (class_names, top_ten_dict, last_ten_dict remain the same) ...
46
  class_names = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
47
  'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
48
  'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
 
65
  'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
66
 
67
  top_ten_dict = {
68
+ "class_name": ["edamame", "macarons", "oysters", "pho", "mussels",
69
  "sashimi", "seaweed_salad", "dumplings", "guacamole", "onion_rings"],
70
  "f1-score": [0.964427, 0.900433, 0.853119, 0.852652, 0.850622,
71
  0.844794, 0.834356, 0.833006, 0.83209, 0.831967]
 
77
  0.340426, 0.340045, 0.339785, 0.324826, 0.282407]
78
  }
79
 
80
+ # 🔹 Custom CSS for Alignment and Spacing
81
  st.markdown(
82
  """
83
  <style>
84
+ /* Center content H, align V top, reduce padding/min-height */
85
  .centered {
86
  display: flex;
87
  flex-direction: column;
88
+ align-items: center; /* Keep horizontal centering */
89
+ justify-content: flex-start; /* Align content to the top */
90
  text-align: center;
91
+ width: 100%;
92
+ min-height: 100px; /* Reduced minimum height */
93
+ padding-top: 5px; /* Reduced padding */
94
+ padding-bottom: 10px;
95
+ height: 100%; /* Allow div to take full column height if needed */
96
  }
97
 
98
+ /* Ensure subheaders within columns have less top margin */
99
+ .centered h3 { /* Targets the subheaders like "1. Provide Image" */
100
+ margin-top: 0px !important;
101
+ padding-top: 0px !important;
102
+ margin-bottom: 15px !important; /* Add space below subheader */
103
+ }
104
+
105
+ /* Style file uploader (ensure consistent padding) */
106
  div[data-testid="stFileUploader"] > section {
107
+ padding: 0;
108
+ display: flex; /* Use flex for centering button inside */
109
+ justify-content: center;
110
  }
111
+ div[data-testid="stFileUploader"] label { /* Style the button-like label */
112
+ margin-bottom: 10px; /* Space below uploader */
113
+ }
114
+ div[data-testid="stCameraInput"] label { /* Style the camera input label */
115
+ margin-bottom: 10px; /* Space below camera input */
116
  }
 
 
 
117
 
118
 
119
  /* Center images and standardize size */
120
+ .centered img {
121
  display: block;
122
  margin-left: auto;
123
  margin-right: auto;
124
+ max-width: 200px;
125
+ max-height: 200px;
126
+ width: auto;
127
+ height: auto;
128
+ object-fit: contain;
129
  border-radius: 20px;
130
+ margin-bottom: 10px; /* Reduced space below image */
131
  }
132
 
133
+ /* Align the TOPS of the columns */
134
  div[data-testid="stVerticalBlock"] div[data-testid="stHorizontalBlock"] {
135
+ align-items: flex-start !important; /* Align column tops */
136
  }
137
 
138
  /* Style the radio buttons */
139
  div[data-testid="stRadio"] > label {
140
+ font-weight: bold;
141
+ margin-bottom: 5px !important; /* Reduced margin */
142
+ padding-top: 0 !important;
143
  }
144
  div[data-testid="stRadio"] > div {
145
+ display: flex;
146
+ justify-content: center;
147
+ gap: 10px; /* Reduced gap */
148
+ margin-bottom: 10px; /* Add space below radio group */
149
  }
150
 
151
  /* Style the button */
152
  div[data-testid="stButton"] > button {
153
+ width: 80%;
154
+ margin-top: 15px; /* Reduced margin */
155
+ }
156
+
157
+ /* Reduce space BELOW the caption ABOVE the columns */
158
+ div[data-testid="stCaptionContainer"] {
159
+ padding-bottom: 0px !important;
160
+ margin-bottom: -10px !important; /* Negative margin pulls following elements up */
161
  }
162
 
163
+ /* Reduce space ABOVE the main H2 header for the demo section */
164
+ h2[data-testid="stHeading"]:has(+ div[data-testid="stCaptionContainer"]) {
165
+ margin-bottom: 5px !important; /* Space between header and caption */
166
+ }
167
+
168
+ /* Reduce space below the final divider before the demo H2 header */
169
+ hr[data-testid="stDivider"] + h2[data-testid="stHeading"] {
170
+ margin-top: -15px !important; /* Pull header closer to divider */
171
+ }
172
+
173
+
174
  </style>
175
  """,
176
  unsafe_allow_html=True
 
182
  st.divider()
183
 
184
  # --- Explanations (Collapsible) ---
185
+ # ... (Keep the expanders as they were) ...
186
  with st.expander("Learn More: What is a CNN?"):
187
  st.write("""
188
  A Neural Network is a system inspired by the human brain, composed of interconnected nodes (neurons) organized in layers: an input layer, one or more hidden layers, and an output layer.
 
371
  st.subheader("Top and Least Performing Classes (by F1-Score)")
372
  with st.container():
373
  top_ten_df = pd.DataFrame(top_ten_dict).sort_values("f1-score", ascending=False)
374
+ last_ten_df = pd.DataFrame(last_ten_dict).sort_values("f1-score", ascending=True)
375
 
 
376
  top_ten_df['class_name_display'] = top_ten_df['class_name'].str.replace('_', ' ').str.title()
377
  last_ten_df['class_name_display'] = last_ten_df['class_name'].str.replace('_', ' ').str.title()
378
 
 
379
  col1, col2 = st.columns(2)
380
  with col1:
381
  st.write("**Top 10 Classes**")
382
+ st.bar_chart(top_ten_df.set_index('class_name_display')['f1-score'], use_container_width=True)
 
 
383
  with col2:
384
  st.write("**Bottom 10 Classes**")
385
+ st.bar_chart(last_ten_df.set_index('class_name_display')['f1-score'], use_container_width=True, color="#ff748c")
386
+ st.divider() # Divider before the interactive section
 
 
 
387
 
388
  # --- Helper Functions ---
389
+ @st.cache_resource
390
  def load_model(filepath):
391
  """Loads a Tensorflow Keras Model."""
392
+ st.write(f"Cache miss: Loading model from {filepath}")
393
  try:
394
  model = tf.keras.models.load_model(filepath)
 
 
395
  return model
396
  except Exception as e:
397
  st.error(f"Error loading model from {filepath}: {e}")
 
400
  def load_prep_image(image_input: UploadedFile, img_shape=224):
401
  """Reads and preprocesses an image for EfficientNet prediction."""
402
  try:
 
403
  bytes_data = image_input.getvalue()
 
404
  image_tensor = tf.io.decode_image(bytes_data, channels=3)
 
 
405
  image_tensor_resized = tf.image.resize(image_tensor, [img_shape, img_shape])
 
406
  image_tensor_expanded = tf.expand_dims(image_tensor_resized, axis=0)
 
 
 
 
 
 
407
  return image_tensor_expanded
408
  except Exception as e:
409
  st.error(f"Error processing image: {e}")
 
426
  try:
427
  with st.spinner("🤖 Model is predicting..."):
428
  pred_prob = model.predict(processed_image)
429
+ predicted_index = tf.argmax(pred_prob, axis=1).numpy()[0]
430
  predicted_class_name = class_names[predicted_index]
431
+ predicted_probability = float(tf.reduce_max(pred_prob).numpy())
432
  return predicted_class_name, predicted_probability
433
  except Exception as e:
434
  st.error(f"Prediction failed: {e}")
435
  return None, None
436
 
437
  # --- Interactive Demo Section ---
438
+ # Header and Caption for the interactive section
439
  st.header(f"Try the Models: :blue[{current_model}] & :blue[{new_model}]")
440
  st.caption("_Model performance may vary. Models are periodically updated._")
441
 
442
+
443
  # Initialize session state keys if they don't exist
444
  if "prediction_result" not in st.session_state:
445
  st.session_state.prediction_result = None
 
450
 
451
 
452
  # Use columns for layout
453
+ cols = st.columns([3, 0.5, 2, 0.5, 3], gap="medium") # Keep original column ratios
454
+
455
 
456
  # --- Column 1: Image Input ---
457
  with cols[0]:
458
  st.markdown('<div class="centered">', unsafe_allow_html=True) # Apply centering
459
+ st.subheader("1. Provide an Image") # H3 targeted by CSS
460
  image_source = st.radio(
461
  "Choose image source:",
462
  ("Upload Image", "Use Camera"),
463
  key="image_source",
464
  horizontal=True,
465
+ label_visibility="collapsed"
466
  )
467
 
468
  uploaded_image = None
 
485
 
486
  # Display uploaded image preview
487
  if uploaded_image:
488
+ image_bytes_for_state = uploaded_image.getvalue()
489
+ st.image(image_bytes_for_state, caption="Your image", use_column_width='auto')
490
+ # Removed success message to save space
491
  else:
492
  st.info("Upload or take a picture.")
493
 
494
+ st.markdown('</div>', unsafe_allow_html=True)
495
+
496
 
497
  # --- Column 2: Arrow 1 ---
498
  with cols[1]:
499
+ # Adjusted padding-top to roughly align arrow with top content
500
+ st.markdown('<div class="centered" style="justify-content: flex-start; padding-top: 50px;">➡️</div>', unsafe_allow_html=True)
501
+
502
 
503
  # --- Column 3: Model Selection & Prediction ---
504
  with cols[2]:
505
+ st.markdown('<div class="centered">', unsafe_allow_html=True)
506
+ st.subheader("2. Select Model") # H3 targeted by CSS
507
 
508
  chosen_model = st.radio(
509
  "Pick a Model:",
 
516
  model_path_to_use = ""
517
  model_image_path = ""
518
 
519
+ if chosen_model == current_model:
520
+ model_image_path = "brain.png"
521
+ model_path_to_use = "model_mini_Food101.keras"
522
+ elif chosen_model == new_model:
523
+ model_image_path = "content/creativity_15557951.png"
524
+ model_path_to_use = "FoodVision.keras"
525
 
 
526
  try:
527
  if model_image_path:
528
+ st.image(model_image_path, width=150) # Keep model image
529
  except Exception as e:
530
  st.warning(f"Could not load model image: {model_image_path}")
531
 
 
534
  label="Predict Food!",
535
  icon="⚛️",
536
  type="primary",
537
+ use_container_width=True,
538
+ disabled=not uploaded_image or not model_path_to_use
539
  )
540
 
541
  if predict_button:
542
  if uploaded_image and model_path_to_use:
 
543
  result_class, result_prob = predict_using_model(uploaded_image, model_path=model_path_to_use)
 
544
  st.session_state.prediction_result = result_class
545
  st.session_state.predicted_prob = result_prob
546
+ st.session_state.predicted_image_bytes = image_bytes_for_state
547
  else:
548
  st.warning("Please provide an image and select a valid model.")
549
 
550
+ st.markdown('</div>', unsafe_allow_html=True)
551
+
552
 
553
  # --- Column 4: Arrow 2 ---
554
  with cols[3]:
555
+ # Adjusted padding-top
556
+ st.markdown('<div class="centered" style="justify-content: flex-start; padding-top: 50px;">➡️</div>', unsafe_allow_html=True)
557
+
558
 
559
  # --- Column 5: Output ---
560
  with cols[4]:
561
+ st.markdown('<div class="centered">', unsafe_allow_html=True)
562
+ st.subheader("3. Prediction Result") # H3 targeted by CSS
563
 
 
564
  if st.session_state.prediction_result and st.session_state.predicted_image_bytes:
 
565
  st.image(st.session_state.predicted_image_bytes, caption="Image Analyzed", use_column_width='auto')
566
 
567
  result_class = st.session_state.prediction_result
568
  probability = st.session_state.predicted_prob
569
 
 
570
  if "_" in result_class:
571
  modified_class = result_class.replace("_", " ").title()
572
  else:
 
574
 
575
  st.success(f"Prediction: **:blue[{modified_class}]**")
576
  if probability:
577
+ st.write(f"Confidence: {probability:.1%}") # Slightly less verbose confidence
578
 
579
  elif predict_button:
580
+ st.error("Prediction failed or image invalid.")
 
581
  else:
582
+ st.info("Result will appear here.")
583
+
584
+ st.markdown('</div>', unsafe_allow_html=True)
585
 
 
586
 
587
  # --- Footer or Final Divider ---
588
+ # st.divider() # Optional: remove if you want less space at the bottom