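# app.py for a Hugging Face Space: a CPU-only Gradio chat UI wrapped
# around the thrishala/mental_health_chatbot causal language model.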
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "thrishala/mental_health_chatbot"
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.float16,  # Note: float16 on CPU can be slow or unsupported for some ops
        low_cpu_mem_usage=True,
        max_memory={"cpu": "15GB"},
        offload_folder="offload",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 256  # Cap the tokenized prompt length so truncation takes effect

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        num_return_sequences=1,
        do_sample=False,  # Greedy decoding for deterministic replies
        truncation=True,
        max_new_tokens=128,
    )
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)  # Abort startup if the model cannot be loaded
def respond(message, history, system_message, max_tokens):
    # Rebuild the conversation as a plain-text prompt, oldest turn first.
    # With chatbot type="messages", history arrives as a list of
    # {"role": ..., "content": ...} dicts rather than (user, bot) tuples.
    prompt = f"{system_message}\n"
    for turn in history:
        role = "User" if turn["role"] == "user" else "Assistant"
        prompt += f"{role}: {turn['content']}\n"
    prompt += f"User: {message}\nAssistant:"
    try:
        response = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )[0]["generated_text"]
        # The pipeline echoes the prompt, so keep only the text after the
        # final "Assistant:" marker.
        bot_response = response.split("Assistant:")[-1].strip()
        yield bot_response  # ChatInterface expects just the assistant reply
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred during generation."
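# Example prompt built by respond() for a one-turn history (values are
# illustrative, not from the model):
#   You are a friendly and helpful mental health chatbot.
#   User: Hi
#   Assistant: Hello! How are you feeling today?
#   User: I feel anxious.
#   Assistant: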
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
    ],
    chatbot=gr.Chatbot(type="messages"),  # openai-style message dicts
)
if __name__ == "__main__":
    demo.launch()
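# To run outside of Spaces (assuming gradio, transformers, and torch are
# installed): python app.py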