|
import os |
|
import torch |
|
import gradio as gr |
|
from model import SmolLM |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Run on Apple's MPS backend when PyTorch reports it available, else on CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Optional Hugging Face access token; None when HF_TOKEN is not set.
hf_token = os.getenv("HF_TOKEN")

# Hub repository that hosts the checkpoints listed in `model_options`.
repo_id = "ZivK/smollm2-end-of-sentence"

# Display name shown in the UI -> checkpoint filename inside `repo_id`.
model_options = {
    "Word-level Model": "word_model.ckpt",
    "Token-level Model": "token_model.ckpt",
}

# Predicted class index -> human-readable label.
label_map = {0: "Incomplete", 1: "Complete"}

# Every checkpoint is downloaded and loaded once at import time, so requests
# only need a dictionary lookup by display name.
models = {}
for display_name, ckpt_file in model_options.items():
    print(f"Loading {display_name} ...")
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=ckpt_file, token=hf_token)
    loaded_model = SmolLM.load_from_checkpoint(ckpt_path).to(device)
    loaded_model.eval()  # switch to evaluation mode before serving predictions
    models[display_name] = loaded_model
|
|
|
|
|
def classify_sentence(sentence, model_choice):
    """Classify ``sentence`` as complete or incomplete with the chosen model.

    Args:
        sentence: Text to classify.
        model_choice: Key into the module-level ``models`` dict.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is a value of
        ``label_map``. ``confidence`` is the sigmoid of the model output
        scaled to 0-100, i.e. the score for class 1 ("Complete") regardless
        of which label was predicted.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``models``.
    """
    classifier = models[model_choice]
    encoded = classifier.tokenizer(
        sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        raw_output = classifier(encoded)

    # .item() requires a single-element tensor, so the model is presumably
    # emitting one logit per sentence -- confirm against SmolLM.forward.
    complete_pct = torch.sigmoid(raw_output).item() * 100

    if complete_pct > 50.0:
        label = label_map[1]
    else:
        label = label_map[0]
    return label, complete_pct
|
|
|
|
|
def chatbot_reply(history, user_input, model_choice):
    """Classify the user's message and append the exchange to the chat history.

    Args:
        history: List of ``(user_message, bot_message)`` tuples. It is
            extended in place; the same list object is returned.
        user_input: Text the user submitted.
        model_choice: Model key forwarded to ``classify_sentence``.

    Returns:
        A ``(history, "")`` tuple: the updated history and an empty string
        for the second output.
    """
    # Nothing to classify: skip the model call and leave the chat untouched
    # instead of sending an empty string through the tokenizer and model.
    if not user_input or not user_input.strip():
        return history, ""

    classification, confidence = classify_sentence(user_input, model_choice)

    if classification == "Incomplete":
        # `confidence` is the score for "Complete"; report its complement so
        # the number shown reflects confidence in the "Incomplete" verdict.
        bot_message = (
            "It looks like you may have stopped mid-sentence. "
            f"Please finish your thought! Confidence: {(100.0 - confidence):.2f}"
        )
    else:
        bot_message = f"Thank you for sharing a complete sentence! Confidence: {confidence:.2f}"

    # Mutate in place on purpose: callers may rely on the passed-in list
    # being updated rather than on the returned reference.
    history.append((user_input, bot_message))
    return history, ""
|
|
|
|
|
# Gradio UI: a chat window, a textbox with a submit button, and a model picker.
with gr.Blocks() as demo:
    gr.Markdown(
        "## Sentence Completeness Chatbot\nType a message and see if the model thinks it’s complete or incomplete!")
    gr.Markdown("#### [Click here to view the model on Hugging Face](https://huggingface.co/ZivK/smollm2-end-of-sentence)")

    chatbot = gr.Chatbot(label="Chat with Me!")
    # Conversation history handed to the handler. It is wired as an input
    # only (never an output), so it appears to persist only through the
    # handler appending to the list in place -- confirm against gr.State.
    state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(show_label=False, placeholder="Type your sentence here...")
        submit_btn = gr.Button("Submit")
    with gr.Row():
        model_input = gr.Dropdown(
            choices=list(model_options),
            # Pre-select the first model explicitly. On Gradio versions where
            # an unset value means "nothing selected", the handler would
            # otherwise receive None and fail the lookup in `models`.
            value=next(iter(model_options)),
            label="Select Model",
        )

    # Clicking the button and pressing Enter in the textbox trigger the same
    # handler with the same wiring.
    reply_wiring = dict(
        fn=chatbot_reply,
        inputs=[state, user_input, model_input],
        outputs=[chatbot, user_input],
    )
    submit_btn.click(**reply_wiring)
    user_input.submit(**reply_wiring)

demo.launch()
|
|