MedRAG_test / app.py
limyehji's picture
Update app.py
e2d169f verified
from fastapi import FastAPI, HTTPException
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
app = FastAPI()
# 1. Load the MiniLM sentence-embedding model used to vectorize both the
#    corpus and incoming queries (must be the same model for both).
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# 2. Load the MedRAG textbooks dataset.
#    NOTE(review): streaming=True avoids downloading the full dataset up
#    front, but the list comprehension below still iterates it to the end,
#    so the whole corpus is materialized in memory anyway — confirm this is
#    intended for the dataset's size.
dataset = load_dataset("MedRAG/textbooks", split="train", streaming=True)
# 3. Extract the raw document texts; each record's "content" field holds the passage.
texts = [entry["content"] for entry in dataset]  # uses the "content" field
# 4. Embed every document and store the vectors in a FAISS index.
#    encode() runs once at startup; for a large corpus this is the slow step.
vectors = embed_model.encode(texts)
dimension = vectors.shape[1]  # embedding dimensionality (384 for MiniLM-L6)
index = faiss.IndexFlatL2(dimension)  # exact (brute-force) L2-distance index
index.add(np.array(vectors))  # register all document vectors with FAISS
# 5. ๊ฒ€์ƒ‰ API (GPTs์—์„œ ํ˜ธ์ถœ ๊ฐ€๋Šฅ)
@app.get("/search")
def search(query: str):
    """Embed the user query and return the top-3 most similar documents.

    Args:
        query: Free-text search string supplied by the caller.

    Returns:
        dict: {"retrieved_docs": [...]} with up to 3 matching passages,
        ranked by ascending L2 distance.

    Raises:
        HTTPException: 400 if the query is empty or whitespace-only.
    """
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail="Query must be a non-empty string.")
    query_vector = np.array(embed_model.encode([query]), dtype=np.float32)  # FAISS needs float32
    # Never request more neighbors than the index holds: FAISS pads missing
    # slots with id -1, and texts[-1] would silently return the LAST document.
    k = min(3, index.ntotal)
    if k == 0:
        return {"retrieved_docs": []}
    _, neighbor_ids = index.search(query_vector, k=k)
    # Defensive filter: drop any -1 padding ids before indexing into texts.
    results = [texts[i] for i in neighbor_ids[0] if i >= 0]
    return {"retrieved_docs": results}