Update app.py
app.py CHANGED
@@ -447,62 +447,155 @@ class RAGPipeline:
         placeholder.warning(message)
         return message
 
-    def query_model(self, prompt: str) -> str:
-        """Query the local Llama model"""
-        try:
-            if self.llm is None:
-                raise RuntimeError("Model not initialized")
-
-            response = self.llm(
-                prompt,
-                max_tokens=512,
-                temperature=0.4,
-                top_p=0.95,
-                echo=False,
-                stop=["Question:", "Context:", "Guidelines:"],  # Removed \n\n from stop tokens to allow paragraphs
-                repeat_penalty=1.1  # Added to encourage more diverse text
-            )
-
-            if response and 'choices' in response and len(response['choices']) > 0:
-                text = response['choices'][0].get('text', '').strip()
-                return text
-            else:
-                raise ValueError("No valid response generated")
-
-        except Exception as e:
-            logging.error(f"Error in query_model: {str(e)}")
-            raise
-
-@st.cache_resource(show_spinner=False)
-def initialize_rag_pipeline():
-    """Initialize the RAG pipeline once"""
-    try:
-        # Create necessary directories
-        os.makedirs("ESPN_data", exist_ok=True)
-
-        # Load embeddings from Drive
-        drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
-        with st.spinner("Loading embeddings from Google Drive..."):
-            cache_data = load_from_drive(drive_file_id)
-            if cache_data is None:
-                st.error("Failed to load embeddings from Google Drive")
-                st.stop()
-
-        # Initialize pipeline
-        data_folder = "ESPN_data"
-        rag = RAGPipeline(data_folder)
-
-        # Store embeddings
-        rag.documents = cache_data['documents']
-        rag.retriever.store_embeddings(cache_data['embeddings'])
-
-        return rag
-
-    except Exception as e:
-        logging.error(f"Pipeline initialization error: {str(e)}")
-        st.error(f"Failed to initialize the system: {str(e)}")
-        raise
-
+    # def query_model(self, prompt: str) -> str:
+    #     """Query the local Llama model"""
+    #     try:
+    #         if self.llm is None:
+    #             raise RuntimeError("Model not initialized")
+
+    #         response = self.llm(
+    #             prompt,
+    #             max_tokens=512,
+    #             temperature=0.4,
+    #             top_p=0.95,
+    #             echo=False,
+    #             stop=["Question:", "Context:", "Guidelines:"],  # Removed \n\n from stop tokens to allow paragraphs
+    #             repeat_penalty=1.1  # Added to encourage more diverse text
+    #         )
+
+    #         if response and 'choices' in response and len(response['choices']) > 0:
+    #             text = response['choices'][0].get('text', '').strip()
+    #             return text
+    #         else:
+    #             raise ValueError("No valid response generated")
+
+    #     except Exception as e:
+    #         logging.error(f"Error in query_model: {str(e)}")
+    #         raise
+
+    def query_model(self, prompt: str) -> str:
+        """Query the local Llama model"""
+        try:
+            if self.llm is None:
+                raise RuntimeError("Model not initialized")
+
+            # Log the prompt for debugging
+            logging.info(f"Sending prompt to model...")
+
+            # Generate response with more explicit parameters
+            response = self.llm(
+                prompt,
+                max_tokens=512,  # Maximum length of the response
+                temperature=0.7,  # Slightly increased for more dynamic responses
+                top_p=0.95,  # Nucleus sampling parameter
+                top_k=50,  # Top-k sampling parameter
+                echo=False,  # Don't include prompt in response
+                stop=["Question:", "Context:", "Guidelines:"],  # Stop tokens
+                repeat_penalty=1.1,  # Penalize repetition
+                presence_penalty=0.5,  # Encourage topic diversity
+                frequency_penalty=0.5  # Discourage word repetition
+            )
+
+            # Log the raw response for debugging
+            logging.info(f"Raw model response: {response}")
+
+            if response and isinstance(response, dict) and 'choices' in response and response['choices']:
+                generated_text = response['choices'][0].get('text', '').strip()
+                if generated_text:
+                    logging.info(f"Generated text: {generated_text[:100]}...")  # Log first 100 chars
+                    return generated_text
+                else:
+                    logging.warning("Model returned empty response")
+                    raise ValueError("Empty response from model")
+            else:
+                logging.warning(f"Unexpected response format: {response}")
+                raise ValueError("Invalid response format from model")
+
+        except Exception as e:
+            logging.error(f"Error in query_model: {str(e)}")
+            logging.error("Full error details: ", exc_info=True)
+            raise
+
+    def initialize_model(self):
+        """Initialize the model with proper error handling and verification"""
+        try:
+            if not os.path.exists(self.model_path):
+                st.info("Downloading model... This may take a while.")
+                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
+                download_file_with_progress(direct_url, self.model_path)
+
+            # Verify file exists and has content
+            if not os.path.exists(self.model_path):
+                raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")
+
+            if os.path.getsize(self.model_path) < 1000000:  # Less than 1MB
+                os.remove(self.model_path)
+                raise ValueError("Downloaded model file is too small, likely corrupted")
+
+            # Updated model configuration
+            llm_config = {
+                "model_path": self.model_path,
+                "n_ctx": 4096,  # Increased context window
+                "n_threads": 4,
+                "n_batch": 512,
+                "n_gpu_layers": 0,
+                "verbose": True,  # Enable verbose mode for debugging
+                "use_mlock": False,  # Disable memory locking
+                "last_n_tokens_size": 64,  # Token window size for repeat penalty
+                "seed": -1  # Random seed for reproducibility
+            }
+
+            logging.info("Initializing Llama model...")
+            self.llm = Llama(**llm_config)
+
+            # Test the model
+            test_response = self.llm(
+                "Test response",
+                max_tokens=10,
+                temperature=0.7,
+                echo=False
+            )
+
+            if not test_response or 'choices' not in test_response:
+                raise RuntimeError("Model initialization test failed")
+
+            logging.info("Model initialized and tested successfully")
+            return self.llm
+
+        except Exception as e:
+            logging.error(f"Error initializing model: {str(e)}")
+            raise
+
+# @st.cache_resource(show_spinner=False)
+# def initialize_rag_pipeline():
+#     """Initialize the RAG pipeline once"""
+#     try:
+#         # Create necessary directories
+#         os.makedirs("ESPN_data", exist_ok=True)
+
+#         # Load embeddings from Drive
+#         drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
+#         with st.spinner("Loading embeddings from Google Drive..."):
+#             cache_data = load_from_drive(drive_file_id)
+#             if cache_data is None:
+#                 st.error("Failed to load embeddings from Google Drive")
+#                 st.stop()
+
+#         # Initialize pipeline
+#         data_folder = "ESPN_data"
+#         rag = RAGPipeline(data_folder)
+
+#         # Store embeddings
+#         rag.documents = cache_data['documents']
+#         rag.retriever.store_embeddings(cache_data['embeddings'])
+
+#         return rag
+
+#     except Exception as e:
+#         logging.error(f"Pipeline initialization error: {str(e)}")
+#         st.error(f"Failed to initialize the system: {str(e)}")
+#         raise
+
 # def main():
 #     try:
 #         # Environment check
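
For context, the sketch below is not part of the commit; it only illustrates the llama-cpp-python call pattern that the new query_model() and initialize_model() rely on, using the same generation parameters this change introduces. The model path and prompt are placeholders, and a GGUF file is assumed to already be available locally.

# Standalone sketch, assuming llama-cpp-python is installed and a GGUF model
# file exists locally; the path and prompt below are placeholders, not the
# ones the Space downloads or builds.
import logging

from llama_cpp import Llama

logging.basicConfig(level=logging.INFO)

llm = Llama(
    model_path="models/example.Q4_K_M.gguf",  # placeholder path
    n_ctx=4096,       # context window, matching the commit's llm_config
    n_threads=4,
    n_batch=512,
    n_gpu_layers=0,   # CPU-only, as configured in the Space
)

response = llm(
    "Question: Who won the match?\nAnswer:",  # placeholder prompt
    max_tokens=512,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
    echo=False,
    stop=["Question:", "Context:", "Guidelines:"],
    repeat_penalty=1.1,
    presence_penalty=0.5,
    frequency_penalty=0.5,
)

# llama-cpp-python returns an OpenAI-style completion dict, which is why the
# new query_model() validates response['choices'][0]['text'] before using it.
if response and isinstance(response, dict) and response.get("choices"):
    print(response["choices"][0].get("text", "").strip())
else:
    logging.warning("Unexpected response format: %s", response)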