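# A Gradio RAG demo: documents are embedded with BAAI/bge-m3, the closest
# passages are retrieved with FAISS, and Qwen2.5-3B generates the answer.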
import gradio as gr
"""
import pathlib
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
"""
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pandas as pd

## Load the LLM (Qwen2.5-3B from safetensors weights)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "Qwen/Qwen2.5-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)
llm.eval()
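# As written, the model loads on CPU (and float16 on CPU can be slow).
# On a GPU machine you would typically pass device_map="auto" to
# from_pretrained (requires the accelerate package) or call llm.to("cuda").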

"""
## Load the LLM (alternative: sarashina2-13b-r1 as a GGUF model via llama-cpp-python)
models_dir = pathlib.Path(__file__).parent / "models"
models_dir.mkdir(exist_ok=True)

model_path = hf_hub_download(
    repo_id="Mori-kamiyama/sarashina2-13b-r1",
    filename="model.gguf",
    local_dir=models_dir
)

llm = Llama(model_path=model_path)
"""

## Load the embedding model
model = SentenceTransformer("BAAI/bge-m3")
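# bge-m3 is a multilingual embedding model (it handles Japanese text) and
# produces 1024-dimensional dense vectors, so `dimension` below will be 1024.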

# Load the documents
df = pd.read_csv("document.csv")
# "text"カラムをリストとして抽出
texts = df['text'].tolist()
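# Assumed layout of document.csv: one passage per row under a "text" header,
# for example:
#   text
#   "Tokyo is the capital of Japan."
#   "Mount Fuji is Japan's highest mountain."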

# Embed the documents
doc_embeddings = model.encode(texts, normalize_embeddings=True)

# Set up FAISS
dimension = doc_embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)  # inner product equals cosine similarity for normalized embeddings
index.add(np.array(doc_embeddings))
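# IndexFlatIP performs exact brute-force search, which is fine for a small
# CSV; for larger corpora an approximate index such as faiss.IndexHNSWFlat
# would be the usual replacement.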

def generate_text(prompt):
    # Build the RAG prompt, then generate with the transformers model
    # (the llama.cpp-style result['choices'][0]['text'] API does not
    # apply to AutoModelForCausalLM)
    inputs = tokenizer(search(prompt), return_tensors="pt").to(llm.device)
    with torch.no_grad():
        output_ids = llm.generate(**inputs, max_new_tokens=256)  # length cap chosen arbitrarily
    # Decode only the newly generated tokens, skipping the echoed prompt
    return tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True)

def search(query):
    query_embedding = model.encode([query], normalize_embeddings=True)

    # Search FAISS for the top_k nearest documents
    top_k = 2
    D, I = index.search(np.array(query_embedding), top_k)

    retrieved_docs = []
    print("\n🔍 Search results:")
    for idx in I[0]:
        doc_text = texts[idx]
        retrieved_docs.append(doc_text)
        print(f"→ {doc_text}")

    # Build the RAG prompt; the Japanese instruction reads "Answer the
    # question with reference to the following documents.", followed by
    # "Documents:" and "Question:"
    prompt = "以下の文書を参照して質問に答えてください。\n\n文書:\n"
    prompt += "\n".join(retrieved_docs)
    prompt += f"\n\n質問: {query}"

    return prompt

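# Note: the title still refers to the sarashina2-13b GGUF backend that is
# commented out above; the active model here is Qwen2.5-3B.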
iface = gr.Interface(fn=generate_text,
                     inputs="text",
                     outputs="text",
                     title="sarashina-R13B-RAG")

iface.launch()
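# launch() serves the UI on a local URL; iface.launch(share=True) would also
# create a temporary public link (useful when running in a hosted notebook).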