Rivalcoder commited on
Commit
91f9d2d
·
1 Parent(s): a85ce4c
Files changed (1) hide show
  1. app.py +40 -13
app.py CHANGED
@@ -1,19 +1,23 @@
1
  import torch
2
  from fastapi import FastAPI, File, UploadFile
3
- from fastapi.responses import JSONResponse
4
  from transformers import ConvNextForImageClassification, AutoImageProcessor
5
  from PIL import Image
6
  import io
 
 
 
 
7
 
8
- # Class names (for skin diseases)
9
  class_names = [
10
  'Acne and Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis Photos',
11
- 'Bullous Disease Photos', 'Cellulitis Impetigo and other Bacterial Infections', 'Eczema Photos', 'Exanthems and Drug Eruptions',
12
- 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs Photos', 'Light Diseases and Disorders of Pigmentation',
13
- 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease',
14
- 'Poison Ivy Photos and other Contact Dermatitis', 'Psoriasis pictures Lichen Planus and related diseases',
15
- 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease',
16
- 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis Photos',
17
  'Warts Molluscum and other Viral Infections'
18
  ]
19
 
@@ -27,8 +31,15 @@ processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
27
 
28
  # FastAPI app
29
  app = FastAPI()
 
 
 
 
 
 
 
30
 
31
- # Prediction helper
32
  def predict(image: Image.Image):
33
  inputs = processor(images=image, return_tensors="pt")
34
  with torch.no_grad():
@@ -36,7 +47,7 @@ def predict(image: Image.Image):
36
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
37
  return predicted_class, class_names[predicted_class]
38
 
39
- # Endpoint: /predict
40
  @app.post("/predict/")
41
  async def predict_endpoint(file: UploadFile = File(...)):
42
  try:
@@ -50,6 +61,22 @@ async def predict_endpoint(file: UploadFile = File(...)):
50
  except Exception as e:
51
  return JSONResponse(content={"error": str(e)}, status_code=500)
52
 
53
- # Required for Hugging Face Spaces (do NOT run uvicorn manually)
54
- # Just expose the app
55
- app = app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from fastapi import FastAPI, File, UploadFile
3
+ from fastapi.responses import JSONResponse, RedirectResponse
4
  from transformers import ConvNextForImageClassification, AutoImageProcessor
5
  from PIL import Image
6
  import io
7
+ import gradio as gr
8
+ from starlette.middleware.cors import CORSMiddleware
9
+ from fastapi.staticfiles import StaticFiles
10
+ from gradio.routes import mount_gradio_app
11
 
12
+ # Class names
13
  class_names = [
14
  'Acne and Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis Photos',
15
+ 'Bullous Disease Photos', 'Cellulitis Impetigo and other Bacterial Infections', 'Eczema Photos', 'Exanthems and Drug Eruptions',
16
+ 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs Photos', 'Light Diseases and Disorders of Pigmentation',
17
+ 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease',
18
+ 'Poison Ivy Photos and other Contact Dermatitis', 'Psoriasis pictures Lichen Planus and related diseases',
19
+ 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease',
20
+ 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis Photos',
21
  'Warts Molluscum and other Viral Infections'
22
  ]
23
 
 
31
 
32
  # FastAPI app
33
  app = FastAPI()
34
+ app.add_middleware(
35
+ CORSMiddleware,
36
+ allow_origins=["*"], # Adjust for production
37
+ allow_credentials=True,
38
+ allow_methods=["*"],
39
+ allow_headers=["*"],
40
+ )
41
 
42
+ # Predict function
43
  def predict(image: Image.Image):
44
  inputs = processor(images=image, return_tensors="pt")
45
  with torch.no_grad():
 
47
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
48
  return predicted_class, class_names[predicted_class]
49
 
50
+ # FastAPI route
51
  @app.post("/predict/")
52
  async def predict_endpoint(file: UploadFile = File(...)):
53
  try:
 
61
  except Exception as e:
62
  return JSONResponse(content={"error": str(e)}, status_code=500)
63
 
64
+ @app.get("/")
65
+ def redirect_root_to_gradio():
66
+ return RedirectResponse(url="/gradio")
67
+
68
+ # Gradio interface
69
+ def gradio_interface(image):
70
+ predicted_class, predicted_name = predict(image)
71
+ return f"{predicted_name} (Class {predicted_class})"
72
+
73
+ gradio_app = gr.Interface(
74
+ fn=gradio_interface,
75
+ inputs=gr.Image(type="pil"),
76
+ outputs="text",
77
+ title="Skin Disease Classifier",
78
+ description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model."
79
+ )
80
+
81
+ # Mount Gradio in FastAPI
82
+ app = mount_gradio_app(app, gradio_app, path="/gradio")