# DigiTwin Streamlit App - Cleaned and Fixed for Fully Merged Models
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# --- Hugging Face Token from Secrets ---
HF_TOKEN = st.secrets["HF_TOKEN"]
# --- Streamlit Page Configuration ---
st.set_page_config(
    page_title="DigiTwin - ValLabs",
    page_icon="🧠",
    layout="centered"
)
# --- Logo and App Title ---
# st.image("assets/valonylabz_logo.png", width=150)  # Replace with your logo
st.title("DigiTwin - ValLabs")
# --- Model Options (All Fully Merged Models) ---
MODEL_OPTIONS = {
    "GM Qwen (merged)": "amiguel/GM_Qwen1.8B_Finetune",
    "GM Mistral (merged)": "amiguel/GM_finetune"
}
# --- Sidebar: Model Selection ---
with st.sidebar:
    st.header("🧠 Choose Model")
    selected_model_name = st.selectbox("Select model", list(MODEL_OPTIONS.keys()), index=0)
    selected_model = MODEL_OPTIONS[selected_model_name]
# --- Session State for Messages & Reload ---
if "messages" not in st.session_state:
    st.session_state.messages = []
if "model_name" not in st.session_state or st.session_state.model_name != selected_model_name:
    st.session_state.model = None
    st.session_state.model_name = selected_model_name
# --- DigiTwin System Prompt ---
SYSTEM_PROMPT = (
    "You are DigiTwin, the digital twin of Ataliba, an inspection engineer with over 17 years of "
    "experience in mechanical integrity, reliability, piping, and asset management. "
    "Be precise, practical, and technical. Provide advice aligned with industry best practices."
)
# --- Load Fully Merged Model ---
def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        token=HF_TOKEN
    )
    return model, tokenizer
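# Optional sketch, not wired into the app below: wrapping the loader in
# st.cache_resource keeps already-loaded checkpoints in memory across reruns,
# so switching back to a previously selected model avoids a full reload.
@st.cache_resource(show_spinner=False)
def load_model_cached(model_path):
    return load_model(model_path)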
# --- Load Model on Selection ---
if st.session_state.model is None:
    with st.spinner(f"Loading model: {selected_model_name}..."):
        st.session_state.model, st.session_state.tokenizer = load_model(selected_model)
model = st.session_state.model
tokenizer = st.session_state.tokenizer
# --- Build Prompt with ChatML Format ---
def build_prompt(messages):
    full_prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
    for msg in messages:
        role = msg["role"]
        full_prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
    full_prompt += "<|im_start|>assistant\n"
    return full_prompt
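# Alternative sketch, assuming the tokenizer ships a chat template: in that
# case apply_chat_template builds the same ChatML prompt without hand-rolled
# special tokens. Not used by the app below.
def build_prompt_from_template(messages, tokenizer):
    chat = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)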
# --- Generate Assistant Response (Streaming) ---
def generate_response(prompt_text, model, tokenizer):
    # model.generate blocks, so it runs in a background thread while the
    # TextIteratorStreamer yields decoded text chunks to the caller.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,  # avoids pad-token warnings when no pad token is set
        "streamer": streamer
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
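# Non-streaming fallback sketch, not called below: useful for checking the
# model and prompt format without the background thread and streamer.
def generate_once(prompt_text, model, tokenizer, max_new_tokens=256):
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only the newly generated continuation, not the prompt tokens
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)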
# --- Avatars ---
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# --- Display Chat History ---
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
# --- Chat Input Handling ---
if prompt := st.chat_input("Ask your inspection or reliability question..."):
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    full_prompt = build_prompt(st.session_state.messages)
    with st.chat_message("assistant", avatar=BOT_AVATAR):
        start_time = time.time()
        streamer = generate_response(full_prompt, model, tokenizer)
        response_container = st.empty()
        full_response = ""
        for chunk in streamer:
            full_response += chunk
            response_container.markdown(full_response + "▌", unsafe_allow_html=True)
        end_time = time.time()
        input_tokens = len(tokenizer(full_prompt)["input_ids"])
        output_tokens = len(tokenizer(full_response)["input_ids"])
        speed = output_tokens / (end_time - start_time)
        st.caption(
            f"Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
            f"Speed: {speed:.1f} tokens/sec"
        )
        response_container.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})