shubham5524 commited on
Commit
bcf4ce7
·
verified ·
1 Parent(s): aeaa86c

Upload 4 files

Browse files
Files changed (4) hide show
  1. .huggingface.yaml +1 -0
  2. Dockerfile +10 -0
  3. app.py +120 -0
  4. requirements.txt +9 -0
.huggingface.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ sdk: docker
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query
2
+ from pydantic import BaseModel
3
+ from typing import List, Tuple
4
+ from fastapi import Body
5
+
6
+ import torch
7
+ import torchxrayvision as xrv
8
+ import torchvision
9
+ import skimage.io
10
+ import numpy as np
11
+ import requests
12
+ import cv2
13
+
14
+ from io import BytesIO
15
+ import matplotlib.pyplot as plt
16
+ from pytorch_grad_cam import GradCAM
17
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
18
+ from pytorch_grad_cam.utils.image import show_cam_on_image
19
+
20
+ app = FastAPI()
21
+
22
+ model = xrv.models.DenseNet(weights="densenet121-res224-all")
23
+ model.eval()
24
+
25
+
26
+ def preprocess_image_from_url(image_url: str) -> torch.Tensor:
27
+ response = requests.get(image_url)
28
+ img = skimage.io.imread(BytesIO(response.content))
29
+ img = xrv.datasets.normalize(img, 255)
30
+
31
+ if img.ndim == 3:
32
+ img = img.mean(2)[None, ...]
33
+ elif img.ndim == 2:
34
+ img = img[None, ...]
35
+
36
+ transform = torchvision.transforms.Compose([
37
+ xrv.datasets.XRayCenterCrop(),
38
+ xrv.datasets.XRayResizer(224)
39
+ ])
40
+ img = transform(img)
41
+ img_tensor = torch.from_numpy(img)
42
+ return img_tensor
43
+
44
+
45
+ def get_predictions_and_bounding_box(img_tensor: torch.Tensor):
46
+ with torch.no_grad():
47
+ output = model(img_tensor[None, ...])[0]
48
+
49
+ predictions = dict(zip(model.pathologies, output.numpy()))
50
+ sorted_preds = sorted(predictions.items(), key=lambda x: -x[1])
51
+
52
+ top_pred_label, top_conf = sorted_preds[0]
53
+ top_pred_index = list(model.pathologies).index(top_pred_label)
54
+
55
+ target_layer = model.features[-1]
56
+ cam = GradCAM(model=model, target_layers=[target_layer])
57
+ grayscale_cam = cam(input_tensor=img_tensor[None, ...],
58
+ targets=[ClassifierOutputTarget(top_pred_index)])[0, :]
59
+
60
+ input_img = img_tensor.numpy()[0]
61
+ input_img_norm = (input_img - input_img.min()) / (input_img.max() - input_img.min())
62
+ input_img_rgb = cv2.cvtColor((input_img_norm * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
63
+
64
+ cam_resized = cv2.resize(grayscale_cam, (224, 224))
65
+ cam_uint8 = (cam_resized * 255).astype(np.uint8)
66
+ _, thresh = cv2.threshold(cam_uint8, 100, 255, cv2.THRESH_BINARY)
67
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
68
+
69
+ bounding_boxes = []
70
+ for cnt in contours:
71
+ x, y, w, h = cv2.boundingRect(cnt)
72
+ bounding_boxes.append((int(x), int(y), int(x + w), int(y + h)))
73
+
74
+ return sorted_preds, bounding_boxes
75
+
76
+
77
+ class Prediction(BaseModel):
78
+ label: str
79
+ confidence: float
80
+
81
+
82
+ class PredictionResponse(BaseModel):
83
+ predictions: List[Prediction]
84
+ top_prediction_bounding_boxes: List[Tuple[int, int, int, int]]
85
+
86
+
87
+ @app.get("/predict", response_model=PredictionResponse)
88
+ def predict(image_url: str = Query(..., description="URL of chest X-ray image")):
89
+ try:
90
+ img_tensor = preprocess_image_from_url(image_url)
91
+ preds, bboxes = get_predictions_and_bounding_box(img_tensor)
92
+ prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
93
+
94
+ return PredictionResponse(
95
+ predictions=prediction_list,
96
+ top_prediction_bounding_boxes=bboxes
97
+ )
98
+ except Exception as e:
99
+ return {"error": str(e)}
100
+
101
+ class URLRequest(BaseModel):
102
+ url: str
103
+
104
+ @app.post("/predict", response_model=PredictionResponse)
105
+ def predict_from_url(body: URLRequest):
106
+ try:
107
+ img_tensor = preprocess_image_from_url(body.url)
108
+ preds, bboxes = get_predictions_and_bounding_box(img_tensor)
109
+ prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
110
+
111
+ return PredictionResponse(
112
+ predictions=prediction_list,
113
+ top_prediction_bounding_boxes=bboxes
114
+ )
115
+ except Exception as e:
116
+ return {"error": str(e)}
117
+
118
+
119
+
120
+ # uvicorn app:app --reload
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ torchvision
5
+ scikit-image
6
+ opencv-python
7
+ requests
8
+ torchxrayvision
9
+ grad-cam