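"""FinBuddy Assistant.

A Streamlit chat app that answers finance questions with the
WiroAI/WiroAI-Finance-Qwen-1.5B model. The tokenizer and a transformers
text-generation pipeline are loaded once (cached with st.cache_resource),
and the conversation is driven by the chat history kept in st.session_state.
"""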
import streamlit as st
import os
from transformers import pipeline, AutoTokenizer # Added AutoTokenizer
import torch 

# --- Set Page Config FIRST --- 
st.set_page_config(layout="wide") 

# --- Configuration --- 
# MODEL_NAME = "AdaptLLM/finance-LLM" # Old model
MODEL_NAME = "WiroAI/WiroAI-Finance-Qwen-1.5B" # New smaller model
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Model Loading (Cached by Streamlit for efficiency) --- 
@st.cache_resource 
def load_resources():
    """Loads the tokenizer and the text generation pipeline."""
    if not HF_TOKEN:
        st.warning("HF_TOKEN secret not found. Ensure the model is public or add the token to secrets.")

    try:
        st.info(f"Loading tokenizer for {MODEL_NAME}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)  # `token` supersedes the deprecated `use_auth_token`
        st.success("Tokenizer loaded.")

        # Determine device: Use GPU if available, otherwise CPU
        # device_map="auto" might be problematic on CPU-only Spaces
        # Start with device_map="auto", but fall back to explicit cpu if needed
        device_map_setting = "auto" 
        # device = 0 if torch.cuda.is_available() else -1 # Alternative: explicit device

        st.info(f"Loading model {MODEL_NAME}... (Using {device_map_setting}) This might take a while.")
        # Use pipeline
        generator = pipeline(
            "text-generation",
            model=MODEL_NAME,
            tokenizer=tokenizer,  # Pass the loaded tokenizer
            model_kwargs={"torch_dtype": torch.bfloat16},  # bfloat16 as per the model card
            device_map=device_map_setting,
            # device=device  # Use this if device_map causes issues
            token=HF_TOKEN,  # Also authenticate the model download (no-op if None)
            trust_remote_code=True,
        )
        st.success(f"Model {MODEL_NAME} loaded successfully!")
        return generator, tokenizer # Return both
    except Exception as e:
        st.error(f"Error loading model/tokenizer: {e}", icon="πŸ”₯")
        st.error("Check memory limits, token access, or try removing device_map='auto'.")
        st.stop()

# --- Load Resources --- 
generator, tokenizer = load_resources()

# --- Streamlit App UI --- 
st.title("πŸ’° FinBuddy Assistant")
st.caption(f"Model: {MODEL_NAME}")

if "messages" not in st.session_state:
    # Add initial system message (as per model card example)
    st.session_state.messages = [
        {"role": "system", "content": "You are a finance chatbot developed by Wiro AI"}
    ]

# Display past chat messages (excluding system message)
for message in st.session_state.messages:
    if message["role"] != "system": # Don't display system message
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

# Get user input
if prompt := st.chat_input("Ask a question about finance..."):
    # Add user prompt to state and display
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate assistant response
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("Thinking...⏳")

        # --- Prepare prompt for the model (use message history) ---
        # Use the messages stored in session state (includes system prompt)
        messages_for_api = st.session_state.messages 

        # --- Define terminators ---
        # Qwen-based chat models mark the end of a turn with <|im_end|>;
        # also stop on the tokenizer's own EOS token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|im_end|>"),
        ]
        # Drop entries where the token lookup failed (None or a list)
        terminators = [term for term in terminators if term is not None and not isinstance(term, list)]

        try:
            # Generate response using the pipeline
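            # Note (assumption about the installed transformers version): passing a list of
            # {"role", "content"} dicts makes the text-generation pipeline apply the model's
            # chat template automatically; this requires a recent transformers release.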
            outputs = generator(
                messages_for_api, # Pass the list of messages
                max_new_tokens=512,
                eos_token_id=terminators,
                pad_token_id=tokenizer.eos_token_id, # Use EOS for padding
                do_sample=True,
                temperature=0.7, # Adjusted slightly from example
                top_p=0.95, # Added common param
                # top_k=50 # Optional parameter
            )

            # --- Extract response --- 
            # The output format is a list containing a dictionary with 'generated_text'
            # which itself is a list of message dictionaries.
            if (outputs and 
                isinstance(outputs, list) and 
                len(outputs) > 0 and 
                isinstance(outputs[0], dict) and 
                'generated_text' in outputs[0] and
                isinstance(outputs[0]['generated_text'], list) and
                len(outputs[0]['generated_text']) > 0):
                
                # The last message in the generated list should be the assistant's reply
                last_message = outputs[0]['generated_text'][-1]
                if isinstance(last_message, dict) and last_message.get('role') == 'assistant':
                    assistant_response = last_message.get('content', "").strip()
                else:
                    # Fallback if the format is unexpected: stringify the last element
                    assistant_response = str(outputs[0]['generated_text'][-1]).strip()

                if not assistant_response:
                    assistant_response = "I generated an empty response."
            
            else:
                print("Unexpected output format:", outputs) # Log for debugging
                assistant_response = "Sorry, I couldn't parse the response format."

            message_placeholder.markdown(assistant_response)
            st.session_state.messages.append({"role": "assistant", "content": assistant_response})

        except Exception as e:
            error_message = f"Error during text generation: {e}"
            st.error(error_message, icon="πŸ”₯")
            message_placeholder.markdown("Sorry, an error occurred generating the response.")
            st.session_state.messages.append({"role": "assistant", "content": f"[Error: {e}]"})