farmax commited on
Commit
f15d519
·
verified ·
1 Parent(s): a23ee22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -2,14 +2,9 @@ from langchain_huggingface import HuggingFaceEmbeddings
2
  import gradio as gr
3
  import os
4
  from googletrans import Translator
5
- # import requests
6
- # from dotenv import load_dotenv
7
- # import numpy as np
8
  from langchain_community.vectorstores import Chroma
9
  from langchain_community.document_loaders import UnstructuredPDFLoader, PyPDFLoader
10
  from langchain.text_splitter import CharacterTextSplitter
11
- # from langchain.chains import RetrievalQAWithSourcesChain
12
- # from langchain.chains import load_qa_with_sources_from_chain_type
13
  from langchain.chains import ConversationalRetrievalChain
14
  from langchain.schema import Document
15
  from langchain.memory import ConversationBufferMemory
@@ -18,7 +13,8 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
18
  from langchain.llms.base import LLM
19
  from typing import List, Dict, Any, Optional
20
  from pydantic import BaseModel
21
- # from tqdm import tqdm
 
22
  import torch
23
  import logging
24
 
@@ -66,24 +62,24 @@ def initialize_database(document, chunk_size, chunk_overlap, progress=gr.Progres
66
 
67
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress(), language="italian"):
68
  logger.info("Initializing LLM chain...")
69
- llm_name = list_llm[llm_option]
70
- print("llm_name: ",llm_name)
71
 
72
- if language == "italian":
 
73
  default_llm = "google/gemma-7b-it"
74
  else:
75
  default_llm = "mistralai/Mistral-7B-Instruct-v0.2"
76
-
77
- if llm_name != default_llm:
78
- print(f"Using default LLM {default_llm} for {language}")
79
- llm_name = default_llm
80
-
 
 
 
81
  qa_chain = ConversationalRetrievalChain.from_llm(
82
- llm=llm_name,
83
  retriever=vector_db.as_retriever(),
84
  chain_type="stuff",
85
- # memory=memory,
86
- return_source_documents=True,
87
  temperature=llm_temperature,
88
  verbose=False,
89
  )
 
2
  import gradio as gr
3
  import os
4
  from googletrans import Translator
 
 
 
5
  from langchain_community.vectorstores import Chroma
6
  from langchain_community.document_loaders import UnstructuredPDFLoader, PyPDFLoader
7
  from langchain.text_splitter import CharacterTextSplitter
 
 
8
  from langchain.chains import ConversationalRetrievalChain
9
  from langchain.schema import Document
10
  from langchain.memory import ConversationBufferMemory
 
13
  from langchain.llms.base import LLM
14
  from typing import List, Dict, Any, Optional
15
  from pydantic import BaseModel
16
+ from langchain.llms.base import LLM
17
+ from transformers import AutoTokenizer, AutoModelForCausalLM
18
  import torch
19
  import logging
20
 
 
62
 
63
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress(), language="italian"):
64
  logger.info("Initializing LLM chain...")
 
 
65
 
66
+ # Define the default LLMS based on the language
67
+ if language == "italiano":
68
  default_llm = "google/gemma-7b-it"
69
  else:
70
  default_llm = "mistralai/Mistral-7B-Instruct-v0.2"
71
+
72
+ # Create an instance of the LLM
73
+ try:
74
+ llm = LLM.from_pretrained(default_llm)
75
+ except Exception as e:
76
+ logger.error(f"Error initializing LLM: {e}")
77
+ return None, "Failed to initialize LLM"
78
+
79
  qa_chain = ConversationalRetrievalChain.from_llm(
80
+ llm=llm,
81
  retriever=vector_db.as_retriever(),
82
  chain_type="stuff",
 
 
83
  temperature=llm_temperature,
84
  verbose=False,
85
  )