File size: 1,244 Bytes
0273ec0
 
 
 
 
 
 
 
e2d169f
0273ec0
 
e2d169f
0273ec0
 
 
 
e2d169f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from fastapi import FastAPI, HTTPException
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset

app = FastAPI()

# --- One-time startup: embed the corpus and build the FAISS index. ---
# NOTE(review): all of this runs at import time and downloads both the model
# and the dataset; server startup will be slow and memory-heavy.

# 1. Load the MiniLM sentence-embedding model (vectorizes the data).
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# 2. Load the MedRAG textbooks dataset. streaming=True avoids downloading the
#    whole archive up front, but the list comprehension below still
#    materializes every entry in memory.
dataset = load_dataset("MedRAG/textbooks", split="train", streaming=True)

# 3. Extract the raw passage text from each record (uses the "content" field).
texts = [entry["content"] for entry in dataset]

# 4. Embed the corpus and store it in FAISS. FAISS requires float32 input;
#    cast explicitly here (the query path in /search already does) instead of
#    relying on the encoder's default dtype.
vectors = np.asarray(embed_model.encode(texts), dtype=np.float32)
dimension = vectors.shape[1]          # embedding dimensionality
index = faiss.IndexFlatL2(dimension)  # exact L2-distance FAISS index
index.add(vectors)                    # add corpus vectors to FAISS

# 5. ๊ฒ€์ƒ‰ API (GPTs์—์„œ ํ˜ธ์ถœ ๊ฐ€๋Šฅ)
@app.get("/search")
def search(query: str):
    """Embed the user's query and return the top-3 most similar corpus passages.

    Args:
        query: free-text search string (must be non-blank).

    Returns:
        dict with key ``"retrieved_docs"``: up to 3 matching passage texts.

    Raises:
        HTTPException(400): if the query is empty or whitespace-only.
    """
    if not query.strip():
        raise HTTPException(status_code=400, detail="query must be non-empty")

    # Encode the query; FAISS requires float32 input.
    query_vector = np.asarray(embed_model.encode([query]), dtype=np.float32)

    _, indices = index.search(query_vector, k=3)  # top-3 nearest neighbors
    # FAISS pads the result with -1 when the index holds fewer than k vectors;
    # filter those out so we don't silently return texts[-1].
    results = [texts[i] for i in indices[0] if i >= 0]

    return {"retrieved_docs": results}