import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

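# Minimal Gradio chat app that serves a causal-LM mental-health model
# entirely on CPU, spilling weights to disk if they exceed the RAM cap.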
model_id = "thrishala/mental_health_chatbot"

try:
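    # Keep the model on CPU; max_memory and offload_folder ask accelerate
    # to spill weights to the "offload" directory rather than exhaust RAM.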
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,  # float16 kernels are poorly supported on CPU; bfloat16 is safer
        low_cpu_mem_usage=True,
        max_memory={"cpu": "15GB"},
        offload_folder="offload",
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 256  # cap prompt length so CPU inference stays fast
    
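    # Greedy decoding (do_sample=False) keeps replies deterministic;
    # truncation enforces the 256-token cap set on the tokenizer above.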
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        num_return_sequences=1,
        do_sample=False,
        truncation=True,
        max_new_tokens=128
    )

except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)  # the app cannot run without the model

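# ChatInterface callback: receives the new message, the running chat
# history, and the values of the additional inputs declared below.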
def respond(message, history, system_message, max_tokens):
    prompt = f"{system_message}\n"
    
    # With type="messages" (see the Chatbot below), history is a list of
    # {"role", "content"} dicts in chronological order; replay it oldest
    # to newest so the prompt reads as a natural transcript.
    for msg in history:
        role = "User" if msg["role"] == "user" else "Assistant"
        prompt += f"{role}: {msg['content']}\n"
        
    prompt += f"User: {message}\nAssistant:"
    
    try:
        response = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )[0]["generated_text"]

        # Keep only the text generated after the final "Assistant:" marker,
        # and drop anything past a hallucinated "User:" turn.
        bot_response = response.split("Assistant:")[-1].split("User:")[0].strip()

        # In messages mode, ChatInterface expects only the assistant reply;
        # it appends both turns to the history itself.
        yield bot_response

    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred during generation."

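# Wire the callback into a chat UI, exposing the system prompt and the
# generation budget as user-adjustable inputs.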
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
    ],
    chatbot=gr.Chatbot(type="messages"),  # history reaches respond() as role/content dicts
)

if __name__ == "__main__":
    demo.launch()