Ani07-05 committed
Commit 7b9bc7a · 1 Parent(s): 92cd1f8

Switch to WiroAI-Finance-Qwen-1.5B model

Files changed (1):
  1. app.py +85 -67
app.py CHANGED
@@ -1,121 +1,139 @@
  import streamlit as st
  import os
- from transformers import pipeline
- import torch # PyTorch is commonly used by transformers

  # --- Set Page Config FIRST ---
- st.set_page_config(layout="wide") # Use wider layout

  # --- Configuration ---
- MODEL_NAME = "AdaptLLM/finance-LLM"
- # Attempt to get token from secrets, handle case where it might not be set yet
  HF_TOKEN = os.environ.get("HF_TOKEN")

  # --- Model Loading (Cached by Streamlit for efficiency) ---
- @st.cache_resource # Cache the pipeline object
- def load_text_generation_pipeline():
-     """Loads the text generation pipeline."""
      if not HF_TOKEN:
          st.warning("HF_TOKEN secret not found. Ensure the model is public or add the token to secrets.")
-         # Decide if you want to stop or proceed cautiously
-         # st.stop() # Uncomment this line to halt execution if token is strictly required

      try:
-         # Determine device: Use GPU (cuda:0) if available, otherwise CPU (-1)
-         # Free Spaces typically only have CPU, so device will likely be -1
-         device = 0 if torch.cuda.is_available() else -1
-
-         st.info(f"Loading model {MODEL_NAME}... This might take a while on the first run.")
-         # Use pipeline for easier text generation
          generator = pipeline(
              "text-generation",
              model=MODEL_NAME,
-             tokenizer=MODEL_NAME,
-             torch_dtype=torch.float16,
-             device=device,
-             trust_remote_code=True
          )
          st.success(f"Model {MODEL_NAME} loaded successfully!")
-         return generator
      except Exception as e:
-         st.error(f"Error loading model pipeline: {e}", icon="🔥")
-         st.error("This could be due to memory limits on the free tier, missing token for a private model, or other issues.")
-         st.stop() # Stop the app if the model fails to load

- # --- Load the Model Pipeline ---
- generator = load_text_generation_pipeline()

  # --- Streamlit App UI ---
  st.title("💰 FinBuddy Assistant")
- st.caption("Your AI-powered financial planning assistant (Text Chat - v1)")

- # Initialize chat history in session state if it doesn't exist
  if "messages" not in st.session_state:
-     st.session_state.messages = []

- # Display past chat messages
  for message in st.session_state.messages:
-     with st.chat_message(message["role"]):
-         st.markdown(message["content"]) # Display content as markdown

- # Get user input using chat_input
  if prompt := st.chat_input("Ask a question about finance..."):
-     # Add user message to session state and display it
      st.session_state.messages.append({"role": "user", "content": prompt})
      with st.chat_message("user"):
          st.markdown(prompt)

      # Generate assistant response
      with st.chat_message("assistant"):
-         message_placeholder = st.empty() # Create placeholder for streaming/final response
-         message_placeholder.markdown("Thinking...⏳") # Initial thinking message

-         # --- Prepare prompt for the model ---
-         # Simple approach: just use the latest user prompt.
-         # TODO: Improve this later to include conversation history for better context.
-         prompt_for_model = prompt

          try:
              # Generate response using the pipeline
              outputs = generator(
-                 prompt_for_model,
-                 max_new_tokens=512, # Limit the length of the response
-                 num_return_sequences=1,
-                 eos_token_id=generator.tokenizer.eos_token_id,
-                 pad_token_id=generator.tokenizer.eos_token_id # Helps prevent warnings/issues
              )

-             if outputs and len(outputs) > 0 and 'generated_text' in outputs[0]:
-                 # Extract the generated text
-                 full_response = outputs[0]['generated_text']
-
-                 # --- Attempt to clean the response ---
-                 # The pipeline often returns the prompt + response. Try to remove the prompt part.
-                 if full_response.startswith(prompt_for_model):
-                     assistant_response = full_response[len(prompt_for_model):].strip()
-                     # Sometimes models add their own role prefix
-                     if assistant_response.lower().startswith("assistant:"):
-                         assistant_response = assistant_response[len("assistant:"):].strip()
-                     elif assistant_response.lower().startswith("response:"):
-                         assistant_response = assistant_response[len("response:"):].strip()
                  else:
-                     assistant_response = full_response # Fallback if prompt isn't found at start

-                 # Handle cases where the response might be empty after cleaning
                  if not assistant_response:
-                     assistant_response = "I received your message, but I don't have a further response right now."
-
              else:
-                 assistant_response = "Sorry, I couldn't generate a response."

-             # Display the final response
              message_placeholder.markdown(assistant_response)
-             # Add the final assistant response to session state
              st.session_state.messages.append({"role": "assistant", "content": assistant_response})

          except Exception as e:
              error_message = f"Error during text generation: {e}"
              st.error(error_message, icon="🔥")
-             message_placeholder.markdown("Sorry, an error occurred while generating the response.")
-             # Add error indication to history
              st.session_state.messages.append({"role": "assistant", "content": f"[Error: {e}]"})
 
  import streamlit as st
  import os
+ from transformers import pipeline, AutoTokenizer # Added AutoTokenizer
+ import torch

  # --- Set Page Config FIRST ---
+ st.set_page_config(layout="wide")

  # --- Configuration ---
+ # MODEL_NAME = "AdaptLLM/finance-LLM" # Old model
+ MODEL_NAME = "WiroAI/WiroAI-Finance-Qwen-1.5B" # New smaller model
  HF_TOKEN = os.environ.get("HF_TOKEN")

  # --- Model Loading (Cached by Streamlit for efficiency) ---
+ @st.cache_resource
+ def load_resources():
+     """Loads the tokenizer and the text generation pipeline."""
      if not HF_TOKEN:
          st.warning("HF_TOKEN secret not found. Ensure the model is public or add the token to secrets.")

      try:
+         st.info(f"Loading tokenizer for {MODEL_NAME}...")
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN if HF_TOKEN else None)
+         st.success("Tokenizer loaded.")
+
+         # Determine device: Use GPU if available, otherwise CPU
+         # device_map="auto" might be problematic on CPU-only Spaces
+         # Start with device_map="auto", but fall back to explicit cpu if needed
+         device_map_setting = "auto"
+         # device = 0 if torch.cuda.is_available() else -1 # Alternative: explicit device
+
+         st.info(f"Loading model {MODEL_NAME}... (Using {device_map_setting}) This might take a while.")
+         # Use pipeline
          generator = pipeline(
              "text-generation",
              model=MODEL_NAME,
+             tokenizer=tokenizer, # Pass loaded tokenizer
+             model_kwargs={"torch_dtype": torch.bfloat16}, # Use bfloat16 as per model card
+             device_map=device_map_setting,
+             # device=device # Use this if device_map causes issues
+             trust_remote_code=True
          )
          st.success(f"Model {MODEL_NAME} loaded successfully!")
+         return generator, tokenizer # Return both
      except Exception as e:
+         st.error(f"Error loading model/tokenizer: {e}", icon="🔥")
+         st.error("Check memory limits, token access, or try removing device_map='auto'.")
+         st.stop()

+ # --- Load Resources ---
+ generator, tokenizer = load_resources()

  # --- Streamlit App UI ---
  st.title("💰 FinBuddy Assistant")
+ st.caption(f"Model: {MODEL_NAME}")

  if "messages" not in st.session_state:
+     # Add initial system message (as per model card example)
+     st.session_state.messages = [
+         {"role": "system", "content": "You are a finance chatbot developed by Wiro AI"}
+     ]

+ # Display past chat messages (excluding system message)
  for message in st.session_state.messages:
+     if message["role"] != "system": # Don't display system message
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])

+ # Get user input
  if prompt := st.chat_input("Ask a question about finance..."):
+     # Add user prompt to state and display
      st.session_state.messages.append({"role": "user", "content": prompt})
      with st.chat_message("user"):
          st.markdown(prompt)

      # Generate assistant response
      with st.chat_message("assistant"):
+         message_placeholder = st.empty()
+         message_placeholder.markdown("Thinking...⏳")
+
+         # --- Prepare prompt for the model (use message history) ---
+         # Use the messages stored in session state (includes system prompt)
+         messages_for_api = st.session_state.messages

+         # --- Define terminators as per model card ---
+         terminators = [
+             tokenizer.eos_token_id,
+             tokenizer.convert_tokens_to_ids("<|end_of_text|>") # Qwen uses <|end_of_text|> usually
+         ]
+         # Handle potential errors if the specific token doesn't exist
+         terminators = [term for term in terminators if term is not None and not isinstance(term, list)] # Filter out None or lists if conversion fails

          try:
              # Generate response using the pipeline
              outputs = generator(
+                 messages_for_api, # Pass the list of messages
+                 max_new_tokens=512,
+                 eos_token_id=terminators,
+                 pad_token_id=tokenizer.eos_token_id, # Use EOS for padding
+                 do_sample=True,
+                 temperature=0.7, # Adjusted slightly from example
+                 top_p=0.95, # Added common param
+                 # top_k=50 # Optional parameter
              )

+             # --- Extract response ---
+             # The output format is a list containing a dictionary with 'generated_text'
+             # which itself is a list of message dictionaries.
+             if (outputs and
+                 isinstance(outputs, list) and
+                 len(outputs) > 0 and
+                 isinstance(outputs[0], dict) and
+                 'generated_text' in outputs[0] and
+                 isinstance(outputs[0]['generated_text'], list) and
+                 len(outputs[0]['generated_text']) > 0):
+
+                 # Get the last message dictionary in the generated list (should be the assistant's reply)
+                 last_message = outputs[0]['generated_text'][-1]
+                 if isinstance(last_message, dict) and last_message.get('role') == 'assistant':
+                     assistant_response = last_message.get('content', "").strip()
                  else:
+                     # Fallback if format is unexpected - try getting last element's text if it's a string
+                     assistant_response = str(outputs[0]['generated_text'][-1]).strip()

                  if not assistant_response:
+                     assistant_response = "I generated an empty response."
+
              else:
+                 print("Unexpected output format:", outputs) # Log for debugging
+                 assistant_response = "Sorry, I couldn't parse the response format."

              message_placeholder.markdown(assistant_response)
              st.session_state.messages.append({"role": "assistant", "content": assistant_response})

          except Exception as e:
              error_message = f"Error during text generation: {e}"
              st.error(error_message, icon="🔥")
+             message_placeholder.markdown("Sorry, an error occurred generating the response.")
              st.session_state.messages.append({"role": "assistant", "content": f"[Error: {e}]"})
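
Note on the new response parsing: it assumes a transformers version recent enough that the "text-generation" pipeline, when given a list of role/content message dicts, applies the model's chat template and returns the whole conversation (input messages plus the new assistant turn) under 'generated_text'. The sketch below only illustrates that assumed output shape; the reply text is invented and extract_assistant_reply is a hypothetical helper, not part of app.py.

# Hypothetical illustration of the output shape the new extraction code targets.
example_outputs = [
    {
        "generated_text": [
            {"role": "system", "content": "You are a finance chatbot developed by Wiro AI"},
            {"role": "user", "content": "What is dollar-cost averaging?"},
            {"role": "assistant", "content": "Dollar-cost averaging means investing a fixed amount at regular intervals."},
        ]
    }
]

def extract_assistant_reply(outputs):
    """Mirror the diff's logic: return the content of the last assistant message, or None."""
    if not (isinstance(outputs, list) and outputs and isinstance(outputs[0], dict)):
        return None
    generated = outputs[0].get("generated_text")
    if not (isinstance(generated, list) and generated):
        return None  # older transformers may return a plain string here instead
    last_message = generated[-1]
    if isinstance(last_message, dict) and last_message.get("role") == "assistant":
        return last_message.get("content", "").strip()
    return str(last_message).strip()

print(extract_assistant_reply(example_outputs))  # -> "Dollar-cost averaging means ..."

If the installed transformers instead returns a single string in 'generated_text' (the older, non-chat behavior), the isinstance(..., list) check fails and the app falls into the "couldn't parse the response format" branch, so it is worth checking the Space's pinned transformers version.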