# ORM models, API schemas and database bootstrap helpers.
#
# NOTE: classes below are documented with `#` comments rather than docstrings
# on purpose: pydantic copies class docstrings into generated JSON schemas.
import uuid as uuid_pkg
from datetime import datetime
from typing import (
    Optional,
    Union,
    List,
    Dict,
    Any
)

from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import declared_attr
from sqlmodel import (
    UniqueConstraint,
    create_engine,
    Relationship,
    SQLModel,
    Session,
    select,
    Field,
)

from config import (
    LLM_DEFAULT_DISTANCE_STRATEGY,
    VECTOR_EMBEDDINGS_COUNT,
    LLM_MAX_OUTPUT_TOKENS,
    DISTANCE_STRATEGIES,
    LLM_MIN_NODE_LIMIT,
    PGVECTOR_ADD_INDEX,
    ENTITY_STATUS,
    CHANNEL_TYPE,
    LLM_MODELS,
    DB_USER,
    SU_DSN,
    logger,
)
from util import snake_case


def _format_repr(obj: Any, *field_names: str) -> str:
    """Build ``<ClassName field=value ...>`` from already-loaded state only.

    Values are read from the instance ``__dict__`` instead of via attribute
    access so that calling ``repr()`` never triggers a database load (and
    therefore never raises on an instance detached from its session).
    Fields that are not currently loaded are rendered as ``None``.
    """
    state = vars(obj)
    parts = " ".join(f"{name}={state.get(name)!r}" for name in field_names)
    return f"<{type(obj).__name__} {parts}>"


# ==========
# Base model
# ==========
# Shared base for all table models: derives the table name from the class
# name and provides small active-record style helpers. Each helper opens its
# own short-lived Session on the engine returned by get_engine().
class BaseModel(SQLModel):
    @declared_attr
    def __tablename__(cls) -> str:
        # Table name is the snake_case form of the class name.
        return snake_case(cls.__name__)

    @classmethod
    def by_uuid(cls, _uuid: uuid_pkg.UUID):
        """Return the first row whose ``uuid`` equals ``_uuid``, or ``None``."""
        with Session(get_engine()) as session:
            q = select(cls).where(cls.uuid == _uuid)
            return session.exec(q).first()

    def update(self, o: Optional[Union[SQLModel, dict]] = None):
        """Copy values from ``o`` onto this instance and persist them.

        Args:
            o: A dict of attribute values, or a model whose explicitly-set
                fields (``exclude_unset=True``) are applied.

        Raises:
            ValueError: If ``o`` is missing or empty.
        """
        if not o:
            raise ValueError("Must provide a model or dict to update values")

        o = o if isinstance(o, dict) else o.dict(exclude_unset=True)
        for key, value in o.items():
            setattr(self, key, value)

        # save and commit to database
        with Session(get_engine()) as session:
            session.add(self)
            session.commit()
            session.refresh(self)

    def delete(self):
        """Soft-delete: set ``status`` to DELETED, stamp ``updated_at``, persist.

        The row is not removed from the database.
        """
        # NOTE(review): this writes `status` and `updated_at`; models that do
        # not declare those fields presumably cannot use delete() - confirm.
        # NOTE(review): utcnow() here vs. datetime.now for the column
        # defaults elsewhere in this file - confirm which clock is intended.
        with Session(get_engine()) as session:
            self.status = ENTITY_STATUS.DELETED
            self.updated_at = datetime.utcnow()
            session.add(self)
            session.commit()
            session.refresh(self)

    @classmethod
    def create(cls, o: Optional[Union[SQLModel, dict]] = None):
        """Insert a new row built from ``o`` and return the refreshed instance.

        Args:
            o: A SQLModel instance (converted with ``from_orm``) or a dict of
                constructor keyword arguments.

        Raises:
            ValueError: If ``o`` is missing or empty.
        """
        if not o:
            raise ValueError("Must provide a model or dict to create a record")

        with Session(get_engine()) as session:
            obj = cls.from_orm(o) if isinstance(o, SQLModel) else cls(**o)
            session.add(obj)
            session.commit()
            session.refresh(obj)

        return obj


# ============
# Organization
# ============
class Organization(BaseModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    uuid: Optional[uuid_pkg.UUID] = Field(
        unique=True, default_factory=uuid_pkg.uuid4
    )  # UUID for the organization
    display_name: Optional[str] = Field(
        default="Untitled Organization 😊", index=True
    )  # display name of the organization
    namespace: str = Field(
        unique=True, index=True
    )  # unique organization namespace for URLs, etc.
    bot_url: Optional[str] = Field(default=None)  # URL for the bot
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # -------------
    # Relationships
    # -------------
    projects: Optional[List["Project"]] = Relationship(back_populates="organization")
    documents: Optional[List["Document"]] = Relationship(back_populates="organization")

    @property
    def project_count(self) -> int:
        return len(self.projects)

    @property
    def document_count(self) -> int:
        return len(self.documents)

    def __repr__(self):
        return _format_repr(self, "id", "uuid", "namespace", "display_name")


class OrganizationCreate(SQLModel):
    display_name: Optional[str]
    namespace: Optional[str]
    bot_url: Optional[str]


class OrganizationRead(SQLModel):
    id: int
    uuid: uuid_pkg.UUID
    display_name: str
    namespace: Optional[str]
    bot_url: Optional[str]
    created_at: datetime
    updated_at: datetime


class OrganizationUpdate(SQLModel):
    display_name: Optional[str]
    namespace: Optional[str]
    bot_url: Optional[str]


# ===============
# User (customer)
# ===============
class User(BaseModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    identifier: str = Field(default=None, unique=True, index=True)
    identifier_type: Optional[CHANNEL_TYPE] = Field(default=None)
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    first_name: Optional[str] = Field(default=None)
    last_name: Optional[str] = Field(default=None)
    email: Optional[str] = Field(default=None)
    phone: Optional[str] = Field(default=None)
    dob: Optional[datetime] = Field(default=None)
    device_fingerprint: Optional[str] = Field(default=None)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # -------------
    # Relationships
    # -------------
    chat_sessions: Optional[List["ChatSession"]] = Relationship(back_populates="user")

    @property
    def chat_session_count(self) -> int:
        return len(self.chat_sessions)

    __table_args__ = (
        UniqueConstraint("identifier", "identifier_type", name="unq_id_idtype"),
    )

    def __repr__(self):
        return _format_repr(self, "id", "uuid", "identifier", "identifier_type")


class UserCreate(SQLModel):
    identifier: str
    identifier_type: CHANNEL_TYPE
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]


class UserReadList(SQLModel):
    id: int
    identifier: Optional[str]
    identifier_type: Optional[CHANNEL_TYPE]
    uuid: uuid_pkg.UUID
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]
    chat_session_count: int
    created_at: datetime
    updated_at: datetime


class UserUpdate(SQLModel):
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]


# =======
# Project
# =======
class Project(BaseModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    organization_id: int = Field(default=None, foreign_key="organization.id")
    display_name: str = Field(default="📝 Untitled Project")
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # -------------
    # Relationships
    # -------------
    organization: Optional["Organization"] = Relationship(back_populates="projects")
    documents: Optional[List["Document"]] = Relationship(back_populates="project")
    chat_sessions: Optional[List["ChatSession"]] = Relationship(
        back_populates="project"
    )

    @property
    def document_count(self) -> int:
        return len(self.documents)

    def __repr__(self):
        return _format_repr(self, "id", "uuid", "display_name", "organization_id")


class ProjectCreate(SQLModel):
    display_name: Optional[str]


class ProjectReadListOrganization(SQLModel):
    uuid: uuid_pkg.UUID
    display_name: str
    namespace: Optional[str]
    document_count: int


class ProjectUpdate(SQLModel):
    display_name: Optional[str]
    status: Optional[ENTITY_STATUS]


# =========
# Documents
# =========
class Document(BaseModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    organization_id: int = Field(default=None, foreign_key="organization.id")
    project_id: int = Field(default=None, foreign_key="project.id")
    display_name: str = Field(default="Untitled Document 😊")
    url: str = Field(default="")
    data: Optional[bytes] = Field(default=None)
    hash: str = Field(default=None)
    version: Optional[int] = Field(default=1)
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # -------------
    # Relationships
    # -------------
    nodes: Optional[List["Node"]] = Relationship(back_populates="document")
    organization: Optional["Organization"] = Relationship(back_populates="documents")
    project: Optional["Project"] = Relationship(back_populates="documents")

    @property
    def node_count(self) -> int:
        return len(self.nodes)

    # NOTE(review): the constraint name suggests (organization, hash) but the
    # columns are (uuid, hash); uuid is already unique on its own - confirm.
    __table_args__ = (UniqueConstraint("uuid", "hash", name="unq_org_document"),)

    def __repr__(self):
        return _format_repr(self, "id", "uuid", "display_name", "version", "hash")


class ProjectRead(SQLModel):
    id: int
    uuid: uuid_pkg.UUID
    organization: Organization
    document_count: int
    documents: Optional[List[Document]] = None
    display_name: str
    created_at: datetime
    updated_at: datetime


class DocumentCreate(SQLModel):
    project: Project
    display_name: Optional[str]
    url: Optional[str]
    # NOTE(review): str here, while Document.version is an int - confirm.
    version: Optional[str]
    data: Optional[bytes]
    hash: Optional[str]


class DocumentUpdate(SQLModel):
    status: Optional[ENTITY_STATUS]


# ==============
# Document Nodes
# ==============
class Node(BaseModel, table=True):
    class Config:
        arbitrary_types_allowed = True

    id: Optional[int] = Field(default=None, primary_key=True)
    document_id: int = Field(default=None, foreign_key="document.id")
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    # pgvector column sized by VECTOR_EMBEDDINGS_COUNT
    embeddings: Optional[List[float]] = Field(
        sa_column=Column(Vector(VECTOR_EMBEDDINGS_COUNT))
    )
    meta: Optional[Dict] = Field(default=None, sa_column=Column(JSONB))
    token_count: Optional[int] = Field(default=None)
    text: str = Field(default=None, nullable=False)
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # -------------
    # Relationships
    # -------------
    document: Optional["Document"] = Relationship(back_populates="nodes")

    def __repr__(self):
        return _format_repr(self, "id", "uuid", "document_id", "token_count")


class NodeCreate(SQLModel):
    document: Document
    embeddings: List[float]
    token_count: Optional[int]
    text: str
    status: Optional[ENTITY_STATUS]


class NodeRead(SQLModel):
    id: int
    document: Document
    embeddings: Optional[List[float]]
    token_count: Optional[int]
    text: str
    created_at: datetime


class DocumentReadNodeList(SQLModel):
    id: int
    uuid: uuid_pkg.UUID
    display_name: str
    node_count: int


class NodeReadResult(SQLModel):
    id: int
    token_count: Optional[int]
    text: str
    meta: Optional[Dict]


class ProjectReadListDocumentList(SQLModel):
    uuid: uuid_pkg.UUID
    display_name: str
    node_count: Optional[int]


class ProjectReadList(SQLModel):
    id: int
    # organization: ProjectReadListOrganization
    documents: Optional[List[DocumentReadNodeList]]
    document_count: int
    uuid: uuid_pkg.UUID
    display_name: str
    created_at: datetime
    updated_at: datetime


class NodeReadList(SQLModel):
    id: int
    document: DocumentReadNodeList
    embeddings: Optional[List[float]]
    token_count: Optional[int]
    text: str
    created_at: datetime


class NodeUpdate(SQLModel):
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)


class NodeReadListDocumentRead(SQLModel):
    uuid: uuid_pkg.UUID
    token_count: Optional[int]
    created_at: datetime


class DocumentReadList(SQLModel):
    id: int
    uuid: uuid_pkg.UUID
    display_name: str
    version: int
    nodes: Optional[List[NodeReadListDocumentRead]] = None
    node_count: int
    hash: str
    created_at: datetime
    updated_at: datetime


# ============
# Chat Session
# ============
class ChatSession(BaseModel, table=True):
    class Config:
        arbitrary_types_allowed = True

    id: Optional[int] = Field(default=None, primary_key=True)
    session_id: Optional[uuid_pkg.UUID] = Field(
        index=True, default_factory=uuid_pkg.uuid4
    )
    user_id: int = Field(default=None, foreign_key="user.id")
    project_id: int = Field(default=None, foreign_key="project.id")
    channel: CHANNEL_TYPE = Field(default=CHANNEL_TYPE.TELEGRAM)
    user_message: str = Field(default=None)
    token_count: Optional[int] = Field(default=None)
    # pgvector column sized by VECTOR_EMBEDDINGS_COUNT
    embeddings: Optional[List[float]] = Field(
        sa_column=Column(Vector(VECTOR_EMBEDDINGS_COUNT))
    )
    response: Optional[str] = Field(default=None)
    meta: Optional[Dict] = Field(default=None, sa_column=Column(JSONB))
    created_at: datetime = Field(default_factory=datetime.now)

    # -------------
    # Relationships
    # -------------
    user: Optional["User"] = Relationship(back_populates="chat_sessions")
    project: Optional["Project"] = Relationship(back_populates="chat_sessions")

    def __repr__(self):
        return _format_repr(
            self, "id", "session_id", "user_id", "project_id", "channel"
        )


class ChatSessionCreatePost(SQLModel):
    project_id: Optional[str] = ""
    organization_id: Optional[str] = "pepe"
    channel: Optional[CHANNEL_TYPE] = CHANNEL_TYPE.TELEGRAM
    query: Optional[str] = "What is the weather like in London right now?"
    identifier: Optional[str] = "@username"
    distance_strategy: Optional[str] = LLM_DEFAULT_DISTANCE_STRATEGY
    max_output_tokens: Optional[int] = LLM_MAX_OUTPUT_TOKENS
    node_limit: Optional[int] = LLM_MIN_NODE_LIMIT
    model: Optional[str] = LLM_MODELS.GPT_35_TURBO
    session_id: Optional[str] = ""


class ChatSessionCreate(SQLModel):
    channel: CHANNEL_TYPE
    token_count: Optional[int]
    user_message: str
    embeddings: List[float]
    response: Optional[str]


class ChatSessionRead(SQLModel):
    id: int
    user: User
    project: Optional[ProjectReadListDocumentList]
    token_count: Optional[int]
    channel: CHANNEL_TYPE
    user_message: str
    embeddings: List[float]
    response: Optional[str]
    meta: Optional[dict]
    created_at: datetime = Field(default_factory=datetime.now)


class ChatSessionResponse(SQLModel):
    meta: Optional[dict]
    response: Optional[str]
    user_message: Optional[str]


class ProjectReadChatSessionRead(SQLModel):
    id: int
    token_count: Optional[int]
    channel: CHANNEL_TYPE
    created_at: datetime = Field(default_factory=datetime.now)


class ChatSessionReadUserRead(SQLModel):
    id: int
    project: Optional[ProjectReadListDocumentList]
    token_count: Optional[int]
    channel: CHANNEL_TYPE
    user_message: str
    response: Optional[str]
    created_at: datetime = Field(default_factory=datetime.now)


class UserRead(SQLModel):
    id: int
    identifier: Optional[str]
    identifier_type: Optional[CHANNEL_TYPE]
    uuid: uuid_pkg.UUID
    # NOTE(review): User declares no `language` column in this file - confirm.
    language: Optional[str]
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]
    chat_session_count: int
    chat_sessions: Optional[List[ChatSessionReadUserRead]]
    created_at: datetime
    updated_at: datetime


class DocumentReadProjectRead(SQLModel):
    uuid: uuid_pkg.UUID
    display_name: str
    namespace: Optional[str]
    document_count: int


class DocumentRead(SQLModel):
    id: int
    uuid: uuid_pkg.UUID
    project: DocumentReadProjectRead
    organization: OrganizationRead
    display_name: str
    node_count: int
    url: Optional[str]
    version: int
    data: bytes
    hash: str
    created_at: datetime
    updated_at: datetime


class WebhookCreate(SQLModel):
    update_id: str
    message: Dict[str, Any]


class WebhookResponse(SQLModel):
    update_id: str
    message: Dict[str, Any]


# ==================
# Database functions
# ==================
# One Engine per DSN, created on first use. An Engine owns a connection pool,
# so building a new one per call would open a fresh pool every time.
_ENGINES: Dict[str, Any] = {}


def _as_statement(sql: Any) -> Any:
    """Wrap plain SQL strings in ``text()``; pass other statements through.

    SQLAlchemy 2.0 rejects raw strings in ``Session.execute``.
    """
    return text(sql) if isinstance(sql, str) else sql


def get_engine(dsn: str = SU_DSN):
    """Return the Engine for ``dsn``, creating and caching it on first use."""
    engine = _ENGINES.get(dsn)
    if engine is None:
        engine = _ENGINES[dsn] = create_engine(dsn)
    return engine


def get_session():
    """Yield a Session on the default engine and close it afterwards."""
    with Session(get_engine()) as session:
        yield session


def create_db():
    """Enable pgvector, create all tables, grant permissions, add indexes."""
    logger.info("...Enabling pgvector and creating database tables")
    enable_vector()
    BaseModel.metadata.create_all(get_engine(dsn=SU_DSN))
    create_user_permissions()
    create_vector_index()


def create_user_permissions():
    """Grant DB_USER read/write privileges on every table in schema public."""
    # grant access to entire database and all tables to user DB_USER
    # DB_USER is interpolated as an identifier (it cannot be a bind
    # parameter); it comes from config, not from request input.
    query = f"GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO {DB_USER};"
    with Session(get_engine(dsn=SU_DSN)) as session:
        session.execute(_as_statement(query))
        session.commit()


def drop_db():
    """Drop every table registered on the shared metadata."""
    BaseModel.metadata.drop_all(get_engine(dsn=SU_DSN))


def create_vector_index():
    """Run each strategy's index statement when PGVECTOR_ADD_INDEX is True."""
    # -------------------------------------
    # Let's add an index for the embeddings
    # -------------------------------------
    if PGVECTOR_ADD_INDEX is True:
        with Session(get_engine(dsn=SU_DSN)) as session:
            for strategy in DISTANCE_STRATEGIES:
                # strategy[3] is executed as-is; presumably the CREATE INDEX
                # statement for that strategy - verify against config.
                session.execute(_as_statement(strategy[3]))
            session.commit()


def enable_vector():
    """Create the pgvector extension and install the match functions."""
    query = "CREATE EXTENSION IF NOT EXISTS vector;"
    with Session(get_engine(dsn=SU_DSN)) as session:
        session.execute(_as_statement(query))
        session.commit()
        add_vector_distance_fn(session)


def add_vector_distance_fn(session: Session):
    """Create or replace one ``match_node_<name>`` function per strategy.

    For each entry of DISTANCE_STRATEGIES, element ``[1]`` is used as the
    function-name suffix and element ``[2]`` as the distance operator placed
    between ``node.embeddings`` and the query vector.

    Commits and then closes ``session``, as it always has; callers must not
    reuse it expecting pending state.
    """
    for strategy in DISTANCE_STRATEGIES:
        strategy_name = strategy[1]
        strategy_distance_str = strategy[2]

        query = f"""create or replace function match_node_{strategy_name} (
            query_embeddings vector({VECTOR_EMBEDDINGS_COUNT}),
            match_threshold float,
            match_count int
        ) returns table (
            uuid uuid,
            text varchar,
            similarity float
        )
        language plpgsql
        as $$
        begin
            return query
            select
                node.uuid,
                node.text,
                1 - (node.embeddings {strategy_distance_str} query_embeddings) as similarity
            from node
            where 1 - (node.embeddings {strategy_distance_str} query_embeddings) > match_threshold
            order by similarity desc
            limit match_count;
        end;
        $$;"""

        session.execute(_as_statement(query))

    session.commit()
    session.close()


if __name__ == "__main__":
    create_db()