Your Name commited on
Commit
beb105b
·
1 Parent(s): ffc27bc

first commit

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ import requests
6
+ from io import BytesIO
7
+ from torchvision.models import resnet18, ResNet18_Weights
8
+
9
+ def predict(img_path = None) -> str:
10
+ # Initialize the model and transform
11
+ resnet_model = resnet18(weights=ResNet18_Weights.DEFAULT)
12
+ resnet_transform = ResNet18_Weights.DEFAULT.transforms()
13
+
14
+
15
+ # Load the image
16
+ if img_path is None:
17
+ image = Image.open("examples/steak.jpeg").convert("RGB")
18
+
19
+ if isinstance(img_path, np.ndarray):
20
+ img = Image.fromarray(img_path.astype("uint8"), "RGB")
21
+
22
+ # img = effnet_b2_transform(img).unsqueeze(0)
23
+
24
+
25
+
26
+
27
+ # Convert to tensor
28
+ # img = torch.from_numpy(np.array(image)).permute(2, 0, 1)
29
+ img = resnet_transform(img)
30
+
31
+ # Inference
32
+ resnet_model.eval()
33
+ with torch.inference_mode():
34
+ logits = resnet_model(img.unsqueeze(0))
35
+ pred_class = torch.softmax(logits, dim=1).argmax(dim=1).item()
36
+ predicted_label = ResNet18_Weights.DEFAULT.meta["categories"][pred_class]
37
+ print(f"Predicted class: {predicted_label}")
38
+ return predicted_label
39
+
40
+
41
+ import numpy as np
42
+ import gradio as gr
43
+
44
+
45
+ demo = gr.Interface(predict, gr.Image(), "label")
46
+ if __name__ == "__main__":
47
+ demo.launch()
48
+
49
+