File size: 2,876 Bytes
781bf2a
 
 
 
 
 
1e6d148
 
781bf2a
 
 
 
 
 
1e6d148
781bf2a
 
 
 
1e6d148
73aff9a
781bf2a
 
 
 
1e6d148
73aff9a
 
781bf2a
1e6d148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34662fc
1e6d148
 
 
 
781bf2a
1e6d148
 
 
 
 
781bf2a
1e6d148
 
 
 
781bf2a
1e6d148
 
 
 
781bf2a
 
1e6d148
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import torch
import gradio as gr
from model import SmolLM
from huggingface_hub import hf_hub_download


# Pick the compute device: prefer a CUDA GPU, then Apple-silicon MPS, and fall
# back to CPU. The models and the tokenized inputs are both moved to this
# same device string.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Optional Hugging Face access token for the checkpoint downloads
# (None when the HF_TOKEN environment variable is unset).
hf_token = os.environ.get("HF_TOKEN")
# Hugging Face Hub repository that hosts the checkpoints.
repo_id = "ZivK/smollm2-end-of-sentence"
# UI display name -> checkpoint filename inside ``repo_id``.
model_options = {
    "Word-level Model": "word_model.ckpt",
    "Token-level Model": "token_model.ckpt",
}
# Predicted class index -> human-readable label.
label_map = {0: "Incomplete", 1: "Complete"}
# Download every checkpoint listed in ``model_options`` from the Hub, move it
# to ``device``, switch it to eval mode, and register it under its UI name.
models = {}
for display_name, ckpt_filename in model_options.items():
    print(f"Loading {display_name} ...")
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename, token=hf_token)
    loaded = SmolLM.load_from_checkpoint(ckpt_path).to(device)
    loaded.eval()
    models[display_name] = loaded


def classify_sentence(sentence, model_choice):
    """Classify ``sentence`` as complete or incomplete with the chosen model.

    Args:
        sentence: the text to classify.
        model_choice: key into ``models`` selecting which loaded model to use.

    Returns:
        A ``(label, confidence)`` tuple. ``confidence`` is the sigmoid of the
        model's output scaled to the range 0-100; values above 50 map to
        "Complete", otherwise "Incomplete".
    """
    selected = models[model_choice]
    encoded = selected.tokenizer(
        sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        raw_score = selected(encoded)

    complete_pct = torch.sigmoid(raw_score).item() * 100
    label_index = int(complete_pct > 50.0)
    return label_map[label_index], complete_pct


def chatbot_reply(history, user_input, model_choice):
    """Classify the user's message and append the exchange to the chat history.

    Args:
        history: list of ``(user_message, bot_message)`` pairs. It is mutated
            in place and also returned.
        user_input: text typed by the user.
        model_choice: key into ``models`` naming the classifier to use.

    Returns:
        ``(history, "")``: the (possibly updated) history for the Chatbot
        display and an empty string that clears the input textbox.
    """
    # Ignore empty or whitespace-only submissions (e.g. pressing Enter on an
    # empty textbox) rather than running the classifier on no text.
    if not user_input or not user_input.strip():
        return history, ""

    classification, confidence = classify_sentence(user_input, model_choice)

    if classification == "Incomplete":
        # ``confidence`` is the score for "Complete", so report its complement.
        bot_message = "It looks like you may have stopped mid-sentence. Please finish your thought! Confidence: " + \
                        f"{(100.0-confidence):.2f}"
    else:
        bot_message = f"Thank you for sharing a complete sentence!  Confidence: {confidence:.2f}"

    # Append the user message and bot response to the conversation history
    history.append((user_input, bot_message))
    return history, ""


with gr.Blocks() as demo:
    gr.Markdown(
        "## Sentence Completeness Chatbot\nType a message and see if the model thinks it’s complete or incomplete!")
    gr.Markdown(f"#### [Click here to view the model on Hugging Face](https://huggingface.co/{repo_id})")

    # Chat transcript display, plus a State holding the history list that
    # chatbot_reply appends (user_message, bot_message) pairs to.
    chatbot = gr.Chatbot(label="Chat with Me!")
    state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(show_label=False, placeholder="Type your sentence here...")
        submit_btn = gr.Button("Submit")
    with gr.Row():
        # Preselect the first model explicitly. Depending on the Gradio
        # version, a Dropdown without a value may start with nothing selected
        # and submit None, which the ``models`` lookup cannot handle.
        model_input = gr.Dropdown(
            choices=list(model_options.keys()),
            value=next(iter(model_options)),
            label="Select Model",
        )

    # Clicking Submit and pressing Enter in the textbox run the same handler:
    # classify the text, refresh the chat display, and clear the textbox.
    for trigger in (submit_btn.click, user_input.submit):
        trigger(fn=chatbot_reply,
                inputs=[state, user_input, model_input],
                outputs=[chatbot, user_input])

# Launch the demo
demo.launch()