import gradio as gr
"""
import pathlib
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
"""
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pandas as pd

## Load the LLM (Qwen2.5-3B, loaded from safetensors)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "Qwen/Qwen2.5-3B"
# Pick the device once; use float16 only on GPU, since half-precision generation on CPU is unreliable
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
llm.eval()

"""
## Load the LLM (alternative: GGUF build via llama.cpp)
models_dir = pathlib.Path(__file__).parent / "models"
models_dir.mkdir(exist_ok=True)

model_path = hf_hub_download(
    repo_id="Mori-kamiyama/sarashina2-13b-r1",
    filename="model.gguf",
    local_dir=models_dir
)

llm = Llama(model_path=model_path)
"""

## Load the embedding model
model = SentenceTransformer("BAAI/bge-m3")

# Load the documents
df = pd.read_csv("document.csv")
# "text"カラムをリストとして抽出
texts = df['text'].tolist()
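# Assumed schema (not shipped with this file): document.csv holds one document per row in a "text" column, e.g.
#   text
#   "First reference document..."
#   "Second reference document..."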

# Embed the documents
doc_embeddings = model.encode(texts, normalize_embeddings=True)

# Set up FAISS
dimension = doc_embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)  # inner product equals cosine similarity, since the embeddings are normalized
index.add(np.array(doc_embeddings))
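# Optional sanity check: the index should now hold exactly one vector per document
assert index.ntotal == len(texts)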

def generate_text(prompt):
    # Retrieve the most relevant documents and wrap them into a RAG prompt
    full_prompt = search(prompt)
    # Keep the inputs on the same device as the model (CUDA or CPU, matching the setup above)
    input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        output_ids = llm.generate(input_ids, max_new_tokens=256)
    # Decode only the newly generated tokens so the prompt is not echoed back to the UI
    result_text = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    return result_text

def search(query):
    query_embedding = model.encode([query], normalize_embeddings=True)

    # Search with FAISS
    top_k = 2
    D, I = index.search(np.array(query_embedding), top_k)
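    # D holds the similarity scores (inner products) and I the row indices of the top_k matches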

    retrieved_docs = []
    print("\n🔍 検索結果:")
    for idx in I[0]:
        doc_text = texts[idx]
        retrieved_docs.append(doc_text)
        print(f"→ {doc_text}")

    # Build the RAG prompt (the Japanese text below asks the model to answer the question by referring to the documents)
    prompt = "以下の文書を参照して質問に答えてください。\n\n文書:\n"
    prompt += "\n".join(retrieved_docs)
    prompt += f"\n\n質問: {query}"

    return prompt

iface = gr.Interface(fn=generate_text,
                     inputs="text",
                     outputs="text",
                     title="sarashina-R13B-RAG")

iface.launch()
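# Note: when running locally rather than on Hugging Face Spaces, iface.launch(share=True) exposes a temporary public URL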