# Web-page header residue (not part of the program), preserved as a comment:
# ZivK's picture | Added torch no_grad | 73aff9a | raw | history blame | 1.56 kB
import os
import torch
import gradio as gr
from model import SmolLM
from huggingface_hub import hf_hub_download
# Hub access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
# Repository passed to hf_hub_download for every checkpoint below.
repo_id = "ZivK/smollm2-end-of-sentence"

# Display name shown in the UI -> checkpoint filename requested from `repo_id`.
model_options = {
    "Word-level Model": "word_model.ckpt",
    "Token-level Model": "token_model.ckpt",
}

# All checkpoints are loaded once, up front, so that handling a request only
# needs a dictionary lookup by display name.
models = {}
for model_name, filename in model_options.items():
    print(f"Loading {model_name} ...")
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        token=hf_token,
    )
    loaded = SmolLM.load_from_checkpoint(checkpoint_path)
    models[model_name] = loaded
    # Evaluation mode: the models are only used for prediction in this script.
    loaded.eval()
def classify_sentence(sentence, model_choice):
    """Classify `sentence` as complete or incomplete with the chosen model.

    Args:
        sentence: Text entered by the user.
        model_choice: Key into the module-level `models` dict (the display
            name picked in the dropdown).

    Returns:
        A two-line string with the predicted label and the confidence on a
        0-100 scale, or a short prompt when the sentence is blank or no
        loaded model matches `model_choice`.
    """
    # Guard first: a blank textbox gives the model nothing to score.
    if sentence is None or not sentence.strip():
        return "Please enter a sentence to classify."

    # The dropdown can deliver None when nothing is selected; a plain
    # `models[model_choice]` lookup would raise KeyError in that case.
    model = models.get(model_choice)
    if model is None:
        return "Please select a model."

    # NOTE(review): `use_fast` is normally a tokenizer-construction option;
    # kept unchanged because `model.tokenizer` is defined outside this file.
    # Confirm whether it is actually needed here.
    inputs = model.tokenizer(
        sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        use_fast=True,
    )
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        logits = model(inputs)

    # Sigmoid of the single logit, scaled to a percentage.
    confidence = torch.sigmoid(logits).item() * 100
    is_complete = confidence > 50.0
    label = "Complete" if is_complete else "Incomplete"
    # Report the confidence of the predicted label, whichever side it is on.
    confidence_to_display = confidence if is_complete else 100 - confidence
    return f"{label} Sentence\nConfidence: {confidence_to_display:.2f}"
# Create the Gradio interface
interface = gr.Interface(
    fn=classify_sentence,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter your sentence here..."),
        gr.Dropdown(
            choices=list(model_options),
            # Preselect the first model so the callback does not receive None
            # when the user submits without touching the dropdown.
            value=next(iter(model_options), None),
            label="Select Model",
        ),
    ],
    outputs="text",
    title="Complete Sentence Classifier",
    description="## Enter a sentence to determine if it's complete or if it might be cut off",
)

# Launch the demo
interface.launch()