import os

import gradio as gr
import keras_nlp

# keras_nlp reads Kaggle API credentials from these environment variables when
# downloading the Gemma preset; fail fast if either is missing (set them with
# `export KAGGLE_USERNAME=...` and `export KAGGLE_KEY=...`).
for var in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    if not os.environ.get(var):
        raise RuntimeError(f"Environment variable {var} must be set to download Gemma from Kaggle.")


# Load the base Gemma 2B model and apply the fine-tuned LoRA adapter
lora_weights_path = "fined-tuned-model.lora.h5"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.backbone.enable_lora(rank=4)  # Rank must match the rank used during fine-tuning
gemma_lm.preprocessor.sequence_length = 512  # Cap sequence length to bound memory use
gemma_lm.backbone.load_lora_weights(lora_weights_path)
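
# If the adapter file could be absent in some deployments, a guarded load is a
# safer pattern; a sketch, assuming it is acceptable to fall back to the base
# Gemma weights:
#
#     if os.path.exists(lora_weights_path):
#         gemma_lm.backbone.load_lora_weights(lora_weights_path)
#     else:
#         print(f"Warning: {lora_weights_path} not found; serving base Gemma weights.")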

# Define the response generation function
def generate_response(message):
    # Instruction-style prompt template; it should match the format used
    # during fine-tuning
    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

    # Fill in the user's message and leave the response empty for the model
    # to complete
    prompt = template.format(instruction=message, response="")
    print("Prompt:\n", prompt)

    # Generate up to 256 tokens (prompt included), then keep only the text
    # after the final "Response:" marker
    response = gemma_lm.generate(prompt, max_length=256)
    response = response.split("Response:")[-1].strip()

    print("Generated Response:\n", response)

    return response  # Adjust this if your model's output structure differs
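
# A quick smoke test before launching the UI (hypothetical prompt; uncomment
# to run):
#
#     print(generate_response("I have trouble sleeping before exams. Any advice?"))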

# Create the Gradio interface
interface = gr.Interface(
    fn=generate_response,  # Function that generates responses
    inputs=gr.Textbox(placeholder="Hello, I am Sage, your mental health advisor", lines=2, scale=7),
    outputs=gr.Textbox(),
    title="Sage, your Mental Health Advisor",
    description="Chat with Sage, your mental health advisor.",
    # live=True is deliberately left off: regenerating on every keystroke would
    # rerun the LLM far too often
)
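
# A chat-style UI with history is an alternative; note that gr.ChatInterface
# passes (message, history) to its callback, so a small adapter is needed.
# A sketch:
#
#     chat = gr.ChatInterface(fn=lambda message, history: generate_response(message))
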
# PROXY_PREFIX allows the app to be served behind a reverse proxy when the
# hosting environment sets it (None otherwise)
proxy_prefix = os.environ.get("PROXY_PREFIX")

# Launch the Gradio app on all interfaces; share=True also creates a public
# Gradio link
interface.launch(server_name="0.0.0.0", server_port=8080, root_path=proxy_prefix, share=True)