Spaces:
Sleeping
Sleeping
File size: 2,092 Bytes
432cd4a f43b68f 3f3da62 432cd4a abe2d0f e929713 b84cd4b abe2d0f b84cd4b abe2d0f b84cd4b e929713 b84cd4b 3c7c10f b84cd4b 14ddf0d 8ca2a6e 14ddf0d b84cd4b 14ddf0d fb27a1f b84cd4b 14ddf0d b84cd4b 3f3da62 432cd4a e929713 b84cd4b 432cd4a b84cd4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "thrishala/mental_health_chatbot"

try:
    # Load the tokenizer and an 8-bit quantized model; device_map="auto"
    # places layers across available devices (needs accelerate + bitsandbytes).
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_8bit=True,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    # Abort with a nonzero status. exit() comes from the site module and is
    # not guaranteed outside the interactive interpreter; SystemExit also
    # writes the message to stderr instead of stdout.
    raise SystemExit(f"Error loading model: {e}")
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate one assistant reply for the Gradio chat interface.

    Builds a flat text prompt from the system message plus the running
    (user, assistant) history, runs the text-generation pipeline, and
    yields only the newly generated assistant text.
    """
    # Assemble the prompt: system message first, then alternating turns.
    turns = [f"{system_message}\n"]
    for past_user, past_bot in history:
        turns.append(f"User: {past_user}\nAssistant: {past_bot}\n")
    turns.append(f"User: {message}\nAssistant:")
    prompt = "".join(turns)

    try:
        generated = pipe(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,  # Use EOS token to stop generation
        )[0]["generated_text"]
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."
    else:
        # The pipeline echoes the prompt; keep only the continuation and
        # cut it off before any hallucinated next "User:" turn.
        continuation = generated[len(prompt):]
        yield continuation.split("User:")[0].strip()
# Wire the chat handler into a Gradio chat UI with tunable sampling controls.
system_box = gr.Textbox(
    value="You are a friendly and helpful mental health chatbot.",
    label="System message",
)
max_tokens_slider = gr.Slider(
    minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
)
temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
)
top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        system_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
)
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (the stray trailing
    # "|" from the original paste was a syntax error and has been removed).
    demo.launch()