File size: 2,092 Bytes
432cd4a
f43b68f
3f3da62
432cd4a
abe2d0f
 
e929713
b84cd4b
abe2d0f
 
b84cd4b
 
 
abe2d0f
b84cd4b
e929713
 
 
 
b84cd4b
 
 
 
 
 
 
 
 
3c7c10f
b84cd4b
14ddf0d
 
8ca2a6e
 
14ddf0d
 
 
b84cd4b
 
 
14ddf0d
fb27a1f
b84cd4b
 
 
14ddf0d
 
b84cd4b
3f3da62
432cd4a
 
 
e929713
 
 
 
b84cd4b
 
 
 
 
 
 
 
 
432cd4a
 
 
 
b84cd4b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

# Identifier handed to `from_pretrained` for both the tokenizer and the model.
model_id = "thrishala/mental_health_chatbot"

# Load the tokenizer and model once at import time and wrap them in a
# text-generation pipeline; `respond` reuses the resulting `pipe` and
# `tokenizer` for every chat turn.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_8bit=True,  # 8-bit weights to cut memory use
        device_map="auto",  # let the library decide device placement
        torch_dtype=torch.float16,
    )
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    print(f"Error loading model: {e}")
    # The app cannot serve anything without the model. Raise SystemExit
    # directly instead of calling the `site`-provided `exit()` helper (which
    # is meant for the interactive shell and may be absent), and use a
    # non-zero status so the failure is visible to whatever launched us.
    raise SystemExit(1) from e

def _format_history(history):
    """Render earlier chat turns as ``User:`` / ``Assistant:`` transcript text.

    Two history layouts are accepted:

    * a sequence of ``(user_msg, bot_msg)`` pairs, and
    * a sequence of ``{"role": ..., "content": ...}`` message dicts.

    Args:
        history: Prior turns in either layout; ``None`` or empty means no
            prior turns.

    Returns:
        The transcript as a single string, each turn terminated by a newline
        (empty string when there is no history).
    """
    parts = []
    for entry in history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            if role == "assistant":
                label = "Assistant"
            elif role == "user":
                label = "User"
            else:
                # Other roles have no slot in this two-speaker transcript.
                continue
            parts.append(f"{label}: {entry.get('content')}\n")
        else:
            user_msg, bot_msg = entry
            parts.append(f"User: {user_msg}\nAssistant: {bot_msg}\n")
    return "".join(parts)


def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply for one chat turn.

    Builds a plain-text prompt (system message, prior turns, the new user
    message, then a trailing ``Assistant:`` cue), runs it through the
    module-level ``pipe``, and yields the generated reply.

    Args:
        message: The new user message.
        history: Prior turns, as ``(user, bot)`` pairs or role/content dicts.
        system_message: Text placed on the first line of the prompt.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply text, or ``"An error occurred."`` if generation raises.
    """
    prompt = (
        f"{system_message}\n"
        f"{_format_history(history)}"
        f"User: {message}\nAssistant:"
    )

    try:
        outputs = pipe(
            prompt,
            # UI sliders may deliver floats; generation expects an integer.
            max_new_tokens=int(max_tokens),
            # Without sampling enabled, temperature/top_p are not applied.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,  # Use EOS token to stop generation
            # Ask for the continuation only, rather than slicing the prompt
            # off the front of the returned text by length.
            return_full_text=False,
        )
        generated = outputs[0]["generated_text"]

        # The model may keep going and invent the next user turn; keep only
        # the text before the first "User:" marker.
        bot_response = generated.split("User:")[0].strip()
        yield bot_response
    except Exception as e:
        # Best-effort boundary: log the failure and keep the chat UI alive.
        print(f"Error during generation: {e}")
        yield "An error occurred."

# Extra controls for the chat UI. They are listed in the same order as the
# trailing parameters of `respond`: system_message, max_tokens, temperature,
# top_p.
_system_message_box = gr.Textbox(
    label="System message",
    value="You are a friendly and helpful mental health chatbot.",
)
_max_tokens_slider = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=2048,
    step=1,
    value=512,
)
_temperature_slider = gr.Slider(
    label="Temperature",
    minimum=0.1,
    maximum=4.0,
    step=0.1,
    value=0.7,
)
_top_p_slider = gr.Slider(
    label="Top-p (nucleus sampling)",
    minimum=0.1,
    maximum=1.0,
    step=0.05,
    value=0.95,
)

# Chat front-end that uses `respond` as its callback.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()