Spaces:
Runtime error
Runtime error
File size: 1,244 Bytes
0273ec0 e2d169f 0273ec0 e2d169f 0273ec0 e2d169f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from fastapi import FastAPI, HTTPException
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
app = FastAPI()

# 1. Load the MiniLM sentence-embedding model (used to vectorize documents
#    at startup and queries at request time).
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# 2. Load the MedRAG textbooks dataset. streaming=True yields an iterable
#    dataset instead of downloading and caching the full split up front.
dataset = load_dataset("MedRAG/textbooks", split="train", streaming=True)

# 3. Materialize the corpus: keep only the "content" field of each record.
#    NOTE(review): this consumes the entire stream into memory at import
#    time; confirm the host has enough RAM/time for the full split.
texts = [entry["content"] for entry in dataset]

# 4. Embed every document and store the vectors in a FAISS index.
#    The matrix is forced to C-contiguous float32 so the corpus side uses the
#    same dtype the query side converts to before searching.
vectors = np.ascontiguousarray(embed_model.encode(texts), dtype=np.float32)
dimension = vectors.shape[1]  # embedding dimensionality
index = faiss.IndexFlatL2(dimension)  # flat index using L2 distance
index.add(vectors)  # row i of the index corresponds to texts[i]
# 5. Search API (callable from GPTs).
@app.get("/search")
def search(query: str):
    """Return the documents whose embeddings are closest to the query.

    The query is embedded with the module-level ``embed_model``, the
    module-level FAISS ``index`` is searched for the 3 nearest vectors, and
    the matching entries of ``texts`` are returned.

    Args:
        query: Free-text search string taken from the ``query`` URL parameter.

    Returns:
        A dict ``{"retrieved_docs": [...]}`` holding up to 3 document strings,
        in the order FAISS returned them.

    Raises:
        HTTPException: 400 if ``query`` is empty or only whitespace.
    """
    if not query.strip():
        raise HTTPException(status_code=400, detail="query must not be empty")

    # FAISS expects a 2-D float32 array with one row per query.
    query_vector = np.asarray(embed_model.encode([query]), dtype=np.float32)
    _, neighbor_ids = index.search(query_vector, k=3)  # top-3 lookup

    # FAISS fills unused result slots with -1 when the index holds fewer than
    # k vectors; skip them so -1 is not used as a Python index (which would
    # silently return the last document).
    results = [texts[int(i)] for i in neighbor_ids[0] if i >= 0]
    return {"retrieved_docs": results}
|