frankai98 commited on
Commit
47b58ef
·
verified ·
1 Parent(s): 0969ab3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
- import numpy as np
5
- import io
6
 
7
  # Set page configuration
8
  st.set_page_config(
@@ -20,6 +20,12 @@ def load_model():
20
  """Load the age classification model and cache it."""
21
  return pipeline("image-classification", model="nateraw/vit-age-classifier")
22
 
 
 
 
 
 
 
23
  # Load the model
24
  pipe = load_model()
25
 
@@ -30,15 +36,18 @@ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png
30
  st.markdown("### Or try an example:")
31
  col1, col2 = st.columns(2)
32
 
 
 
 
33
  with col1:
34
  if st.button("Example 1"):
35
- example_img = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/person1.jpg"
36
- st.session_state.example_image = example_img
37
 
38
  with col2:
39
  if st.button("Example 2"):
40
- example_img = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/person2.jpg"
41
- st.session_state.example_image = example_img
42
 
43
  # Process the image and display results
44
  if uploaded_file is not None:
@@ -55,12 +64,16 @@ if uploaded_file is not None:
55
  st.progress(float(pred["score"]))
56
  st.write(f"{pred['label']}: {pred['score']:.2%}")
57
 
58
- elif 'example_image' in st.session_state:
59
  # Process example image
60
- st.image(st.session_state.example_image, caption="Example Image", use_column_width=True)
 
 
 
61
 
62
  with st.spinner("Analyzing age..."):
63
- predictions = pipe(st.session_state.example_image)
 
64
 
65
  # Display results
66
  st.markdown("### Results:")
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
 
7
  # Set page configuration
8
  st.set_page_config(
 
20
  """Load the age classification model and cache it."""
21
  return pipeline("image-classification", model="nateraw/vit-age-classifier")
22
 
23
+ def load_image_from_url(url):
24
+ """Load an image from a URL."""
25
+ response = requests.get(url)
26
+ img = Image.open(BytesIO(response.content))
27
+ return img
28
+
29
  # Load the model
30
  pipe = load_model()
31
 
 
36
  st.markdown("### Or try an example:")
37
  col1, col2 = st.columns(2)
38
 
39
+ EXAMPLE_1 = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/person1.jpg"
40
+ EXAMPLE_2 = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/person2.jpg"
41
+
42
  with col1:
43
  if st.button("Example 1"):
44
+ st.session_state.example_image = EXAMPLE_1
45
+ st.session_state.example_loaded = True
46
 
47
  with col2:
48
  if st.button("Example 2"):
49
+ st.session_state.example_image = EXAMPLE_2
50
+ st.session_state.example_loaded = True
51
 
52
  # Process the image and display results
53
  if uploaded_file is not None:
 
64
  st.progress(float(pred["score"]))
65
  st.write(f"{pred['label']}: {pred['score']:.2%}")
66
 
67
+ elif 'example_loaded' in st.session_state and st.session_state.example_loaded:
68
  # Process example image
69
+ with st.spinner("Loading example image..."):
70
+ # Download and load the image properly
71
+ image = load_image_from_url(st.session_state.example_image)
72
+ st.image(image, caption="Example Image", use_column_width=True)
73
 
74
  with st.spinner("Analyzing age..."):
75
+ # Pass the actual PIL Image object to the pipeline
76
+ predictions = pipe(image)
77
 
78
  # Display results
79
  st.markdown("### Results:")