ogirald0 committed
Commit 18869bb · 0 Parent(s)

Initial commit for Hugging Face deployment

.env.example ADDED
@@ -0,0 +1,6 @@
+ APP_NAME="ML/AI Models API Service"
+ DEBUG=true
+ HOST="0.0.0.0"
+ PORT=8000
+ MODEL_CACHE_DIR="model_cache"
+ MAX_BATCH_SIZE=32
.gitignore ADDED
@@ -0,0 +1,43 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ ENV/
+ env/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # Logs
+ *.log
+
+ # Local development
+ .env
+ .env.local
+ flagged/
+
+ # Gradio
+ gradio_cached_examples/
README-HF.md ADDED
@@ -0,0 +1,32 @@
+ # Text Classification Model Demo
+
+ This is a Gradio interface for text classification using a BERT-based model. The model can classify text into predefined categories.
+
+ ## Model Details
+
+ - Base Model: prajjwal1/bert-tiny
+ - Task: Text Classification
+ - Interface: Gradio
+
+ ## Usage
+
+ 1. Enter your text in the input textbox
+ 2. Click submit
+ 3. View the classification results
+
+ ## Technical Details
+
+ - Python 3.9+
+ - Key Dependencies:
+   - gradio
+   - transformers
+   - torch
+   - numpy
+
+ ## Deployment
+
+ This model is deployed using Hugging Face Spaces with a Gradio interface.
+
+ ## License
+
+ MIT License
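
Beyond the UI steps above, a deployed Space can also be called programmatically. A minimal sketch using `gradio_client`; the Space id below is a placeholder, not the actual deployment:

```python
# Sketch: call the Space's predict endpoint via gradio_client.
from gradio_client import Client

client = Client("ogirald0/text-classification-demo")  # placeholder Space id
result = client.predict(
    "This movie was fantastic!",  # input text
    api_name="/predict",          # matches api_name="predict" in the Interface
)
print(result)  # (sentiment label, confidence)
```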
README.md ADDED
@@ -0,0 +1,66 @@
+ # ML/AI Models API Service
+
+ A centralized service that hosts various machine learning and AI models using Gradio interfaces, exposed via REST APIs for external frontend clients.
+
+ ## Features
+
+ - Multiple ML/AI model endpoints
+ - Gradio interfaces for each model
+ - FastAPI backend for API exposure
+ - Easy model management and deployment
+ - Scalable architecture
+
+ ## Setup
+
+ 1. Create a virtual environment:
+ ```bash
+ python -m venv venv
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Create a `.env` file:
+ ```bash
+ cp .env.example .env
+ ```
+
+ 4. Run the development server:
+ ```bash
+ python src/main.py
+ ```
+
+ ## Project Structure
+
+ ```
+ ├── src/
+ │   ├── main.py          # Main application entry point
+ │   ├── config.py        # Configuration settings
+ │   ├── models/          # ML/AI model implementations
+ │   │   ├── __init__.py
+ │   │   └── base.py      # Base model class
+ │   ├── interfaces/      # Gradio interfaces
+ │   │   └── __init__.py
+ │   └── api/             # FastAPI routes
+ │       └── __init__.py
+ ├── tests/               # Test files
+ ├── .env.example         # Example environment variables
+ ├── requirements.txt     # Project dependencies
+ └── README.md            # This file
+ ```
+
+ ## Adding New Models
+
+ 1. Add your model implementation in `src/models/`
+ 2. Create a Gradio interface in `src/interfaces/`
+ 3. Add API endpoints in `src/api/`
+ 4. Register the model in `src/main.py`
+
+ ## API Documentation
+
+ Once the server is running, visit:
+ - API documentation: `http://localhost:8000/docs`
+ - Gradio interface: `http://localhost:8000/gradio/text-classification`
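
With the development server running (`python src/main.py`), the endpoints above can be exercised directly. A minimal sketch using `requests`, assuming the default host and port:

```python
# Sketch: hit the model-management endpoints; assumes the server is on :8000.
import requests

base_url = "http://localhost:8000"

print(requests.get(f"{base_url}/").json())  # {"name": ..., "version": ..., "status": "running"}

models = requests.get(f"{base_url}/api/models").json()
print(models)  # registered models, e.g. the "text-classification" entry

if models:
    model_id = models[0]["id"]
    # Load the model into memory via the management API
    print(requests.post(f"{base_url}/api/models/{model_id}/load").json())
```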
app.py ADDED
@@ -0,0 +1,39 @@
+ import gradio as gr
+ from src.models.text_classification import TextClassificationModel
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ def create_demo():
+     try:
+         # Initialize the model
+         logger.info("Initializing Text Classification model...")
+         model = TextClassificationModel()
+
+         # Create the interface
+         logger.info("Creating Gradio interface...")
+         demo = model.create_interface()
+
+         logger.info("Gradio interface created successfully")
+         return demo
+
+     except Exception as e:
+         logger.error(f"Error creating demo: {str(e)}")
+         raise
+
+ if __name__ == "__main__":
+     try:
+         logger.info("Starting the application...")
+         demo = create_demo()
+         logger.info("Launching the interface...")
+         # launch() blocks until the server is stopped
+         demo.launch(
+             server_name="0.0.0.0",  # Allow external connections
+             server_port=7860,       # Specify port explicitly
+             share=True              # Enable a public URL
+         )
+     except Exception as e:
+         logger.error(f"Application error: {str(e)}")
+         raise
app_hf.py ADDED
@@ -0,0 +1,18 @@
+ import gradio as gr
+ from src.models.text_classification import TextClassificationModel
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize the model
+ model = TextClassificationModel()
+
+ # Create the interface
+ demo = model.create_interface()
+
+ # Launch the interface (Hugging Face will handle the server configuration).
+ # NOTE: Spaces runs app.py by default; to use this file instead, set
+ # `app_file: app_hf.py` in the README's YAML front matter.
+ demo.launch()
requirements-full.txt ADDED
@@ -0,0 +1,11 @@
+ gradio>=4.19.2
+ fastapi>=0.110.0
+ uvicorn>=0.27.1
+ python-dotenv>=1.0.1
+ pydantic>=2.6.3
+ pydantic-settings>=2.2.1  # required by src/config.py (also in requirements.txt)
+ numpy>=1.26.4
+ torch>=2.2.1
+ transformers>=4.38.2
+ pillow>=10.2.0
+ python-multipart>=0.0.9
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio>=4.19.2
+ transformers>=4.38.2
+ torch>=2.2.1
+ numpy>=1.26.4
+ fastapi>=0.110.0
+ uvicorn>=0.27.1
+ pydantic>=2.6.3
+ pydantic-settings>=2.2.1
+ python-dotenv>=1.0.1
src/__init__.py ADDED
@@ -0,0 +1 @@
+ """ML Models API Service."""
src/api/__init__.py ADDED
@@ -0,0 +1 @@
+ """API endpoints for ML models."""
src/api/models.py ADDED
@@ -0,0 +1,42 @@
+ from fastapi import APIRouter, HTTPException
+ from typing import List, Dict
+ from models.registry import GradioRegistry
+
+ router = APIRouter()
+ registry = GradioRegistry()
+
+ @router.get("/models", response_model=List[Dict])
+ async def list_models():
+     """List all available models."""
+     return registry.list_models()
+
+ @router.get("/models/{model_id}")
+ async def get_model_info(model_id: str):
+     """Get information about a specific model."""
+     model_info = registry.get_model_info(model_id)
+     if not model_info:
+         raise HTTPException(status_code=404, detail="Model not found")
+     return model_info
+
+ @router.get("/models/{model_id}/status")
+ async def get_model_status(model_id: str):
+     """Get the current status of a model."""
+     model_info = registry.get_model_info(model_id)
+     if not model_info:
+         raise HTTPException(status_code=404, detail="Model not found")
+     return {"status": model_info.status}
+
+ @router.post("/models/{model_id}/load")
+ async def load_model(model_id: str):
+     """Load a model into memory."""
+     model = registry.get_model(model_id)
+     if not model:
+         raise HTTPException(status_code=404, detail="Model not found")
+
+     try:
+         model.load_model()
+         registry.update_model_status(model_id, "loaded")
+         return {"status": "loaded"}
+     except Exception as e:
+         registry.update_model_status(model_id, "error")
+         raise HTTPException(status_code=500, detail=str(e))
src/api/text_classification.py ADDED
@@ -0,0 +1,54 @@
+ from fastapi import APIRouter, HTTPException
+ from pydantic import BaseModel
+ from typing import Dict, Union, List
+ from models.text_classification import TextClassificationModel
+
+ # NOTE: this router is not mounted in src/main.py as committed; include it with
+ # app.include_router(router, prefix="/api") to expose these endpoints.
+ router = APIRouter()
+ model = TextClassificationModel()
+
+ class TextInput(BaseModel):
+     text: str
+
+ class BatchTextInput(BaseModel):
+     texts: List[str]
+
+ class PredictionResponse(BaseModel):
+     label: str
+     confidence: float
+
+ class BatchPredictionResponse(BaseModel):
+     predictions: List[PredictionResponse]
+
+ @router.post("/predict", response_model=PredictionResponse)
+ async def predict(input_data: TextInput) -> Dict[str, Union[str, float]]:
+     """Make a prediction for a single text."""
+     try:
+         result = await model.predict(input_data.text)
+         return result
+     except Exception as e:
+         raise HTTPException(
+             status_code=500,
+             detail=f"Prediction failed: {str(e)}"
+         )
+
+ @router.post("/predict_batch", response_model=BatchPredictionResponse)
+ async def predict_batch(input_data: BatchTextInput) -> Dict[str, List[Dict[str, Union[str, float]]]]:
+     """Make predictions for multiple texts."""
+     try:
+         predictions = []
+         for text in input_data.texts:
+             result = await model.predict(text)
+             predictions.append(result)
+         return {"predictions": predictions}
+     except Exception as e:
+         raise HTTPException(
+             status_code=500,
+             detail=f"Batch prediction failed: {str(e)}"
+         )
+
+ @router.get("/info")
+ async def get_model_info():
+     """Get information about the text classification model."""
+     return model.get_info()
src/config.py ADDED
@@ -0,0 +1,22 @@
+ from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+ class Settings(BaseSettings):
+     """Application settings."""
+     # pydantic-settings v2 style: env vars map to field names
+     # case-insensitively (APP_NAME, DEBUG, HOST, PORT).
+     model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
+
+     app_name: str = "ML Models API"
+     debug: bool = False
+     host: str = "127.0.0.1"
+     port: int = 8000
+
+     # Add model-specific configurations here
+     MODEL_CACHE_DIR: str = "model_cache"
+     MAX_BATCH_SIZE: int = 32
+
+
+ def get_settings() -> Settings:
+     """Get application settings."""
+     return Settings()
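
Environment variables take precedence over `.env`, which takes precedence over the defaults above. A quick sketch of the resolution order, assuming it is run from inside `src/` with a `.env` copied from `.env.example`:

```python
# Sketch: inspect resolved settings; run from src/ so `config` is importable.
from config import get_settings

settings = get_settings()
print(settings.app_name)  # "ML/AI Models API Service" when .env is present
print(settings.debug)     # True from .env; a real DEBUG env var would win over .env
print(settings.port)      # 8000 unless PORT is overridden
```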
src/main.py ADDED
@@ -0,0 +1,67 @@
+ import uvicorn
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ import gradio as gr
+ from config import get_settings
+ from models.text_classification import TextClassificationModel
+ from api.models import router as models_router, registry
+
+ app = FastAPI(
+     title=get_settings().app_name,
+     description="API for managing and running ML models",
+     version="1.0.0",
+     docs_url="/docs",
+     redoc_url="/redoc",
+ )
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Modify this in production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Register models
+ text_classifier = TextClassificationModel()
+ registry.register_model(
+     "text-classification",
+     text_classifier,
+     "/gradio/text-classification"
+ )
+
+ # Mount the models API router under /api, so its routes resolve to
+ # /api/models, /api/models/{model_id}, etc.
+ app.include_router(
+     models_router,
+     prefix="/api",
+     tags=["models"]
+ )
+
+ # Mount Gradio interface
+ app = gr.mount_gradio_app(
+     app,
+     text_classifier.create_interface(),
+     path="/gradio/text-classification"
+ )
+
+ @app.get("/")
+ async def root():
+     """Root endpoint returning basic API information."""
+     return {
+         "name": get_settings().app_name,
+         "version": "1.0.0",
+         "status": "running"
+     }
+
+ if __name__ == "__main__":
+     # Initialize settings
+     settings = get_settings()
+
+     uvicorn.run(
+         "main:app",
+         host=settings.host,
+         port=settings.port,
+         reload=settings.debug
+     )
src/models/__init__.py ADDED
@@ -0,0 +1 @@
+ """ML model implementations."""
src/models/base.py ADDED
@@ -0,0 +1,42 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+ import gradio as gr
+
+
+ class BaseModel(ABC):
+     """Base class for all ML/AI models."""
+
+     def __init__(self, name: str, description: str):
+         self.name = name
+         self.description = description
+         self._model: Optional[Any] = None
+         self._interface: Optional[gr.Interface] = None
+
+     @abstractmethod
+     def load_model(self) -> None:
+         """Load the model into memory."""
+         pass
+
+     @abstractmethod
+     def create_interface(self) -> gr.Interface:
+         """Create and return a Gradio interface for the model."""
+         pass
+
+     @abstractmethod
+     async def predict(self, *args, **kwargs) -> Any:
+         """Make predictions using the model."""
+         pass
+
+     def get_interface(self) -> gr.Interface:
+         """Get or create the Gradio interface."""
+         if self._interface is None:
+             self._interface = self.create_interface()
+         return self._interface
+
+     def get_info(self) -> Dict[str, str]:
+         """Get model information."""
+         return {
+             "name": self.name,
+             "description": self.description,
+             "status": "loaded" if self._model is not None else "unloaded"
+         }
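
To make the abstract contract concrete, here is a minimal sketch of a subclass; `EchoModel` and its trivial behavior are hypothetical, purely to show the three methods every new model must implement before it can be registered in `src/main.py`:

```python
# Sketch: a minimal BaseModel subclass (EchoModel is a made-up example).
import gradio as gr
from src.models.base import BaseModel


class EchoModel(BaseModel):
    def __init__(self):
        super().__init__(name="Echo", description="Returns its input unchanged")

    def load_model(self) -> None:
        # Stand-in for real model loading (downloads, weights, etc.)
        self._model = lambda text: text

    async def predict(self, text: str) -> str:
        if self._model is None:
            self.load_model()
        return self._model(text)

    def create_interface(self) -> gr.Interface:
        return gr.Interface(
            fn=self.predict, inputs="text", outputs="text",
            title=self.name, description=self.description,
        )
```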
src/models/registry.py ADDED
@@ -0,0 +1,79 @@
+ from typing import Dict, List, Optional
+ import gradio as gr
+ from .base import BaseModel
+
+ class GradioModelInfo:
+     """Information about a Gradio model."""
+     def __init__(self,
+                  model_id: str,
+                  name: str,
+                  description: str,
+                  input_type: str,
+                  output_type: List[str],
+                  examples: List[List[str]],
+                  api_path: str):
+         self.model_id = model_id
+         self.name = name
+         self.description = description
+         self.input_type = input_type
+         self.output_type = output_type
+         self.examples = examples
+         self.api_path = api_path
+         self.status = "unloaded"
+
+ class GradioRegistry:
+     """Registry for Gradio models."""
+
+     def __init__(self):
+         self._models: Dict[str, BaseModel] = {}
+         self._model_info: Dict[str, GradioModelInfo] = {}
+
+     def register_model(self,
+                        model_id: str,
+                        model: BaseModel,
+                        api_path: str) -> None:
+         """Register a new Gradio model."""
+         self._models[model_id] = model
+
+         # Create interface to extract information
+         interface = model.create_interface()
+
+         # Store model information
+         self._model_info[model_id] = GradioModelInfo(
+             model_id=model_id,
+             name=model.name,
+             description=model.description,
+             input_type=interface.input_components[0].__class__.__name__,
+             output_type=[comp.__class__.__name__ for comp in interface.output_components],
+             examples=interface.examples or [],
+             api_path=api_path
+         )
+
+     def get_model(self, model_id: str) -> Optional[BaseModel]:
+         """Get a model by ID."""
+         return self._models.get(model_id)
+
+     def get_model_info(self, model_id: str) -> Optional[GradioModelInfo]:
+         """Get model information by ID."""
+         return self._model_info.get(model_id)
+
+     def list_models(self) -> List[Dict]:
+         """List all registered models."""
+         return [
+             {
+                 "id": info.model_id,
+                 "name": info.name,
+                 "description": info.description,
+                 "input_type": info.input_type,
+                 "output_type": info.output_type,
+                 "examples": info.examples,
+                 "api_path": info.api_path,
+                 "status": info.status
+             }
+             for info in self._model_info.values()
+         ]
+
+     def update_model_status(self, model_id: str, status: str) -> None:
+         """Update the status of a model."""
+         if model_id in self._model_info:
+             self._model_info[model_id].status = status
src/models/text_classification.py ADDED
@@ -0,0 +1,109 @@
+ from typing import Any, Dict, Union, Tuple
+ import gradio as gr
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
+ import logging
+ from .base import BaseModel
+
+ logger = logging.getLogger(__name__)
+
+ class TextClassificationModel(BaseModel):
+     """Lightweight text classification model using tiny BERT."""
+
+     def __init__(self):
+         super().__init__(
+             name="Lightweight Text Classifier",
+             description="Fast text classification using a tiny BERT model (~4.4M parameters)"
+         )
+         self.model_name = "prajjwal1/bert-tiny"
+         self._model = None
+
+     def load_model(self) -> None:
+         """Load the classification model."""
+         try:
+             logger.info(f"Loading model: {self.model_name}")
+
+             # Initialize model with binary classification.
+             # NOTE: bert-tiny ships without a sequence-classification head, so
+             # the head added here is randomly initialized; predictions are for
+             # demo purposes until the model is fine-tuned.
+             model = AutoModelForSequenceClassification.from_pretrained(
+                 self.model_name,
+                 num_labels=2
+             )
+             tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+             self._model = pipeline(
+                 "text-classification",
+                 model=model,
+                 tokenizer=tokenizer,
+                 device=-1  # CPU; use device=0 for GPU
+             )
+
+             # Log model size
+             model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
+             logger.info(f"Model loaded successfully. Size: {model_size_mb:.2f} MB")
+
+         except Exception as e:
+             logger.error(f"Error loading model: {str(e)}")
+             raise
+
+     async def predict(self, text: str) -> Dict[str, Union[str, float]]:
+         """Make a prediction using the model."""
+         try:
+             if self._model is None:
+                 self.load_model()
+
+             logger.info(f"Processing text: {text[:50]}...")
+             result = self._model(text)[0]
+
+             # Map raw labels to sentiment
+             label_map = {
+                 "LABEL_0": "NEGATIVE",
+                 "LABEL_1": "POSITIVE"
+             }
+
+             prediction = {
+                 "label": label_map.get(result["label"], result["label"]),
+                 "confidence": float(result["score"])
+             }
+             logger.info(f"Prediction result: {prediction}")
+             return prediction
+
+         except Exception as e:
+             logger.error(f"Prediction error: {str(e)}")
+             raise
+
+     async def predict_for_interface(self, text: str) -> Tuple[str, float]:
+         """Make a prediction and return it in a format suitable for the Gradio interface."""
+         result = await self.predict(text)
+         return result["label"], result["confidence"]
+
+     def create_interface(self) -> gr.Interface:
+         """Create a Gradio interface for text classification."""
+         if self._model is None:
+             self.load_model()
+
+         examples = [
+             ["This movie was fantastic! I really enjoyed it."],
+             ["The service was terrible and the food was cold."],
+             ["It was an okay experience, nothing special."],
+             ["The weather is nice today!"],
+             ["I'm feeling sick and tired."]
+         ]
+
+         return gr.Interface(
+             fn=self.predict_for_interface,  # Use the interface-specific prediction function
+             inputs=gr.Textbox(
+                 lines=3,
+                 placeholder="Enter text to classify...",
+                 label="Input Text"
+             ),
+             outputs=[
+                 gr.Label(label="Sentiment"),
+                 gr.Number(label="Confidence", precision=4)
+             ],
+             title=self.name,
+             description=self.description + "\n\nThis model is also available via API!",
+             examples=examples,
+             api_name="predict"
+         )
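
The class can also be exercised without any server. A minimal smoke test, assuming it is run from the repository root; because the classification head is untrained (see the note in `load_model`), the label is arbitrary and the confidence hovers near 0.5:

```python
# Sketch: call the model class directly; run from the repo root.
import asyncio

from src.models.text_classification import TextClassificationModel

model = TextClassificationModel()
result = asyncio.run(model.predict("This movie was fantastic!"))
print(result)  # e.g. {"label": "POSITIVE", "confidence": 0.51}
```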
test_api.py ADDED
@@ -0,0 +1,62 @@
+ import requests
+ import time
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ def test_api():
+     # Wait a bit for the model to load
+     logger.info("Waiting for the server to start...")
+     time.sleep(10)
+
+     # NOTE: assumes the routes from src/api/text_classification.py are served
+     # under /api on this port; adjust base_url to match your deployment.
+     base_url = "http://127.0.0.1:7860"
+
+     # Test single prediction
+     test_texts = [
+         "This is amazing! I love it!",
+         "This is terrible, I hate it.",
+         "It's okay, nothing special."
+     ]
+
+     logger.info("\nTesting single predictions:")
+     for text in test_texts:
+         try:
+             logger.info(f"\nTesting with text: {text}")
+             response = requests.post(
+                 f"{base_url}/api/predict",
+                 json={"text": text}
+             )
+
+             if response.status_code == 200:
+                 result = response.json()
+                 logger.info(f"Result: {result}")
+             else:
+                 logger.error(f"Error: {response.status_code} - {response.text}")
+
+         except Exception as e:
+             logger.error(f"Request failed: {str(e)}")
+
+         time.sleep(1)  # Small delay between requests
+
+     # Test batch prediction
+     logger.info("\nTesting batch prediction:")
+     try:
+         response = requests.post(
+             f"{base_url}/api/predict_batch",
+             json={"texts": test_texts}
+         )
+
+         if response.status_code == 200:
+             result = response.json()
+             logger.info(f"Batch results: {result}")
+         else:
+             logger.error(f"Batch Error: {response.status_code} - {response.text}")
+
+     except Exception as e:
+         logger.error(f"Batch request failed: {str(e)}")
+
+ if __name__ == "__main__":
+     test_api()
test_client.py ADDED
@@ -0,0 +1,132 @@
+ import requests
+ import json
+ import time
+ import logging
+ from typing import Dict, List, Any, Optional
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class MLModelsClient:
+     """Client for interacting with the ML Models API."""
+
+     def __init__(self, base_url: str = "http://localhost:8000"):
+         self.base_url = base_url
+
+     def list_models(self) -> List[Dict]:
+         """List all available models."""
+         try:
+             logger.info("Fetching available models...")
+             response = requests.get(f"{self.base_url}/api/models")
+             response.raise_for_status()
+             models = response.json()
+             logger.info(f"Found {len(models)} models")
+             return models
+         except Exception as e:
+             logger.error(f"Error listing models: {str(e)}")
+             raise
+
+     def get_model_info(self, model_id: str) -> Dict:
+         """Get information about a specific model."""
+         try:
+             logger.info(f"Fetching info for model {model_id}...")
+             response = requests.get(f"{self.base_url}/api/models/{model_id}")
+             response.raise_for_status()
+             return response.json()
+         except Exception as e:
+             logger.error(f"Error getting model info: {str(e)}")
+             raise
+
+     def get_model_status(self, model_id: str) -> str:
+         """Get the current status of a model."""
+         try:
+             logger.info(f"Fetching status for model {model_id}...")
+             response = requests.get(f"{self.base_url}/api/models/{model_id}/status")
+             response.raise_for_status()
+             return response.json()["status"]
+         except Exception as e:
+             logger.error(f"Error getting model status: {str(e)}")
+             raise
+
+     def load_model(self, model_id: str) -> str:
+         """Load a model into memory."""
+         try:
+             logger.info(f"Loading model {model_id}...")
+             response = requests.post(f"{self.base_url}/api/models/{model_id}/load")
+             response.raise_for_status()
+             return response.json()["status"]
+         except Exception as e:
+             logger.error(f"Error loading model: {str(e)}")
+             raise
+
+     def predict(self, model_id: str, text: str) -> Dict:
+         """Make a prediction using a model."""
+         try:
+             logger.info(f"Making prediction with model {model_id}...")
+             model_info = self.get_model_info(model_id)
+             # NOTE: api_path points at the model's Gradio mount; this assumes a
+             # REST /predict route is also served under that path.
+             response = requests.post(
+                 f"{self.base_url}{model_info.get('api_path')}/predict",
+                 json={"text": text}
+             )
+             response.raise_for_status()
+             return response.json()
+         except Exception as e:
+             logger.error(f"Error making prediction: {str(e)}")
+             raise
+
+ def test_model_workflow():
+     """Test the complete model workflow."""
+     client = MLModelsClient()
+
+     try:
+         # 1. List available models
+         logger.info("\n1. Testing model listing...")
+         models = client.list_models()
+         for model in models:
+             logger.info(f"Found model: {json.dumps(model, indent=2)}")
+
+         if not models:
+             logger.error("No models found!")
+             return
+
+         # Use the first model for testing
+         model_id = models[0]["id"]
+
+         # 2. Get model information
+         logger.info(f"\n2. Testing model info retrieval for {model_id}...")
+         model_info = client.get_model_info(model_id)
+         logger.info(f"Model info: {json.dumps(model_info, indent=2)}")
+
+         # 3. Get model status
+         logger.info(f"\n3. Testing model status retrieval for {model_id}...")
+         status = client.get_model_status(model_id)
+         logger.info(f"Model status: {status}")
+
+         # 4. Load the model
+         logger.info(f"\n4. Testing model loading for {model_id}...")
+         load_status = client.load_model(model_id)
+         logger.info(f"Load status: {load_status}")
+
+         # 5. Test predictions
+         test_texts = [
+             "This is amazing! I really love it!",
+             "This is terrible, I hate it.",
+             "It's okay, nothing special."
+         ]
+
+         logger.info(f"\n5. Testing predictions for {model_id}...")
+         for text in test_texts:
+             logger.info(f"\nPredicting for text: {text}")
+             result = client.predict(model_id, text)
+             logger.info(f"Prediction: {json.dumps(result, indent=2)}")
+             time.sleep(1)  # Small delay between predictions
+
+     except Exception as e:
+         logger.error(f"Test workflow failed: {str(e)}")
+         raise
+
+ if __name__ == "__main__":
+     logger.info("Starting model testing workflow...")
+     test_model_workflow()