File size: 4,372 Bytes
05b45a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import torch

from api.src.services.tts_service import TTSService


@pytest.fixture
def mock_managers():
    """Patch the model/voice manager factories used by ``tts_service``.

    Yields a ``(model_manager, voice_manager)`` tuple of ``AsyncMock``
    objects. While the requesting test runs:

    * ``api.src.services.tts_service.get_model_manager`` returns
      ``model_manager``, whose ``get_backend`` returns a ``MagicMock``;
    * ``api.src.services.tts_service.get_voice_manager`` returns
      ``voice_manager``, whose ``get_voice_path`` resolves to
      ``"/path/to/voice.pt"`` and whose ``list_voices`` resolves to
      ``["voice1", "voice2"]``.

    Both patches are undone at fixture teardown.
    """
    model_manager = AsyncMock()
    model_manager.get_backend.return_value = MagicMock()

    voice_manager = AsyncMock()
    voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
    voice_manager.list_voices.return_value = ["voice1", "voice2"]

    with (
        patch(
            "api.src.services.tts_service.get_model_manager",
            return_value=model_manager,
        ),
        patch(
            "api.src.services.tts_service.get_voice_manager",
            return_value=voice_manager,
        ),
    ):
        # Yield from *inside* the ``with`` so the patches remain in effect
        # for the duration of the test; returning here (as a coroutine did
        # before) would exit the context and undo them immediately.
        yield model_manager, voice_manager


@pytest.fixture
def tts_service(mock_managers):
    """Provide a not-yet-awaited coroutine that builds a ``TTSService``.

    Awaiting the returned object calls ``TTSService.create("test_output")``
    and produces the service. ``mock_managers`` is requested only so that
    its setup runs before this fixture; its value is not used here.
    """

    async def _build():
        service = await TTSService.create("test_output")
        return service

    pending = _build()
    return pending


@pytest.mark.asyncio
async def test_service_creation():
    """TTSService.create stores the output dir and the factory-provided managers."""
    fake_model = AsyncMock()
    fake_voice = AsyncMock()

    model_patch = patch(
        "api.src.services.tts_service.get_model_manager", return_value=fake_model
    )
    voice_patch = patch(
        "api.src.services.tts_service.get_voice_manager", return_value=fake_voice
    )

    with model_patch, voice_patch:
        service = await TTSService.create("test_output")

        # The service must hold on to the exact manager objects it was given.
        assert service.output_dir == "test_output"
        assert service.model_manager is fake_model
        assert service._voice_manager is fake_voice


@pytest.mark.asyncio
async def test_get_voice_path_single():
    """A single voice name resolves through the voice manager unchanged."""
    fake_model = AsyncMock()
    fake_voice = AsyncMock()
    fake_voice.get_voice_path.return_value = "/path/to/voice1.pt"

    model_patch = patch(
        "api.src.services.tts_service.get_model_manager", return_value=fake_model
    )
    voice_patch = patch(
        "api.src.services.tts_service.get_voice_manager", return_value=fake_voice
    )

    with model_patch, voice_patch:
        service = await TTSService.create("test_output")
        voice_name, voice_path = await service._get_voices_path("voice1")

        assert voice_name == "voice1"
        assert voice_path == "/path/to/voice1.pt"
        # Exactly one lookup, for the requested voice.
        fake_voice.get_voice_path.assert_called_once_with("voice1")


@pytest.mark.asyncio
async def test_get_voice_path_combined():
    """A "+"-joined voice name yields a combined voice file that gets saved."""
    fake_model = AsyncMock()
    fake_voice = AsyncMock()
    fake_voice.get_voice_path.return_value = "/path/to/voice.pt"

    model_patch = patch(
        "api.src.services.tts_service.get_model_manager", return_value=fake_model
    )
    voice_patch = patch(
        "api.src.services.tts_service.get_voice_manager", return_value=fake_voice
    )
    # Every voice tensor loaded during the test is a vector of ten ones.
    load_patch = patch("torch.load", return_value=torch.ones(10))
    save_patch = patch("torch.save")
    tmpdir_patch = patch("tempfile.gettempdir", return_value="/tmp")

    with model_patch, voice_patch, load_patch, save_patch as fake_save, tmpdir_patch:
        service = await TTSService.create("test_output")
        voice_name, voice_path = await service._get_voices_path("voice1+voice2")

        assert voice_name == "voice1+voice2"
        assert voice_path.endswith("voice1+voice2.pt")
        fake_save.assert_called_once()


@pytest.mark.asyncio
async def test_list_voices():
    """list_voices returns whatever the voice manager reports."""
    fake_model = AsyncMock()
    fake_voice = AsyncMock()
    fake_voice.list_voices.return_value = ["voice1", "voice2"]

    model_patch = patch(
        "api.src.services.tts_service.get_model_manager", return_value=fake_model
    )
    voice_patch = patch(
        "api.src.services.tts_service.get_voice_manager", return_value=fake_voice
    )

    with model_patch, voice_patch:
        service = await TTSService.create("test_output")
        listed = await service.list_voices()

        assert listed == ["voice1", "voice2"]
        # The manager must have been queried exactly once.
        assert fake_voice.list_voices.call_count == 1