import copy
import logging
import os
from typing import List, Optional

import dspy
import dspy.evaluate  # noqa: F401  (kept: may be relied on elsewhere)
from dotenv import load_dotenv
from dspy.teleprompt import BootstrapFewShotWithRandomSearch  # noqa: F401
from pydantic import BaseModel

# Populate os.environ from a .env file so GROQ_API_KEY can be supplied there.
load_dotenv()

logger = logging.getLogger(__name__)


class Agent(dspy.Module):
    """
    Base Agent Module.

    Creates a Groq-hosted language model and installs it as dspy's
    process-wide default LM via ``dspy.settings.configure``.
    """

    def __init__(
        self,
        model: Optional[str] = "llama3",
        client: Optional[str] = "ollama",
        max_tokens: Optional[int] = 4096,
        temperature: Optional[float] = 0.5,
    ) -> None:
        """
        Initialise the Agent Module.

        Args:
            model: Accepted for API compatibility but currently ignored; the
                Groq model "llama3-8b-8192" is always used.
            client: Accepted for API compatibility but currently ignored; the
                Groq client is always used.
            max_tokens: Maximum completion tokens for the LM (default 4096).
            temperature: Sampling temperature for the LM (default 0.5).

        Side effects:
            Reconfigures the global ``dspy.settings`` with the new LM, so
            every Agent instance replaces the LM of the previous one.
        """
        # Let dspy.Module set up its own internal state before ours.
        super().__init__()
        self.model = dspy.GROQ(
            model="llama3-8b-8192",
            temperature=temperature,
            api_key=os.getenv("GROQ_API_KEY"),
            max_tokens=max_tokens,
            # Strong penalties discourage the LM from repeating itself.
            frequency_penalty=1.5,
            presence_penalty=1.5,
        )
        dspy.settings.configure(
            lm=self.model, max_tokens=max_tokens, temperature=temperature
        )

    def __deepcopy__(self, memo):
        """
        Deep-copy every attribute except ``model``, which is shared by reference.

        The LM client object is reused rather than copied, presumably because
        it holds resources that should not be duplicated when dspy copies the
        module (NOTE(review): confirm).
        """
        new_instance = self.__class__.__new__(self.__class__)
        # Register early so self-references inside attributes resolve to the copy.
        memo[id(self)] = new_instance
        for k, v in self.__dict__.items():
            if k != "model":
                setattr(new_instance, k, copy.deepcopy(v, memo))
        new_instance.model = self.model
        return new_instance


class OutputFormat(BaseModel):
    # Structured LM output. Deliberately no class docstring: pydantic would put
    # it into the JSON schema, which may be shown to the LM.
    expand: Optional[str]  # expanded question; key is required, value may be null
    topic: str  # two-level topic tag, e.g. "Sports - Football", or "GENERAL"


class Conversation(BaseModel):
    # One previous chat turn.
    role: str
    content: str


class Memory(BaseModel):
    # Previous chat turns passed to the LM as context.
    conversations: List[Conversation]


class BaseSignature(dspy.Signature):
    # NOTE: this docstring is the instruction text dspy sends to the LM.
    # Changing its wording changes model behaviour.
    """
    You are an expert in expanding the user question and generating suitable tags for the question.
    Follow the exact instructions given:
    1. Expand with only single question.
    2. Try to keep the actual content in the expand question. Example: User question: What is math ?, expand: What is mathematics ?
    3. Tags should be 2-level hierarchy topics. Eg - India - Politics, Sports- Football. Tags should be as specific as possible. If it is a general question topic: GENERAL
    4. Do not give the reference of the previous question in the expanded question.
    5. If there is no expanded version of the user question, then give it as expand = "None"
    6. If there is a general question asked, do not expand the question, just give it as expand="None"
    7. topic can not be "None"
    8. Use the provided memory to understand context and provide more relevant expansions and topics.
    """

    query: str = dspy.InputField(prefix="Question: ")
    memory: Memory = dspy.InputField(
        prefix="Previous conversations: ",
        desc="This is a list of previous conversations.",
    )
    # NOTE(review): dspy.TypedPredictor may already append the model's JSON
    # schema for pydantic outputs, duplicating the one written here — confirm.
    output: OutputFormat = dspy.OutputField(desc='''Expanded user question and tags are generated as output. Respond with a single JSON object. JSON Schema: {"properties": {"expand": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Expand"}, "topic": {"title": "Topic", "type": "string"}}, "required": ["expand", "topic"], "title": "OutputFormat", "type": "object"}''')


class OutputAgent(Agent):
    """
    Agent that expands a user question and tags it with a topic.

    Inherits LM construction and global dspy configuration from ``Agent``.
    """

    def __init__(
        self,
        model: Optional[str] = "llama3",
        client: Optional[str] = "ollama",
        max_tokens: Optional[int] = 8192,
        temperature: Optional[float] = 0.5,
    ) -> None:
        """
        Initialise the OutputAgent.

        Args:
            model: Forwarded to ``Agent`` (currently ignored there).
            client: Forwarded to ``Agent`` (currently ignored there).
            max_tokens: Maximum completion tokens (default 8192).
                NOTE(review): this equals the model's nominal 8192 context
                size; confirm Groq accepts it together with the prompt.
            temperature: Sampling temperature (default 0.5, same as ``Agent``).
        """
        super().__init__(
            model=model,
            client=client,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    def __call__(
        self,
        query: str,
        memory: Optional[List[dict]],
        *,
        max_retries: int = 3,
    ) -> dspy.Prediction:
        """
        Expand the user question and generate its topic tag.

        Args:
            query: The current user query.
            memory: Previous conversations, each a dict with "role" and
                "content" keys. ``None`` is treated as no history.
            max_retries: Extra attempts after the first failed LM call
                (default 3). Negative values mean a single attempt.

        Returns:
            dspy.Prediction whose ``output`` field holds the expanded
            question and topic.

        Raises:
            KeyError: if a memory entry lacks "role" or "content".
            Exception: the error from the final attempt, once retries are exhausted.
        """
        # Validate and convert the history once, outside the retry loop.
        conversations = [
            Conversation(role=m["role"], content=m["content"]) for m in (memory or [])
        ]
        memory_model = Memory(conversations=conversations)

        output_generator = dspy.TypedPredictor(BaseSignature)

        # Bounded retry: the old self-recursion never terminated on
        # persistent failures (bad API key, network outage, repeated parse errors).
        attempts = max(1, max_retries + 1)
        for attempt in range(1, attempts + 1):
            try:
                return output_generator(query=query, memory=memory_model)
            except Exception as exc:  # LM/network/parse errors vary by backend
                if attempt == attempts:
                    raise
                logger.warning(
                    "Retrying (%d/%d) after error: %s", attempt, attempts - 1, exc
                )
        # Unreachable: the loop either returns or re-raises.
        raise AssertionError("unreachable")


def get_expanded_query_and_topic(
    query: str, conversation_context: Optional[List[dict]]
) -> OutputFormat:
    """
    Expand ``query`` and tag it with a topic. Entry point intended for app.py.

    Args:
        query: The current user query.
        conversation_context: Previous conversations as dicts with "role" and
            "content" keys. ``None`` is treated as no history.

    Returns:
        The ``output`` field of the prediction (declared as ``OutputFormat``
        on ``BaseSignature``).

    Note:
        A new OutputAgent is created per call, which reconfigures the global
        dspy LM each time.
    """
    agent = OutputAgent()
    result = agent(query, conversation_context)
    return result.output