Spaces:
Sleeping
Sleeping
Rivalcoder
commited on
Commit
·
2118cd6
1
Parent(s):
91f9d2d
[Edit]
Browse files- app.py +24 -10
- requirements.txt +1 -0
app.py
CHANGED
@@ -8,8 +8,11 @@ import gradio as gr
|
|
8 |
from starlette.middleware.cors import CORSMiddleware
|
9 |
from fastapi.staticfiles import StaticFiles
|
10 |
from gradio.routes import mount_gradio_app
|
|
|
|
|
|
|
11 |
|
12 |
-
# Class names
|
13 |
class_names = [
|
14 |
'Acne and Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis Photos',
|
15 |
'Bullous Disease Photos', 'Cellulitis Impetigo and other Bacterial Infections', 'Eczema Photos', 'Exanthems and Drug Eruptions',
|
@@ -21,7 +24,7 @@ class_names = [
|
|
21 |
'Warts Molluscum and other Viral Infections'
|
22 |
]
|
23 |
|
24 |
-
# Load model and processor
|
25 |
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
|
26 |
model.classifier = torch.nn.Linear(in_features=1024, out_features=23)
|
27 |
model.load_state_dict(torch.load("./models/convnext_base_finetuned.pth", map_location="cpu"))
|
@@ -29,17 +32,19 @@ model.eval()
|
|
29 |
|
30 |
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
|
31 |
|
32 |
-
# FastAPI app
|
33 |
app = FastAPI()
|
|
|
|
|
34 |
app.add_middleware(
|
35 |
CORSMiddleware,
|
36 |
-
allow_origins=["*"], #
|
37 |
allow_credentials=True,
|
38 |
allow_methods=["*"],
|
39 |
allow_headers=["*"],
|
40 |
)
|
41 |
|
42 |
-
#
|
43 |
def predict(image: Image.Image):
|
44 |
inputs = processor(images=image, return_tensors="pt")
|
45 |
with torch.no_grad():
|
@@ -47,9 +52,10 @@ def predict(image: Image.Image):
|
|
47 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
48 |
return predicted_class, class_names[predicted_class]
|
49 |
|
50 |
-
# FastAPI route
|
51 |
-
@app.post("/predict
|
52 |
-
async def
|
|
|
53 |
try:
|
54 |
img_bytes = await file.read()
|
55 |
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
@@ -61,15 +67,18 @@ async def predict_endpoint(file: UploadFile = File(...)):
|
|
61 |
except Exception as e:
|
62 |
return JSONResponse(content={"error": str(e)}, status_code=500)
|
63 |
|
|
|
64 |
@app.get("/")
|
65 |
def redirect_root_to_gradio():
|
66 |
return RedirectResponse(url="/gradio")
|
67 |
|
68 |
-
# Gradio interface
|
69 |
def gradio_interface(image):
|
|
|
70 |
predicted_class, predicted_name = predict(image)
|
71 |
return f"{predicted_name} (Class {predicted_class})"
|
72 |
|
|
|
73 |
gradio_app = gr.Interface(
|
74 |
fn=gradio_interface,
|
75 |
inputs=gr.Image(type="pil"),
|
@@ -78,5 +87,10 @@ gradio_app = gr.Interface(
|
|
78 |
description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model."
|
79 |
)
|
80 |
|
81 |
-
# Mount Gradio
|
82 |
app = mount_gradio_app(app, gradio_app, path="/gradio")
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from starlette.middleware.cors import CORSMiddleware
|
9 |
from fastapi.staticfiles import StaticFiles
|
10 |
from gradio.routes import mount_gradio_app
|
11 |
+
import tempfile
|
12 |
+
import os
|
13 |
+
from typing import Optional
|
14 |
|
15 |
+
# Class names for skin disease classification
|
16 |
class_names = [
|
17 |
'Acne and Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis Photos',
|
18 |
'Bullous Disease Photos', 'Cellulitis Impetigo and other Bacterial Infections', 'Eczema Photos', 'Exanthems and Drug Eruptions',
|
|
|
24 |
'Warts Molluscum and other Viral Infections'
|
25 |
]
|
26 |
|
27 |
+
# Load the ConvNeXt model and processor
|
28 |
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
|
29 |
model.classifier = torch.nn.Linear(in_features=1024, out_features=23)
|
30 |
model.load_state_dict(torch.load("./models/convnext_base_finetuned.pth", map_location="cpu"))
|
|
|
32 |
|
33 |
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
|
34 |
|
35 |
+
# FastAPI app setup
|
36 |
app = FastAPI()
|
37 |
+
|
38 |
+
# CORS Middleware to allow cross-origin requests
|
39 |
app.add_middleware(
|
40 |
CORSMiddleware,
|
41 |
+
allow_origins=["*"], # Allow all origins for demo purposes
|
42 |
allow_credentials=True,
|
43 |
allow_methods=["*"],
|
44 |
allow_headers=["*"],
|
45 |
)
|
46 |
|
47 |
+
# Function to predict the skin disease from an image
|
48 |
def predict(image: Image.Image):
|
49 |
inputs = processor(images=image, return_tensors="pt")
|
50 |
with torch.no_grad():
|
|
|
52 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
53 |
return predicted_class, class_names[predicted_class]
|
54 |
|
55 |
+
# FastAPI route for prediction via API
|
56 |
+
@app.post("/api/predict")
|
57 |
+
async def predict_from_upload(file: UploadFile = File(...)):
|
58 |
+
"""API endpoint for image uploads"""
|
59 |
try:
|
60 |
img_bytes = await file.read()
|
61 |
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
|
|
67 |
except Exception as e:
|
68 |
return JSONResponse(content={"error": str(e)}, status_code=500)
|
69 |
|
70 |
+
# Redirect root to Gradio interface
|
71 |
@app.get("/")
|
72 |
def redirect_root_to_gradio():
|
73 |
return RedirectResponse(url="/gradio")
|
74 |
|
75 |
+
# Gradio interface for testing
|
76 |
def gradio_interface(image):
|
77 |
+
"""Gradio function to handle the prediction from image"""
|
78 |
predicted_class, predicted_name = predict(image)
|
79 |
return f"{predicted_name} (Class {predicted_class})"
|
80 |
|
81 |
+
# Gradio app setup
|
82 |
gradio_app = gr.Interface(
|
83 |
fn=gradio_interface,
|
84 |
inputs=gr.Image(type="pil"),
|
|
|
87 |
description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model."
|
88 |
)
|
89 |
|
90 |
+
# Mount Gradio app into FastAPI
|
91 |
app = mount_gradio_app(app, gradio_app, path="/gradio")
|
92 |
+
|
93 |
+
# For running the app locally with uvicorn
|
94 |
+
if __name__ == "__main__":
|
95 |
+
import uvicorn
|
96 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ Pillow
|
|
8 |
python-multipart
|
9 |
gradio
|
10 |
transformers
|
|
|
11 |
|
|
|
8 |
python-multipart
|
9 |
gradio
|
10 |
transformers
|
11 |
+
starlette
|
12 |
|