import streamlit as st
import os
from transformers import pipeline, AutoTokenizer  # Added AutoTokenizer
import torch

# --- Set Page Config FIRST ---
st.set_page_config(layout="wide")

# --- Configuration ---
# MODEL_NAME = "AdaptLLM/finance-LLM"  # Old model
MODEL_NAME = "WiroAI/WiroAI-Finance-Qwen-1.5B"  # New, smaller model
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Model Loading (cached by Streamlit for efficiency) ---
@st.cache_resource
def load_resources():
    """Load the tokenizer and the text-generation pipeline."""
    if not HF_TOKEN:
        st.warning("HF_TOKEN secret not found. Ensure the model is public or add the token to the Space secrets.")
    try:
        st.info(f"Loading tokenizer for {MODEL_NAME}...")
        # `token` replaces the deprecated `use_auth_token` argument
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN if HF_TOKEN else None)
        st.success("Tokenizer loaded.")

        # Determine device placement: device_map="auto" uses the GPU if available,
        # but might be problematic on CPU-only Spaces (it also requires the
        # `accelerate` package). Start with "auto" and fall back to an explicit
        # device if needed.
        device_map_setting = "auto"
        # device = 0 if torch.cuda.is_available() else -1  # Alternative: explicit device

        st.info(f"Loading model {MODEL_NAME}... (Using {device_map_setting}) This might take a while.")
        # Use the text-generation pipeline
        generator = pipeline(
            "text-generation",
            model=MODEL_NAME,
            tokenizer=tokenizer,  # Pass the loaded tokenizer
            model_kwargs={"torch_dtype": torch.bfloat16},  # Use bfloat16 as per the model card
            device_map=device_map_setting,
            # device=device,  # Use this instead if device_map causes issues
            trust_remote_code=True,
        )
        st.success(f"Model {MODEL_NAME} loaded successfully!")
        return generator, tokenizer  # Return both
    except Exception as e:
        st.error(f"Error loading model/tokenizer: {e}", icon="🔥")
        st.error("Check memory limits, token access, or try removing device_map='auto'.")
        st.stop()

# --- Load Resources ---
generator, tokenizer = load_resources()

# --- Streamlit App UI ---
st.title("💰 FinBuddy Assistant")
st.caption(f"Model: {MODEL_NAME}")

if "messages" not in st.session_state:
    # Seed the history with the system message from the model card example
    st.session_state.messages = [
        {"role": "system", "content": "You are a finance chatbot developed by Wiro AI"}
    ]

# Display past chat messages (excluding the system message)
for message in st.session_state.messages:
    if message["role"] != "system":  # Don't display the system message
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

# Get user input
if prompt := st.chat_input("Ask a question about finance..."):
    # Add the user prompt to state and display it
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate the assistant response
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("Thinking...⏳")

        # --- Prepare the prompt for the model (use the message history) ---
        # Use the messages stored in session state (includes the system prompt)
        messages_for_api = st.session_state.messages

        # --- Define terminators ---
        # Qwen-based chat models normally end a turn with <|im_end|>;
        # <|end_of_text|> is kept from the original example and is simply
        # dropped by the filter below if the tokenizer does not know it.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|im_end|>"),
            tokenizer.convert_tokens_to_ids("<|end_of_text|>"),
        ]
        # Handle tokens that don't exist in the vocab
        terminators = [term for term in terminators if term is not None and not isinstance(term, list)]  # Filter out None (or lists) if conversion fails

        try:
            # Generate the response using the pipeline
            outputs = generator(
                messages_for_api,  # Pass the list of chat messages
                max_new_tokens=512,
                eos_token_id=terminators,
                pad_token_id=tokenizer.eos_token_id,  # Use EOS for padding
                do_sample=True,
                temperature=0.7,  # Adjusted slightly from the example
                top_p=0.95,  # Added common param
                # top_k=50,  # Optional parameter
            )

            # --- Extract the response ---
            # The output format is a list containing a dictionary with 'generated_text',
            # which itself is a list of message dictionaries.
            if (outputs and isinstance(outputs, list) and len(outputs) > 0
                    and isinstance(outputs[0], dict) and 'generated_text' in outputs[0]
                    and isinstance(outputs[0]['generated_text'], list)
                    and len(outputs[0]['generated_text']) > 0):
                # Get the last message dictionary in the generated list (should be the assistant's reply)
                last_message = outputs[0]['generated_text'][-1]
                if isinstance(last_message, dict) and last_message.get('role') == 'assistant':
                    assistant_response = last_message.get('content', "").strip()
                else:
                    # Fallback if the format is unexpected: stringify the last element
                    assistant_response = str(outputs[0]['generated_text'][-1]).strip()
                if not assistant_response:
                    assistant_response = "I generated an empty response."
            else:
                print("Unexpected output format:", outputs)  # Log for debugging
                assistant_response = "Sorry, I couldn't parse the response format."

            message_placeholder.markdown(assistant_response)
            st.session_state.messages.append({"role": "assistant", "content": assistant_response})

        except Exception as e:
            error_message = f"Error during text generation: {e}"
            st.error(error_message, icon="🔥")
            message_placeholder.markdown("Sorry, an error occurred while generating the response.")
            st.session_state.messages.append({"role": "assistant", "content": f"[Error: {e}]"})
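

# --- Optional: manual prompt construction (illustrative sketch, not used above) ---
# Passing a list of chat messages directly to the text-generation pipeline, as done
# above, requires a fairly recent transformers release. If that path is unavailable,
# the history can instead be rendered into a single prompt string with the tokenizer's
# standard `apply_chat_template` API and that string passed to the pipeline.
# The helper name below is hypothetical; it is a sketch and is never called here.
def build_prompt_from_messages(messages):
    """Render the chat history into one prompt string using the model's chat template."""
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,               # return a string rather than token ids
        add_generation_prompt=True,   # append the assistant-turn marker so the model replies as the assistant
    )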