ruslanmv committed
Commit b1744c8 · verified · 1 Parent(s): 40e0f8e

Update app.py

Files changed (1):
  app.py (+94, -120)
app.py CHANGED
@@ -1,51 +1,16 @@
- import os
  import gradio as gr

- # ------------------------------------------------------------------------------
- # Environment and Model/Client Initialization
- # ------------------------------------------------------------------------------
- # Try to import google.colab to decide whether to load a local model or use InferenceClient.
- try:
-     from google.colab import userdata  # In Colab, use local model inference.
-     HF_TOKEN = userdata.get('HF_TOKEN')
-     import torch
-     from transformers import AutoTokenizer, AutoModelForCausalLM
-
-     # Small performance tweak if your input sizes remain similar.
-     torch.backends.cudnn.benchmark = True
-
-     model_name = "HuggingFaceH4/zephyr-7b-beta"
-     # Pass token if required for private models.
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         use_auth_token=HF_TOKEN,
-         torch_dtype=torch.bfloat16,
-         device_map="auto"
-     )
-     # Optionally compile the model for extra speed if using PyTorch 2.0+
-     if hasattr(torch, "compile"):
-         model = torch.compile(model)
-
-     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
-     inference_mode = "local"
-
- except ImportError:
-     # Not in Google Colab – use the Hugging Face InferenceClient.
-     HF_TOKEN = os.getenv("HF_TOKEN")
-     if not HF_TOKEN:
-         raise ValueError("HF_TOKEN environment variable not set")
-     from huggingface_hub import InferenceClient
-     from transformers import AutoTokenizer
-
-     model_name = "HuggingFaceH4/zephyr-7b-beta"
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     # Pass the token to the client to avoid authentication errors.
-     client = InferenceClient(model_name, token=HF_TOKEN)
-     inference_mode = "client"
-
- # ------------------------------------------------------------------------------
- # SYSTEM PROMPT (PATIENT ROLE)
- # ------------------------------------------------------------------------------
  nvc_prompt_template = """You are now taking on the role of a single user (a “patient”) seeking support for various personal and emotional challenges.
  BEHAVIOR INSTRUCTIONS:
  - You will respond ONLY as this user/patient.
@@ -65,102 +30,111 @@ BEHAVIOR INSTRUCTIONS:
  - Keep your responses concise, aiming for a maximum of {max_response_words} words.
  Start the conversation by expressing your current feelings or challenges from the patient's point of view."""

- # ------------------------------------------------------------------------------
- # Utility Functions
- # ------------------------------------------------------------------------------
- def build_prompt(history: list[tuple[str, str]], system_message: str, message: str, max_response_words: int) -> str:
-     """
-     Build a text prompt that starts with the system message (with a max word limit),
-     followed by the conversation history (with "Doctor:" and "Patient:" lines), and
-     ends with a new "Doctor:" line prompting the patient to reply.
-     """
-     prompt = system_message.format(max_response_words=max_response_words) + "\n"
-     for user_msg, assistant_msg in history:
-         prompt += f"Doctor: {user_msg}\n"
-         if assistant_msg:
-             prompt += f"Patient: {assistant_msg}\n"
-     prompt += f"Doctor: {message}\nPatient: "
-     return prompt

- def truncate_response(text: str, max_words: int) -> str:
-     """
-     Truncate the response text to the specified maximum number of words.
-     """
      words = text.split()
      if len(words) > max_words:
-         return " ".join(words[:max_words]) + "..."
      return text

- # ------------------------------------------------------------------------------
- # Response Function
- # ------------------------------------------------------------------------------
  def respond(
-     message: str,
      history: list[tuple[str, str]],
-     system_message: str,
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
-     max_response_words: int,
  ):
-     """
-     Generate a response based on the built prompt.
-     If running locally (in Colab), use the loaded model; otherwise, use InferenceClient.
-     """
-     prompt = build_prompt(history, system_message, message, max_response_words)
-
-     if inference_mode == "local":
-         # Tokenize the prompt and generate a response using the local model.
-         input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-         output_ids = model.generate(
-             input_ids,
-             max_new_tokens=max_tokens,
-             do_sample=True,
-             temperature=temperature,
-             top_p=top_p,
-         )
-         full_generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-         generated_response = full_generated_text[len(prompt):].strip()
-         final_response = truncate_response(generated_response, max_response_words)
-         return final_response
-     else:
-         # Use InferenceClient to generate a response.
-         response = client.text_generation(
-             prompt,
-             max_new_tokens=max_tokens,
-             do_sample=True,
              temperature=temperature,
              top_p=top_p,
-         )
-         full_generated_text = response[0]['generated_text']
-         generated_response = full_generated_text[len(prompt):].strip()
-         final_response = truncate_response(generated_response, max_response_words)
-         return final_response
-
- # ------------------------------------------------------------------------------
- # Optional Initial Message and Gradio Interface
- # ------------------------------------------------------------------------------
  initial_user_message = (
-     "I’m sorry you’ve been feeling overwhelmed. Could you tell me more "
-     "about your arguments with your partner and how that’s affecting you?"
  )

  demo = gr.ChatInterface(
      fn=respond,
      additional_inputs=[
          gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
-         gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
-         gr.Slider(minimum=10, maximum=200, value=100, step=10, label="Max response words"),
      ],
      title="Patient Interview Practice Chatbot",
-     description=(
-         "Simulate a patient interview. You (the user) act as the doctor, "
-         "and the chatbot replies with the patient's perspective only."
-     ),
  )

  if __name__ == "__main__":
-     demo.launch(share=True)
 
 
app.py (after the change; added lines are marked with +):

  import gradio as gr
+ from huggingface_hub import InferenceClient
+ from transformers import AutoTokenizer

+ # Import the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+ # Define a maximum context length (tokens). Check your model's documentation!
+ MAX_CONTEXT_LENGTH = 4096  # Example: adjust based on your model
+ MAX_RESPONSE_WORDS = 100  # Define the maximum words for patient responses
+
+ ################################# SYSTEM PROMPT (PATIENT ROLE) #################################
  nvc_prompt_template = """You are now taking on the role of a single user (a “patient”) seeking support for various personal and emotional challenges.
  BEHAVIOR INSTRUCTIONS:
  - You will respond ONLY as this user/patient.

  - Keep your responses concise, aiming for a maximum of {max_response_words} words.
  Start the conversation by expressing your current feelings or challenges from the patient's point of view."""

+ def count_tokens(text: str) -> int:
+     """Counts the number of tokens in a given string."""
+     return len(tokenizer.encode(text))
+
+ def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
+     """Truncates the conversation history to fit within the maximum token limit."""
+     truncated_history = []
+     system_message_tokens = count_tokens(system_message)
+     current_length = system_message_tokens
+
+     # Iterate backwards through the history (newest to oldest)
+     for user_msg, assistant_msg in reversed(history):
+         user_tokens = count_tokens(user_msg) if user_msg else 0
+         assistant_tokens = count_tokens(assistant_msg) if assistant_msg else 0
+         turn_tokens = user_tokens + assistant_tokens
+         if current_length + turn_tokens <= max_length:
+             truncated_history.insert(0, (user_msg, assistant_msg))  # Add to the beginning
+             current_length += turn_tokens
+         else:
+             break  # Stop adding turns if we exceed the limit
+     return truncated_history

+ def truncate_response_words(text: str, max_words: int) -> str:
+     """Truncates a text to a maximum number of words."""
      words = text.split()
      if len(words) > max_words:
+         return " ".join(words[:max_words]) + "..."  # Add ellipsis to indicate truncation
      return text

+
  def respond(
+     message,
      history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+     max_response_words_param,  # Pass max_response_words as parameter
  ):
+     """Responds to a user message, maintaining conversation history."""
+     # Use the system prompt that instructs the LLM to behave as the patient
+     formatted_system_message = system_message.format(max_response_words=max_response_words_param)
+
+     # Truncate history to fit within max tokens
+     truncated_history = truncate_history(
+         history,
+         formatted_system_message,
+         MAX_CONTEXT_LENGTH - max_tokens - 100  # Reserve some space
+     )
+
+     # Build the messages list with the system prompt first
+     messages = [{"role": "system", "content": formatted_system_message}]
+
+     # Replay truncated conversation
+     for user_msg, assistant_msg in truncated_history:
+         if user_msg:
+             messages.append({"role": "user", "content": f"<|user|>\n{user_msg}</s>"})
+         if assistant_msg:
+             messages.append({"role": "assistant", "content": f"<|assistant|>\n{assistant_msg}</s>"})
+
+     # Add the latest user query
+     messages.append({"role": "user", "content": f"<|user|>\n{message}</s>"})
+
+     response = ""
+     try:
+         # Generate response from the LLM, streaming tokens
+         for chunk in client.chat_completion(
+             messages,
+             max_tokens=max_tokens,
+             stream=True,
              temperature=temperature,
              top_p=top_p,
+         ):
+             token = chunk.choices[0].delta.content
+             response += token
+
+             truncated_response = truncate_response_words(response, max_response_words_param)  # Truncate response to word limit
+             yield truncated_response
+
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         yield "I'm sorry, I encountered an error. Please try again."
+
+ # OPTIONAL: An initial user message (the LLM "as user") if desired
  initial_user_message = (
+     "I really don’t know where to begin… I feel overwhelmed lately. "
+     "My neighbors keep playing loud music, and I’m arguing with my partner about money. "
+     "Also, two of my friends are fighting, and the group is drifting apart. "
+     "I just feel powerless."
  )

+ # --- Gradio Interface ---
  demo = gr.ChatInterface(
      fn=respond,
      additional_inputs=[
          gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+         gr.Slider(minimum=10, maximum=200, value=MAX_RESPONSE_WORDS, step=10, label="Max response words"),  # Slider for max words
      ],
+     # You can optionally set 'title' or 'description' to show some info in the UI:
      title="Patient Interview Practice Chatbot",
+     description="Practice medical interviews with a patient simulator. Ask questions and the patient will respond based on their defined persona and emotional challenges.",
  )

  if __name__ == "__main__":
+     demo.launch()
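
A minimal usage sketch of the token-budget truncation this commit introduces. It assumes the new file above is saved as app.py, is importable as `app`, and that its dependencies (including downloading the HuggingFaceH4/zephyr-7b-beta tokenizer) are available; the module name and the sample dialogue are illustrative assumptions, not part of the commit.

  # Sketch only: shows that truncate_history keeps the newest turns and drops
  # the oldest ones once the token budget is exhausted.
  from app import count_tokens, truncate_history

  system = "You are the patient."
  history = [
      ("How have you been sleeping?", "Not well, I keep waking up at night."),
      ("What do you think is causing that?", "Mostly the arguments with my partner."),
  ]

  # A budget that only fits the system prompt plus the most recent turn,
  # so the oldest exchange is dropped and the newest one is kept.
  budget = count_tokens(system) + count_tokens(history[-1][0]) + count_tokens(history[-1][1])
  print(truncate_history(history, system, budget))  # prints only the second (newest) turn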