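# NOTE: the original paste began with the non-code header lines
# "Spaces:", "Sleeping", "Sleeping"; kept here as a comment so the module parses.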
import base64
import io
import os

import numpy as np
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
class HuggingFaceClassifier:
    """Image classifier wrapping a Hugging Face image-classification checkpoint.

    Loads a feature extractor and an ``AutoModelForImageClassification`` model
    for ``model_name``, moves the model to CUDA when available, and returns the
    top-k labels with softmax scores from :meth:`predict`.
    """

    def __init__(self, model_name="microsoft/resnet-50"):
        """
        Initialize Hugging Face model and feature extractor.

        Args:
            model_name (str): Hugging Face model identifier.

        Raises:
            ValueError: If the feature extractor or model fails to load
                (the underlying exception is chained as ``__cause__``).
        """
        try:
            # NOTE(review): AutoFeatureExtractor is the legacy entry point for
            # vision preprocessing; newer transformers releases recommend
            # AutoImageProcessor — confirm against the pinned version.
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)
            # Move to GPU if available
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model.to(self.device)
            # Switch to evaluation mode for inference.
            self.model.eval()
        except Exception as e:
            raise ValueError(f"Model loading error: {e}") from e

    def preprocess_image(self, image):
        """
        Preprocess an image for model input.

        Args:
            image: A PIL image (or any input the feature extractor accepts).
                Objects exposing a non-RGB ``mode`` (e.g. RGBA, grayscale or
                palette PIL images) are converted to RGB first.

        Returns:
            torch.Tensor: The extractor's ``pixel_values`` moved to ``self.device``.
        """
        # Classification checkpoints expect 3-channel input; RGBA PNGs or
        # single-channel images could otherwise be rejected or mis-normalized
        # by the feature extractor. Duck-typed so non-PIL inputs pass through.
        mode = getattr(image, "mode", None)
        if mode is not None and mode != "RGB":
            image = image.convert("RGB")
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs.pixel_values.to(self.device)

    def predict(self, image, top_k=5):
        """
        Classify an image and return the highest-scoring labels.

        Args:
            image: Input image (see ``preprocess_image``).
            top_k (int): Maximum number of predictions to return. Clamped to
                the number of classes the model outputs, so checkpoints with
                fewer than ``top_k`` labels still work.

        Returns:
            list[dict]: Up to ``top_k`` entries ``{"label": ..., "score": float}``
            (labels from ``model.config.id2label``), ordered by descending
            softmax probability.

        Raises:
            ValueError: If ``top_k`` is less than 1.
            RuntimeError: If preprocessing or inference fails (the original
                exception is chained as ``__cause__``).
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        try:
            pixel_values = self.preprocess_image(image)
            with torch.no_grad():
                logits = self.model(pixel_values).logits
            probabilities = torch.softmax(logits, dim=-1)
            # torch.topk raises if k exceeds the number of classes.
            k = min(top_k, probabilities.shape[-1])
            top = torch.topk(probabilities, k=k)
            id2label = self.model.config.id2label
            # Index [0]: results for the first (only) image in the batch.
            return [
                {"label": id2label[idx.item()], "score": prob.item()}
                for idx, prob in zip(top.indices[0], top.values[0])
            ]
        except Exception as e:
            raise RuntimeError(f"Prediction error: {e}") from e
# Flask application plus the single shared classifier used by every endpoint.
# The checkpoint is loaded once, when this module is imported; change
# model_name to serve a different image-classification model.
app = Flask(__name__)
classifier = HuggingFaceClassifier(model_name="microsoft/resnet-50")
@app.route('/predict', methods=['POST'])
def predict_image():
    """
    Image classification endpoint.

    Accepts either:
      * a JSON body ``{"image": "<base64 image bytes>"}`` (a browser data URL
        such as ``data:image/png;base64,...`` is also accepted), or
      * a multipart/form-data upload with a ``file`` field.

    Returns:
        JSON ``{"predictions": [...], "status": "success"}`` on success;
        ``{"error": ..., "status": "failed"}`` with HTTP 400 when no image is
        supplied or it cannot be decoded, and HTTP 500 when inference fails.
    """
    # get_json(silent=True) yields None for non-JSON requests (e.g. multipart
    # uploads) instead of raising, so the file-upload branch is reachable.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    try:
        if 'image' in payload:
            encoded = payload['image']
            # Strip a data-URL header: b64decode (non-validating) would keep
            # its alphanumeric characters and corrupt the decoded bytes.
            if isinstance(encoded, str) and encoded.startswith('data:') and ',' in encoded:
                encoded = encoded.split(',', 1)[1]
            image_data = base64.b64decode(encoded)
            image = Image.open(io.BytesIO(image_data))
        elif 'file' in request.files:
            image = Image.open(request.files['file'])
        else:
            return jsonify({
                'error': 'No image provided',
                'status': 'failed'
            }), 400
        # Image.open is lazy; decode now so corrupt data is reported here.
        image.load()
    except (TypeError, ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError; PIL's
        # UnidentifiedImageError and truncated-image errors are OSErrors;
        # a non-string "image" value raises TypeError in b64decode.
        return jsonify({
            'error': f'Invalid image: {e}',
            'status': 'failed'
        }), 400
    except Exception as e:
        return jsonify({
            'error': str(e),
            'status': 'failed'
        }), 500

    try:
        predictions = classifier.predict(image)
    except Exception as e:
        return jsonify({
            'error': str(e),
            'status': 'failed'
        }), 500

    return jsonify({
        'predictions': predictions,
        'status': 'success'
    })
@app.route('/models', methods=['GET'])
def available_models():
    """
    List suggested pre-trained model identifiers.

    Returns:
        JSON ``{"models": [...], "total_models": int}``. The list is static;
        the running server only serves the model the module-level
        ``classifier`` was constructed with.
    """
    models = [
        "microsoft/resnet-50",
        "google/vit-base-patch16-224",
        # NOTE(review): vit-mae-base appears to be a self-supervised (MAE)
        # checkpoint without a classification head; it may not load via
        # AutoModelForImageClassification — confirm before advertising it.
        "facebook/vit-mae-base",
        "microsoft/beit-base-patch16-224",
    ]
    return jsonify({
        'models': models,
        'total_models': len(models),
    })
@app.route('/health', methods=['GET'])
def health_check():
    """
    API health check endpoint.

    Returns:
        JSON with ``status`` ("healthy"), the loaded model's ``model_type``
        from its config, and the string form of the torch device it runs on.
    """
    return jsonify({
        'status': 'healthy',
        'model': classifier.model.config.model_type,
        'device': str(classifier.device),
    })
if __name__ == '__main__':
    # The Werkzeug debugger lets a browser execute arbitrary Python code, and
    # host 0.0.0.0 exposes it to the network, so debug is opt-in via
    # FLASK_DEBUG=1 (local development only). PORT overrides the default 5000.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    port = int(os.environ.get('PORT', '5000'))
    app.run(host='0.0.0.0', port=port, debug=debug)