import base64
import pickle
from typing import Any, Iterable, List, Optional, Tuple

from omagent_core.memories.ltms.ltm_base import LTMBase
from omagent_core.services.connectors.milvus import MilvusConnector
from omagent_core.utils.registry import registry
from pydantic import Field
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
                      utility)

# Sentinel that lets pop() tell "no default given" apart from an explicit None.
_MISSING = object()


def _key_filter(key_str: str) -> str:
    """Return a Milvus filter expression matching records whose ``key`` is ``key_str``.

    Backslashes and double quotes are escaped so that a key containing them
    cannot terminate the string literal early and change the expression.
    """
    escaped = key_str.replace("\\", "\\\\").replace('"', '\\"')
    return f'key == "{escaped}"'


@registry.register_component()
class VideoMilvusLTM(LTMBase):
    """Dict-like long-term memory backed by a single Milvus collection.

    Each record has a VARCHAR primary key ``key`` (max length 256), a JSON
    ``value`` and a FLOAT_VECTOR ``embedding`` of length ``dim``. Keys are
    converted with ``str()`` before use; values are written directly into the
    JSON field (no pickling/base64 encoding is applied).
    """

    milvus_ltm_client: MilvusConnector
    storage_name: str = Field(default="default")  # Milvus collection name
    dim: int = Field(default=128)  # length of the embedding vectors

    def model_post_init(self, __context: Any) -> None:
        # The collection is created lazily by __setitem__ via _create_collection.
        pass

    def _create_collection(self) -> None:
        """Create the collection and its vector index if it does not exist yet."""
        client = self.milvus_ltm_client._client
        if client.has_collection(self.storage_name):
            return

        key_field = FieldSchema(
            name="key", dtype=DataType.VARCHAR, is_primary=True, max_length=256
        )
        value_field = FieldSchema(
            name="value", dtype=DataType.JSON, description="Json value"
        )
        embedding_field = FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            description="Embedding vector",
            dim=self.dim,
        )
        schema = CollectionSchema(
            fields=[key_field, value_field, embedding_field],
            description="Key-Value storage with embeddings",
        )

        # Index every vector field; the index is passed to create_collection,
        # so it is built together with the collection.
        index_params = client.prepare_index_params()
        for field in schema.fields:
            if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
                index_params.add_index(
                    field_name=field.name,
                    index_name=field.name,
                    index_type="FLAT",
                    metric_type="COSINE",
                    params={"nlist": 128},
                )
        client.create_collection(
            self.storage_name, schema=schema, index_params=index_params
        )
        print(f"Created storage {self.storage_name} successfully")

    def __getitem__(self, key: Any) -> Any:
        """Return the value stored under ``key``.

        Raises:
            KeyError: if no record has this key.
        """
        res = self.milvus_ltm_client._client.query(
            self.storage_name, _key_filter(str(key)), output_fields=["value"]
        )
        if not res:
            raise KeyError(f"Key {key} not found")
        return res[0]["value"]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Insert ``key``, replacing any existing record with the same key.

        Args:
            key: converted with ``str()`` and used as the primary key.
            value: a dict with a ``"value"`` entry (stored in the JSON field)
                and a non-None ``"embedding"`` entry (the vector; the schema
                expects ``dim`` components).

        Raises:
            ValueError: if ``value`` is not such a dict or the embedding is None.
        """
        self._create_collection()
        key_str = str(key)

        if not (isinstance(value, dict) and "value" in value and "embedding" in value):
            raise ValueError(
                "When setting an item, value must be a dictionary containing 'value' and 'embedding' keys."
            )
        actual_value = value["value"]
        embedding = value["embedding"]
        if embedding is None:
            raise ValueError("An embedding vector must be provided.")

        # Milvus does not deduplicate primary keys on insert, so remove the old
        # record first to keep keys unique.
        if key_str in self:
            self.__delitem__(key_str)

        data = [{"key": key_str, "value": actual_value, "embedding": embedding}]
        self.milvus_ltm_client._client.insert(
            collection_name=self.storage_name, data=data
        )

    def __delitem__(self, key: Any) -> None:
        """Delete the record stored under ``key``.

        Raises:
            KeyError: if no record has this key.
        """
        key_str = str(key)
        if key_str not in self:
            raise KeyError(f"Key {key} not found")
        # Pass the expression by keyword (as clear() does): in pymilvus'
        # MilvusClient.delete the second positional parameter is the list of
        # primary-key ids, not a filter expression.
        self.milvus_ltm_client._client.delete(
            self.storage_name, filter=_key_filter(key_str)
        )

    def __contains__(self, key: Any) -> bool:
        """Return True if a record with ``str(key)`` as primary key exists."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter=_key_filter(str(key)),
            output_fields=["key"],
        )
        return bool(res)

    def __len__(self) -> int:
        """Return the number of records (query with Strong consistency)."""
        results = self.milvus_ltm_client._client.query(
            self.storage_name,
            'key != ""',  # matches every record
            output_fields=["key"],
            consistency_level="Strong",
        )
        return len(results)

    def keys(self, limit: int = 10) -> Iterable[Any]:
        """Return a generator over at most ``limit`` keys (10 by default)."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name, "", output_fields=["key"], limit=limit
        )
        return (item["key"] for item in res)

    def values(self) -> Iterable[Any]:
        """Yield every stored value, exactly as returned by the JSON field.

        Values are not decoded: __setitem__ stores them without pickling or
        base64 encoding, matching __getitem__ and items().
        """
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            'key != ""',  # matches every record
            output_fields=["value"],
            consistency_level="Strong",
        )
        for item in res:
            yield item["value"]

    def items(self) -> Iterable[Tuple[Any, Any]]:
        """Yield ``(key, value)`` pairs for every stored record."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name, 'key != ""', output_fields=["key", "value"]
        )
        for item in res:
            yield (item["key"], item["value"])

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the value for ``key``, or ``default`` if it is absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def clear(self) -> None:
        """Delete every record whose key is non-empty."""
        self.milvus_ltm_client._client.delete(self.storage_name, filter='key != ""')

    def pop(self, key: Any, default: Any = _MISSING) -> Any:
        """Remove ``key`` and return its value.

        If the key is absent, return ``default`` when one was given (including
        an explicit ``None``); otherwise re-raise the KeyError.
        """
        try:
            value = self[key]
            self.__delitem__(key)
        except KeyError:
            if default is _MISSING:
                raise
            return default
        return value

    def update(self, other: Iterable[Tuple[Any, Any]]) -> None:
        """Store every ``(key, value)`` pair from ``other``.

        ``other`` may be an iterable of pairs or a mapping (its ``items()`` are
        used). Each value must have the shape __setitem__ expects.
        """
        pairs = other.items() if hasattr(other, "items") else other
        for key, value in pairs:
            self[key] = value

    def get_by_vector(
        self,
        embedding: List[float],
        top_k: int = 10,
        threshold: float = 0.0,
        filter: str = "",
    ) -> List[Tuple[Any, Any]]:
        """Return up to ``top_k`` ``(key, value)`` pairs nearest to ``embedding``.

        Uses COSINE similarity with Strong consistency. ``radius=threshold`` and
        ``range_filter=1`` request a Milvus range search bounded by
        ``threshold`` and 1. ``filter`` is an optional Milvus boolean expression
        applied to the candidates. Scores are not included in the result.
        """
        search_params = {
            "metric_type": "COSINE",
            "params": {"nprobe": 10, "range_filter": 1, "radius": threshold},
        }
        results = self.milvus_ltm_client._client.search(
            self.storage_name,
            data=[embedding],
            anns_field="embedding",
            search_params=search_params,
            limit=top_k,
            output_fields=["key", "value"],
            consistency_level="Strong",
            filter=filter,
        )
        items = []
        for match in results[0]:
            entity = match.get("entity")
            items.append((entity.get("key"), entity.get("value")))
        return items