Commit · 18869bb · Parent(s): none (initial commit)
Initial commit for Hugging Face deployment

Files changed:
- .env.example +6 -0
- .gitignore +43 -0
- README-HF.md +32 -0
- README.md +66 -0
- app.py +39 -0
- app_hf.py +16 -0
- requirements-full.txt +10 -0
- requirements.txt +9 -0
- src/__init__.py +1 -0
- src/api/__init__.py +1 -0
- src/api/models.py +42 -0
- src/api/text_classification.py +52 -0
- src/config.py +23 -0
- src/main.py +66 -0
- src/models/__init__.py +1 -0
- src/models/base.py +42 -0
- src/models/registry.py +79 -0
- src/models/text_classification.py +106 -0
- test_api.py +60 -0
- test_client.py +130 -0
.env.example
ADDED
@@ -0,0 +1,6 @@
APP_NAME="ML/AI Models API Service"
DEBUG=true
HOST="0.0.0.0"
PORT=8000
MODEL_CACHE_DIR="model_cache"
MAX_BATCH_SIZE=32
.gitignore
ADDED
@@ -0,0 +1,43 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
venv/
ENV/
env/

# IDE
.idea/
.vscode/
*.swp
*.swo

# Logs
*.log

# Local development
.env
.env.local
flagged/

# Gradio
gradio_cached_examples/
README-HF.md
ADDED
@@ -0,0 +1,32 @@
# Text Classification Model Demo

This is a Gradio interface for text classification using a BERT-based model. The demo classifies input text into two sentiment categories (positive/negative).

## Model Details

- Base Model: prajjwal1/bert-tiny
- Task: Text Classification
- Interface: Gradio

## Usage

1. Enter your text in the input textbox
2. Click submit
3. View the classification results

The Space can also be called programmatically; see the client sketch below.

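A minimal sketch using the official `gradio_client` package; the Space ID `user/space-name` is a placeholder for wherever this demo is actually published:

```python
# Sketch: call the Space's predict endpoint programmatically.
# "user/space-name" is a placeholder Space ID; substitute the real one.
from gradio_client import Client

client = Client("user/space-name")
result = client.predict(
    "This movie was fantastic!",  # value for the "Input Text" textbox
    api_name="/predict",          # matches api_name on the gr.Interface
)
print(result)  # a tuple: (sentiment label, confidence)
```
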
## Technical Details

- Python 3.9+
- Key Dependencies:
  - gradio
  - transformers
  - torch
  - numpy

## Deployment

This model is deployed using Hugging Face Spaces with a Gradio interface.

## License

MIT License
README.md
ADDED
@@ -0,0 +1,66 @@
# ML/AI Models API Service

A centralized service that hosts machine learning and AI models behind Gradio interfaces and exposes them via REST APIs for external frontend clients.

## Features

- Multiple ML/AI model endpoints
- Gradio interfaces for each model
- FastAPI backend for API exposure
- Easy model management and deployment
- Scalable architecture

## Setup

1. Create a virtual environment:
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Create a `.env` file:
```bash
cp .env.example .env
```

4. Run the development server:
```bash
python src/main.py
```

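Once the server is up, a quick smoke test against the root endpoint (a sketch, assuming the defaults from `.env.example`, i.e. port 8000):

```python
# Smoke test: the root endpoint defined in src/main.py returns service metadata.
import requests

print(requests.get("http://localhost:8000/").json())
# e.g. {"name": "ML/AI Models API Service", "version": "1.0.0", "status": "running"}
```
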
## Project Structure

```
├── src/
│   ├── main.py                     # Application entry point
│   ├── config.py                   # Configuration settings (pydantic-settings)
│   ├── models/                     # ML/AI model implementations
│   │   ├── __init__.py
│   │   ├── base.py                 # Base model class
│   │   ├── registry.py             # Model registry
│   │   └── text_classification.py  # Tiny-BERT text classifier
│   └── api/                        # FastAPI routes
│       ├── __init__.py
│       ├── models.py               # Model management endpoints
│       └── text_classification.py  # Prediction endpoints
├── app.py                          # Standalone Gradio app (local)
├── app_hf.py                       # Gradio entry point for Hugging Face Spaces
├── test_api.py                     # API smoke tests
├── test_client.py                  # Example API client
├── .env.example                    # Example environment variables
├── requirements.txt                # Project dependencies
└── README.md                       # This file
```

## Adding New Models

1. Add your model implementation in `src/models/` (subclass `BaseModel` from `src/models/base.py`)
2. Implement `create_interface()` on the model to expose its Gradio interface
3. Add API endpoints in `src/api/`
4. Register the model in `src/main.py` (see the sketch below)

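A minimal sketch of steps 1 and 4 with a hypothetical `EchoModel` (illustrative only; not part of this commit):

```python
# src/models/echo.py (hypothetical): the smallest useful BaseModel subclass.
import gradio as gr
from .base import BaseModel


class EchoModel(BaseModel):
    def __init__(self):
        super().__init__(name="Echo", description="Returns its input unchanged")

    def load_model(self) -> None:
        self._model = lambda text: text  # stand-in for a real model

    async def predict(self, text: str) -> str:
        if self._model is None:
            self.load_model()
        return self._model(text)

    def create_interface(self) -> gr.Interface:
        return gr.Interface(fn=self.predict, inputs="text", outputs="text",
                            title=self.name, description=self.description)

# Then, in src/main.py:
#   echo = EchoModel()
#   registry.register_model("echo", echo, "/gradio/echo")
#   app = gr.mount_gradio_app(app, echo.create_interface(), path="/gradio/echo")
```
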
## API Documentation

Once the server is running, visit:
- API documentation: `http://localhost:8000/docs`
- Gradio interface: `http://localhost:8000/gradio/text-classification`
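
For example, the model-management endpoints from `src/api/models.py` can be exercised with `requests` (a sketch, assuming the server is running on the default port):

```python
# Sketch: list registered models, then load one via the management API.
import requests

BASE = "http://localhost:8000"

models = requests.get(f"{BASE}/api/models").json()
print(models)  # e.g. [{"id": "text-classification", "status": "unloaded", ...}]

print(requests.post(f"{BASE}/api/models/text-classification/load").json())
# e.g. {"status": "loaded"}
```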
app.py
ADDED
@@ -0,0 +1,39 @@
import gradio as gr
from src.models.text_classification import TextClassificationModel
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_demo():
    try:
        # Initialize the model
        logger.info("Initializing Text Classification model...")
        model = TextClassificationModel()

        # Create the interface
        logger.info("Creating Gradio interface...")
        demo = model.create_interface()

        logger.info("Gradio interface created successfully")
        return demo

    except Exception as e:
        logger.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    try:
        logger.info("Starting the application...")
        demo = create_demo()
        logger.info("Launching the interface...")
        demo.launch(
            server_name="0.0.0.0",  # Allow external connections
            server_port=7860,       # Specify port explicitly
            share=True              # Enable public URL
        )
        logger.info("Interface launched successfully")
    except Exception as e:
        logger.error(f"Application error: {str(e)}")
        raise
app_hf.py
ADDED
@@ -0,0 +1,16 @@
import gradio as gr
from src.models.text_classification import TextClassificationModel
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the model
model = TextClassificationModel()

# Create the interface
demo = model.create_interface()

# Launch the interface (Hugging Face will handle the server configuration)
demo.launch()
requirements-full.txt
ADDED
@@ -0,0 +1,10 @@
gradio>=4.19.2
fastapi>=0.110.0
uvicorn>=0.27.1
python-dotenv>=1.0.1
pydantic>=2.6.3
pydantic-settings>=2.2.1  # required by src/config.py
numpy>=1.26.4
torch>=2.2.1
transformers>=4.38.2
pillow>=10.2.0
python-multipart>=0.0.9
requirements.txt
ADDED
@@ -0,0 +1,9 @@
gradio>=4.19.2
transformers>=4.38.2
torch>=2.2.1
numpy>=1.26.4
fastapi>=0.110.0
uvicorn>=0.27.1
pydantic>=2.6.3
pydantic-settings>=2.2.1
python-dotenv>=1.0.1
src/__init__.py
ADDED
@@ -0,0 +1 @@
"""ML Models API Service."""
src/api/__init__.py
ADDED
@@ -0,0 +1 @@
"""API endpoints for ML models."""
src/api/models.py
ADDED
@@ -0,0 +1,42 @@
from fastapi import APIRouter, HTTPException
from typing import List, Dict
from models.registry import GradioRegistry  # top-level import works when started via `python src/main.py`

router = APIRouter()
registry = GradioRegistry()

@router.get("/models", response_model=List[Dict])
async def list_models():
    """List all available models."""
    return registry.list_models()

@router.get("/models/{model_id}")
async def get_model_info(model_id: str):
    """Get information about a specific model."""
    model_info = registry.get_model_info(model_id)
    if not model_info:
        raise HTTPException(status_code=404, detail="Model not found")
    return model_info

@router.get("/models/{model_id}/status")
async def get_model_status(model_id: str):
    """Get the current status of a model."""
    model_info = registry.get_model_info(model_id)
    if not model_info:
        raise HTTPException(status_code=404, detail="Model not found")
    return {"status": model_info.status}

@router.post("/models/{model_id}/load")
async def load_model(model_id: str):
    """Load a model into memory."""
    model = registry.get_model(model_id)
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    try:
        model.load_model()
        registry.update_model_status(model_id, "loaded")
        return {"status": "loaded"}
    except Exception as e:
        registry.update_model_status(model_id, "error")
        raise HTTPException(status_code=500, detail=str(e))
src/api/text_classification.py
ADDED
@@ -0,0 +1,52 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Dict, Union, List
from models.text_classification import TextClassificationModel

router = APIRouter()
model = TextClassificationModel()

class TextInput(BaseModel):
    text: str

class BatchTextInput(BaseModel):
    texts: List[str]

class PredictionResponse(BaseModel):
    label: str
    confidence: float

class BatchPredictionResponse(BaseModel):
    predictions: List[PredictionResponse]

@router.post("/predict", response_model=PredictionResponse)
async def predict(input_data: TextInput) -> Dict[str, Union[str, float]]:
    """Make a prediction for a single text."""
    try:
        result = await model.predict(input_data.text)
        return result
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        )

@router.post("/predict_batch", response_model=BatchPredictionResponse)
async def predict_batch(input_data: BatchTextInput) -> Dict[str, List[Dict[str, Union[str, float]]]]:
    """Make predictions for multiple texts."""
    try:
        predictions = []
        for text in input_data.texts:
            result = await model.predict(text)
            predictions.append(result)
        return {"predictions": predictions}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Batch prediction failed: {str(e)}"
        )

@router.get("/info")
async def get_model_info():
    """Get information about the text classification model."""
    return model.get_info()
src/config.py
ADDED
@@ -0,0 +1,23 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings.

    pydantic-settings matches environment variables to field names
    case-insensitively (e.g. DEBUG -> debug), so no per-field env
    mapping is needed.
    """
    app_name: str = "ML Models API"
    debug: bool = False
    host: str = "127.0.0.1"
    port: int = 8000

    # Add model-specific configurations here
    MODEL_CACHE_DIR: str = "model_cache"
    MAX_BATCH_SIZE: int = 32

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")


def get_settings() -> Settings:
    """Get application settings."""
    return Settings()
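A quick check of which values are in effect (a hypothetical snippet, not part of the commit; run it from `src/` so the import matches how `main.py` imports this module):

```python
# Hypothetical: print the effective settings after .env overrides the defaults.
from config import get_settings

settings = get_settings()
print(settings.app_name, settings.host, settings.port, settings.debug)
```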
src/main.py
ADDED
@@ -0,0 +1,66 @@
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
from config import get_settings
from models.text_classification import TextClassificationModel
from api.models import router as models_router, registry
from api.text_classification import router as text_router

app = FastAPI(
    title=get_settings().app_name,
    description="API for managing and running ML models",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Modify this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register models
text_classifier = TextClassificationModel()
registry.register_model(
    "text-classification",
    text_classifier,
    "/gradio/text-classification"
)

# Mount the models API router. Its routes already begin with /models,
# so the prefix is just /api (yielding /api/models, /api/models/{id}, ...).
app.include_router(
    models_router,
    prefix="/api",
    tags=["models"]
)

# Mount the text-classification prediction router
# (/api/predict, /api/predict_batch, /api/info)
app.include_router(
    text_router,
    prefix="/api",
    tags=["text-classification"]
)

# Mount Gradio interface
app = gr.mount_gradio_app(
    app,
    text_classifier.create_interface(),
    path="/gradio/text-classification"
)

@app.get("/")
async def root():
    """Root endpoint returning basic API information."""
    return {
        "name": get_settings().app_name,
        "version": "1.0.0",
        "status": "running"
    }

if __name__ == "__main__":
    # Initialize settings
    settings = get_settings()

    uvicorn.run(
        "main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug
    )
src/models/__init__.py
ADDED
@@ -0,0 +1 @@
"""ML model implementations."""
src/models/base.py
ADDED
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import gradio as gr


class BaseModel(ABC):
    """Base class for all ML/AI models."""

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description
        self._model: Optional[Any] = None
        self._interface: Optional[gr.Interface] = None

    @abstractmethod
    def load_model(self) -> None:
        """Load the model into memory."""
        pass

    @abstractmethod
    def create_interface(self) -> gr.Interface:
        """Create and return a Gradio interface for the model."""
        pass

    @abstractmethod
    async def predict(self, *args, **kwargs) -> Any:
        """Make predictions using the model."""
        pass

    def get_interface(self) -> gr.Interface:
        """Get or create the Gradio interface."""
        if self._interface is None:
            self._interface = self.create_interface()
        return self._interface

    def get_info(self) -> Dict[str, str]:
        """Get model information."""
        return {
            "name": self.name,
            "description": self.description,
            "status": "loaded" if self._model is not None else "unloaded"
        }
src/models/registry.py
ADDED
@@ -0,0 +1,79 @@
from typing import Dict, List, Optional
import gradio as gr
from .base import BaseModel

class GradioModelInfo:
    """Information about a Gradio model."""
    def __init__(self,
                 model_id: str,
                 name: str,
                 description: str,
                 input_type: str,
                 output_type: List[str],
                 examples: List[List[str]],
                 api_path: str):
        self.model_id = model_id
        self.name = name
        self.description = description
        self.input_type = input_type
        self.output_type = output_type
        self.examples = examples
        self.api_path = api_path
        self.status = "unloaded"

class GradioRegistry:
    """Registry for Gradio models."""

    def __init__(self):
        self._models: Dict[str, BaseModel] = {}
        self._model_info: Dict[str, GradioModelInfo] = {}

    def register_model(self,
                       model_id: str,
                       model: BaseModel,
                       api_path: str) -> None:
        """Register a new Gradio model."""
        self._models[model_id] = model

        # Create interface to extract information
        interface = model.create_interface()

        # Store model information
        self._model_info[model_id] = GradioModelInfo(
            model_id=model_id,
            name=model.name,
            description=model.description,
            input_type=interface.input_components[0].__class__.__name__,
            output_type=[comp.__class__.__name__ for comp in interface.output_components],
            examples=interface.examples or [],
            api_path=api_path
        )

    def get_model(self, model_id: str) -> Optional[BaseModel]:
        """Get a model by ID."""
        return self._models.get(model_id)

    def get_model_info(self, model_id: str) -> Optional[GradioModelInfo]:
        """Get model information by ID."""
        return self._model_info.get(model_id)

    def list_models(self) -> List[Dict]:
        """List all registered models."""
        return [
            {
                "id": info.model_id,
                "name": info.name,
                "description": info.description,
                "input_type": info.input_type,
                "output_type": info.output_type,
                "examples": info.examples,
                "api_path": info.api_path,
                "status": info.status
            }
            for info in self._model_info.values()
        ]

    def update_model_status(self, model_id: str, status: str) -> None:
        """Update the status of a model."""
        if model_id in self._model_info:
            self._model_info[model_id].status = status
src/models/text_classification.py
ADDED
@@ -0,0 +1,106 @@
from typing import Any, Dict, Union, Tuple
import gradio as gr
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import logging
from .base import BaseModel

logger = logging.getLogger(__name__)

class TextClassificationModel(BaseModel):
    """Lightweight text classification model using tiny BERT."""

    def __init__(self):
        super().__init__(
            name="Lightweight Text Classifier",
            description="Fast text classification using a tiny BERT model (4.4MB)"
        )
        self.model_name = "prajjwal1/bert-tiny"
        self._model = None

    def load_model(self) -> None:
        """Load the classification model."""
        try:
            logger.info(f"Loading model: {self.model_name}")

            # Initialize model with binary classification.
            # NOTE: prajjwal1/bert-tiny ships no sequence-classification head,
            # so transformers initializes one with random weights here; the
            # labels are only meaningful after fine-tuning.
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            self._model = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=-1  # CPU, use device=0 for GPU
            )

            # Log model size
            model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            logger.info(f"Model loaded successfully. Size: {model_size_mb:.2f} MB")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    async def predict(self, text: str) -> Dict[str, Union[str, float]]:
        """Make a prediction using the model."""
        try:
            if self._model is None:
                self.load_model()

            logger.info(f"Processing text: {text[:50]}...")
            result = self._model(text)[0]

            # Map raw labels to sentiment
            label_map = {
                "LABEL_0": "NEGATIVE",
                "LABEL_1": "POSITIVE"
            }

            prediction = {
                "label": label_map.get(result["label"], result["label"]),
                "confidence": float(result["score"])
            }
            logger.info(f"Prediction result: {prediction}")
            return prediction

        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            raise

    async def predict_for_interface(self, text: str) -> Tuple[str, float]:
        """Make a prediction and return it in a format suitable for the Gradio interface."""
        result = await self.predict(text)
        return result["label"], result["confidence"]

    def create_interface(self) -> gr.Interface:
        """Create a Gradio interface for text classification."""
        if self._model is None:
            self.load_model()

        examples = [
            ["This movie was fantastic! I really enjoyed it."],
            ["The service was terrible and the food was cold."],
            ["It was an okay experience, nothing special."],
            ["The weather is nice today!"],
            ["I'm feeling sick and tired."]
        ]

        return gr.Interface(
            fn=self.predict_for_interface,  # Use the interface-specific prediction function
            inputs=gr.Textbox(
                lines=3,
                placeholder="Enter text to classify...",
                label="Input Text"
            ),
            outputs=[
                gr.Label(label="Sentiment"),
                gr.Number(label="Confidence", precision=4)
            ],
            title=self.name,
            description=self.description + "\n\nThis model is also available via API!",
            examples=examples,
            api_name="predict"
        )
test_api.py
ADDED
@@ -0,0 +1,60 @@
import requests
import time
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_api():
    # Wait a bit for the model to load
    logger.info("Waiting for the server to start...")
    time.sleep(10)

    # Target the FastAPI app started by `python src/main.py`; it serves
    # /api/predict and /api/predict_batch (see src/api/text_classification.py).
    base_url = "http://127.0.0.1:8000"

    # Test single prediction
    test_texts = [
        "This is amazing! I love it!",
        "This is terrible, I hate it.",
        "It's okay, nothing special."
    ]

    logger.info("\nTesting single predictions:")
    for text in test_texts:
        try:
            logger.info(f"\nTesting with text: {text}")
            response = requests.post(
                f"{base_url}/api/predict",
                json={"text": text}
            )

            if response.status_code == 200:
                result = response.json()
                logger.info(f"Result: {result}")
            else:
                logger.error(f"Error: {response.status_code} - {response.text}")

        except Exception as e:
            logger.error(f"Request failed: {str(e)}")

        time.sleep(1)  # Small delay between requests

    # Test batch prediction
    logger.info("\nTesting batch prediction:")
    try:
        response = requests.post(
            f"{base_url}/api/predict_batch",
            json={"texts": test_texts}
        )

        if response.status_code == 200:
            result = response.json()
            logger.info(f"Batch results: {result}")
        else:
            logger.error(f"Batch Error: {response.status_code} - {response.text}")

    except Exception as e:
        logger.error(f"Batch request failed: {str(e)}")

if __name__ == "__main__":
    test_api()
test_client.py
ADDED
@@ -0,0 +1,130 @@
import requests
import json
import time
import logging
from typing import Dict, List, Any, Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MLModelsClient:
    """Client for interacting with the ML Models API."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url

    def list_models(self) -> List[Dict]:
        """List all available models."""
        try:
            logger.info("Fetching available models...")
            response = requests.get(f"{self.base_url}/api/models")
            response.raise_for_status()
            models = response.json()
            logger.info(f"Found {len(models)} models")
            return models
        except Exception as e:
            logger.error(f"Error listing models: {str(e)}")
            raise

    def get_model_info(self, model_id: str) -> Dict:
        """Get information about a specific model."""
        try:
            logger.info(f"Fetching info for model {model_id}...")
            response = requests.get(f"{self.base_url}/api/models/{model_id}")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error getting model info: {str(e)}")
            raise

    def get_model_status(self, model_id: str) -> str:
        """Get the current status of a model."""
        try:
            logger.info(f"Fetching status for model {model_id}...")
            response = requests.get(f"{self.base_url}/api/models/{model_id}/status")
            response.raise_for_status()
            return response.json()["status"]
        except Exception as e:
            logger.error(f"Error getting model status: {str(e)}")
            raise

    def load_model(self, model_id: str) -> str:
        """Load a model into memory."""
        try:
            logger.info(f"Loading model {model_id}...")
            response = requests.post(f"{self.base_url}/api/models/{model_id}/load")
            response.raise_for_status()
            return response.json()["status"]
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def predict(self, model_id: str, text: str) -> Dict:
        """Make a prediction using a model."""
        try:
            logger.info(f"Making prediction with model {model_id}...")
            # NOTE: the registry's api_path points at the model's Gradio UI
            # mount; the JSON prediction endpoint is served under /api
            # (see src/api/text_classification.py), so post there instead.
            response = requests.post(
                f"{self.base_url}/api/predict",
                json={"text": text}
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error making prediction: {str(e)}")
            raise

def test_model_workflow():
    """Test the complete model workflow."""
    client = MLModelsClient()

    try:
        # 1. List available models
        logger.info("\n1. Testing model listing...")
        models = client.list_models()
        for model in models:
            logger.info(f"Found model: {json.dumps(model, indent=2)}")

        if not models:
            logger.error("No models found!")
            return

        # Use the first model for testing
        model_id = models[0]["id"]

        # 2. Get model information
        logger.info(f"\n2. Testing model info retrieval for {model_id}...")
        model_info = client.get_model_info(model_id)
        logger.info(f"Model info: {json.dumps(model_info, indent=2)}")

        # 3. Get model status
        logger.info(f"\n3. Testing model status retrieval for {model_id}...")
        status = client.get_model_status(model_id)
        logger.info(f"Model status: {status}")

        # 4. Load the model
        logger.info(f"\n4. Testing model loading for {model_id}...")
        load_status = client.load_model(model_id)
        logger.info(f"Load status: {load_status}")

        # 5. Test predictions
        test_texts = [
            "This is amazing! I really love it!",
            "This is terrible, I hate it.",
            "It's okay, nothing special."
        ]

        logger.info(f"\n5. Testing predictions for {model_id}...")
        for text in test_texts:
            logger.info(f"\nPredicting for text: {text}")
            result = client.predict(model_id, text)
            logger.info(f"Prediction: {json.dumps(result, indent=2)}")
            time.sleep(1)  # Small delay between predictions

    except Exception as e:
        logger.error(f"Test workflow failed: {str(e)}")
        raise

if __name__ == "__main__":
    logger.info("Starting model testing workflow...")
    test_model_workflow()