# 🚀 DigiTwin Streamlit App – Cleaned and Fixed for Fully Merged Models
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# --- Hugging Face Token from Secrets ---
HF_TOKEN = st.secrets["HF_TOKEN"]
# --- Streamlit Page Configuration ---
st.set_page_config(
    page_title="DigiTwin - ValLabs",
    page_icon="🚀",
    layout="centered"
)
# --- Logo and App Title ---
# st.image("assets/valonylabz_logo.png", width=150)  # Replace with your logo
st.title("🚀 DigiTwin - ValLabs 🚀")
# --- Model Options (All Fully Merged Models) ---
MODEL_OPTIONS = {
    "GM Qwen (merged)": "amiguel/GM_Qwen1.8B_Finetune",
    "GM Mistral (merged)": "amiguel/GM_finetune"
}
# --- Sidebar: Model Selection ---
with st.sidebar:
    st.header("🧠 Choose Model")
    selected_model_name = st.selectbox("Select model", list(MODEL_OPTIONS.keys()), index=0)
    selected_model = MODEL_OPTIONS[selected_model_name]
# --- Session State for Messages & Reload ---
if "messages" not in st.session_state:
st.session_state.messages = []
if "model_name" not in st.session_state or st.session_state.model_name != selected_model_name:
st.session_state.model = None
st.session_state.model_name = selected_model_name
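# Resetting model to None forces the loader below to rerun whenever the sidebar
# selection changes; since load_model() is cached, a previously used model is
# restored from memory rather than reloaded from the Hub.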
# --- DigiTwin System Prompt ---
SYSTEM_PROMPT = (
    "You are DigiTwin, the digital twin of Ataliba, an inspection engineer with over 17 years of "
    "experience in mechanical integrity, reliability, piping, and asset management. "
    "Be precise, practical, and technical. Provide advice aligned with industry best practices."
)
# --- Load Fully Merged Model ---
@st.cache_resource
def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        token=HF_TOKEN
    )
    return model, tokenizer
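# st.cache_resource keys the cache on the function arguments, so each model
# path is downloaded and loaded at most once per server process; switching
# models in the sidebar reuses the cached (model, tokenizer) pair.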
# --- Load Model on Selection ---
if st.session_state.model is None:
    with st.spinner(f"Loading model: {selected_model_name}..."):
        st.session_state.model, st.session_state.tokenizer = load_model(selected_model)
model = st.session_state.model
tokenizer = st.session_state.tokenizer
# --- Build Prompt with ChatML Format ---
def build_prompt(messages):
    full_prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
    for msg in messages:
        role = msg["role"]
        full_prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
    full_prompt += "<|im_start|>assistant\n"
    return full_prompt
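# Illustrative output of build_prompt() for a single (hypothetical) user turn:
#   <|im_start|>system
#   You are DigiTwin, the digital twin of Ataliba, ...<|im_end|>
#   <|im_start|>user
#   How often should I inspect insulated piping for CUI?<|im_end|>
#   <|im_start|>assistant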
# --- Generate Assistant Response (Streaming) ---
def generate_response(prompt_text, model, tokenizer):
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        # Some merged checkpoints ship without a pad token; falling back to EOS
        # avoids generate() warnings and is safe for single-sequence decoding.
        "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
        "streamer": streamer
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
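# generate() runs in a background thread while the caller iterates the
# TextIteratorStreamer, which yields decoded text chunks as soon as they are
# produced; this is what enables the token-by-token rendering below.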
# --- Avatars ---
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# --- Display Chat History ---
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
# --- Chat Input Handling ---
if prompt := st.chat_input("Ask your inspection or reliability question..."):
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    full_prompt = build_prompt(st.session_state.messages)
    with st.chat_message("assistant", avatar=BOT_AVATAR):
        start_time = time.time()
        streamer = generate_response(full_prompt, model, tokenizer)
        response_container = st.empty()
        full_response = ""
        for chunk in streamer:
            full_response += chunk
            # Trailing block cursor signals that streaming is still in progress.
            response_container.markdown(full_response + "▌")
        end_time = time.time()
        input_tokens = len(tokenizer(full_prompt)["input_ids"])
        output_tokens = len(tokenizer(full_response)["input_ids"])
        # Guard against division by zero on (near-)instant responses.
        speed = output_tokens / max(end_time - start_time, 1e-6)
        st.caption(
            f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
            f"🕒 Speed: {speed:.1f} tokens/sec"
        )
        response_container.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})