Avinash109 committed
Commit 3e64474 · verified · Parent(s): 891b2fe

Update app.py

Files changed (1): app.py (+76 -32)
app.py CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import datetime
 
 # Set Streamlit page configuration
 st.set_page_config(
@@ -13,80 +14,123 @@ st.set_page_config(
 st.title("💬 Qwen2.5-Coder Chat Interface")
 
 # Initialize session state for messages (store conversation history)
-if 'messages' not in st.session_state:
-    st.session_state['messages'] = []
+st.session_state.setdefault('messages', [])
 
 # Load the model and tokenizer
 @st.cache_resource
 def load_model():
     model_name = "Qwen/Qwen2.5-Coder-32B-Instruct" # Replace with the correct model path
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        load_in_8bit=True # Optional: Use if supported for reduced memory usage
+    )
     return tokenizer, model
 
-# Load tokenizer and model
-with st.spinner("Loading model... This may take a while..."):
-    tokenizer, model = load_model()
+# Load tokenizer and model with error handling
+try:
+    with st.spinner("Loading model... This may take a while..."):
+        tokenizer, model = load_model()
+except Exception as e:
+    st.error(f"Error loading model: {e}")
+    st.stop()
 
 # Function to generate model response
-def generate_response(user_input, max_tokens=150, temperature=0.7, top_p=0.9):
-    # Tokenize the user input
-    inputs = tokenizer.encode(user_input, return_tensors="pt").to(model.device)
+def generate_response(messages, tokenizer, model, max_tokens=150, temperature=0.7, top_p=0.9):
+    """
+    Generates a response from the model based on the conversation history.
+
+    Args:
+        messages (list): List of message dictionaries containing 'role' and 'content'.
+        tokenizer: The tokenizer instance.
+        model: The language model instance.
+        max_tokens (int): Maximum number of tokens for the response.
+        temperature (float): Sampling temperature.
+        top_p (float): Nucleus sampling probability.
+
+    Returns:
+        str: The generated response text.
+    """
+    # Concatenate all previous messages
+    conversation = ""
+    for message in messages:
+        role = "You" if message['role'] == 'user' else "Qwen2.5-Coder"
+        conversation += f"**{role}:** {message['content']}\n"
+
+    # Append the latest user input
+    conversation += f"**You:** {messages[-1]['content']}\n**Qwen2.5-Coder:**"
+
+    # Tokenize the conversation
+    inputs = tokenizer.encode(conversation, return_tensors="pt").to(model.device)
 
     # Generate a response
     with torch.no_grad():
         outputs = model.generate(
             inputs,
-            max_length=max_tokens,
+            max_length=inputs.shape[1] + max_tokens,
             temperature=temperature,
             top_p=top_p,
             do_sample=True,
-            num_return_sequences=1
+            num_return_sequences=1,
+            pad_token_id=tokenizer.eos_token_id
         )
 
     # Decode the response
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Return the response without the input prompt
-    return response[len(user_input):].strip()
+    # Extract the generated response after the conversation
+    generated_response = response.split("Qwen2.5-Coder:")[-1].strip()
+    return generated_response
 
 # Layout: Two columns for the main chat and sidebar
 chat_col, sidebar_col = st.columns([4, 1])
 
 with chat_col:
-    # Display chat messages
-    for message in st.session_state['messages']:
-        if message['role'] == 'user':
-            st.markdown(f"**You:** {message['content']}")
-        else:
-            st.markdown(f"**Qwen2.5-Coder:** {message['content']}")
+    st.markdown("### Chat")
+    chat_container = st.container()
+    with chat_container:
+        for message in st.session_state['messages']:
+            time = message.get('timestamp', '')
+            if message['role'] == 'user':
+                st.markdown(f"**You:** {message['content']} _({time})_")
+            else:
+                st.markdown(f"**Qwen2.5-Coder:** {message['content']} _({time})_")
 
     # Input area for user message
     with st.form(key='chat_form', clear_on_submit=True):
        user_input = st.text_area("You:", height=100)
        submit_button = st.form_submit_button(label='Send')
 
-    if submit_button and user_input:
+    if submit_button and user_input.strip():
+        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         # Append the user's message to the chat history
-        st.session_state['messages'].append({'role': 'user', 'content': user_input})
+        st.session_state['messages'].append({'role': 'user', 'content': user_input, 'timestamp': timestamp})
 
         # Generate and append the model's response
-        with st.spinner("Qwen2.5-Coder is typing..."):
-            response = generate_response(user_input)
-
-        # Append the model's response to the chat history
-        st.session_state['messages'].append({'role': 'assistant', 'content': response})
-
-        # Rerun the app to display new messages
-        st.experimental_rerun()
+        try:
+            with st.spinner("Qwen2.5-Coder is typing..."):
+                response = generate_response(
+                    st.session_state['messages'],
+                    tokenizer,
+                    model,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p
+                )
+            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            st.session_state['messages'].append({'role': 'assistant', 'content': response, 'timestamp': timestamp})
+        except Exception as e:
+            st.error(f"Error generating response: {e}")
 
 with sidebar_col:
     st.sidebar.header("Settings")
     max_tokens = st.sidebar.slider(
         "Maximum Tokens",
-        min_value=512,
+        min_value=50,
         max_value=4096,
-        value=150,
+        value=512,
         step=256,
         help="Set the maximum number of tokens for the model's response."
     )
@@ -111,4 +155,4 @@ with sidebar_col:
 
     if st.sidebar.button("Clear Chat"):
         st.session_state['messages'] = []
-        st.experimental_rerun()
+        st.success("Chat history cleared.")
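For reference, Qwen2.5-Coder instruct checkpoints ship a chat template, so the manual `**You:** ... **Qwen2.5-Coder:**` prompt concatenation in `generate_response` could instead be delegated to `tokenizer.apply_chat_template`, and newer transformers releases prefer a `BitsAndBytesConfig` over the bare `load_in_8bit=True` flag. The sketch below is not part of this commit; it is a minimal alternative, assuming a recent transformers plus bitsandbytes install, and reuses the same model name and sampling parameters as the app.

```python
# Minimal sketch (not from this commit): chat-template prompting + 8-bit loading
# via BitsAndBytesConfig. Assumes recent transformers and bitsandbytes are installed.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # instead of load_in_8bit=True
)

# The app's message dicts ({'role': ..., 'content': ...}) can be passed straight
# to the tokenizer's chat template instead of building "**You:** ..." strings by hand.
messages = [{"role": "user", "content": "Write a Python function that reverses a string."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,  # bounds only the newly generated tokens, unlike max_length
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode only the tokens generated after the prompt, skipping special tokens.
reply = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(reply)
```

Slicing the output at `inputs["input_ids"].shape[1]` avoids the string-splitting step (`response.split("Qwen2.5-Coder:")`) used in the diff to separate the reply from the prompt.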