Spaces:
Sleeping
Sleeping
File size: 3,744 Bytes
d16e9aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
"""
Tests for the RAG engine module.
"""
import pytest
from unittest.mock import Mock, patch
from src.rag_engine import RAGEngine
@pytest.fixture
def mock_azure_client():
"""Create a mock Azure OpenAI client."""
with patch('openai.AzureOpenAI') as mock_client:
yield mock_client
@pytest.fixture
def mock_chroma_client():
"""Create a mock Chroma client."""
with patch('chromadb.Client') as mock_client:
yield mock_client
@pytest.fixture
def rag_engine(mock_azure_client, mock_chroma_client):
"""Create a RAG engine instance with mocked dependencies."""
return RAGEngine("test-deployment")
def test_create_embeddings(rag_engine, mock_azure_client):
"""Test embedding creation."""
# Setup mock response
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6])
]
rag_engine.client.embeddings.create.return_value = mock_response
# Test
texts = ["Text 1", "Text 2"]
embeddings = rag_engine.create_embeddings(texts)
# Verify
assert len(embeddings) == 2
assert all(isinstance(emb, list) for emb in embeddings)
assert len(embeddings[0]) == 3 # Embedding dimension
def test_initialize_vector_store(rag_engine):
"""Test vector store initialization."""
rag_engine.initialize_vector_store("test_collection")
# Verify the collection was created
assert rag_engine.collection is not None
def test_add_documents(rag_engine):
"""Test adding documents to vector store."""
# Setup
rag_engine.initialize_vector_store("test_collection")
texts = ["Document 1", "Document 2"]
metadata = [{"source": "test1"}, {"source": "test2"}]
# Create mock embeddings
with patch.object(rag_engine, 'create_embeddings') as mock_create_embeddings:
mock_create_embeddings.return_value = [[0.1, 0.2], [0.3, 0.4]]
# Test
rag_engine.add_documents(texts, metadata)
# Verify
mock_create_embeddings.assert_called_once_with(texts)
assert rag_engine.collection.add.called
def test_query(rag_engine):
"""Test querying the RAG engine."""
# Setup
rag_engine.initialize_vector_store("test_collection")
# Mock embeddings creation
with patch.object(rag_engine, 'create_embeddings') as mock_create_embeddings:
mock_create_embeddings.return_value = [[0.1, 0.2]]
# Mock vector store query
mock_results = {
'documents': [["Relevant document 1", "Relevant document 2"]],
'distances': [[0.1, 0.2]]
}
rag_engine.collection.query.return_value = mock_results
# Mock chat completion
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="Test answer"))]
rag_engine.client.chat.completions.create.return_value = mock_response
# Test
result = rag_engine.query("Test question")
# Verify
assert isinstance(result, dict)
assert "answer" in result
assert "context" in result
assert "source_documents" in result
assert result["answer"] == "Test answer"
def test_error_handling(rag_engine):
"""Test error handling in RAG engine."""
# Test error in embeddings creation
rag_engine.client.embeddings.create.side_effect = Exception("API Error")
with pytest.raises(Exception):
rag_engine.create_embeddings(["Test"])
# Test error in vector store initialization
rag_engine.chroma_client.get_or_create_collection.side_effect = Exception("DB Error")
with pytest.raises(Exception):
rag_engine.initialize_vector_store("test")
|