'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import typing

import torch
from torch import nn

from .experts import FusedExperts as Experts
from .gate import TopKGate


class TaskMoE(torch.nn.Module):
    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 k=1,
                 capacity_factor=1.,
                 eval_capacity_factor=1.,
                 min_capacity=4,
                 noisy_gate_policy: typing.Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts=True,
                 use_tutel: bool = False,
                 cfg=None):
        """Initialize an MoE layer.

        Arguments:
            hidden_size (int): the hidden dimension of the model; importantly, this is also the input and output dimension.

            expert (torch.nn.Module): the torch module that defines the expert (currently nn.Linear or nn.MultiheadAttention).

            num_experts (int, optional): default=1, the total number of experts per layer.

            k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.

            capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.

            eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.

            min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.

            noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.

            drop_tokens (bool, optional): default=True, whether to drop tokens (setting this to False is equivalent to infinite capacity).

            use_rts (bool, optional): default=True, whether to use Random Token Selection.

            use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).

            cfg (optional): default=None, configuration object forwarded to the gate and expert modules.
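
        Example (illustrative sketch; the sizes and the cfg object are
        placeholders, and cfg must provide whatever fields TopKGate and
        FusedExperts expect):

            expert = torch.nn.Linear(512, 512)
            moe = TaskMoE(hidden_size=512, expert=expert, num_experts=8, k=2, cfg=cfg)
            output, gate_decision = moe(hidden_states)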
        """

        super().__init__()

        self.num_experts = num_experts

        # Infer the expert type from the expert module; only linear and
        # attention experts are currently supported.
        if isinstance(expert, nn.Linear):
            self.expert_type = 'linear'
        elif isinstance(expert, nn.MultiheadAttention):
            self.expert_type = 'attention'
        else:
            raise NotImplementedError(
                f'Unsupported expert type: {type(expert)}. '
                'Expected nn.Linear or nn.MultiheadAttention.')

        experts = Experts(expert, cfg, num_experts)

        self.gate = TopKGate(hidden_size,
                             num_experts,
                             k,
                             noisy_gate_policy,
                             cfg,
                             moe_type=self.expert_type)
        self.experts = experts

    def forward(self, hidden_states, gate_decision=None, **kwargs):
        """MoE forward.

        Arguments:
            hidden_states (Tensor): input to the layer.
            gate_decision (optional): a precomputed (top_indices, gates) pair.
                If provided, it is reused instead of running the gate, which
                allows gate decisions to be shared across layers or tasks.
        Returns:
            A tuple of:
            * output (Tensor): output of the expert computation.
            * gate_decision (list): the [top_indices, gates] actually used,
              which can be passed to another forward call.
        """
        if gate_decision is not None:
            # Reuse an externally supplied gate decision.
            top_indices, gates = gate_decision
        else:
            # Compute top-k gating for the current input.
            top_indices, gates = self.gate(hidden_states, **kwargs)

        expert_output = self.experts(hidden_states, top_indices, gates, **kwargs)

        return expert_output, [top_indices, gates]
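
# Usage sketch (illustrative, not part of the module): the gate decision
# returned by one TaskMoE forward pass can be fed to another TaskMoE layer so
# that both layers route tokens identically. The expert modules, sizes, and
# cfg below are placeholders; cfg must match what TopKGate/FusedExperts expect.
#
#   moe_a = TaskMoE(hidden_size=512, expert=nn.Linear(512, 512),
#                   num_experts=8, k=2, cfg=cfg)
#   moe_b = TaskMoE(hidden_size=512, expert=nn.Linear(512, 512),
#                   num_experts=8, k=2, cfg=cfg)
#
#   out_a, gate_decision = moe_a(hidden_states)
#   out_b, _ = moe_b(hidden_states, gate_decision=gate_decision)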