from fastapi import FastAPI, HTTPException import faiss import numpy as np from sentence_transformers import SentenceTransformer from datasets import load_dataset app = FastAPI() # 1. MiniLM 모델 로드 (데이터 벡터화) embed_model = SentenceTransformer("all-MiniLM-L6-v2") # 2. MedRAG 데이터셋 로드 dataset = load_dataset("MedRAG/textbooks", split="train", streaming=True) # 3. 데이터 변환 texts = [entry["content"] for entry in dataset] # "content" 필드 활용 # 4. 벡터 임베딩 생성 및 FAISS에 저장 vectors = embed_model.encode(texts) dimension = vectors.shape[1] # 임베딩 차원 index = faiss.IndexFlatL2(dimension) # L2 거리 기반 FAISS 인덱스 생성 index.add(np.array(vectors)) # FAISS에 벡터 추가 # 5. 검색 API (GPTs에서 호출 가능) @app.get("/search") def search(query: str): """ 사용자의 쿼리를 벡터 변환 후, FAISS에서 검색하여 관련 문서 반환 """ query_vector = embed_model.encode([query]) query_vector = np.array(query_vector, dtype=np.float32) # FAISS 호환 _, I = index.search(query_vector, k=3) # FAISS로 Top-3 검색 results = [texts[i] for i in I[0]] # 검색된 문서 반환 return {"retrieved_docs": results}