import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "thrishala/mental_health_chatbot"

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,  # Same memory as fp16, but fp16 matmuls are unsupported on most CPUs
        low_cpu_mem_usage=True,
        max_memory={"cpu": "15GB"},
        offload_folder="offload",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 256  # Cap tokenized prompt length to keep CPU inference fast

    pipe = pipeline(
        "text-generation",
        model=model,  # Pipeline inherits the model's dtype; no need to pass torch_dtype again
        tokenizer=tokenizer,
        num_return_sequences=1,
        truncation=True,
        max_new_tokens=128,  # Default; respond() overrides this per request
    )

except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)  # Abort startup with a nonzero exit code

def respond(message, history, system_message, max_tokens, temperature, top_p):
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        response = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=True,  # Sampling must be enabled for temperature/top_p to take effect
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )[0]["generated_text"]
        
        # generated_text includes the prompt, so keep only the new completion.
        bot_response = response.split("Assistant:")[-1].strip()
        yield bot_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred during generation."

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)",
        ),
    ],
    chatbot=gr.Chatbot(type="messages"),  # "messages" format: history entries are {"role", "content"} dicts
)

if __name__ == "__main__":
    demo.launch()