LEGENDCODER1 commited on
Commit
91a2209
·
verified ·
1 Parent(s): 945e65b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -1,27 +1,26 @@
1
- import torch
2
-
3
- print("PyTorch version:", torch.__version__)
4
-
5
-
6
  from transformers import pipeline
7
  import gradio as gr
8
 
9
- # Step 1: Load a pre-trained model for image classification
10
- model = pipeline("image-classification", model="google/vit-base-patch16-224")
11
 
12
- # Step 2: Define a function for classifying images
13
  def classify_image(image):
14
- predictions = model(image)
15
- return predictions
 
16
 
17
- # Step 3: Create a Gradio interface
18
  interface = gr.Interface(
19
- fn=classify_image,
20
- inputs="image", # Input is an image
21
- outputs="label", # Output is a label (e.g., "sitting", "standing")
22
- title="Pose Detection: Sitting or Standing"
 
23
  )
24
 
25
- # Step 4: Launch the app
26
  if __name__ == "__main__":
27
  interface.launch()
 
 
1
+ # Import necessary libraries
 
 
 
 
2
  from transformers import pipeline
3
  import gradio as gr
4
 
5
+ # Load a lightweight image classification model
6
+ model = pipeline("image-classification", model="facebook/deit-tiny-patch16-224", cache_dir="./model_cache")
7
 
8
+ # Function to classify an uploaded image
9
  def classify_image(image):
10
+ predictions = model(image) # Make predictions
11
+ # Format predictions as a dictionary: Label -> Confidence
12
+ return {pred["label"]: round(pred["score"], 4) for pred in predictions}
13
 
14
+ # Create a Gradio interface for the app
15
  interface = gr.Interface(
16
+ fn=classify_image, # Function to call
17
+ inputs=gr.Image(type="pil"), # Input: Image (PIL format)
18
+ outputs=gr.Label(), # Output: Label with confidence scores
19
+ title="Image Classification App",
20
+ description="Upload an image, and the app will classify it using a vision transformer model."
21
  )
22
 
23
+ # Run the app
24
  if __name__ == "__main__":
25
  interface.launch()
26
+