File size: 823 Bytes
2cec8f9
c68a924
6dd59cc
2cec8f9
 
c68a924
6dd59cc
2cec8f9
 
 
c68a924
 
2cec8f9
c68a924
 
 
 
 
2cec8f9
c68a924
2cec8f9
 
c68a924
 
 
2cec8f9
 
c68a924
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
from config import MLP
import gradio as gr

# Build the network, then restore its trained parameters from the checkpoint
# shipped next to this script. map_location forces every stored tensor onto the
# CPU, so a checkpoint saved from a GPU still loads on a CPU-only host.
model = MLP()
model.load_state_dict(
    torch.load(
        "pytorch_model.pth",
        map_location=torch.device("cpu"),
    )
)
# Put the module in evaluation mode (training=False) before serving predictions.
model.eval()

# Define a prediction function
def predict(inputs):
    with torch.no_grad():
        inputs = torch.tensor(inputs).float().unsqueeze(0)  # Add batch dimension
        output = model(inputs)
        if isinstance(output, torch.Tensor):
            return output.squeeze().tolist()
        return output  # fallback

# Create the Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter comma-separated input values (e.g., 1.2, 3.4, 5.6)"),
    outputs=gr.Textbox(label="Model Output"),
    title="PyTorch MLP Classifier"
)

# Launch
if __name__ == "__main__":
    demo.launch()