File size: 3,723 Bytes
372531f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import asyncio
from typing import List, Dict, Optional, Set

from ..context.compression import ContextCompressor, WrittenContentCompressor, VectorstoreCompressor
from ..actions.utils import stream_output


class ContextManager:
    """Manages context for the researcher agent.

    Each retrieval method delegates the actual lookup to one of the
    compressor classes imported at the top of this file. When
    ``researcher.verbose`` is truthy, a progress message is first streamed to
    ``researcher.websocket`` through ``stream_output``.
    """

    def __init__(self, researcher):
        """Keep a reference to the researcher this manager works for.

        Args:
            researcher: Object whose ``verbose``, ``websocket``,
                ``memory.get_embeddings()``, ``add_costs`` and
                ``vector_store`` attributes are read by the methods below.
        """
        self.researcher = researcher

    async def get_similar_content_by_query(self, query, pages):
        """Return the context ``ContextCompressor`` extracts from ``pages``.

        Args:
            query: Query the content should be relevant to.
            pages: Documents handed to ``ContextCompressor`` as ``documents``
                (their exact schema is defined by the compressor, not here).

        Returns:
            Whatever ``ContextCompressor.async_get_context`` returns for at
            most 10 results; costs are reported via ``researcher.add_costs``.
        """
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_query_content",
                f"📚 Getting relevant content based on query: {query}...",
                self.researcher.websocket,
            )

        context_compressor = ContextCompressor(
            documents=pages, embeddings=self.researcher.memory.get_embeddings()
        )
        return await context_compressor.async_get_context(
            query=query, max_results=10, cost_callback=self.researcher.add_costs
        )

    async def get_similar_content_by_query_with_vectorstore(self, query, filter):
        """Return the context ``VectorstoreCompressor`` finds for ``query``.

        Args:
            query: Query the content should be relevant to.
            filter: Passed unchanged to ``VectorstoreCompressor`` together
                with ``researcher.vector_store``. The name shadows the
                builtin but is kept so keyword callers keep working.

        Returns:
            Whatever ``VectorstoreCompressor.async_get_context`` returns for
            at most 8 results.
        """
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_query_format",
                f" Getting relevant content based on query: {query}...",
                self.researcher.websocket,
            )
        vectorstore_compressor = VectorstoreCompressor(self.researcher.vector_store, filter)
        return await vectorstore_compressor.async_get_context(query=query, max_results=8)

    async def get_similar_written_contents_by_draft_section_titles(
        self,
        current_subtopic: str,
        draft_section_titles: List[str],
        written_contents: List[Dict],
        max_results: int = 10
    ) -> List[str]:
        """Collect written contents relevant to a subtopic and its draft titles.

        The subtopic and every draft section title are each used as a query
        against ``written_contents``; the lookups run concurrently.

        Args:
            current_subtopic: Subtopic being written; used as the first query.
            draft_section_titles: Section titles, each used as a further query.
            written_contents: Previously written contents to search in.
            max_results: Maximum number of unique contents to return.

        Returns:
            Unique contents in first-seen order (results for the subtopic
            first, then for each title in the order given, each in the order
            the lookup returned them), truncated to ``max_results``.
        """
        # Unpacking (rather than ``list + ...``) also accepts tuples and
        # other iterables of titles.
        all_queries = [current_subtopic, *draft_section_titles]

        # gather() yields the per-query results in the order of ``all_queries``.
        per_query_results = await asyncio.gather(
            *(
                self.__get_similar_written_contents_by_query(query, written_contents)
                for query in all_queries
            )
        )

        # dict.fromkeys() removes duplicates while keeping insertion order.
        # A set would make the truncation below keep an arbitrary subset that
        # can differ from run to run.
        unique_contents = dict.fromkeys(
            content for contents in per_query_results for content in contents
        )
        relevant_contents = list(unique_contents)[:max_results]

        if relevant_contents and self.researcher.verbose:
            prettier_contents = "\n".join(relevant_contents)
            await stream_output(
                "logs", "relevant_contents_context", f"📃 {prettier_contents}", self.researcher.websocket
            )

        return relevant_contents

    async def __get_similar_written_contents_by_query(
        self,
        query: str,
        written_contents: List[Dict],
        similarity_threshold: float = 0.5,
        max_results: int = 10
    ) -> List[str]:
        """Return the written contents ``WrittenContentCompressor`` selects.

        Args:
            query: Query the contents should be relevant to.
            written_contents: Documents handed to the compressor.
            similarity_threshold: Forwarded to ``WrittenContentCompressor``.
            max_results: Forwarded to ``async_get_context``.

        Returns:
            The result of ``WrittenContentCompressor.async_get_context``;
            costs are reported via ``researcher.add_costs``.
        """
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_relevant_written_content",
                f"🔎 Getting relevant written content based on query: {query}...",
                self.researcher.websocket,
            )

        written_content_compressor = WrittenContentCompressor(
            documents=written_contents,
            embeddings=self.researcher.memory.get_embeddings(),
            similarity_threshold=similarity_threshold
        )
        return await written_content_compressor.async_get_context(
            query=query, max_results=max_results, cost_callback=self.researcher.add_costs
        )