import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Hugging Face access token, read from Streamlit secrets (.streamlit/secrets.toml)
HF_TOKEN = st.secrets["HF_TOKEN"]

st.set_page_config(page_title="DigiTwin - TinyLLaMA", page_icon="🦙", layout="centered")

st.title("🦙 DigiTwin - TinyLLaMA ChatML 🦙")

# TinyLLaMA checkpoint on the Hugging Face Hub
MODEL_ID = "amiguel/TinyLLaMA-110M-general-knowledge"

SYSTEM_PROMPT = (
    "You are DigiTwin, the digital twin of Ataliba, an inspection engineer with over 17 years "
    "of experience in mechanical integrity, reliability, piping, and asset management. "
    "Be precise, practical, and technical. Provide advice aligned with industry best practices."
)

# Avatar images shown next to user and assistant messages
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"


@st.cache_resource
def load_tinyllama_model():
    """Load the tokenizer and model once and cache them across Streamlit reruns."""
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        use_fast=False,
        token=HF_TOKEN,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    return model, tokenizer


model, tokenizer = load_tinyllama_model()


def build_chatml_prompt(messages):
    """Assemble a ChatML prompt: system turn, then the chat history, then an open assistant turn."""
    prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
    for msg in messages:
        role = msg["role"]
        prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"
    return prompt
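
# For a single user turn, build_chatml_prompt produces text of this shape:
#   <|im_start|>system
#   <system prompt text><|im_end|>
#   <|im_start|>user
#   <user question><|im_end|>
#   <|im_start|>assistant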


def generate_response(prompt_text):
    """Start generation on a background thread and return a streamer of decoded text chunks."""
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer,
    }

    # model.generate blocks, so it runs on a worker thread while the UI consumes the streamer
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
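
# Illustrative usage outside the Streamlit UI (not executed here): drain the streamer
# synchronously and print text as it arrives. The example question is made up.
#   prompt = build_chatml_prompt([{"role": "user", "content": "What is a corrosion loop?"}])
#   for piece in generate_response(prompt):
#       print(piece, end="", flush=True)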


# Chat history persists in session state across Streamlit reruns
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far
for msg in st.session_state.messages:
    avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
    with st.chat_message(msg["role"], avatar=avatar):
        st.markdown(msg["content"])

if prompt := st.chat_input("Ask DigiTwin about inspection, piping, or reliability..."):
    # Show and record the user's message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    prompt_text = build_chatml_prompt(st.session_state.messages)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        start = time.time()
        streamer = generate_response(prompt_text)

        response_area = st.empty()
        full_response = ""

        # Stream chunks into the placeholder, showing a cursor while generation is running
        for chunk in streamer:
            full_response += chunk.replace("<|im_end|>", "")
            response_area.markdown(full_response + "▌")

        end = time.time()
        input_tokens = len(tokenizer(prompt_text)["input_ids"])
        output_tokens = len(tokenizer(full_response)["input_ids"])
        speed = output_tokens / (end - start)

        st.caption(
            f"🔢 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
            f"⚡ Speed: {speed:.1f} tokens/sec"
        )

        response_area.markdown(full_response.strip())
        st.session_state.messages.append({"role": "assistant", "content": full_response.strip()})
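
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py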