# video_rag_routes.py

import os
import uuid
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from google import genai
from google.genai import types

router = APIRouter()

# --- Helpers -------------------------------------------------------

def init_google_client():
    api_key = os.getenv("GOOGLE_API_KEY", "")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY must be set")
    return genai.Client(api_key=api_key)

def get_llm():
    api_key = os.getenv("CHATGROQ_API_KEY", "")
    if not api_key:
        raise ValueError("CHATGROQ_API_KEY must be set")
    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=api_key,
    )

def get_embeddings():
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
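
# With normalize_embeddings=True the vectors are unit-length, so FAISS's
# default L2 ranking is equivalent to ranking by cosine similarity.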

# Prompt template for answering questions over the retrieved context
rag_prompt = """
You are an assistant specialized in answering questions based on the provided context.
If the context does not contain the answer, reply "I don't know."

Context:
{context}
"""
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", rag_prompt),
    ("human", "{question}"),
])

def create_chain(retriever):
    return ConversationalRetrievalChain.from_llm(
        llm=get_llm(),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
        combine_docs_chain_kwargs={"prompt": chat_prompt},
        verbose=False,
    )
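
# Note: ConversationalRetrievalChain first condenses the incoming question
# together with chat_history into a standalone query, retrieves documents
# with it, and then "stuffs" them into {context} of the prompt above.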

# In-memory session store (process-local: sessions are lost on restart and
# not shared across workers)
sessions: dict[str, dict] = {}

def process_transcription(text: str) -> str:
    # split -> embed -> index -> store retriever & empty history
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
    chunks = splitter.split_text(text)
    vs = FAISS.from_texts(chunks, get_embeddings())
    retr = vs.as_retriever(search_kwargs={"k": 3})
    sid = str(uuid.uuid4())
    sessions[sid] = {"retriever": retr, "history": []}
    return sid
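
# A minimal sketch of the round-trip above (the transcript text and query
# here are made up for illustration):
#
#     sid = process_transcription("The speaker introduces FAISS and explains "
#                                 "how vector search powers retrieval.")
#     docs = sessions[sid]["retriever"].invoke("What powers retrieval?")
#     print(docs[0].page_content)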

# --- Endpoints -----------------------------------------------------

class URLIn(BaseModel):
    youtube_url: str

@router.post("/transcribe_video")
async def transcribe_url(body: URLIn):
    try:
        client = init_google_client()
        resp = client.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=types.Content(parts=[
                types.Part(text="Transcribe the video"),
                types.Part(file_data=types.FileData(file_uri=body.youtube_url))
            ])
        )
        txt = resp.candidates[0].content.parts[0].text
        sid = process_transcription(txt)
        return {"session_id": sid}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/upload_video")
async def upload_file(
    file: UploadFile = File(...),
    prompt: str = "Transcribe the video",
):
    data = await file.read()
    client = init_google_client()
    try:
        resp = client.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=types.Content(parts=[
                types.Part(text=prompt),
                types.Part(inline_data=types.Blob(data=data, mime_type=file.content_type))
            ])
        )
        txt = resp.candidates[0].content.parts[0].text
        sid = process_transcription(txt)
        return {"session_id": sid}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
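
# Note: inline_data ships the raw bytes inside the request itself, which the
# Gemini API caps at roughly 20 MB per request; substantially larger videos
# would need the separate Files API (client.files.upload) instead.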

class QueryIn(BaseModel):
    session_id: str
    query: str

@router.post("/vid_query")
async def query_rag(body: QueryIn):
    sess = sessions.get(body.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Session not found")
    chain = create_chain(sess["retriever"])
    result = chain.invoke({
        "question": body.query,
        "chat_history": sess["history"]
    })
    answer = result.get("answer", "I don't know.")
    # update history
    sess["history"].append((body.query, answer))
    # collect source snippets
    docs = result.get("source_documents") or []
    srcs = [getattr(d, "page_content", str(d)) for d in docs]
    return {"answer": answer, "source_documents": srcs}