File size: 4,503 Bytes
d23e46f
3aafe68
330fc4f
699a812
d23e46f
05e25f7
 
eaba60c
d23e46f
 
1836de9
d23e46f
 
b5b8672
d23e46f
 
 
 
 
 
 
 
 
 
 
 
 
b5b8672
d23e46f
a454488
 
 
d23e46f
05e25f7
d23e46f
 
 
 
 
 
 
05e25f7
d23e46f
05e25f7
d23e46f
 
05e25f7
330fc4f
05e25f7
 
d23e46f
b5b8672
d23e46f
 
 
 
 
 
 
 
b5b8672
d23e46f
 
 
 
05e25f7
ac19c17
e65b516
 
ac19c17
 
 
 
 
d23e46f
ac19c17
05e25f7
 
 
ac19c17
 
d23e46f
 
 
1836de9
d23e46f
 
 
 
 
1836de9
d23e46f
 
 
a454488
0373f3c
330fc4f
27b07a6
d23e46f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# 🚀 DigiTwin - TinyLLaMA ChatML Inference App
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# --- Hugging Face Token (use Streamlit secrets for secure deployment) ---
# Read from Streamlit's secrets store; indexing (rather than .get) means a
# missing "HF_TOKEN" entry raises immediately instead of running tokenless.
HF_TOKEN = st.secrets["HF_TOKEN"]

# --- Streamlit Page Configuration ---
st.set_page_config(page_title="DigiTwin - TinyLLaMA", page_icon="🦙", layout="centered")

# --- App Logo and Title ---
# st.image("assets/valonylabz_logo.png", width=160)  # Optional: Add your logo
st.title("🚀 DigiTwin - TinyLLaMA ChatML 🚀")

# --- Model Path ---
# Hugging Face Hub repo id used for both the tokenizer and the model weights.
MODEL_ID = "amiguel/TinyLLaMA-110M-general-knowledge"

# --- System Prompt (ChatML format) ---
# Emitted as the leading "system" turn of every prompt sent to the model.
SYSTEM_PROMPT = (
    "You are DigiTwin, the digital twin of Ataliba, an inspection engineer with over 17 years "
    "of experience in mechanical integrity, reliability, piping, and asset management. "
    "Be precise, practical, and technical. Provide advice aligned with industry best practices."
)

# --- Avatars ---
# Image URLs passed as `avatar=` to st.chat_message for user / assistant turns.
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# --- Load Model & Tokenizer ---
@st.cache_resource
def load_tinyllama_model():
    """Load the chat model and its tokenizer from the Hugging Face Hub.

    Decorated with ``st.cache_resource`` so the objects are created once and
    reused across Streamlit reruns.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Options shared by both from_pretrained calls.
    # NOTE(review): trust_remote_code=True allows the Hub repo to execute its
    # own Python code locally; confirm this repo actually requires it.
    hub_options = {"trust_remote_code": True, "token": HF_TOKEN}

    # bfloat16 weights when a CUDA device is available, float32 otherwise.
    weight_dtype = torch.float32
    if torch.cuda.is_available():
        weight_dtype = torch.bfloat16

    # use_fast=False selects the slow (SentencePiece-based) tokenizer class.
    tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, **hub_options)
    llm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=weight_dtype,
        **hub_options,
    )
    return llm, tok

model, tokenizer = load_tinyllama_model()

# --- Build ChatML Prompt ---
def build_chatml_prompt(messages, system_prompt=None):
    """Render a conversation as a ChatML prompt ending in an open assistant turn.

    Args:
        messages: Iterable of ``{"role": ..., "content": ...}`` dicts in
            chronological order; roles are emitted verbatim.
        system_prompt: Text of the leading system turn. ``None`` (the
            default) uses the module-level ``SYSTEM_PROMPT``.

    Returns:
        The system turn, one ``<|im_start|>role ... <|im_end|>`` turn per
        message, then an opened ``assistant`` turn for the model to complete.
    """
    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT
    # Collect the pieces and join once instead of repeated string `+=`.
    parts = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    parts.extend(
        f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n" for msg in messages
    )
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)

# --- Generate Response with Streaming ---
def generate_response(
    prompt_text,
    *,
    max_new_tokens=1024,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
):
    """Start sampling a reply to ``prompt_text`` on a background thread.

    Args:
        prompt_text: ChatML-formatted prompt for the model to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature, top_p, repetition_penalty: Sampling settings passed to
            ``model.generate`` (sampling is always enabled).

    Returns:
        A ``TextIteratorStreamer``; iterating it yields decoded text chunks
        (prompt excluded, special tokens skipped) until generation finishes.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "streamer": streamer,
    }

    # LLaMA-style tokenizers often define no pad token; pass EOS explicitly
    # rather than relying on generate()'s implicit (warning) fallback.
    if model.generation_config.pad_token_id is None and tokenizer.eos_token_id is not None:
        generation_kwargs["pad_token_id"] = tokenizer.eos_token_id

    # When "<|im_end|>" is a single known token, also stop on it so the model
    # ends its turn instead of writing the next (invented) turn.
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != tokenizer.unk_token_id:
        eos_ids = model.generation_config.eos_token_id
        if eos_ids is None:
            eos_ids = []
        elif isinstance(eos_ids, int):
            eos_ids = [eos_ids]
        generation_kwargs["eos_token_id"] = list(dict.fromkeys([*eos_ids, im_end_id]))

    def _generate():
        """Run generation; close the streamer if generate() fails."""
        try:
            model.generate(**generation_kwargs)
        except Exception:
            # generate() only closes the streamer on success; without this the
            # consumer's `for chunk in streamer` would block forever.
            streamer.end()
            raise

    # daemon=True: an unfinished generation must not keep the process alive.
    Thread(target=_generate, daemon=True).start()
    return streamer

# --- Session State ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Display Chat History ---
for msg in st.session_state.messages:
    avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
    with st.chat_message(msg["role"], avatar=avatar):
        st.markdown(msg["content"])

# --- Handle Chat Input ---
# ChatML end-of-turn marker. If it shows up in the decoded text, everything
# after it is the model starting a new turn on its own.
IM_END_MARKER = "<|im_end|>"

if prompt := st.chat_input("Ask DigiTwin about inspection, piping, or reliability..."):
    # Show user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Build ChatML-formatted prompt from the whole conversation so far
    prompt_text = build_chatml_prompt(st.session_state.messages)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        # perf_counter is monotonic, unlike time.time(), so it suits durations.
        start = time.perf_counter()
        streamer = generate_response(prompt_text)

        response_area = st.empty()
        full_response = ""

        for chunk in streamer:
            # Append chunks verbatim: they already carry the spaces and
            # newlines between words, and markdown needs those newlines.
            full_response += chunk
            if IM_END_MARKER in full_response:
                # Keep only the assistant's own turn. The background thread
                # finishes on its own; its remaining output is ignored.
                full_response = full_response.split(IM_END_MARKER, 1)[0]
                break
            # Render as plain markdown (no raw HTML from the model), matching
            # how the chat history above is rendered.
            response_area.markdown(full_response + "▌")

        elapsed = time.perf_counter() - start
        full_response = full_response.strip()

        input_tokens = len(tokenizer(prompt_text)["input_ids"])
        # add_special_tokens=False so the count excludes any BOS token the
        # tokenizer would prepend to the generated text.
        output_tokens = len(tokenizer(full_response, add_special_tokens=False)["input_ids"])
        speed = output_tokens / elapsed if elapsed > 0 else 0.0

        st.caption(
            f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
            f"🕒 Speed: {speed:.1f} tokens/sec"
        )

        response_area.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})