import streamlit as st
import os
from transformers import pipeline, AutoTokenizer # Added AutoTokenizer
import torch
# --- Set Page Config FIRST ---
st.set_page_config(layout="wide")
# --- Configuration ---
# MODEL_NAME = "AdaptLLM/finance-LLM" # Old model
MODEL_NAME = "WiroAI/WiroAI-Finance-Qwen-1.5B" # New smaller model
HF_TOKEN = os.environ.get("HF_TOKEN")
# --- Model Loading (Cached by Streamlit for efficiency) ---
@st.cache_resource
def load_resources():
"""Loads the tokenizer and the text generation pipeline."""
if not HF_TOKEN:
st.warning("HF_TOKEN secret not found. Ensure the model is public or add the token to secrets.")
try:
st.info(f"Loading tokenizer for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN if HF_TOKEN else None)
st.success("Tokenizer loaded.")
# Determine device: Use GPU if available, otherwise CPU
# device_map="auto" might be problematic on CPU-only Spaces
# Start with device_map="auto", but fall back to explicit cpu if needed
device_map_setting = "auto"
# device = 0 if torch.cuda.is_available() else -1 # Alternative: explicit device
st.info(f"Loading model {MODEL_NAME}... (Using {device_map_setting}) This might take a while.")
# Use pipeline
generator = pipeline(
"text-generation",
model=MODEL_NAME,
tokenizer=tokenizer, # Pass loaded tokenizer
model_kwargs={"torch_dtype": torch.bfloat16}, # Use bfloat16 as per model card
device_map=device_map_setting,
# device=device # Use this if device_map causes issues
trust_remote_code=True
)
st.success(f"Model {MODEL_NAME} loaded successfully!")
return generator, tokenizer # Return both
except Exception as e:
st.error(f"Error loading model/tokenizer: {e}", icon="π₯")
st.error("Check memory limits, token access, or try removing device_map='auto'.")
st.stop()
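# Note: @st.cache_resource keeps the returned objects in memory across Streamlit reruns,
# so the heavyweight model load should only happen once per process, not on every interaction.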
# --- Load Resources ---
generator, tokenizer = load_resources()
# --- Streamlit App UI ---
st.title("π° FinBuddy Assistant")
st.caption(f"Model: {MODEL_NAME}")
if "messages" not in st.session_state:
    # Add the initial system message (as per the model card example)
    st.session_state.messages = [
        {"role": "system", "content": "You are a finance chatbot developed by Wiro AI"}
    ]
# Display past chat messages (excluding system message)
for message in st.session_state.messages:
if message["role"] != "system": # Don't display system message
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Get user input
if prompt := st.chat_input("Ask a question about finance..."):
    # Add the user prompt to state and display it
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    # Generate the assistant response
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("Thinking... ⏳")
        # --- Prepare the prompt for the model (use the message history) ---
        # Use the messages stored in session state (includes the system prompt)
        messages_for_api = st.session_state.messages
        # --- Define terminators as per the model card ---
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|im_end|>"),  # Qwen chat models typically end turns with <|im_end|>
        ]
        # Drop entries where the token lookup failed (None or an unexpected list)
        terminators = [term for term in terminators if term is not None and not isinstance(term, list)]
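        # Note: recent transformers versions let the text-generation pipeline take a list of
        # {"role": ..., "content": ...} dicts and apply the model's chat template automatically;
        # on older versions you may need tokenizer.apply_chat_template() to build the prompt first.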
        try:
            # Generate a response using the pipeline
            outputs = generator(
                messages_for_api,  # Pass the list of messages
                max_new_tokens=512,
                eos_token_id=terminators,
                pad_token_id=tokenizer.eos_token_id,  # Use EOS for padding
                do_sample=True,
                temperature=0.7,  # Adjusted slightly from the model card example
                top_p=0.95,  # Added common param
                # top_k=50,  # Optional parameter
            )
            # --- Extract the response ---
            # The output format is a list containing a dictionary with 'generated_text',
            # which itself is a list of message dictionaries.
            if (outputs and
                    isinstance(outputs, list) and
                    len(outputs) > 0 and
                    isinstance(outputs[0], dict) and
                    'generated_text' in outputs[0] and
                    isinstance(outputs[0]['generated_text'], list) and
                    len(outputs[0]['generated_text']) > 0):
                # The last message in the generated list should be the assistant's reply
                last_message = outputs[0]['generated_text'][-1]
                if isinstance(last_message, dict) and last_message.get('role') == 'assistant':
                    assistant_response = last_message.get('content', "").strip()
                else:
                    # Fallback if the format is unexpected: stringify the last element
                    assistant_response = str(outputs[0]['generated_text'][-1]).strip()
                if not assistant_response:
                    assistant_response = "I generated an empty response."
            else:
                print("Unexpected output format:", outputs)  # Log for debugging
                assistant_response = "Sorry, I couldn't parse the response format."
            message_placeholder.markdown(assistant_response)
            st.session_state.messages.append({"role": "assistant", "content": assistant_response})
        except Exception as e:
            error_message = f"Error during text generation: {e}"
            st.error(error_message, icon="🔥")
            message_placeholder.markdown("Sorry, an error occurred while generating the response.")
            st.session_state.messages.append({"role": "assistant", "content": f"[Error: {e}]"})