Spaces:
Sleeping
Sleeping
File size: 1,964 Bytes
911b23c e62ae46 911b23c 7e55c71 2eff9ee 7e55c71 2eff9ee 7e55c71 911b23c 2eff9ee 911b23c e62ae46 911b23c e62ae46 911b23c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Record which accelerator is visible: "cuda" when a GPU is available, else "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model_name = "tanusrich/Mental_Health_Chatbot"

# Loading options are collected in one place so they can be read at a glance.
_load_options = {
    "torch_dtype": torch.float16,  # half precision to reduce memory usage
    "device_map": "cpu",  # the model is explicitly placed on the CPU
    "low_cpu_mem_usage": True,
    "max_memory": {0: "3.5GiB", "cpu": "12GiB"},  # limits keyed by device (GPU 0 / CPU)
    "offload_folder": None,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_load_options)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Disabled snippet, kept for reference: save a local copy of model and tokenizer.
# model_save_path = "./model"
# # Save model
# model.save_pretrained(model_save_path)
# # Save tokenizer
# tokenizer.save_pretrained(model_save_path)
def generate_response(user_input, history=None):
    """Generate the chatbot's reply to a single user message.

    Args:
        user_input: The latest message typed by the user.
        history: Earlier conversation turns. ``gr.ChatInterface`` passes this
            as the second positional argument to its callback; it is accepted
            for compatibility and is not used to build the prompt.

    Returns:
        The generated reply as a string, with everything up to and including
        the last ``"Chatbot:"`` label (if the model emitted one) removed and
        surrounding whitespace stripped.
    """
    # Keep the input tensors on whatever device the model was loaded onto.
    inputs = tokenizer(user_input, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=150,
            # temperature / top_k / top_p only take effect when sampling is on.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )

    # For a causal LM the returned sequence starts with the prompt tokens;
    # decode only the continuation so the user's own text is never echoed back.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    # Extract only the chatbot's latest response.
    chatbot_response = response.split("Chatbot:")[-1].strip()

    # NOTE: the previous `conversation_history += ...` line was removed: the
    # name was never initialised, so it raised UnboundLocalError on every call.
    # Conversation state is supplied by the caller through `history`.
    return chatbot_response
# Disabled terminal conversation loop, kept for reference:
# while True:
#     user_input = input("You: ")  # Take user input
#     if user_input.lower() in ["exit", "quit", "stop"]:
#         print("Chatbot: Goodbye!")
#         break
#     response = generate_response(user_input)
#     print("Chatbot:", response)

# Build the Gradio chat UI around generate_response and start serving it.
chatbot = gr.ChatInterface(
    title="Mental Health Chatbot",
    fn=generate_response,
)
chatbot.launch()

# Disabled one-off example call, kept for reference:
# user_input = "I'm feeling suicidal."
# response = generate_response(user_input)
# print("Chatbot: ", response)
|