File size: 3,817 Bytes
db24a4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from frame import Frame 
from helper import OBJECT_MAP, get_hypernym_path

from numpy.polynomial import polynomial
from nltk.corpus import wordnet as wn
import json
import os

class NodeFrame:
    """A frame paired with a list of probabilities for one object label.

    ``p_list`` holds independent success probabilities (presumably one per
    candidate detection of the object in this frame -- confirm against the
    JSON producer). The number of successes is then distributed as the sum
    of independent Bernoulli trials (a Poisson-binomial distribution).
    """

    def __init__(self, frame: "Frame", p_list: list[float]) -> None:
        self.frame = frame
        self.p_list = p_list
        # Expected number of successes: the sum of the individual probabilities.
        self.p_total = self.calculate_p_total(p_list)
        # p_exactly[k] is the probability of exactly k successes.
        self.p_exactly = self.calculate_p_exactly(p_list)

    def calculate_p_total(self, p_list: list[float]) -> float:
        """Return ``sum(p_list)``, the expected number of successes."""
        return sum(p_list)

    def calculate_p_exactly(self, p_list: list[float]) -> list[float]:
        """Return the probability of exactly k successes, indexed by k.

        Each independent trial has generating polynomial ``(1 - p) + p*x``.
        Multiplying them all together gives a polynomial whose k-th
        coefficient is P(exactly k successes). An empty ``p_list`` yields
        ``[1.0]``: zero successes is certain.

        NOTE(review): numpy's ``polymul`` appears to drop trailing zero
        coefficients (e.g. when some p == 0), so the result may be shorter
        than ``len(p_list) + 1``.
        """
        coeffs = [1.0]
        for p in p_list:
            coeffs = polynomial.polymul(coeffs, [1 - p, p])
        return [float(c) for c in coeffs]

    def p_of(self, amount: int, decay: float = 0.1) -> float:
        """Return the probability of exactly ``amount`` successes.

        Counts beyond the computed distribution are not treated as impossible.
        Each count past the end gets the last computed probability scaled down
        by a further factor of ``decay``; the default 0.1 matches the
        previously hard-coded factor. A negative count is impossible and
        yields 0.0.
        """
        if amount < 0:
            # Without this guard a negative index would silently read from
            # the end of p_exactly.
            return 0.0
        size = len(self.p_exactly)
        if amount < size:
            return self.p_exactly[amount]
        return self.p_exactly[-1] * decay ** (amount - size + 1)

    def serialize(self) -> dict:
        """Return a JSON-ready dict with the serialized frame and raw ``p_list``."""
        return {
            'frame': self.frame.serialize(),
            'p_list': self.p_list,
        }

class Node:
    """A single trie vertex: frames stored at this exact path plus child vertices."""

    def __init__(self, node_frames: list[NodeFrame]) -> None:
        # The caller's list is kept by reference (not copied), so appends made
        # through this node are visible to whoever passed the list in.
        self.node_frames = node_frames
        # Child vertices keyed by the next path component.
        self.children: dict[str, "Node"] = dict()

class Trie:
    """Prefix tree mapping word paths to the NodeFrames inserted under them.

    A search for a path returns every NodeFrame stored at that path or at any
    path that extends it.
    """

    def __init__(self) -> None:
        self.root = Node([])

    def insert(self, node_frame: NodeFrame, path: list[str]) -> None:
        """Store ``node_frame`` at ``path``, creating any missing nodes on the way."""
        node = self.root
        for word in path:
            child = node.children.get(word)
            if child is None:
                child = Node([])
                node.children[word] = child
            node = child
        node.node_frames.append(node_frame)

    def search(self, path: list[str]) -> list[NodeFrame]:
        """Return all frames at ``path`` or below it; ``[]`` if ``path`` is absent."""
        node = self.root
        for word in path:
            node = node.children.get(word)
            if node is None:
                return []
        return self.search_all_children(node)

    def search_all_children(self, node: Node) -> list[NodeFrame]:
        """Collect the frames of ``node`` and all its descendants in pre-order."""
        result = list(node.node_frames)
        for child in node.children.values():
            result.extend(self.search_all_children(child))
        return result

    def load_from_dir(self, dir: str) -> None:
        """Insert every entry from all ``*.json`` files found recursively under ``dir``.

        Each file is read as ``{frame_name: {object_name: p_list}}``. The file
        name without its ``.json`` suffix becomes the frame's ``video``. Each
        object is placed at the path returned by ``get_hypernym_path``.
        (``dir`` shadows the builtin but is kept for keyword-call compatibility.)
        """
        for root, _, files in os.walk(dir):
            for file_name in files:
                if not file_name.endswith('.json'):
                    continue
                # Close the file deterministically instead of relying on GC.
                with open(os.path.join(root, file_name), encoding='utf-8') as fh:
                    data = json.load(fh)
                video = file_name[:-len('.json')]
                for frame_name, frame_data in data.items():
                    for obj, p_list in frame_data.items():
                        hypernym_path = get_hypernym_path(obj)
                        frame = Frame(video=video, frame_name=frame_name)
                        self.insert(NodeFrame(frame, p_list), hypernym_path)

    def save_to_cache(self, cache_path: str) -> None:
        """Write the trie as JSON to ``cache_path`` (format: see ``serialize``)."""
        # The context manager guarantees the data is flushed and the file closed.
        with open(cache_path, 'w', encoding='utf-8') as fh:
            json.dump(self.serialize(), fh)

    def load_from_cache(self, cache_path: str) -> None:
        """Read a JSON file written by ``save_to_cache`` and insert its contents."""
        with open(cache_path, encoding='utf-8') as fh:
            data = json.load(fh)
        self.deserialize(data)

    def serialize(self) -> dict:
        """Return ``{'/'-joined path: [serialized NodeFrame, ...]}`` for non-empty nodes.

        Frames stored at the root are keyed by the empty string ``''``.
        NOTE(review): path components containing '/' would not round-trip
        through this format.
        """
        output = {}

        def dfs(node: Node, path: list[str]) -> None:
            if node.node_frames:
                output['/'.join(path)] = [nf.serialize() for nf in node.node_frames]
            for word, child in node.children.items():
                dfs(child, path + [word])

        dfs(self.root, [])
        return output

    def deserialize(self, input: dict) -> None:
        """Insert the entries of a dict produced by ``serialize``.

        Only ``frame['id']`` is used to rebuild each Frame -- presumably
        ``Frame.serialize`` emits an ``'id'`` key; confirm against Frame.
        """
        for key, node_frames in input.items():
            # ''.split('/') is [''], which would misplace root-level frames
            # under a child named ''; the empty key means the root itself.
            path = key.split('/') if key else []
            for entry in node_frames:
                frame = Frame(id=entry['frame']['id'])
                self.insert(NodeFrame(frame, entry['p_list']), path)