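# NOTE: captured from a Hugging Face Spaces page whose status banner read
# "Runtime error"; the banner text is not part of the program.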
import faiss
import numpy as np
from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from sentence_transformers import SentenceTransformer
# Application object on which request handlers are registered.
app = FastAPI()

# 1. Load the MiniLM model used to turn text into embedding vectors.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# 2. Load the MedRAG textbooks dataset (train split, streaming mode).
dataset = load_dataset("MedRAG/textbooks", split="train", streaming=True)

# 3. Keep only the "content" field of each entry. A text's position in this
#    list is the id under which its vector is stored in the FAISS index below.
texts = [entry["content"] for entry in dataset]

# 4. Embed the corpus and store the vectors in FAISS.
#    The conversion makes the float32 / C-contiguous layout explicit (the
#    query path in search() applies the same float32 cast) and avoids the
#    unconditional copy that np.array() would make.
vectors = np.ascontiguousarray(embed_model.encode(texts), dtype=np.float32)
dimension = vectors.shape[1]  # embedding dimension
index = faiss.IndexFlatL2(dimension)  # L2-distance FAISS index
index.add(vectors)  # row i of `vectors` corresponds to texts[i]
# 5. Search API (callable from GPTs).
@app.get("/search")
def search(query: str):
    """Return the indexed documents most similar to ``query``.

    The query is embedded with ``embed_model``, the FAISS ``index`` is
    searched for its three nearest vectors, and the corresponding entries of
    ``texts`` are returned.

    Args:
        query: Free-text search string.

    Returns:
        A dict ``{"retrieved_docs": [...]}`` holding up to three document
        texts, in the order FAISS returned them.

    Raises:
        HTTPException: With status 400 if ``query`` is empty or contains only
            whitespace.
    """
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail="query must not be empty")

    # FAISS expects float32 input.
    query_vector = np.asarray(embed_model.encode([query]), dtype=np.float32)
    _, neighbor_ids = index.search(query_vector, k=3)  # top-3 lookup

    # FAISS pads the id list with -1 when the index holds fewer than k
    # vectors; without this filter texts[-1] would silently return the last
    # document instead of "no result".
    results = [texts[i] for i in neighbor_ids[0] if i >= 0]
    return {"retrieved_docs": results}