StevenChen16 committed on
Commit
52b36bc
·
verified ·
1 Parent(s): a2fadca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -8
app.py CHANGED
@@ -68,10 +68,41 @@ class RAGChatbot:
68
  Now, please guide me step by step to describe the legal issues I am facing, according to the above requirements.
69
  '''
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  @spaces.GPU
72
  def init_models(self):
73
  """Initialize the LLM and embedding models"""
 
 
 
 
 
 
 
 
 
74
  # LLM initialization
 
75
  self.llm_model_name = 'StevenChen16/llama3-8b-Lawyer'
76
  self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
77
  self.model = AutoModelForCausalLM.from_pretrained(
@@ -82,28 +113,27 @@ class RAGChatbot:
82
  self.tokenizer.eos_token_id,
83
  self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
84
  ]
85
-
86
- # Embedding model initialization
87
- self.embeddings = HuggingFaceEmbeddings(
88
- model_name='intfloat/multilingual-e5-large-instruct',
89
- model_kwargs={'trust_remote_code': True}
90
- )
91
 
92
  def init_vector_store(self):
93
  """Load vector store from HuggingFace Hub"""
94
  try:
 
95
  # Download FAISS files from HuggingFace Hub
96
  repo_path = snapshot_download(
97
  repo_id="StevenChen16/laws.faiss",
98
  repo_type="model"
99
  )
100
 
 
101
  # Load the vector store from downloaded files
102
  self.vector_store = FAISS.load_local(
103
- repo_path,
104
- self.embeddings,
105
  allow_dangerous_deserialization=True
106
  )
 
 
107
  except Exception as e:
108
  raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}")
109
 
 
68
  Now, please guide me step by step to describe the legal issues I am facing, according to the above requirements.
69
  '''
70
 
71
+ import gradio as gr
72
+ import os
73
+ import spaces
74
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
75
+ from threading import Thread
76
+ from langchain_community.vectorstores.faiss import FAISS
77
+ from langchain_huggingface import HuggingFaceEmbeddings
78
+ from huggingface_hub import snapshot_download
79
+
80
+ class RAGChatbot:
81
+ def __init__(self):
82
+ # First initialize models to create embeddings
83
+ self.init_models()
84
+ # Then initialize vector store which uses embeddings
85
+ self.init_vector_store()
86
+
87
+ self.background_prompt = '''
88
+ As an AI legal assistant, you are a highly trained expert in U.S. and Canadian law...
89
+ [rest of your existing background prompt]
90
+ '''
91
+
92
  @spaces.GPU
93
  def init_models(self):
94
  """Initialize the LLM and embedding models"""
95
+ print("Initializing models...")
96
+
97
+ # Embedding model initialization first
98
+ print("Loading embedding model...")
99
+ self.embeddings = HuggingFaceEmbeddings(
100
+ model_name='intfloat/multilingual-e5-large-instruct',
101
+ model_kwargs={'trust_remote_code': True}
102
+ )
103
+
104
  # LLM initialization
105
+ print("Loading LLM model...")
106
  self.llm_model_name = 'StevenChen16/llama3-8b-Lawyer'
107
  self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
108
  self.model = AutoModelForCausalLM.from_pretrained(
 
113
  self.tokenizer.eos_token_id,
114
  self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
115
  ]
116
+ print("Models initialized successfully")
 
 
 
 
 
117
 
118
  def init_vector_store(self):
119
  """Load vector store from HuggingFace Hub"""
120
  try:
121
+ print("Downloading vector store from HuggingFace Hub...")
122
  # Download FAISS files from HuggingFace Hub
123
  repo_path = snapshot_download(
124
  repo_id="StevenChen16/laws.faiss",
125
  repo_type="model"
126
  )
127
 
128
+ print("Loading vector store...")
129
  # Load the vector store from downloaded files
130
  self.vector_store = FAISS.load_local(
131
+ folder_path=repo_path, # Specify the parameter name explicitly
132
+ embeddings=self.embeddings,
133
  allow_dangerous_deserialization=True
134
  )
135
+ print("Vector store loaded successfully")
136
+
137
  except Exception as e:
138
  raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}")
139