# Hugging Face Space app (Space status: Running).
from fastapi import FastAPI, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
import torch
import torchvision
import numpy as np
import requests
import skimage.io
import cv2
import tempfile
import os
import threading
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
import joblib
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torchxrayvision as xrv
import requests
from io import BytesIO
import logging
# Keep uvicorn's own logger quiet (warnings and above only).
logging.getLogger("uvicorn").setLevel(logging.WARNING)
import tempfile
# Point library cache/config directories at the system temp dir (presumably
# because the default home-directory locations are not writable on the host).
temp_dir = tempfile.gettempdir()
matplotlib_cache = os.path.join(temp_dir, "matplotlib")
torchxrayvision_cache = os.path.join(temp_dir, "torchxrayvision")
# NOTE(review): these are set after the imports above; if any of those imports
# already loaded matplotlib, MPLCONFIGDIR has no effect for this process.
os.environ["MPLCONFIGDIR"] = matplotlib_cache
# NOTE(review): confirm torchxrayvision actually reads an environment variable
# with this exact (case-sensitive) name; if not, its weights are cached in the
# library's default location instead.
os.environ["TORCHXrayVISION_CACHE"] = torchxrayvision_cache
os.makedirs(matplotlib_cache, exist_ok=True)
os.makedirs(torchxrayvision_cache, exist_ok=True)
app = FastAPI()
# Chest X-ray pathology classifier (torchxrayvision DenseNet-121, "all" weights).
# It is never moved to `device`: it stays on the CPU and is fed the CPU tensors
# produced by preprocess_image.
cxr_model = xrv.models.DenseNet(weights="densenet121-res224-all")
cxr_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DINOv2 X-ray backbone; predict_tb uses its pooled embedding as TB features.
tb_processor = AutoImageProcessor.from_pretrained("StanfordAIMI/dinov2-base-xray-224")
tb_model = AutoModel.from_pretrained("StanfordAIMI/dinov2-base-xray-224").to(device)
# Logistic-regression TB head over the DINOv2 embeddings. joblib.load unpickles
# arbitrary objects, so only ever point it at a trusted file. The path is
# resolved relative to the process working directory.
logreg = joblib.load("logreg_model.joblib")
def preprocess_image(image_path):
    """Load an X-ray image and prepare it for the torchxrayvision DenseNet.

    Args:
        image_path: path to an image file readable by ``skimage.io.imread``.

    Returns:
        torch.Tensor, channel-first with a single channel, centre-cropped and
        resized by ``XRayCenterCrop`` / ``XRayResizer(224)``.

    Raises:
        ValueError: if the decoded image is neither 2-D (greyscale) nor 3-D
            (trailing channel axis).
    """
    img = skimage.io.imread(image_path)
    if img.ndim == 3 and img.shape[2] in (2, 4):
        # Greyscale+alpha or RGBA: drop the alpha channel so it is not
        # averaged into the pixel intensity below.
        img = img[..., :-1]
    # 8-bit range by default. Integer images with values above 255 (e.g. 16-bit
    # PNGs) are scaled by their dtype's maximum instead, because xrv's
    # normalize() is expected to reject values above maxval (confirm for the
    # installed version). Images that worked before still get maxval=255.
    maxval = 255
    if np.issubdtype(img.dtype, np.integer) and img.max() > maxval:
        maxval = np.iinfo(img.dtype).max
    img = xrv.datasets.normalize(img, maxval)
    if img.ndim == 3:
        # Collapse colour channels to one and add a leading channel axis.
        img = img.mean(2)[None, ...]
    elif img.ndim == 2:
        img = img[None, ...]
    else:
        raise ValueError(f"Unsupported image shape {img.shape}; expected 2-D or 3-D")
    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224),
    ])
    img = transform(img)
    return torch.from_numpy(img)
def get_predictions(img_tensor, model):
    """Run *model* on a single image and label its scores.

    Args:
        img_tensor: one unbatched input tensor (e.g. the output of
            preprocess_image); a leading batch dimension is added here.
        model: callable module exposing a ``pathologies`` sequence that names
            each output column.

    Returns:
        (preds, outputs): ``preds`` maps each pathology name to this image's
        score (a NumPy scalar); ``outputs`` is the raw batched model output.
    """
    with torch.no_grad():
        outputs = model(img_tensor[None, ...])
    # .cpu() is a no-op for CPU tensors and lets this also work when the model
    # and input live on a GPU (.numpy() cannot read device memory).
    preds = dict(zip(model.pathologies, outputs[0].detach().cpu().numpy()))
    return preds, outputs
def get_top_preds(preds, tolerance=0.01, topk=5):
    """Rank predictions and pick those tied (within *tolerance*) with the best.

    Args:
        preds: mapping of label -> confidence score.
        tolerance: maximum absolute gap from the top score for a label to count
            as "similar".
        topk: cap on how many similar labels are returned.

    Returns:
        (sorted_preds, similar_preds):
        - ``sorted_preds`` is a list of ``(label, score)`` pairs, highest score
          first; ties keep their input order.
        - ``similar_preds`` is a list of ``(rank, label, score)`` triples for at
          most *topk* labels whose score is within *tolerance* of the top score.
          ``rank`` is the position in ``sorted_preds``.
        Both lists are empty when *preds* is empty.
    """
    sorted_preds = sorted(preds.items(), key=lambda item: -item[1])
    if not sorted_preds:
        # Nothing to rank; previously this raised IndexError on sorted_preds[0].
        return [], []
    top_conf = sorted_preds[0][1]
    similar_preds = [
        (rank, label, conf)
        for rank, (label, conf) in enumerate(sorted_preds)
        if abs(conf - top_conf) <= tolerance
    ][:topk]
    return sorted_preds, similar_preds
def get_bounding_boxes(img_tensor, model, similar_preds, threshold=100, size=224):
    """Localize each pathology in *similar_preds* with Grad-CAM.

    Args:
        img_tensor: one unbatched preprocessed image (as from preprocess_image);
            a batch dimension is added here.
        model: classifier exposing ``features`` (its last layer is the Grad-CAM
            target) and ``pathologies`` (output column names).
        similar_preds: iterable of ``(rank, pathology, score)`` triples, as
            returned by get_top_preds; only the pathology name is used.
        threshold: uint8 activation level (0-255); CAM pixels strictly above it
            count as "hot".
        size: side length in pixels of the square the CAM is resized to before
            boxing. Returned coordinates are in this frame.

    Returns:
        dict mapping pathology name -> ``[[x0, y0], [x1, y1]]`` (top-left,
        bottom-right) around the largest hot region. Pathologies with no hot
        region are omitted.
    """
    boxes = {}
    pathologies = list(model.pathologies)
    target_layer = model.features[-1]
    batch = img_tensor[None, ...]
    # One Grad-CAM instance serves every target. The context manager removes
    # its hooks from the shared model on exit instead of waiting for GC.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        for _, pathology, _ in similar_preds:
            target = ClassifierOutputTarget(pathologies.index(pathology))
            grayscale_cam = cam(input_tensor=batch, targets=[target])[0]
            cam_resized = cv2.resize(grayscale_cam, (size, size))
            cam_uint8 = (cam_resized * 255).astype(np.uint8)
            _, mask = cv2.threshold(cam_uint8, threshold, 255, cv2.THRESH_BINARY)
            # [-2] selects the contour list under both OpenCV 3.x (3 return
            # values) and 4.x (2 return values).
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            if contours:
                # findContours does not order by size; box the dominant region.
                largest = max(contours, key=cv2.contourArea)
                x, y, w, h = cv2.boundingRect(largest)
                boxes[pathology] = [[x, y], [x + w, y + h]]
    return boxes
def predict_tb(image_path):
    """Classify a chest X-ray image as TB or not.

    The image is embedded with the DINOv2 X-ray backbone (``tb_processor`` /
    ``tb_model``). The pooled embedding is then classified by the
    logistic-regression head ``logreg``.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        1 if ``logreg`` predicts the label ``"tb"``, else 0.
    """
    # Context manager closes the file handle PIL keeps open for lazy loading,
    # so it is not left dangling when the file is later removed.
    with Image.open(image_path) as image:
        inputs = tb_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = tb_model(**inputs)
    embeddings = outputs.pooler_output.cpu().numpy()
    prediction = logreg.predict(embeddings)
    return int(prediction[0] == "tb")
# Seconds to wait when connecting to / reading from the client-supplied URL.
_DOWNLOAD_TIMEOUT_S = 30
# GradCAM registers hooks on the shared cxr_model, so concurrent requests would
# capture each other's activations. All model work is serialized on this lock.
_INFERENCE_LOCK = threading.Lock()


# NOTE(review): the handler had no route decorator, so FastAPI never registered
# it. "/predict" is an assumed path; confirm what clients actually call.
@app.get("/predict")
async def predict_cxr(image_url: str = Query(..., description="URL to a chest X-ray image")):
    """Download a chest X-ray from *image_url* and run every model on it.

    The download and inference are blocking calls, so they run in FastAPI's
    thread pool instead of stalling the event loop.

    Returns:
        JSONResponse with these keys:
        - ``prediction_result``: pathology -> score rounded to 2 decimals.
        - ``bounding_box``: pathology -> ``[[x0, y0], [x1, y1]]`` (top-left,
          bottom-right) in the 224x224 preprocessed frame.
        - ``tb_finding``: 1 or 0.
        Status is 400 if the image cannot be downloaded, and 500 (with the
        exception message) for any other failure.
    """
    try:
        return await run_in_threadpool(_predict_cxr_sync, image_url)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # report only the message to the client.
        logging.getLogger(__name__).exception("CXR prediction failed for %s", image_url)
        return JSONResponse(content={"error": str(e)}, status_code=500)


def _predict_cxr_sync(image_url):
    """Blocking body of predict_cxr: download, run the models, delete the temp file."""
    # SECURITY(review): the server fetches an arbitrary client-supplied URL
    # (server-side request forgery risk). Restrict schemes/hosts if exposed publicly.
    try:
        response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_S)
    except requests.RequestException:
        return JSONResponse(content={"error": "Failed to download image"}, status_code=400)
    if response.status_code != 200:
        return JSONResponse(content={"error": "Failed to download image"}, status_code=400)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(response.content)
        tmp_path = tmp.name
    try:
        with _INFERENCE_LOCK:
            img_tensor = preprocess_image(tmp_path)
            preds, _ = get_predictions(img_tensor, cxr_model)
            _, similar_preds = get_top_preds(preds)
            prediction_result = {k: float(f"{v:.2f}") for k, v in preds.items()}
            bounding_boxes = get_bounding_boxes(img_tensor, cxr_model, similar_preds)
            tb_result = predict_tb(tmp_path)
    finally:
        # Remove the temp file even when preprocessing or a model raises.
        os.remove(tmp_path)
    return JSONResponse(content={
        "prediction_result": prediction_result,
        "bounding_box": bounding_boxes,  # top-left , bottom-right coordinates
        "tb_finding": tb_result
    })