nishantgaurav23 committed on
Commit c650a86 · verified · 1 Parent(s): 6665ca1

Update app.py

Files changed (1)
  1. app.py +137 -44
app.py CHANGED
@@ -447,62 +447,155 @@ class RAGPipeline:
         placeholder.warning(message)
         return message

-    def query_model(self, prompt: str) -> str:
-        """Query the local Llama model"""
-        try:
-            if self.llm is None:
-                raise RuntimeError("Model not initialized")
+    # def query_model(self, prompt: str) -> str:
+    #     """Query the local Llama model"""
+    #     try:
+    #         if self.llm is None:
+    #             raise RuntimeError("Model not initialized")

-            response = self.llm(
-                prompt,
-                max_tokens=512,
-                temperature=0.4,
-                top_p=0.95,
-                echo=False,
-                stop=["Question:", "Context:", "Guidelines:"],  # Removed \n\n from stop tokens to allow paragraphs
-                repeat_penalty=1.1  # Added to encourage more diverse text
-            )
+    #         response = self.llm(
+    #             prompt,
+    #             max_tokens=512,
+    #             temperature=0.4,
+    #             top_p=0.95,
+    #             echo=False,
+    #             stop=["Question:", "Context:", "Guidelines:"],  # Removed \n\n from stop tokens to allow paragraphs
+    #             repeat_penalty=1.1  # Added to encourage more diverse text
+    #         )

-            if response and 'choices' in response and len(response['choices']) > 0:
-                text = response['choices'][0].get('text', '').strip()
-                return text
-            else:
-                raise ValueError("No valid response generated")
+    #         if response and 'choices' in response and len(response['choices']) > 0:
+    #             text = response['choices'][0].get('text', '').strip()
+    #             return text
+    #         else:
+    #             raise ValueError("No valid response generated")

-        except Exception as e:
-            logging.error(f"Error in query_model: {str(e)}")
-            raise
+    #     except Exception as e:
+    #         logging.error(f"Error in query_model: {str(e)}")
+    #         raise

-    @st.cache_resource(show_spinner=False)
-    def initialize_rag_pipeline():
-        """Initialize the RAG pipeline once"""
+    def query_model(self, prompt: str) -> str:
+        """Query the local Llama model"""
         try:
-            # Create necessary directories
-            os.makedirs("ESPN_data", exist_ok=True)
-
-            # Load embeddings from Drive
-            drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
-            with st.spinner("Loading embeddings from Google Drive..."):
-                cache_data = load_from_drive(drive_file_id)
-                if cache_data is None:
-                    st.error("Failed to load embeddings from Google Drive")
-                    st.stop()
+            if self.llm is None:
+                raise RuntimeError("Model not initialized")
+
+            # Log the prompt for debugging
+            logging.info(f"Sending prompt to model...")

-            # Initialize pipeline
-            data_folder = "ESPN_data"
-            rag = RAGPipeline(data_folder)
+            # Generate response with more explicit parameters
+            response = self.llm(
+                prompt,
+                max_tokens=512,         # Maximum length of the response
+                temperature=0.7,        # Slightly increased for more dynamic responses
+                top_p=0.95,             # Nucleus sampling parameter
+                top_k=50,               # Top-k sampling parameter
+                echo=False,             # Don't include prompt in response
+                stop=["Question:", "Context:", "Guidelines:"],  # Stop tokens
+                repeat_penalty=1.1,     # Penalize repetition
+                presence_penalty=0.5,   # Encourage topic diversity
+                frequency_penalty=0.5   # Discourage word repetition
+            )

-            # Store embeddings
-            rag.documents = cache_data['documents']
-            rag.retriever.store_embeddings(cache_data['embeddings'])
+            # Log the raw response for debugging
+            logging.info(f"Raw model response: {response}")

-            return rag
+            if response and isinstance(response, dict) and 'choices' in response and response['choices']:
+                generated_text = response['choices'][0].get('text', '').strip()
+                if generated_text:
+                    logging.info(f"Generated text: {generated_text[:100]}...")  # Log first 100 chars
+                    return generated_text
+                else:
+                    logging.warning("Model returned empty response")
+                    raise ValueError("Empty response from model")
+            else:
+                logging.warning(f"Unexpected response format: {response}")
+                raise ValueError("Invalid response format from model")
+
+        except Exception as e:
+            logging.error(f"Error in query_model: {str(e)}")
+            logging.error("Full error details: ", exc_info=True)
+            raise
+
+    def initialize_model(self):
+        """Initialize the model with proper error handling and verification"""
+        try:
+            if not os.path.exists(self.model_path):
+                st.info("Downloading model... This may take a while.")
+                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
+                download_file_with_progress(direct_url, self.model_path)
+
+            # Verify file exists and has content
+            if not os.path.exists(self.model_path):
+                raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")
+
+            if os.path.getsize(self.model_path) < 1000000:  # Less than 1MB
+                os.remove(self.model_path)
+                raise ValueError("Downloaded model file is too small, likely corrupted")
+
+            # Updated model configuration
+            llm_config = {
+                "model_path": self.model_path,
+                "n_ctx": 4096,              # Increased context window
+                "n_threads": 4,
+                "n_batch": 512,
+                "n_gpu_layers": 0,
+                "verbose": True,            # Enable verbose mode for debugging
+                "use_mlock": False,         # Disable memory locking
+                "last_n_tokens_size": 64,   # Token window size for repeat penalty
+                "seed": -1                  # Random seed for reproducibility
+            }
+
+            logging.info("Initializing Llama model...")
+            self.llm = Llama(**llm_config)
+
+            # Test the model
+            test_response = self.llm(
+                "Test response",
+                max_tokens=10,
+                temperature=0.7,
+                echo=False
+            )
+
+            if not test_response or 'choices' not in test_response:
+                raise RuntimeError("Model initialization test failed")
+
+            logging.info("Model initialized and tested successfully")
+            return self.llm

         except Exception as e:
-            logging.error(f"Pipeline initialization error: {str(e)}")
-            st.error(f"Failed to initialize the system: {str(e)}")
+            logging.error(f"Error initializing model: {str(e)}")
             raise

+    # @st.cache_resource(show_spinner=False)
+    # def initialize_rag_pipeline():
+    #     """Initialize the RAG pipeline once"""
+    #     try:
+    #         # Create necessary directories
+    #         os.makedirs("ESPN_data", exist_ok=True)
+
+    #         # Load embeddings from Drive
+    #         drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
+    #         with st.spinner("Loading embeddings from Google Drive..."):
+    #             cache_data = load_from_drive(drive_file_id)
+    #             if cache_data is None:
+    #                 st.error("Failed to load embeddings from Google Drive")
+    #                 st.stop()
+
+    #         # Initialize pipeline
+    #         data_folder = "ESPN_data"
+    #         rag = RAGPipeline(data_folder)
+
+    #         # Store embeddings
+    #         rag.documents = cache_data['documents']
+    #         rag.retriever.store_embeddings(cache_data['embeddings'])
+
+    #         return rag
+
+    #     except Exception as e:
+    #         logging.error(f"Pipeline initialization error: {str(e)}")
+    #         st.error(f"Failed to initialize the system: {str(e)}")
+    #         raise
+
 # def main():
 #     try:
 #         # Environment check
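
For reference, below is a minimal, self-contained sketch of how the generation call added in this commit can be exercised outside the RAGPipeline class. It assumes llama-cpp-python is installed and that a GGUF model file is already available locally; the MODEL_PATH value, the generate_answer helper name, and the sample prompt are illustrative placeholders, not values taken from the repository.

# Standalone sketch (not part of the commit): load a GGUF model with
# llama-cpp-python and run one completion using the same sampling
# parameters the updated query_model passes. MODEL_PATH is a placeholder.
import logging

from llama_cpp import Llama

MODEL_PATH = "mistral-7b-v0.1.Q4_K_M.gguf"  # placeholder local path

llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=4096,      # context window, mirroring the commit's llm_config
    n_threads=4,
    n_batch=512,
    n_gpu_layers=0,  # CPU-only inference
    verbose=False,
)

def generate_answer(prompt: str) -> str:
    """Run one completion and return the stripped text, as query_model does."""
    response = llm(
        prompt,
        max_tokens=512,
        temperature=0.7,
        top_p=0.95,
        top_k=50,
        echo=False,
        stop=["Question:", "Context:", "Guidelines:"],
        repeat_penalty=1.1,
        presence_penalty=0.5,
        frequency_penalty=0.5,
    )
    if response and response.get("choices"):
        text = response["choices"][0].get("text", "").strip()
        if text:
            return text
    raise ValueError("Empty or invalid response from model")

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print(generate_answer("Question: Who won the 2014 FIFA World Cup?\nAnswer:"))

In the commit itself, initialize_model() additionally downloads the GGUF file if it is missing, rejects downloads smaller than 1 MB as likely corrupted, and runs a 10-token test completion before the pipeline starts answering queries.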