Shashank commited on
Commit
944ccbf
·
1 Parent(s): 58a504a

initial commit

Browse files
Files changed (3) hide show
  1. app.py +29 -0
  2. requirements.txt +4 -0
  3. resnet50_diabretino.pth +3 -0
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.models import resnet50
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ # Load model
8
+ model = resnet50(weights=None)
9
+ model.fc = torch.nn.Linear(model.fc.in_features, 5)
10
+ model.load_state_dict(torch.load("resnet50_dr.pth", map_location="cpu"))
11
+ model.eval()
12
+
13
+ class_names = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
14
+ transform = transforms.Compose([
15
+ transforms.Resize((224, 224)),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
18
+ std=[0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ def predict(image):
22
+ image = image.convert("RGB")
23
+ img_tensor = transform(image).unsqueeze(0)
24
+ with torch.no_grad():
25
+ outputs = model(img_tensor)
26
+ _, predicted = torch.max(outputs, 1)
27
+ return class_names[predicted.item()]
28
+
29
+ gr.Interface(fn=predict, inputs="image", outputs="text").launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow
resnet50_diabretino.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8392eed4552b9f30fedc3eb221a62dfb9d64687341e91bfad5875839b4e03dd4
3
+ size 94392002