Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,55 +1,30 @@
|
|
1 |
-
import os, json
|
2 |
-
import numpy as np
|
3 |
-
import tensorflow as tf
|
4 |
-
from PIL import Image
|
5 |
import gradio as gr
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
REPO_ID = "Adriana213/vgg16-fruit-classifier"
|
10 |
-
local_dir = snapshot_download(repo_id=REPO_ID)
|
11 |
-
|
12 |
-
# 2️⃣ Load labels
|
13 |
-
with open(os.path.join(local_dir, "class_labels.json"), "r") as f:
|
14 |
-
id2label = json.load(f)
|
15 |
-
|
16 |
-
# 3️⃣ Load as a TF SavedModel and grab the default signature
|
17 |
-
loaded = tf.saved_model.load(local_dir)
|
18 |
-
infer = loaded.signatures["serving_default"]
|
19 |
-
|
20 |
-
# 4️⃣ Figure out the input key name
|
21 |
-
input_key = list(infer.structured_input_signature[1].keys())[0]
|
22 |
-
|
23 |
-
# 5️⃣ Preprocess helper (same as before)
|
24 |
-
def preprocess(img: Image.Image) -> np.ndarray:
|
25 |
-
img = img.resize((100,100))
|
26 |
-
arr = np.array(img.convert("RGB")).astype("float32")
|
27 |
-
arr[...,0] -= 123.68
|
28 |
-
arr[...,1] -= 116.779
|
29 |
-
arr[...,2] -= 103.939
|
30 |
-
return np.expand_dims(arr, 0)
|
31 |
-
|
32 |
-
# 6️⃣ Inference fn
|
33 |
-
def classify_fruit(img: Image.Image):
|
34 |
-
x = preprocess(img)
|
35 |
-
# call the signature with the correct kwarg
|
36 |
-
outputs = infer(**{input_key: tf.constant(x)})
|
37 |
-
# grab the first tensor in the outputs dict
|
38 |
-
logits = list(outputs.values())[0].numpy()
|
39 |
-
probs = tf.nn.softmax(logits, axis=1).numpy()
|
40 |
-
idx = int(np.argmax(probs, axis=1)[0])
|
41 |
-
label = id2label[str(idx)]
|
42 |
-
score = float(probs[0][idx])
|
43 |
-
return {label: round(score, 4)}
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
)
|
53 |
|
54 |
if __name__ == "__main__":
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
4 |
+
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
# Load model and processor
|
7 |
+
model = AutoModelForImageClassification.from_pretrained("jazzmacedo/fruits-and-vegetables-detector-36")
|
8 |
+
processor = AutoImageProcessor.from_pretrained("jazzmacedo/fruits-and-vegetables-detector-36")
|
9 |
+
labels = list(model.config.id2label.values())
|
10 |
+
|
11 |
+
def classify_image(image):
|
12 |
+
# Preprocess the image
|
13 |
+
inputs = processor(images=image, return_tensors="pt")
|
14 |
+
with torch.no_grad():
|
15 |
+
outputs = model(**inputs)
|
16 |
+
predicted_idx = torch.argmax(outputs.logits, dim=1).item()
|
17 |
+
predicted_label = labels[predicted_idx]
|
18 |
+
return predicted_label
|
19 |
+
|
20 |
+
# Create Gradio interface
|
21 |
+
interface = gr.Interface(
|
22 |
+
fn=classify_image,
|
23 |
+
inputs=gr.Image(type="pil"),
|
24 |
+
outputs=gr.Text(label="Detected Label"),
|
25 |
+
title="Fruit & Vegetable Detector",
|
26 |
+
description="Upload an image of a fruit or vegetable and get the predicted label."
|
27 |
)
|
28 |
|
29 |
if __name__ == "__main__":
|
30 |
+
interface.launch()
|