farmax committed
Commit 5724f7c · verified · 1 Parent(s): f15d519

Update app.py

Files changed (1)
  1. app.py +10 -5
app.py CHANGED
```diff
@@ -60,24 +60,29 @@ def initialize_database(document, chunk_size, chunk_overlap, progress=gr.Progres
     logger.info("Database initialized successfully.")
     return vectorstore, None, "Initialized"  # Added None as the second output
 
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress(), language="italian"):
+def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress(), language="italiano"):
     logger.info("Initializing LLM chain...")
-
+
     # Define the default LLMs based on the language
     if language == "italiano":
         default_llm = "google/gemma-7b-it"
     else:
         default_llm = "mistralai/Mistral-7B-Instruct-v0.2"
 
-    # Create an instance of the LLM
+    # Try to load the tokenizer and model
     try:
-        llm = LLM.from_pretrained(default_llm)
+        tokenizer = AutoTokenizer.from_pretrained(default_llm)
+        model = AutoModelForCausalLM.from_pretrained(default_llm)
     except Exception as e:
         logger.error(f"Error initializing LLM: {e}")
         return None, "Failed to initialize LLM"
 
+    # Resize token embeddings if needed
+    if len(tokenizer) > model.config.max_position_embeddings:
+        model.resize_token_embeddings(len(tokenizer))
+
     qa_chain = ConversationalRetrievalChain.from_llm(
-        llm=llm,
+        llm=model,
         retriever=vector_db.as_retriever(),
         chain_type="stuff",
         temperature=llm_temperature,
```
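
A note on the new resize guard: `resize_token_embeddings` adjusts the vocabulary axis of the embedding matrix, while `max_position_embeddings` is the model's context length, a different dimension entirely. A minimal sketch of the more usual comparison, assuming the `model` and `tokenizer` loaded in the commit:

```python
# Sketch only: compare the tokenizer's vocabulary against the rows of the
# input-embedding matrix (the vocabulary axis), not against the context
# length stored in max_position_embeddings.
embedding_rows = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_rows:
    # Grow the embedding matrix to cover any tokens added to the tokenizer.
    model.resize_token_embeddings(len(tokenizer))
```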
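
Separately, LangChain's `ConversationalRetrievalChain.from_llm` expects a LangChain-compatible LLM and has no `temperature` parameter, so passing the raw `AutoModelForCausalLM` as `llm=model` together with `temperature=llm_temperature` will fail at runtime. One common bridge is a `HuggingFacePipeline` wrapper. The sketch below is an assumption about the intended wiring, not part of the commit; it reuses the function's `llm_temperature`, `max_tokens`, and `top_k` parameters as generation kwargs, and the import path varies with the installed LangChain version:

```python
from transformers import pipeline
from langchain_community.llms import HuggingFacePipeline  # older versions: langchain.llms

# Build a text-generation pipeline around the loaded model and tokenizer,
# moving the sampling parameters out of from_llm() and into generation.
generate = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=max_tokens,
    temperature=llm_temperature,
    top_k=top_k,
    do_sample=True,
)
llm = HuggingFacePipeline(pipeline=generate)

qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vector_db.as_retriever(),
    chain_type="stuff",
)
```

With this shape, `from_llm` receives only the arguments its signature actually defines, and all decoding behavior lives in the pipeline.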