import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from model import SmolLM

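# The Hub token is read from the environment and passed to hf_hub_download; it may be None for public checkpoints.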
hf_token = os.environ.get("HF_TOKEN")
repo_id = "ZivK/smollm2-end-of-sentence"
model_options = {
    "Word-level Model": "word_model.ckpt",
    "Token-level Model": "token_model.ckpt"
}
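
# Download each checkpoint from the Hub once at startup and keep the loaded
# models in memory, switched to eval mode to disable training-only behaviour such as dropout.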
models = {}
for model_name, filename in model_options.items():
    print(f"Loading {model_name} ...")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    models[model_name] = SmolLM.load_from_checkpoint(checkpoint_path)
    models[model_name].eval()


def classify_sentence(sentence, model_choice):
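    """Classify a sentence as complete or incomplete with the selected model."""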
    model = models[model_choice]
    inputs = model.tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(inputs)
    # The model returns a single logit; sigmoid maps it to the probability (in %) that the sentence is complete.
    confidence = torch.sigmoid(logits).item() * 100
    confidence_to_display = confidence if confidence > 50.0 else 100 - confidence
    label = "Complete" if confidence > 50.0 else "Incomplete"

    return f"{label} Sentence\nConfidence: {confidence_to_display:.2f}"
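

# Gradio UI: a single-line text box for the sentence and a dropdown to choose the model.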
interface = gr.Interface(
    fn=classify_sentence,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter your sentence here..."),
        gr.Dropdown(choices=list(model_options.keys()), label="Select Model")
    ],
    outputs="text",
    title="Complete Sentence Classifier",
    description="## Enter a sentence to determine if it's complete or if it might be cut off"
)
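
# Start the Gradio server (on Hugging Face Spaces this file would be the Space's entry point).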
interface.launch()