krisha06 commited on
Commit
23b09ce
·
verified ·
1 Parent(s): 8f45aef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -60
app.py CHANGED
@@ -1,71 +1,181 @@
1
  import streamlit as st
2
- from datasets import load_dataset
3
- from sentence_transformers import SentenceTransformer
4
  import chromadb
 
 
 
 
 
5
 
6
- # Load dataset
7
- # Load dataset
8
  def load_recipes():
9
  try:
10
- dataset = load_dataset("mbien/recipe_nlg", split="train", trust_remote_code=True)
11
- print(" Dataset loaded successfully!")
12
- return dataset
 
 
 
 
13
  except Exception as e:
14
- print(f" Error loading dataset: {e}")
15
- return None
16
 
17
  recipes_df = load_recipes()
18
 
19
- if recipes_df is None:
20
- st.error("❌ Failed to load dataset! Check internet or dataset availability.")
21
- st.stop() # Stops Streamlit from running further if the dataset isn't loaded
22
-
23
- # Load embedding model
24
  @st.cache_resource
25
  def load_embedding_model():
26
- return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
27
-
28
- embed_model = load_embedding_model()
29
-
30
- # Initialize ChromaDB
31
- chroma_client = chromadb.PersistentClient(path="./chroma_db") # Saves vectors
32
- recipe_collection = chroma_client.get_or_create_collection(name="recipes")
33
-
34
- # Ensure recipes_df is iterable
35
- if isinstance(recipes_df, list) or isinstance(recipes_df, dict):
36
- if recipe_collection.count() == 0:
37
- st.info("Indexing recipes... This will take a few minutes.")
38
- for i, recipe in enumerate(recipes_df):
39
- title = recipe.get("title", "Unknown Title") # Handle missing keys
40
- ingredients = ", ".join(recipe.get("ingredients", []))
41
- instructions = recipe.get("instructions", "No instructions available")
42
-
43
- embedding = embed_model.encode(title).tolist()
44
- recipe_collection.add(
45
- ids=[str(i)],
46
- embeddings=[embedding],
47
- metadatas=[{"title": title, "ingredients": ingredients, "index": i}],
48
- )
49
- else:
50
- st.error("❌ Dataset is not in the correct format!")
51
-
52
- # UI
53
- st.title("🍽️ AI Recipe Finder with ChromaDB RAG")
54
- query = st.text_input("🔍 Search for a recipe (e.g., pasta, cake)")
55
-
56
- if query:
57
- query_embedding = embed_model.encode(query).tolist()
58
- results = recipe_collection.query(
59
- query_embeddings=[query_embedding], n_results=5
60
- )
61
-
62
- st.subheader("🔎 Most relevant recipes:")
63
- for result in results["metadatas"][0]:
64
- index = result["index"]
65
- recipe = recipes_df[index]
66
- st.write(f"**🍴 {recipe.get('title', 'No title available')}**")
67
- st.write(f"**Ingredients:** {', '.join(recipe.get('ingredients', []))}")
68
- st.write(f"**Instructions:** {recipe.get('instructions', 'No instructions available')}")
69
- st.write("---")
70
- else:
71
- st.info("Type a recipe name to find similar recipes.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
 
3
  import chromadb
4
+ from sentence_transformers import SentenceTransformer
5
+ from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import requests
9
 
10
+ # --- 1. Load Recipes Dataset ---
11
+ @st.cache_data
12
  def load_recipes():
13
  try:
14
+ recipes_df = pd.read_csv("recipes.csv")
15
+ recipes_df = recipes_df.rename(columns={"recipe_name": "title", "directions": "instructions"})
16
+ recipes_df = recipes_df[['title', 'ingredients', 'instructions', 'img_src']]
17
+ recipes_df.fillna("", inplace=True)
18
+ recipes_df["ingredients"] = recipes_df["ingredients"].str.lower().str.replace(r'[^\w\s]', '', regex=True)
19
+ recipes_df["combined_text"] = recipes_df["title"] + " " + recipes_df["ingredients"]
20
+ return recipes_df
21
  except Exception as e:
22
+ st.error(f" Error loading recipes: {e}")
23
+ return pd.DataFrame()
24
 
25
  recipes_df = load_recipes()
26
 
27
+ # --- 2. Load SentenceTransformer Model ---
 
 
 
 
28
  @st.cache_resource
29
  def load_embedding_model():
30
+ return SentenceTransformer("all-mpnet-base-v2")
31
+
32
+ embedding_model = load_embedding_model()
33
+
34
+ # --- 3. Initialize ChromaDB ---
35
+ chroma_client = chromadb.PersistentClient(path="./chroma_db")
36
+ collection = chroma_client.get_or_create_collection(name="recipe_collection")
37
+
38
+ # --- 4. Generate & Store Embeddings ---
39
+ def get_sentence_transformer_embeddings(text):
40
+ return embedding_model.encode(text).tolist()
41
+
42
+ try:
43
+ existing_data = collection.get()
44
+ existing_ids = set(existing_data["ids"]) if existing_data and "ids" in existing_data else set()
45
+ except Exception as e:
46
+ st.error(f"⚠ ChromaDB Error: {e}")
47
+ existing_ids = set()
48
+
49
+ for index, row in recipes_df.iterrows():
50
+ recipe_id = str(index)
51
+ if recipe_id in existing_ids:
52
+ continue
53
+ embedding = get_sentence_transformer_embeddings(row["combined_text"])
54
+ if embedding:
55
+ collection.add(embeddings=[embedding], documents=[row["combined_text"]], ids=[recipe_id])
56
+
57
+ # --- 5. Retrieve Similar Recipes ---
58
+ def retrieve_recipes(query, top_k=3):
59
+ query_embedding = get_sentence_transformer_embeddings(query)
60
+ results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
61
+
62
+ if results and "ids" in results and results["ids"]: # Check existence before accessing
63
+ recipe_indices = [int(id) for id in results["ids"][0] if id.isdigit()]
64
+ return recipes_df.iloc[recipe_indices] if recipe_indices else None
65
+ return None
66
+
67
+ # --- 6. Load a Compatible LLM for Q&A ---
68
+ @st.cache_resource
69
+ def load_llm_model():
70
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2") # Better Q&A model
71
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
72
+ return pipeline("question-answering", model=model, tokenizer=tokenizer)
73
+
74
+ llm_model = load_llm_model()
75
+
76
+ # --- 5. Answer Greeting and Handle Q&A Queries ---
77
+ def answer_question(query, context=""):
78
+ # Handle greetings or non-informational queries
79
+ greetings = ["hi", "hello", "hii", "hey", "greetings", "how are you", "what's up", "how's it going"]
80
+ if query.lower().strip() in greetings:
81
+ return "Hello! How can I assist you today? Feel free to ask about recipes or any other questions."
82
+
83
+ # If not a greeting, check if it is a valid Q&A query
84
+ if query.lower().strip() not in greetings:
85
+ # Use the QA model for other questions
86
+ response = qa_model(question=query, context=context)
87
+
88
+ # Check if the response from the model is valid
89
+ if response and "answer" in response and response["answer"].strip():
90
+ return response["answer"]
91
+ else:
92
+ return "I'm sorry, I couldn't generate a response for your query."
93
+
94
+ return None
95
+
96
+ # --- 6. Classify Query Type (Q&A or Recipe Search) ---
97
+ @st.cache_resource
98
+ def load_classifier():
99
+ return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
100
+
101
+ classifier = load_classifier()
102
+
103
+ def classify_query(query):
104
+ # Keywords that may indicate a recipe-related query
105
+ recipe_keywords = ["make", "cook", "bake", "recipe", "prepare"]
106
+
107
+ # Check if query contains common recipe-related keywords
108
+ if any(keyword in query.lower() for keyword in recipe_keywords):
109
+ return "Recipe Search"
110
+
111
+ labels = ["Q&A", "Recipe Search"]
112
+ result = classifier(query, labels)
113
+ return result["labels"][0]
114
+
115
+ # --- 8. Display Image Function ---
116
+ def display_image(image_url, recipe_name):
117
+ try:
118
+ if not isinstance(image_url, str) or not image_url.startswith("http"):
119
+ raise ValueError("Invalid or missing image URL")
120
+ response = requests.get(image_url, timeout=5)
121
+ response.raise_for_status()
122
+ image = Image.open(BytesIO(response.content))
123
+ st.image(image, caption=recipe_name, use_container_width=True)
124
+ except requests.exceptions.RequestException as e:
125
+ st.warning(f"⚠ Image fetch error: {e}")
126
+ placeholder_url = "https://via.placeholder.com/300?text=No+Image"
127
+ st.image(placeholder_url, caption=recipe_name, use_container_width=True)
128
+
129
+ # --- Streamlit UI ---
130
+ st.title("🍽️ AI Recipe & Q&A Assistant")
131
+
132
+ # Unique key for the main user query input
133
+ user_query = st.text_input("Enter your question or recipe search query:", "", key="main_query_input")
134
+
135
+ # Use session state to store the retrieved recipe
136
+ if "retrieved_recipes" not in st.session_state:
137
+ st.session_state["retrieved_recipes"] = None
138
+
139
+ if st.button("Ask AI"):
140
+ if user_query:
141
+ # Handle greetings and other specific queries with answer_question
142
+ response = answer_question(user_query)
143
+
144
+ if response and "Hello!" in response:
145
+ st.subheader("🤖 AI Answer:")
146
+ st.write(response)
147
+ else:
148
+ # Classify the query if not a greeting
149
+ intent = classify_query(user_query)
150
+
151
+ if intent == "Q&A":
152
+ st.subheader("🤖 AI Answer:")
153
+ context = "You can add specific context here, or leave it empty."
154
+ response = answer_question(user_query, context)
155
+ st.write(response)
156
+
157
+ elif intent == "Recipe Search":
158
+ retrieved_recipes = retrieve_recipes(user_query)
159
+ if retrieved_recipes is not None and not retrieved_recipes.empty:
160
+ st.session_state["retrieved_recipes"] = retrieved_recipes # Store retrieved recipes in session state
161
+ st.subheader("🍴 Found Recipes:")
162
+ for index, recipe in retrieved_recipes.iterrows():
163
+ st.markdown(f"### {recipe['title']}")
164
+ st.write(f"**Ingredients:** {recipe['ingredients']}")
165
+ st.write(f"**Instructions:** {recipe['instructions']}")
166
+ display_image(recipe.get('img_src', ''), recipe['title'])
167
+
168
+ # Unique key for each follow-up question input
169
+ follow_up_query = st.text_input(
170
+ "Any modifications or follow-up questions about this recipe?",
171
+ key=f"follow_up_query_{index}"
172
+ )
173
+
174
+ if st.button(f"Submit Follow-up for {recipe['title']}", key=f"submit_follow_up_{index}"):
175
+ # Handle follow-up query
176
+ response = handle_follow_up(follow_up_query, recipe)
177
+ st.write(response)
178
+ else:
179
+ st.warning("⚠️ No relevant recipes found.")
180
+ else:
181
+ st.warning("❌ Unable to classify the query.")