File size: 3,698 Bytes
507c938
 
 
4096d75
37145a8
507c938
37145a8
61c90c9
 
 
 
 
507c938
dbd33b2
185fa42
 
 
 
 
 
 
 
 
 
 
 
dbd33b2
 
 
37145a8
 
 
 
 
 
507c938
37145a8
507c938
37145a8
 
 
 
 
 
 
507c938
37145a8
 
507c938
 
dbd33b2
185fa42
 
 
 
507c938
 
 
 
 
 
37145a8
 
 
 
 
 
507c938
 
 
 
 
 
37145a8
507c938
37145a8
 
 
507c938
37145a8
507c938
 
 
dbd33b2
507c938
37145a8
 
 
 
 
 
 
 
507c938
 
 
37145a8
507c938
 
37145a8
 
 
 
 
 
 
 
507c938
 
 
37145a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
from dotenv import load_dotenv
import logging
import sys
from transformers import pipeline

# Logging setup: records at INFO level and above are written to stdout, each
# line carrying the timestamp, logger name and severity before the message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(stream=sys.stdout, format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)

# Prompt skeleton for transcript question answering. The `{context}` and
# `{question}` placeholders are filled in with `str.format`.
RAG_PROMPT_TEMPLATE = "\n".join([
    "You are an AI assistant analyzing YouTube video transcripts. Your task is to answer questions based on the provided transcript context.",
    "",
    "Context from transcript:",
    "{context}",
    "",
    "User Question: {question}",
    "",
    "Please provide a clear, concise answer based only on the information given in the context. If the context doesn't contain enough information to fully answer the question, acknowledge this in your response.",
])

class RAGSystem:
    """Retrieval-augmented question answering over processed transcripts.

    Retrieval is delegated to ``data_processor.search``; text generation is
    done by a Hugging Face ``pipeline`` running on CPU.
    """

    # FLAN-T5 is an encoder-decoder (seq2seq) model, so it has to be loaded
    # with the "text2text-generation" task. The "text-generation" task only
    # supports decoder-only causal language models.
    MODEL_NAME = "google/flan-t5-base"  # Using a smaller model suitable for Spaces
    PIPELINE_TASK = "text2text-generation"

    def __init__(self, data_processor):
        """Store the retriever and load the generation pipeline.

        Args:
            data_processor: Object exposing
                ``search(query, num_results=..., method=..., index_name=...)``
                that returns an iterable of dicts with a ``'content'`` key.
        """
        self.data_processor = data_processor
        self.model = pipeline(
            self.PIPELINE_TASK,
            model=self.MODEL_NAME,
            device=-1,  # Use CPU
        )
        logger.info("Initialized RAG system with flan-t5-base model")

    def generate(self, prompt):
        """Run the model on ``prompt`` and return the generated text.

        Returns:
            The ``'generated_text'`` of the first returned sequence, or
            ``None`` if the pipeline raised (the error is logged).
        """
        try:
            outputs = self.model(
                prompt,
                max_length=512,
                min_length=64,
                num_return_sequences=1,
            )
            return outputs[0]['generated_text']
        except Exception as e:
            # Best-effort by design: callers treat None as "generation failed".
            logger.error("Error generating response: %s", e)
            return None

    def get_prompt(self, user_query, relevant_docs):
        """Fill the RAG prompt template with the retrieved context.

        Args:
            user_query: The user's question.
            relevant_docs: Iterable of dicts; each ``'content'`` value is
                joined with newlines to form the context.
        """
        context = "\n".join(doc['content'] for doc in relevant_docs)
        return RAG_PROMPT_TEMPLATE.format(
            context=context,
            question=user_query
        )

    def query(self, user_query, search_method='hybrid', index_name=None, num_results=3):
        """Answer ``user_query`` from documents retrieved out of ``index_name``.

        Args:
            user_query: The user's question.
            search_method: Passed to ``data_processor.search`` as ``method``.
            index_name: Index to search; required.
            num_results: Number of documents to retrieve (default 3).

        Returns:
            A ``(answer, prompt)`` tuple. ``prompt`` is the empty string when
            no prompt was built (no documents found, or an error occurred
            before/while building it). Errors are logged and reported in the
            answer text rather than raised.
        """
        try:
            if not index_name:
                raise ValueError("No index name provided. Please select a video and ensure it has been processed.")

            relevant_docs = self.data_processor.search(
                user_query,
                num_results=num_results,
                method=search_method,
                index_name=index_name
            )

            if not relevant_docs:
                logger.warning("No relevant documents found for the query.")
                return "I couldn't find any relevant information to answer your query.", ""

            prompt = self.get_prompt(user_query, relevant_docs)
            answer = self.generate(prompt)

            if not answer:
                return "I encountered an error generating the response.", prompt

            return answer, prompt

        except Exception as e:
            logger.error("An error occurred in the RAG system: %s", e)
            return f"An error occurred: {str(e)}", ""

    def _rewrite(self, query, prompt):
        """Generate from ``prompt``; fall back to ``query`` on an empty/failed result."""
        response = self.generate(prompt)
        return (response or query), prompt

    def rewrite_cot(self, query):
        """Rewrite ``query`` with a chain-of-thought style prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original ``query`` is returned
            unchanged when generation yields nothing.
        """
        prompt = f"""
        Think through this step by step:
        1. Original query: {query}
        2. What are the key components of this query?
        3. How can we break this down into a clearer question?
        
        Rewritten query:
        """
        return self._rewrite(query, prompt)

    def rewrite_react(self, query):
        """Rewrite ``query`` with a ReAct-style step-by-step prompt.

        Returns:
            ``(rewritten_query, prompt)``; the original ``query`` is returned
            unchanged when generation yields nothing.
        """
        prompt = f"""
        Let's approach this step-by-step:
        1. Question: {query}
        2. What information do we need?
        3. What's the best way to structure this query?
        
        Rewritten query:
        """
        return self._rewrite(query, prompt)