jatin1233232 committed
Commit 3dd8f3d · verified · 1 Parent(s): d945f50

Update app.py

Files changed (1)
  1. app.py +25 -50
app.py CHANGED
@@ -1,55 +1,30 @@
- import os, json
- import numpy as np
- import tensorflow as tf
- from PIL import Image
  import gradio as gr
- from huggingface_hub import snapshot_download
-
- # 1️⃣ Download the entire SavedModel repo
- REPO_ID = "Adriana213/vgg16-fruit-classifier"
- local_dir = snapshot_download(repo_id=REPO_ID)
-
- # 2️⃣ Load labels
- with open(os.path.join(local_dir, "class_labels.json"), "r") as f:
-     id2label = json.load(f)
-
- # 3️⃣ Load as a TF SavedModel and grab the default signature
- loaded = tf.saved_model.load(local_dir)
- infer = loaded.signatures["serving_default"]
-
- # 4️⃣ Figure out the input key name
- input_key = list(infer.structured_input_signature[1].keys())[0]
-
- # 5️⃣ Preprocess helper (same as before)
- def preprocess(img: Image.Image) -> np.ndarray:
-     img = img.resize((100,100))
-     arr = np.array(img.convert("RGB")).astype("float32")
-     arr[...,0] -= 123.68
-     arr[...,1] -= 116.779
-     arr[...,2] -= 103.939
-     return np.expand_dims(arr, 0)
-
- # 6️⃣ Inference fn
- def classify_fruit(img: Image.Image):
-     x = preprocess(img)
-     # call the signature with the correct kwarg
-     outputs = infer(**{input_key: tf.constant(x)})
-     # grab the first tensor in the outputs dict
-     logits = list(outputs.values())[0].numpy()
-     probs = tf.nn.softmax(logits, axis=1).numpy()
-     idx = int(np.argmax(probs, axis=1)[0])
-     label = id2label[str(idx)]
-     score = float(probs[0][idx])
-     return {label: round(score, 4)}
+ import torch
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
+ from PIL import Image

- # 7️⃣ Gradio interface
- demo = gr.Interface(
-     fn=classify_fruit,
-     inputs=gr.Image(type="pil", label="Upload Fruit Image"),
-     outputs=gr.Label(num_top_classes=5, label="Top Predictions"),
-     title="🍉 Fruit Classifier (131 types)",
-     description="Classify a fruit image into one of 131 categories."
+ # Load model and processor
+ model = AutoModelForImageClassification.from_pretrained("jazzmacedo/fruits-and-vegetables-detector-36")
+ processor = AutoImageProcessor.from_pretrained("jazzmacedo/fruits-and-vegetables-detector-36")
+ labels = list(model.config.id2label.values())
+
+ def classify_image(image):
+     # Preprocess the image
+     inputs = processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     predicted_idx = torch.argmax(outputs.logits, dim=1).item()
+     predicted_label = labels[predicted_idx]
+     return predicted_label
+
+ # Create Gradio interface
+ interface = gr.Interface(
+     fn=classify_image,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Text(label="Detected Label"),
+     title="Fruit & Vegetable Detector",
+     description="Upload an image of a fruit or vegetable and get the predicted label."
  )

  if __name__ == "__main__":
-     demo.launch()
+     interface.launch()
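
For a quick local check of the new inference path outside Gradio, a minimal sketch is shown below. It mirrors classify_image with the same model id from the diff; the MODEL_ID constant, the placeholder image, and the softmax confidence at the end are illustrative additions not present in app.py, and it assumes torch and transformers are installed.

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_ID = "jazzmacedo/fruits-and-vegetables-detector-36"

# Same model/processor pair the Space now loads at startup
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
labels = list(model.config.id2label.values())

# Placeholder RGB image; substitute any fruit or vegetable photo opened with PIL
image = Image.new("RGB", (224, 224), color="red")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Top-1 prediction, as in classify_image; softmax is added here only to expose a score
probs = torch.softmax(logits, dim=-1)
idx = int(probs.argmax(dim=-1).item())
print(labels[idx], round(float(probs[0, idx]), 4))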