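# Gradio demo: classifies whether an input sentence is complete or likely cut off,
# using SmolLM checkpoints downloaded from the Hugging Face Hub.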
import os
import torch
import gradio as gr
from model import SmolLM
from huggingface_hub import hf_hub_download

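# Hugging Face access token from the environment (needed if the checkpoint repo is gated or private)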
hf_token = os.environ.get("HF_TOKEN")
repo_id = "ZivK/smollm2-end-of-sentence"
model_options = {
    "Word-level Model": "word_model.ckpt",
    "Token-level Model": "token_model.ckpt"
}
models = {}
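# Download each checkpoint from the Hub (cached locally) and load it in evaluation mode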
for model_name, filename in model_options.items():
    print(f"Loading {model_name} ...")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    models[model_name] = SmolLM.load_from_checkpoint(checkpoint_path)
    models[model_name].eval()


def classify_sentence(sentence, model_choice):
    model = models[model_choice]
    # Tokenize the input sentence. Note: use_fast is a from_pretrained() option,
    # not a valid argument to the tokenizer's __call__, so it is not passed here.
    inputs = model.tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(inputs)
    # Sigmoid maps the single logit to the probability (in %) that the sentence is complete;
    # report the confidence of whichever class is predicted.
    confidence = torch.sigmoid(logits).item() * 100
    confidence_to_display = confidence if confidence > 50.0 else 100 - confidence
    label = "Complete" if confidence > 50.0 else "Incomplete"

    return f"{label} Sentence\nConfidence: {confidence_to_display:.2f}%"


# Create the Gradio interface
interface = gr.Interface(
    fn=classify_sentence,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter your sentence here...", label="Sentence"),
        gr.Dropdown(choices=list(model_options.keys()), label="Select Model")
    ],
    outputs="text",
    title="Complete Sentence Classifier",
    description="## Enter a sentence to determine if it's complete or if it might be cut off"
)

# Launch the demo
interface.launch()