redfernstech committed on
Commit
0259037
·
verified ·
1 Parent(s): 195dc2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -72
app.py CHANGED
@@ -236,189 +236,240 @@ from fastapi.templating import Jinja2Templates
236
  from simple_salesforce import Salesforce, SalesforceLogin
237
  from langchain_groq import ChatGroq
238
  from langchain_core.prompts import ChatPromptTemplate
239
- from llama_index.core import StorageContext, VectorStoreIndex, SimpleDirectoryReader, Settings, load_index_from_storage
 
240
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
241
 
242
  # Configure logging
243
  logging.basicConfig(level=logging.INFO)
244
  logger = logging.getLogger(__name__)
245
 
 
246
  class MessageRequest(BaseModel):
247
  message: str
248
 
 
249
  app = FastAPI()
250
 
 
251
  app.add_middleware(
252
  CORSMiddleware,
253
- allow_origins=["*"],
254
  allow_credentials=True,
255
  allow_methods=["*"],
256
  allow_headers=["*"],
257
  )
258
 
 
259
  app.mount("/static", StaticFiles(directory="static"), name="static")
260
  templates = Jinja2Templates(directory="static")
261
 
 
262
  required_env_vars = ["CHATGROQ_API_KEY", "username", "password", "security_token", "domain", "HF_TOKEN"]
263
  for var in required_env_vars:
264
  if not os.getenv(var):
265
- logger.error(f"Missing environment variable: {var}")
266
  raise ValueError(f"Environment variable {var} is not set")
267
 
268
- # LLM & Embedding Setup
269
  GROQ_API_KEY = os.getenv("CHATGROQ_API_KEY")
270
- llm = ChatGroq(model_name="llama3-8b-8192", api_key=GROQ_API_KEY, temperature=0.1, max_tokens=50)
 
 
 
 
 
 
 
 
 
 
 
 
271
  Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
272
 
273
- # Salesforce setup
 
 
 
 
 
 
274
  sf = None
275
  try:
276
  session_id, sf_instance = SalesforceLogin(
277
- username=os.getenv("username"),
278
- password=os.getenv("password"),
279
- security_token=os.getenv("security_token"),
280
- domain=os.getenv("domain")
281
  )
282
  sf = Salesforce(instance=sf_instance, session_id=session_id)
283
- logger.info("Salesforce connected.")
284
  except Exception as e:
285
- logger.warning(f"Salesforce connection failed: {e}")
286
 
 
287
  chat_history = []
288
  current_chat_history = []
289
  MAX_HISTORY_SIZE = 100
290
 
 
291
  PDF_DIRECTORY = "data"
292
  PERSIST_DIR = "db"
 
 
293
  os.makedirs(PDF_DIRECTORY, exist_ok=True)
294
  os.makedirs(PERSIST_DIR, exist_ok=True)
295
 
296
  def data_ingestion_from_directory():
 
297
  try:
298
  documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
299
  if not documents:
300
  logger.warning("No documents found in PDF_DIRECTORY")
301
- return
302
  storage_context = StorageContext.from_defaults()
303
  index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
304
  index.storage_context.persist(persist_dir=PERSIST_DIR)
305
- logger.info("Data ingestion and embedding complete.")
 
306
  except Exception as e:
307
- logger.error(f"Data ingestion failed: {e}")
308
- raise HTTPException(status_code=500, detail="Data ingestion failed")
309
 
310
  def initialize():
 
311
  try:
312
- data_ingestion_from_directory()
 
313
  except Exception as e:
314
- logger.error(f"Initialization error: {e}")
315
- raise HTTPException(status_code=500, detail="Startup initialization failed")
316
 
 
317
  initialize()
318
 
319
  def handle_query(query: str) -> str:
 
 
320
  chat_context = ""
321
  for past_query, response in reversed(current_chat_history[-10:]):
322
- chat_context += f"User: {past_query}\nBot: {response}\n"
 
323
 
324
- # Load index
325
  try:
326
  storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
327
  index = load_index_from_storage(storage_context)
328
  query_engine = index.as_query_engine(similarity_top_k=2)
329
  retrieved = query_engine.query(query)
330
- doc_context = getattr(retrieved, 'response', "No relevant documents found.")
 
331
  except Exception as e:
332
- logger.error(f"Retrieval error: {e}")
333
- doc_context = "No relevant documents found."
334
 
335
- # Prompt template
336
  prompt_template = ChatPromptTemplate.from_messages([
337
  ("system", """
338
- You are a helpful and professional company chatbot.
339
- Answer user queries based on the provided document context and chat history.
340
- If you are unsure about the answer, politely respond with "I'm sorry, I don't know that yet."
341
 
342
- Document Context:
343
- {doc_context}
344
 
345
- Chat History:
346
- {chat_context}
347
 
348
- Question:
349
- {query}
350
- """)
351
  ])
352
  prompt = prompt_template.format(doc_context=doc_context, chat_context=chat_context, query=query)
353
 
 
354
  try:
355
  response = llm.invoke(prompt)
356
  response_text = response.content.strip()
357
- if "I'm sorry" not in response_text and len(response_text.strip()) < 3:
358
- response_text = "I'm sorry, I don't know that yet."
359
  except Exception as e:
360
- logger.error(f"Groq API Error: {e}")
361
- response_text = "I'm sorry, I don't know that yet."
362
 
 
363
  if len(current_chat_history) >= MAX_HISTORY_SIZE:
364
  current_chat_history.pop(0)
365
  current_chat_history.append((query, response_text))
366
-
367
  return response_text
368
 
369
  @app.get("/ch/{id}", response_class=HTMLResponse)
370
  async def load_chat(request: Request, id: str):
 
371
  return templates.TemplateResponse("index.html", {"request": request, "user_id": id})
372
 
373
  @app.post("/hist/")
374
  async def save_chat_history(history: dict):
 
375
  if not sf:
376
- return JSONResponse({"error": "Salesforce not connected"}, status_code=503)
 
377
 
378
  user_id = history.get('userId')
379
  if not user_id:
380
- return JSONResponse({"error": "userId missing"}, status_code=400)
 
381
 
382
- hist = '\n'.join([f"{entry['sender']}: {entry['message']}" for entry in history.get("history", [])])
383
- summary = "This is the chat summary: " + hist
384
 
385
  try:
386
- sf.Lead.update(user_id, {'Description': summary})
387
- return {"summary": summary, "message": "Chat history saved"}
 
388
  except Exception as e:
389
- return JSONResponse({"error": str(e)}, status_code=500)
 
390
 
391
  @app.post("/webhook")
392
  async def receive_form_data(request: Request):
 
393
  if not sf:
394
- return JSONResponse({"error": "Salesforce not connected"}, status_code=503)
 
395
 
396
  try:
397
  form_data = await request.json()
398
  except json.JSONDecodeError:
 
399
  return JSONResponse({"error": "Invalid JSON"}, status_code=400)
400
 
401
- first_name, last_name = split_name(form_data.get("name", ""))
402
- lead_data = {
403
- "FirstName": first_name,
404
- "LastName": last_name,
405
- "Company": form_data.get("company", ""),
406
- "Phone": form_data.get("phone", ""),
407
- "Email": form_data.get("email", ""),
408
- "Description": "Lead from website form"
409
  }
410
 
411
  try:
412
- result = sf.Lead.create(lead_data)
413
- return {"id": result.get("id")}
 
 
414
  except Exception as e:
415
- return JSONResponse({"error": str(e)}, status_code=500)
 
416
 
417
  @app.post("/chat/")
418
  async def chat(request: MessageRequest):
419
- message = request.message
 
 
 
 
420
  response = handle_query(message)
421
- chat_entry = {
422
  "sender": "User",
423
  "message": message,
424
  "response": response,
@@ -426,24 +477,28 @@ async def chat(request: MessageRequest):
426
  }
427
  if len(chat_history) >= MAX_HISTORY_SIZE:
428
  chat_history.pop(0)
429
- chat_history.append(chat_entry)
 
430
  return {"response": response}
431
 
432
  @app.get("/health")
433
  async def health_check():
 
434
  try:
435
  storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
436
- load_index_from_storage(storage_context)
437
- return {"status": "healthy"}
 
438
  except Exception as e:
 
439
  return {"status": "unhealthy", "error": str(e)}
440
 
441
  @app.get("/")
442
- def read_root():
443
- return {"message": "Welcome to the company chatbot API"}
444
-
445
- def split_name(full_name):
446
- parts = full_name.strip().split()
447
- if len(parts) == 1:
448
- return '', parts[0]
449
- return parts[0], ' '.join(parts[1:])
 
236
  from simple_salesforce import Salesforce, SalesforceLogin
237
  from langchain_groq import ChatGroq
238
  from langchain_core.prompts import ChatPromptTemplate
239
+ from llama_index.core import StorageContext, VectorStoreIndex, SimpleDirectoryReader, Settings
240
+ from llama_index.core import load_index_from_storage
241
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
242
 
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MessageRequest(BaseModel):
    """Request body accepted by the ``/chat/`` endpoint."""

    # Raw text of the user's chat message.
    message: str


# Initialize FastAPI app
app = FastAPI()

# CORS configuration.
# Allowed origins come from the comma-separated ALLOWED_ORIGINS environment
# variable and fall back to the wildcard when it is unset or empty.
allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    # Credentialed cross-origin requests are only valid for explicitly listed
    # origins; combining them with "*" would let any site send the user's
    # cookies, so credentials are enabled only when origins are restricted.
    allow_credentials="*" not in allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static assets; the same directory also holds the Jinja2 templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="static")
266
 
267
# Validate environment variables: stop at import time on the first one that
# is missing or empty, so the service never starts half-configured.
required_env_vars = ["CHATGROQ_API_KEY", "username", "password", "security_token", "domain", "HF_TOKEN"]
for var in required_env_vars:
    if not os.getenv(var):
        logger.error(f"Environment variable {var} is not set")
        raise ValueError(f"Environment variable {var} is not set")

# Initialize Groq model
GROQ_API_KEY = os.getenv("CHATGROQ_API_KEY")
GROQ_MODEL = "llama3-8b-8192"
try:
    llm = ChatGroq(
        model_name=GROQ_MODEL,
        api_key=GROQ_API_KEY,
        temperature=0.1,
        max_tokens=50,
    )
except Exception as e:
    logger.error(f"Failed to initialize Groq model: {e}")
    # This code runs while the module is imported, not inside a request
    # handler, so there is no HTTP response to build. Raise a plain runtime
    # error and keep the original exception as its cause.
    raise RuntimeError("Failed to initialize Groq model") from e
287
+
288
# Global LlamaIndex embedding model
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Salesforce credentials, taken from the environment
username = os.getenv("username")
password = os.getenv("password")
security_token = os.getenv("security_token")
domain = os.getenv("domain")  # e.g., 'test' for sandbox

# Salesforce client. It stays None when the login fails; the endpoints that
# need it check for that and answer 503 instead of crashing.
sf = None
try:
    session_id, sf_instance = SalesforceLogin(
        username=username,
        password=password,
        security_token=security_token,
        domain=domain,
    )
    sf = Salesforce(instance=sf_instance, session_id=session_id)
    logger.info("Salesforce connection established")
except Exception as e:
    logger.warning(f"Failed to connect to Salesforce: {e}. Continuing without Salesforce integration.")

# In-memory chat logs, each capped at MAX_HISTORY_SIZE entries by their writers
chat_history = []
current_chat_history = []
MAX_HISTORY_SIZE = 100

# Source documents are read from PDF_DIRECTORY; the index is persisted to PERSIST_DIR
PDF_DIRECTORY = "data"
PERSIST_DIR = "db"

# Create both directories up front if they do not exist yet
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)
320
 
def data_ingestion_from_directory():
    """Build a vector index from the files in PDF_DIRECTORY and persist it.

    Returns:
        True when documents were loaded, indexed and persisted to PERSIST_DIR;
        False when the reader returned no documents (nothing is persisted).

    Raises:
        HTTPException: status 500, wrapping any error raised while loading,
            indexing or persisting.
    """
    try:
        loaded_docs = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
        if loaded_docs:
            vector_index = VectorStoreIndex.from_documents(
                loaded_docs,
                storage_context=StorageContext.from_defaults(),
            )
            vector_index.storage_context.persist(persist_dir=PERSIST_DIR)
            logger.info("Data ingestion and embedding storage completed successfully")
            return True

        logger.warning("No documents found in PDF_DIRECTORY")
        return False
    except Exception as e:
        logger.error(f"Error during data ingestion: {e}")
        raise HTTPException(status_code=500, detail=f"Data ingestion failed: {str(e)}")
336
 
def initialize():
    """Run document ingestion once and log when there was nothing to ingest.

    Raises:
        HTTPException: status 500 when ingestion raises for any reason.
    """
    try:
        ingested = data_ingestion_from_directory()
    except Exception as e:
        logger.error(f"Initialization failed: {e}")
        raise HTTPException(status_code=500, detail="Initialization failed")

    if not ingested:
        logger.info("No documents to ingest, proceeding with empty index")


# Run initialization when the module is imported
initialize()
348
 
def handle_query(query: str) -> str:
    """Answer a user query from retrieved document context and recent chat history.

    The function loads the persisted vector index, retrieves context for
    ``query``, formats a prompt that also contains up to the last ten
    exchanges, and sends it to the Groq LLM. Retrieval and LLM failures are
    logged and replaced by fallback text rather than raised.

    Args:
        query: The user's message.

    Returns:
        The stripped model reply, or a fixed apology when the reply is empty,
        equals "unknown" (case-insensitive), or the LLM call fails.

    Side effects:
        Appends ``(query, response_text)`` to ``current_chat_history``,
        dropping the oldest entry once MAX_HISTORY_SIZE is reached.
    """
    # NOTE(review): current_chat_history is a single module-level list, so
    # every caller of this function shares the same history.

    # Build the history in chronological order (oldest first) so the model
    # reads the conversation the way it actually happened.
    chat_context = "".join(
        f"User: {past_query}\nBot: {past_response}\n"
        for past_query, past_response in current_chat_history[-10:]
        if past_query.strip()
    )

    # Load vector index and retrieve relevant documents
    try:
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        index = load_index_from_storage(storage_context)
        query_engine = index.as_query_engine(similarity_top_k=2)
        retrieved = query_engine.query(query)
        # The response attribute may be missing, None or empty; treat all of
        # those as "nothing found" instead of failing on the slice below.
        raw_context = getattr(retrieved, "response", None)
        doc_context = str(raw_context) if raw_context else "No relevant information found."
        logger.info(f"Retrieved context for query '{query}': {doc_context[:100]}...")
    except Exception as e:
        logger.error(f"Error retrieving documents: {e}")
        doc_context = "Failed to retrieve relevant information."

    # Construct prompt for Redferns Tech chatbot
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", """
You are Clara, a chatbot for Redferns Tech. Provide accurate, professional answers in 10-15 words.
Use the provided document context and chat history to inform your response.
If you don't know the answer, politely say: "I'm sorry, I don't have the information to answer that."

Document Context:
{doc_context}

Chat History:
{chat_context}

Question:
{query}
"""),
    ])
    prompt = prompt_template.format(doc_context=doc_context, chat_context=chat_context, query=query)

    # Query Groq model
    try:
        response = llm.invoke(prompt)
        response_text = response.content.strip()
        if not response_text or response_text.lower() == "unknown":
            response_text = "I'm sorry, I don't have the information to answer that."
    except Exception as e:
        logger.error(f"Error querying Groq API: {e}")
        response_text = "I'm sorry, I don't have the information to answer that."

    # Update chat history, keeping at most MAX_HISTORY_SIZE entries
    if len(current_chat_history) >= MAX_HISTORY_SIZE:
        current_chat_history.pop(0)
    current_chat_history.append((query, response_text))
    return response_text
403
 
@app.get("/ch/{id}", response_class=HTMLResponse)
async def load_chat(request: Request, id: str):
    """Render ``index.html`` with the path's ``id`` exposed to the template as ``user_id``."""
    template_context = {"request": request, "user_id": id}
    return templates.TemplateResponse("index.html", template_context)
408
 
@app.post("/hist/")
async def save_chat_history(history: dict):
    """Write a chat transcript into the Description field of a Salesforce Lead.

    Args:
        history: JSON body. ``userId`` is used as the Lead record id;
            ``history`` is a list of entries, each with ``sender`` and
            ``message`` keys. A missing or empty ``history`` is treated as
            an empty transcript.

    Returns:
        ``{"summary": ..., "message": ...}`` on success, otherwise a
        JSONResponse with an ``error`` field: 503 when Salesforce is not
        connected, 400 when ``userId`` is missing or ``history`` is
        malformed, 500 when the Lead update fails.
    """
    if not sf:
        logger.error("Salesforce integration is disabled")
        return JSONResponse({"error": "Salesforce integration is unavailable"}, status_code=503)

    user_id = history.get('userId')
    if not user_id:
        logger.error("userId is missing in history request")
        return JSONResponse({"error": "userId is required"}, status_code=400)

    # A malformed payload is a client error: answer 400 instead of letting a
    # KeyError/TypeError escape as an unhandled 500.
    entries = history.get('history') or []
    try:
        hist = ''.join(f"{entry['sender']}: {entry['message']}\n" for entry in entries)
    except (KeyError, TypeError) as e:
        logger.error(f"Malformed history payload: {e!r}")
        return JSONResponse(
            {"error": "history must be a list of entries with sender and message"},
            status_code=400,
        )

    # NOTE(review): this text is worded as a summarisation prompt, but no
    # model is called here; the prompt plus transcript is stored as-is.
    summary_prompt = f"Summarize user interests from this conversation:\n{hist}"

    try:
        sf.Lead.update(user_id, {'Description': summary_prompt})
        logger.info(f"Chat history updated for user {user_id}")
        return {"summary": summary_prompt, "message": "Chat history saved"}
    except Exception as e:
        logger.error(f"Failed to update lead: {e}")
        return JSONResponse({"error": f"Failed to update lead: {str(e)}"}, status_code=500)
431
 
@app.post("/webhook")
async def receive_form_data(request: Request):
    """Create a Salesforce Lead from a JSON form submission.

    The body is expected to be a JSON object; the keys read from it are
    ``name``, ``company``, ``phone`` and ``email``.

    Returns:
        JSONResponse ``{"id": <new lead id>}`` on success, otherwise a
        JSONResponse with an ``error`` field: 503 when Salesforce is not
        connected, 400 when the body is not valid JSON or not a JSON object,
        500 when Lead creation fails.
    """
    if not sf:
        logger.error("Salesforce integration is disabled")
        return JSONResponse({"error": "Salesforce integration is unavailable"}, status_code=503)

    try:
        form_data = await request.json()
    except json.JSONDecodeError:
        logger.error("Invalid JSON in webhook request")
        return JSONResponse({"error": "Invalid JSON"}, status_code=400)

    # Valid JSON can still be a list, string or number; only an object has
    # the fields read below.
    if not isinstance(form_data, dict):
        logger.error("Webhook payload is not a JSON object")
        return JSONResponse({"error": "Expected a JSON object"}, status_code=400)

    # name and phone may arrive as null or as a non-string value; coerce
    # them before calling string methods on them.
    first_name, last_name = split_name(str(form_data.get('name') or ''))
    data = {
        'FirstName': first_name,
        'LastName': last_name,
        'Description': 'Lead created via webhook',
        'Company': form_data.get('company', ''),
        'Phone': str(form_data.get('phone') or '').strip(),
        'Email': form_data.get('email', ''),
    }

    try:
        result = sf.Lead.create(data)
        unique_id = result['id']
        logger.info(f"Lead created with ID {unique_id}")
        return JSONResponse({"id": unique_id})
    except Exception as e:
        logger.error(f"Failed to create lead: {e}")
        return JSONResponse({"error": f"Failed to create lead: {str(e)}"}, status_code=500)
463
 
464
  @app.post("/chat/")
465
  async def chat(request: MessageRequest):
466
+ """Handle chat messages and return responses."""
467
+ message = request.message.strip()
468
+ if not message:
469
+ return JSONResponse({"error": "Message cannot be empty"}, status_code=400)
470
+
471
  response = handle_query(message)
472
+ message_data = {
473
  "sender": "User",
474
  "message": message,
475
  "response": response,
 
477
  }
478
  if len(chat_history) >= MAX_HISTORY_SIZE:
479
  chat_history.pop(0)
480
+ chat_history.append(message_data)
481
+ logger.info(f"Chat message processed: {message}")
482
  return {"response": response}
483
 
@app.get("/health")
async def health_check():
    """Report whether the persisted vector index can be loaded.

    Returns:
        ``{"status": "healthy", "pdf_ingestion": "successful"}`` with HTTP 200
        when the index loads from PERSIST_DIR; otherwise a 503 JSONResponse
        with ``status`` "unhealthy" and the error text.
    """
    try:
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        # Only the ability to load matters here; the index itself is unused.
        load_index_from_storage(storage_context)
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        # Use a non-2xx status so probes that only inspect the status code
        # also see the failure.
        return JSONResponse({"status": "unhealthy", "error": str(e)}, status_code=503)

    logger.info("Vector index loaded successfully")
    return {"status": "healthy", "pdf_ingestion": "successful"}
495
 
@app.get("/")
async def read_root():
    """Return the fixed welcome message served at the API root."""
    welcome = {"message": "Welcome to the Redferns Tech Chatbot API"}
    return welcome
500
+
501
def split_name(full_name: str) -> tuple:
    """Split a full name into ``(first_name, last_name)``.

    The first whitespace-separated word is the first name and the remaining
    words, joined by single spaces, form the last name. A single-word name
    is returned as the last name, because the result is used to fill a
    Salesforce Lead, where LastName is the mandatory field.

    Args:
        full_name: The name as entered; surrounding and repeated whitespace
            is ignored.

    Returns:
        A ``(first_name, last_name)`` tuple; ``('', '')`` for a blank name.
    """
    words = full_name.split()
    if not words:
        return ('', '')
    if len(words) == 1:
        # Only one word: keep it where the required field is.
        return ('', words[0])
    return (words[0], ' '.join(words[1:]))