Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
from fastapi import FastAPI, File, UploadFile | |
from fastapi.responses import JSONResponse | |
from transformers import ConvNextForImageClassification, AutoImageProcessor | |
from PIL import Image | |
import io | |
# Label for each output index of the classifier head, in index order:
# predict() maps the argmax logit index straight into this list.
class_names = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]
# Load the ConvNeXt-Base backbone, replace its classification head with one
# sized for `class_names`, then load the fine-tuned weights from disk.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
# Derive the head's sizes instead of hard-coding 1024/23: the input width comes
# from the pretrained head being replaced, the output width from the label
# list, so the logits and `class_names` cannot silently drift apart.
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=len(class_names),
)
# The checkpoint is a plain state dict (it is passed to load_state_dict), so
# weights_only=True refuses arbitrary pickled objects. Requires torch >= 1.13.
model.load_state_dict(
    torch.load(
        "./models/convnext_base_finetuned.pth",
        map_location="cpu",
        weights_only=True,
    )
)
# Inference mode for layers whose behaviour differs between training and eval.
model.eval()
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
# ASGI application object; the `__main__` block below serves it with Uvicorn.
app: FastAPI = FastAPI()
# Helper shared by the FastAPI endpoint and the Gradio callback
def predict(image: Image.Image):
    """Classify a single image with the fine-tuned ConvNeXt model.

    Args:
        image: A PIL image in any mode; it is converted to RGB before
            preprocessing.

    Returns:
        A ``(index, name)`` tuple: the argmax class index as a Python ``int``
        and the matching entry of ``class_names``.
    """
    # Normalise to 3-channel RGB: uploaded files may be grayscale ("L"),
    # palette ("P") or carry an alpha channel ("RGBA"). Those modes would
    # otherwise reach the processor with the wrong number of channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = model(**inputs)
    # .item() turns the 1-element index tensor into a plain Python int.
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return predicted_class, class_names[predicted_class]
# FastAPI endpoint to handle image upload and prediction.
# NOTE(review): the original function was never registered as a route;
# "/predict" is an assumed path - adjust to whatever clients expect.
@app.post("/predict")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as JSON.

    Success body: ``{"predicted_class": int, "predicted_name": str}``.

    Error responses:
        400: ``{"error": ...}`` when the upload cannot be decoded as an image.
        500: ``{"error": ...}`` for any other failure.
    """
    # Local import keeps this fix self-contained; it runs a sync callable in
    # Starlette's worker thread pool.
    from fastapi.concurrency import run_in_threadpool

    try:
        img_bytes = await file.read()
        try:
            img = Image.open(io.BytesIO(img_bytes))
            # Image.open is lazy; decode now so corrupt or truncated data is
            # reported as a client error instead of failing inside the model.
            img.load()
        except (OSError, Image.DecompressionBombError) as e:
            # UnidentifiedImageError (unknown format) is a subclass of OSError.
            return JSONResponse(
                content={"error": f"Invalid image file: {e}"}, status_code=400
            )
        # Model inference is synchronous and CPU-bound; running it off the
        # event loop keeps the server responsive to other requests.
        predicted_class, predicted_name = await run_in_threadpool(predict, img)
        return JSONResponse(
            content={"predicted_class": predicted_class, "predicted_name": predicted_name}
        )
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Gradio callback; reuses the same predict() helper as the FastAPI endpoint
def gradio_predict(image: Image.Image):
    """Return a human-readable prediction string for the Gradio UI.

    Args:
        image: The uploaded PIL image, or ``None`` when the user submits
            without choosing one.

    Returns:
        ``"Predicted Class: <name>"``, or a prompt to upload an image.
    """
    # gr.Image hands the callback None when nothing was uploaded; without this
    # guard predict() would fail with an opaque error.
    if image is None:
        return "Please upload an image."
    _, predicted_name = predict(image)
    return f"Predicted Class: {predicted_name}"
# Gradio UI: a PIL image goes in, gradio_predict's text comes out. Components
# are built in the same order as before (image input, then text output).
_image_input = gr.Image(type="pil")
_label_output = gr.Textbox()
iface = gr.Interface(
    fn=gradio_predict,
    inputs=_image_input,
    outputs=_label_output,
)
# Serve the Gradio interface on the FastAPI app. Mounting puts the UI at
# /gradio on the same Uvicorn server as the JSON API. Calling iface.launch()
# from a handler would instead start a second, blocking server (plus a share
# tunnel), and that handler was never registered as a route anyway.
app = gr.mount_gradio_app(app, iface, path="/gradio")


@app.get("/")
async def gradio_interface():
    """Redirect the site root to the mounted Gradio UI."""
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/gradio")
# Script entry point: serve `app` with Uvicorn on every interface, port 8000.
if __name__ == "__main__":
    from uvicorn import run as serve_with_uvicorn

    serve_with_uvicorn(app, port=8000, host="0.0.0.0")