import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")


# Load the model and tokenizer only once, caching them in session state
def load_model():
    try:
        # If the model and tokenizer are already in session state, reuse them
        if "tokenizer" in st.session_state and "model" in st.session_state:
            return st.session_state.tokenizer, st.session_state.model
        else:
            tokenizer = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", use_auth_token=HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-2-2b-bnb-4bit", use_auth_token=HF_TOKEN)
            # Store the model and tokenizer in session state
            st.session_state.tokenizer = tokenizer
            st.session_state.model = model
            return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None


# Load model and tokenizer (cached)
tokenizer, model = load_model()
if not tokenizer or not model:
    st.stop()

# Default to CPU, or use GPU if available.
# Note: bitsandbytes 4-bit checkpoints are already placed on a device at load
# time and cannot be moved with .to(); only move the model if it is not quantized.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not getattr(model, "is_quantized", False):
    model = model.to(device)

# Initialize session state messages
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# Placeholder used as a simple debugging log panel
log_box = st.empty()


# Generate a response, surfacing each step in the debugging log panel
def generate_response(prompt):
    try:
        # Tokenize the input prompt with dynamic padding and truncation
        log_box.text_area("Debugging Logs", "Tokenizing the prompt...", height=200)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)

        # Display the tokenized inputs
        log_box.text_area("Debugging Logs", f"Tokenized inputs: {inputs['input_ids']}", height=200)

        # Generate output from the model (pass the attention mask since padding is enabled)
        log_box.text_area("Debugging Logs", "Generating output...", height=200)
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
        )

        # Display the raw output from the model
        log_box.text_area("Debugging Logs", f"Raw model output (tokens): {outputs}", height=200)

        # Decode the generated tokens into text
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Display the final decoded response
        log_box.text_area("Debugging Logs", f"Decoded response: {response}", height=200)

        return response
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."


# User input field for gardening questions
user_input = st.chat_input("Type your gardening question here:")

if user_input:
    with st.chat_message("user"):
        st.write(user_input)

    with st.chat_message("assistant"):
        with st.spinner("Generating your answer..."):
            response = generate_response(user_input)
            st.write(response)

    # Update session state with the new turn
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})
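
# Usage sketch (assumptions: this file is saved as app.py, and a .env file next to it
# contains a line like HF_TOKEN=hf_xxx with a Hugging Face access token that can read
# the two repos referenced above). Run from a terminal with the dependencies installed:
#
#   pip install streamlit torch transformers accelerate bitsandbytes python-dotenv
#   streamlit run app.py
#
# The unsloth/gemma-2-2b-bnb-4bit checkpoint is a bitsandbytes 4-bit model, so it
# needs the bitsandbytes package (and typically a CUDA GPU); if loading fails, the
# error is shown via st.error and the app halts at st.stop().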