|
import torch |
|
import torch.nn as nn |
|
from config import Net |
|
import gradio as gr |
|
|
|
|
|
# Build the network, restore its saved parameters onto the CPU, and put it
# into evaluation mode; this script only runs predictions, never training.
model = Net()
model.load_state_dict(
    torch.load("pytorch_model.pth", map_location="cpu"),
)
model.eval()
|
|
|
|
|
def predict(inputs):
    """Run the model on a single sample and return its output.

    Args:
        inputs: Either the raw text typed into the Gradio textbox (numbers
            separated by commas, e.g. ``"1.2, 3.4, 5.6"``) or an
            already-numeric sequence that ``torch.tensor`` accepts.

    Returns:
        ``output.squeeze().tolist()`` when the model returns a tensor;
        otherwise the model's return value unchanged.

    Raises:
        ValueError: If ``inputs`` is a string that contains a non-numeric
            field or no values at all.
    """
    # The interface feeds this function from a Textbox, so the value arrives
    # as one string; torch.tensor cannot convert text, so parse it first.
    if isinstance(inputs, str):
        fields = [field.strip() for field in inputs.split(",")]
        try:
            # Skip empty fields so a trailing comma or doubled comma is
            # tolerated rather than rejected.
            inputs = [float(field) for field in fields if field]
        except ValueError as err:
            raise ValueError(
                f"Expected comma-separated numbers, got {inputs!r}"
            ) from err
        if not inputs:
            raise ValueError("No input values were provided")

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        # Add a leading dimension so the single sample forms a batch of one.
        batch = torch.tensor(inputs).float().unsqueeze(0)
        output = model(batch)

    if isinstance(output, torch.Tensor):
        return output.squeeze().tolist()
    return output
|
|
|
|
|
# Expose `predict` through a text-in / text-out Gradio interface.
demo = gr.Interface(
    title="PyTorch MLP Classifier",
    fn=predict,
    inputs=gr.Textbox(
        label="Enter comma-separated input values (e.g., 1.2, 3.4, 5.6)",
    ),
    outputs=gr.Textbox(label="Model Output"),
)
|
|
|
|
|
# Start the web UI only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    demo.launch()
|
|