Mattral committed on
Commit c51711e · verified · 1 Parent(s): 58dd7d1

Update app.py

Files changed (1):
  1. app.py +15 -14
app.py CHANGED
@@ -4,17 +4,19 @@ import torch
 import os
 from dotenv import load_dotenv
 
+# Load environment variables
 load_dotenv()
-
 api_key = os.getenv("api_key")
+
 # App title and description
 st.title("I am Your GrowBuddy 🌱")
 st.write("Let me help you start gardening. Let's grow together!")
 
+# Function to load model
 def load_model():
     try:
-        tokenizer = AutoTokenizer.from_pretrained("KhunPop/Gardening", use_auth_token=api_key)
-        model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", use_auth_token=api_key)
+        tokenizer = AutoTokenizer.from_pretrained("KhunPop/Gardening", use_auth_token=api_key)
+        model = AutoModelForCausalLM.from_pretrained("QuantFactory/leniachat-gemma-2b-v0-GGUF", use_auth_token=api_key)
         return tokenizer, model
     except Exception as e:
         st.error(f"Failed to load model: {e}")
@@ -30,46 +32,45 @@ if not tokenizer or not model:
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
 
-# Initialize session state messages if not already initialized
+# Initialize session state messages
 if "messages" not in st.session_state:
     st.session_state.messages = [
         {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
     ]
 
-# Display the conversation history
+# Display conversation history
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.write(message["content"])
 
+# Function to generate response
 def generate_response(prompt):
     try:
-        # Tokenize the input prompt
+        # Tokenize input prompt with dynamic padding and truncation
         inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
 
-        # Ensure the model is generating properly (without a target)
-        outputs = model.generate(inputs["input_ids"], max_new_tokens=150, temperature=0.7, do_sample=True)
+        # Generate output from model
+        outputs = model.generate(inputs["input_ids"], max_new_tokens=100, temperature=0.7, do_sample=True)
 
-        # Decode the output to text
+        # Decode and return response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."
 
-# User input field for asking questions
+# User input field for gardening questions
 user_input = st.chat_input("Type your gardening question here:")
 
 if user_input:
-    # Display user message
     with st.chat_message("user"):
         st.write(user_input)
 
-    # Generate and display assistant's response
     with st.chat_message("assistant"):
-        with st.spinner("I'm gonna tell you..."):
+        with st.spinner("Generating your answer..."):
             response = generate_response(user_input)
             st.write(response)
 
-    # Update session state with the new conversation
+    # Update session state
     st.session_state.messages.append({"role": "user", "content": user_input})
     st.session_state.messages.append({"role": "assistant", "content": response})
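
A note on the new model id: a *-GGUF repository normally ships quantized .gguf files rather than standard safetensors weights, so AutoModelForCausalLM.from_pretrained on the repo id alone will usually not find a loadable checkpoint. Newer transformers releases can load GGUF checkpoints for supported architectures through the gguf_file argument, and token has replaced the deprecated use_auth_token. A minimal sketch under those assumptions; the quant filename below is hypothetical and should be checked against the repo's file listing:

import os
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer

load_dotenv()
api_key = os.getenv("api_key")

repo_id = "QuantFactory/leniachat-gemma-2b-v0-GGUF"
gguf_file = "leniachat-gemma-2b-v0.Q4_K_M.gguf"  # hypothetical filename, check the repo

# `token` replaces the deprecated `use_auth_token`; `gguf_file` selects one
# quant file from the repo (transformers dequantizes it on load)
tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file, token=api_key)
model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=gguf_file, token=api_key)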
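
Separately, two small issues in generate_response survive this commit: decoding outputs[0] returns the prompt followed by the reply, so the user's question is echoed back, and passing only input_ids forces generate() to infer the attention mask. A sketch of both fixes, reusing the tokenizer, model, and device names from the diff (not part of this commit):

def generate_response(prompt):
    try:
        # Tokenize; `inputs` carries both input_ids and attention_mask
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)

        # Unpack the full encoding so generate() receives the attention mask
        outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)

        # outputs[0] = prompt tokens + new tokens; keep only the new tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."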