user committed on
Commit
b3b4e83
·
1 Parent(s): 5412ab7
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -67,8 +67,13 @@ def load_and_process_text(file_path):
67
 
68
  @st.cache_data
69
  def create_embeddings(chunks, _embedding_model):
70
- tokenizer = AutoTokenizer.from_pretrained(_embedding_model)
71
- model = AutoModel.from_pretrained(_embedding_model)
 
 
 
 
 
72
 
73
  embeddings = []
74
  for chunk in chunks:
@@ -76,7 +81,8 @@ def create_embeddings(chunks, _embedding_model):
76
  with torch.no_grad():
77
  outputs = model(**inputs)
78
  embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
79
- return np.array(embeddings)
 
80
 
81
  @st.cache_resource
82
  def create_faiss_index(embeddings):
 
67
 
68
  @st.cache_data
69
  def create_embeddings(chunks, _embedding_model):
70
+ if isinstance(_embedding_model, str):
71
+ tokenizer = AutoTokenizer.from_pretrained(_embedding_model)
72
+ model = AutoModel.from_pretrained(_embedding_model)
73
+ else:
74
+ # Assume _embedding_model is already a model instance
75
+ model = _embedding_model
76
+ tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
77
 
78
  embeddings = []
79
  for chunk in chunks:
 
81
  with torch.no_grad():
82
  outputs = model(**inputs)
83
  embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
84
+
85
+ return np.vstack(embeddings)
86
 
87
  @st.cache_resource
88
  def create_faiss_index(embeddings):