File size: 1,697 Bytes
21db53c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from PIL import Image

from app.Services.transformers_service import TransformersService
from app.util.calculate_vectors_cosine import calculate_vectors_cosine
from ..assets import assets_path


class TestTransformersService:
    """Integration tests for the embedding methods of ``TransformersService``.

    Every test asserts the shape of the returned vectors. Tests that embed a
    pair of related inputs also assert that the pair's cosine similarity is
    above ``SIMILARITY_THRESHOLD``.
    """

    # Length of the 1-D embedding every service method is expected to return.
    EMBEDDING_DIM = 768
    # Minimum cosine similarity required between two related inputs.
    SIMILARITY_THRESHOLD = 0.8

    @classmethod
    def setup_class(cls):
        """Create a single service instance shared by all tests in the class.

        pytest calls ``setup_class`` once with the class object, so the
        service is stored as a class attribute and reused by every test.
        """
        cls.transformers_service = TransformersService()

    def _assert_embedding_shape(self, *vectors):
        """Assert that each vector is 1-D with ``EMBEDDING_DIM`` entries."""
        for vector in vectors:
            assert vector.shape == (self.EMBEDDING_DIM,)

    def _assert_similar(self, vector1, vector2):
        """Assert both vectors have the expected shape and are similar."""
        self._assert_embedding_shape(vector1, vector2)
        assert calculate_vectors_cosine(vector1, vector2) > self.SIMILARITY_THRESHOLD

    def _image_vector(self, name):
        """Return the image embedding of ``test_images/<name>`` from the assets.

        The image is opened in a ``with`` block so its file handle is closed
        as soon as the embedding is computed, rather than whenever the
        ``Image`` object happens to be garbage-collected.
        """
        with Image.open(assets_path / 'test_images' / name) as image:
            return self.transformers_service.get_image_vector(image)

    def test_get_image_vector(self):
        # Two different cat photos should embed close to each other.
        vector1 = self._image_vector('cat_0.jpg')
        vector2 = self._image_vector('cat_1.jpg')
        self._assert_similar(vector1, vector2)

    def test_get_text_vector(self):
        # Two tag-style prompts describing the same subject.
        vector1 = self.transformers_service.get_text_vector('1girl')
        vector2 = self.transformers_service.get_text_vector('girl, solo')
        self._assert_similar(vector1, vector2)

    def test_get_bert_vector(self):
        # Synonymous greetings should yield similar BERT embeddings.
        vector1 = self.transformers_service.get_bert_vector('hi')
        vector2 = self.transformers_service.get_bert_vector('hello')
        self._assert_similar(vector1, vector2)

    def test_get_bert_vector_long_text(self):
        # Very long English and Chinese inputs must still produce a single
        # fixed-size embedding. NOTE(review): presumably these exceed the
        # model's maximum sequence length and exercise truncation inside
        # TransformersService -- confirm against the service implementation.
        vector1 = self.transformers_service.get_bert_vector('The quick brown fox jumps over the lazy dog ' * 100)
        vector2 = self.transformers_service.get_bert_vector('我可以吞下玻璃而不伤身体' * 100)
        self._assert_embedding_shape(vector1, vector2)