File size: 21,074 Bytes
6ee9a4f
 
 
 
 
 
 
 
 
 
 
2c40901
 
 
 
 
6ee9a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c40901
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
from typing import Optional, Tuple, Union
from collections import namedtuple
from einops import rearrange

import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

from .configuration_muddpythia import MUDDPythiaConfig
#try:
#    from .configuration_muddpythia import MUDDPythiaConfig
#except:
#    from configuration_muddpythia import MUDDPythiaConfig

from transformers.modeling_utils import PreTrainedModel


class KVCache(nn.Module):
    """Pre-allocated key/value store for incremental decoding.

    Two zero-initialised buffers, ``k_cache`` and ``v_cache``, are registered
    with shape ``(max_batch_size, n_heads, max_seq_length, head_dim)``.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float16):
        super().__init__()
        self.seq_length = max_seq_length
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        for buffer_name in ('k_cache', 'v_cache'):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val``/``v_val`` at sequence positions ``input_pos``.

        ``input_pos`` is ``[S]`` and ``k_val`` is ``[B, H, S, D]``; the number
        of positions must equal the sequence axis of ``k_val``.  The buffers
        are modified in place and the *full* buffers are returned.
        """
        assert input_pos.shape[0] == k_val.shape[2]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

class LayerCache(nn.Module):
    """Buffer holding one single-token hidden state per slot.

    The buffer ``layer_cache`` has layout ``(num_layers + 1, B, 1, D)``
    ("LBTD" with T fixed to 1).
    """

    def __init__(self, max_batch_size, num_layers, model_dim, dtype=torch.float16):
        super().__init__()
        n_slots = num_layers + 1
        storage = torch.zeros((n_slots, max_batch_size, 1, model_dim), dtype=dtype)
        self.register_buffer('layer_cache', storage)

    def update(self, x, lidx):
        """Store ``x`` in slot ``lidx`` and return slots ``0..lidx`` inclusive."""
        cache = self.layer_cache
        cache[lidx] = x
        return cache[:lidx + 1]

class MultiwayDynamicDenseBlock(nn.Module):
    """Predicts per-token mixing weights over ``lidx + 2`` hidden states.

    A small two-layer MLP (``w1`` -> GELU -> ``w2``) applied to the
    scale-free RMS-normalised input emits ``C * (lidx + 2)`` weights per
    token, where ``C`` is ``len(config.dense_type)``, or 1 for the last layer.
    """

    def __init__(self, config: MUDDPythiaConfig, lidx: int, last_layer=False) -> None:
        super().__init__()
        self.norm = RMSnormNoscale(epsilon=config.norm_eps)
        self.C = 1 if last_layer else len(config.dense_type)
        self.lidx = lidx
        n_sources = lidx + 2
        out_dim = n_sources * self.C
        hid_dim = out_dim
        if last_layer and config.expand_last:
            hid_dim = hid_dim * 4
        if config.round64:
            # always moves to the NEXT multiple of 64, even from an exact multiple
            hid_dim = (hid_dim // 64 + 1) * 64
        self.w1 = nn.Linear(config.dim, hid_dim, bias=False)
        self.act = nn.GELU()
        self.w2 = nn.Linear(hid_dim, out_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x`` (B, T, D) to mixing weights laid out as (C, B, T, L)."""
        projected = self.w1(self.norm(x))
        flat_weights = self.w2(self.act(projected))  # B T (C*L)
        return rearrange(flat_weights, 'B T (C L) -> C B T L', C=self.C)

    def layer_mix(self, hids, dw) -> Tensor:
        """Return a C-tuple of weighted sums of ``hids[0 .. lidx+1]``.

        ``dw[c, :, :, j]`` (B, T) is broadcast over the feature axis and
        multiplies ``hids[j]``.
        """
        n_sources = self.lidx + 2
        mixed = []
        for cidx in range(self.C):
            acc = 0
            for j in range(n_sources):
                acc = acc + dw[cidx, :, :, j, None] * hids[j]
            mixed.append(acc)
        return tuple(mixed)

class MUDDPythia(PreTrainedModel):
    """Decoder-only language model with (dynamic) dense cross-layer connections.

    After every ``TransformerBlock`` the outputs of the embedding and of all
    blocks so far can be re-mixed with per-token weights predicted by a
    ``MultiwayDynamicDenseBlock`` plus a learned static bias (``dense_bs``).

    ``setup_caches`` must be called before ``forward``/``generate``: it builds
    the rotary table, the causal mask and, outside training, the KV and layer
    caches.
    """
    config_class=MUDDPythiaConfig
    def __init__(self, config: MUDDPythiaConfig) -> None:
        super().__init__(config)
        self.config = config
        self.use_gradient_checkpointing = config.use_gradient_checkpointing 
        self.is_training = config.is_training
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(TransformerBlock(config, lidx) for lidx in range(config.n_layer))
        self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        C = len(self.config.dense_type)
        # Static mixing bias per layer: (C, lidx+2); the last layer mixes a single stream.
        self.dense_bs = nn.ParameterList([nn.Parameter(data=torch.randn(C if lidx != config.n_layer-1 else 1, lidx+2)) for lidx in range(config.n_layer)])

        self.layer_cache = None
        # The layer cache is an inference-only optimisation.
        self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
        self.stack_hidden = self.config.stack_hidden
        self.dynamic = self.config.dynamic_dense
        self.dense = self.config.dense
        if self.dynamic:
            self.dynamic_dense = nn.ModuleList([MultiwayDynamicDenseBlock(config, lidx, last_layer=lidx==config.n_layer-1) for lidx in range(config.n_layer)])

        self.rotary_ndims = int(config.head_dim * config.rotary_pct)
        self.freqs_cis: Optional[Tensor] = None
        # NOTE: not read anywhere in this class; setup_caches() stores the mask in `causal_mask`.
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1
    
    def tie_weights(self): # placeholder
        return 

    def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float16):
        """Allocate caches, the rotary table and the causal mask.

        Does nothing when the existing allocation already covers the requested
        batch size and sequence length.  ``max_seq_length`` is rounded up to a
        multiple of 8.  KV caches (and the optional layer cache) are created
        only when ``config.is_training`` is false.
        """
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        if not self.config.is_training:
            if self.use_layer_cache:
                self.layer_cache = LayerCache(max_batch_size, self.config.n_layer, self.config.dim, dtype=dtype)
            for b in self.layers:
                b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=dtype)

        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.rotary_ndims, self.config.rope_base).to(self.tok_embeddings.weight.device)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))

    def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
        """Greedy (argmax) decoding.

        Args:
            input_ids: prompt token ids of shape (batch, seq_length).
            num_tokens_to_generate: number of new tokens appended to the prompt.
            compiled_decode_one_token: optional callable invoked as
                ``fn(model, cur_token, input_pos)`` in place of
                ``self.decode_one_token``.

        Returns:
            Int tensor of shape (batch, seq_length + num_tokens_to_generate)
            holding the prompt followed by the generated tokens.
        """
        batch_size, seq_length = input_ids.shape
        input_pos = torch.arange(seq_length, device=self.device)
        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
        generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
        # Prefill: run the whole prompt once and take the prediction at its last position.
        logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
        _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        # The decode step always runs at max_batch_size; unused rows stay 0.
        next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
        next_token[:batch_size] = _next_token
        generated_ids[:, seq_length] = next_token[:batch_size, 0]
        input_pos = torch.tensor([seq_length], device=self.device)
        for _ in range(1, num_tokens_to_generate):
            if compiled_decode_one_token is not None:
                next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
            else:
                next_token = self.decode_one_token(next_token.clone(), input_pos)
            generated_ids[:, input_pos+1] = next_token.int()[:batch_size]
            input_pos += 1
        return generated_ids
    
    def decode_one_token(self, cur_token, input_pos):
        """Run one forward step and return the argmax token, shape (batch, 1)."""
        logits = self.forward(
            cur_token,
            input_pos=input_pos,
            return_tensor=True
        )
        new_token = torch.argmax(logits[:, -1], dim=-1)[:,None]
        return new_token

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
        """Compute next-token logits for ``idx``.

        Args:
            idx: token ids, (batch, seq).
            input_pos: positions of the tokens; defaults to ``arange(seq)``.
            return_tensor: when true return the logits tensor directly,
                otherwise a ``CausalLMOutput`` namedtuple with a ``logits`` field.

        Raises:
            AssertionError: if ``setup_caches`` has not been called.
        """
        assert self.freqs_cis is not None, "Caches must be initialized first"
        if input_pos is None:
            input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)
        _, seqlen, _ = x.shape
        # The layer cache only holds single-token states, so it is used for decode steps only.
        use_layer_cache = self.use_layer_cache and seqlen == 1
        if use_layer_cache:
            self.layer_cache.update(x, 0)
        else:
            hiddens = [x]
        for i, layer in enumerate(self.layers):
            if self.use_gradient_checkpointing:
                x = checkpoint(layer, x, input_pos, freqs_cis, mask)
            else:
                x = layer(x, input_pos, freqs_cis, mask)
            if use_layer_cache:
                _hidden = self.layer_cache.update(x, i+1) # LBTD
            else:
                hiddens.append(x)
                _hidden = hiddens
            if self.dynamic and self.dense:
                dw = self.dynamic_dense[i](x) # BTD -> CBTL
                dw = dw + self.dense_bs[i][:,None,None,:] # CBTL
                if self.stack_hidden:
                    # einsum needs a single LBTD tensor; without the layer cache the
                    # history is still a Python list and must be stacked first.
                    if isinstance(_hidden, (list, tuple)):
                        _hidden = torch.stack(_hidden)
                    x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
                else:
                    x = self.dynamic_dense[i].layer_mix(_hidden, dw)
        if self.config.dense_type == 'qkvr' and self.config.dense and self.config.dynamic_dense:
            # The last layer mixes a single stream (C == 1); unwrap it.
            x = x[0]
        x = self.norm(x)
        logits = self.output(x)
        if return_tensor: 
            return logits
        else:
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
            return CausalLMOutput(logits=logits)

class TransformerBlock(nn.Module):
    """One attention + feed-forward layer.

    The input is a single tensor when ``dense_type == 'l'``, dense connections
    are off, or this is layer 0.  With ``dense_type == 'qkvr'`` it is an
    indexable collection of streams: the first three feed the attention and
    the last one is the residual.
    """

    def __init__(self, config: MUDDPythiaConfig, lidx) -> None:
        super().__init__()
        self.lidx = lidx
        self.config = config
        self.attention = Attention(config, lidx)

        self.feed_forward = FeedForward(config, lidx)
        self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
        self.use_parallel_residual = config.use_parallel_residual

        separate_norms = self.config.sepln and self.lidx > 0
        if separate_norms:
            # one LayerNorm per attention input stream
            self.attention_norms = torch.nn.ModuleList(
                [nn.LayerNorm(config.dim, eps=config.norm_eps) for _ in range(3)]
            )
        else:
            self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)

    def forward(self, x: Union[Tuple[Tensor], Tensor], input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        """Apply attention and the feed-forward net with residual connections."""
        cfg = self.config
        single_stream = cfg.dense_type == 'l' or self.lidx == 0 or not cfg.dense
        if single_stream:
            res = x
            normed_x = self.attention_norm(x)
        elif cfg.dense_type == 'qkvr':
            res = x[-1]  # residual stream, also used for the mlp
            attn_streams = x[:3]
            if cfg.sepln and not cfg.stack_hidden:
                normed_x = tuple(
                    [norm_fn(stream) for norm_fn, stream in zip(self.attention_norms, attn_streams)]
                )
            else:
                normed_x = self.attention_norm(attn_streams)
        attn_out = self.attention(normed_x, freqs_cis, mask, input_pos)
        h = res + attn_out
        # parallel residual: the feed-forward reads the block input, not the attention output
        ffn_input = res if self.use_parallel_residual else h
        return h + self.feed_forward(self.ffn_norm(ffn_input))

class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings and an optional KV cache.

    Projection layout depends on the config:

    * ``dense_type == 'l'`` or dense connections disabled: one fused ``wqkv``.
    * ``dense_type == 'qkvr'``: separate ``wq``/``wk``/``wv`` so that query, key
      and value can each read a different input stream.
    """

    def __init__(self, config: MUDDPythiaConfig, lidx):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.config = config
        if self.config.dense_type == 'l' or not self.config.dense:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=True)
        elif self.config.dense_type == 'qkvr':
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=True)
            self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)
            self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)

        self.wo = nn.Linear(config.dim, config.dim, bias=True)
        self.lidx = lidx
        # Attached from outside by the model's setup_caches(); None means "no cache".
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.scale_factor = 1 / math.sqrt(self.head_dim)
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim, config.norm_eps)
            self.k_norm = RMSNorm(self.head_dim, config.norm_eps)

        self.rotary_ndims = int(self.head_dim * config.rotary_pct)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse split ``wq``/``wk``/``wv`` checkpoint entries into ``wqkv``.

        Only active when this module uses the fused projection.  Both the
        weight and the bias are fused (``wqkv`` is created with ``bias=True``),
        concatenated in q, k, v order along dim 0.
        """
        uses_fused_projection = self.config.dense_type == 'l' or not self.config.dense
        if not uses_fused_projection:
            return
        for suffix in ("weight", "bias"):
            if prefix + "wq." + suffix not in state_dict:
                continue
            parts = [state_dict.pop(prefix + name + "." + suffix) for name in ("wq", "wk", "wv")]
            state_dict[prefix + "wqkv." + suffix] = torch.cat(parts)

    def forward(self, x: Union[Tuple[Tensor], Tensor], freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """Attend over ``x`` and return the output projection, shape (B, T, dim).

        ``x`` is a (B, T, D) tensor for layer 0 / ``dense_type == 'l'`` / no
        dense connections; otherwise it carries separate q, k, v inputs either
        stacked as (C, B, T, D) or as a tuple of (B, T, D) tensors.
        """
        if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
            bsz, seqlen, _ = x.shape
        else:
            if self.config.stack_hidden:
                C, bsz, seqlen, _ = x.shape
            else:
                C, (bsz, seqlen, _) = len(x), x[0].shape
        kv_size = self.n_local_heads * self.head_dim

        if self.config.dense_type == 'l' or not self.config.dense:
            q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

            q = q.view(bsz, seqlen, self.n_head, self.head_dim)
            k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        elif self.config.dense_type == 'qkvr':
            if self.lidx == 0:
                # the first layer has a single input stream shared by q, k and v
                xq, xk, xv = x, x, x
            else:
                xq, xk, xv = x[0], x[1], x[2]
            q = self.wq(xq).view(bsz, seqlen, self.n_head, self.head_dim)
            k = self.wk(xk).view(bsz, seqlen, self.n_local_heads, self.head_dim)
            v = self.wv(xv).view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.use_qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        if self.rotary_ndims == self.head_dim:
            q = apply_rotary_emb(q, freqs_cis) #BTND
            k = apply_rotary_emb(k, freqs_cis)
        else:
            # partial rotary: rotate the first rotary_ndims features, pass the rest through
            q_rot = q[..., : self.rotary_ndims]
            q_pass = q[..., self.rotary_ndims :]
            k_rot = k[..., : self.rotary_ndims]
            k_pass = k[..., self.rotary_ndims :]
            q_rot = apply_rotary_emb(q_rot, freqs_cis, mode='half') #BTND
            k_rot = apply_rotary_emb(k_rot, freqs_cis, mode='half')
            q = torch.cat((q_rot, q_pass), dim=-1)
            k = torch.cat((k_rot, k_pass), dim=-1)

        # BTND -> BNTD
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        if self.kv_cache is not None:
            if seqlen == 1:
                # decode step: attend over the whole cache
                k, v = self.kv_cache.update(input_pos, k, v)
            else:
                # prefill: fill the cache but attend over the fresh k/v only
                _, _ = self.kv_cache.update(input_pos, k, v)

        if seqlen == 1 and self.kv_cache is not None: # one-token generation
            k_mask = mask[:,:,:,:self.kv_cache.seq_length]
        else:# prefill, or a length-1 sequence without a cache
            k_mask = mask[:,:,:,:k.shape[-2]] 

        logits = q @ k.transpose(-2, -1) * self.scale_factor 
        # masked positions get the most negative fp16 value before softmax
        min_value = torch.finfo(torch.float16).min
        logits = torch.where(k_mask, logits, min_value)
        probs = logits.softmax(-1)
        y = probs @ v
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y

class RMSnormNoscale(nn.Module):
    """Root-mean-square normalisation without any learnable scale.

    Divides the input by ``sqrt(mean(x**2) + epsilon)`` taken over ``dim``.
    """

    def __init__(self, epsilon=1e-6, dim=-1):
        super().__init__()
        self.dim = dim
        self.epsilon = epsilon

    def forward(self, inputs):
        mean_square = torch.pow(inputs, 2).mean(dim=self.dim, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.epsilon)
        return inputs * inv_rms

class RMSNorm(nn.Module):
    """RMS normalisation over the last axis with a learnable per-feature scale.

    The normalisation is computed in float32, cast back to the input dtype,
    and then multiplied by ``weight`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

class FeedForward(nn.Module):
    """Two-layer GELU MLP whose hidden width can grow with the layer index.

    Args:
        config: model config; reads ``intermediate_size``, ``n_layer``, ``dim``
            and ``use_linear_bias``.
        lidx: index of the layer this block belongs to.
        round128: round the hidden width to the nearest multiple of 128.
        scale_with_layer: scale ``intermediate_size`` by
            ``lidx / (n_layer - 1) + 0.5``, i.e. 0.5x at ``lidx == 0`` up to
            1.5x at ``lidx == n_layer - 1``.
    """

    def __init__(self, config: MUDDPythiaConfig, lidx, round128=True, scale_with_layer=True) -> None:
        super().__init__()
        hid_dim = config.intermediate_size
        # With a single layer the ramp is undefined (0/0); keep the configured width.
        if scale_with_layer and config.n_layer > 1:
            hid_dim = hid_dim * (lidx/(config.n_layer -1) +0.5)
        if round128:
            hid_dim = round(hid_dim / 128) * 128
        # nn.Linear needs an integer width; the scaled value is a float when not rounded.
        hid_dim = int(hid_dim)
        self.w1 = nn.Linear(config.dim, hid_dim, bias=config.use_linear_bias)
        self.w2 = nn.Linear(hid_dim, config.dim, bias=config.use_linear_bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``w2(gelu(w1(x)))``."""
        return self.w2(F.gelu(self.w1(x)))

def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k`` (``n`` itself if already one)."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder

def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000
) -> Tensor:
    """Build the rotary-embedding table.

    Returns a float16 tensor of shape ``(seq_len, n_elem // 2, 2)`` whose last
    axis holds the (real, imaginary) parts of ``exp(i * pos * freq)``, with
    ``freq_j = base ** (-2j / n_elem)``.
    """
    n_pairs = n_elem // 2
    exponents = torch.arange(0, n_elem, 2)[:n_pairs].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    unit_phasors = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit_phasors.real, unit_phasors.imag], dim=-1)
    return table.to(dtype=torch.float16)

def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
    """Rotate feature pairs of ``x`` by the angles stored in ``freqs_cis``.

    ``x`` is indexed as (B, T, N, D); ``freqs_cis`` carries
    (real, imaginary) pairs on its last axis and is reshaped to broadcast over
    batch and heads.  ``mode`` selects how features are paired:

    * ``'half'``: feature ``i`` pairs with feature ``i + D/2``.
    * ``'alternative'``: features ``2i`` and ``2i + 1`` form a pair.

    In both modes the rotated pairs are written out interleaved, and the
    result is cast back to the dtype of ``x``.
    """
    leading_shape = x.shape[:-1]
    if mode == 'half':
        xshaped = x.float().reshape(*leading_shape, 2, -1).transpose(-1, -2)
    elif mode == 'alternative':
        xshaped = x.float().reshape(*leading_shape, -1, 2)
    rot = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    first, second = xshaped[..., 0], xshaped[..., 1]
    real, imag = rot[..., 0], rot[..., 1]
    # complex multiplication (first + i*second) * (real + i*imag)
    rotated = torch.stack(
        [
            first * real - second * imag,
            second * real + first * imag,
        ],
        -1,
    )
    return rotated.flatten(3).type_as(x)

def match_weight_muddpythia(model, w, strict=False, pythia=True):
    """Load an externally named checkpoint ``w`` into ``model``.

    For every entry of ``model.named_parameters()`` the matching array is
    looked up in ``w`` under a ``state.mdl_vars.params.lm...`` key, reshaped or
    transposed where the layouts differ, wrapped with ``torch.tensor`` and
    finally loaded via ``model.load_state_dict``.

    Args:
        model: model to fill; reads ``model.config.dim``, ``n_head``,
            ``head_dim`` and ``vocab_size``.
        w: mapping from checkpoint key to array (presumably numpy arrays --
            confirm against the caller).
        strict: forwarded to ``load_state_dict``.
        pythia: selects the feed-forward name mapping and enables the ``+1``
            offset on norm weights (see below).

    Returns:
        ``model``, after loading.

    Note:
        A parameter that matches none of the lookups below (for example an
        attention parameter other than wq/wk/wv/wo/q_norm/k_norm) keeps the
        model's current value.
    """
    # Parameter-name fragment -> checkpoint-name fragment.
    if pythia:
        map_dict={'wq':'query', 'wk':'key', 'wv':'value', 'wo':'post', 'w1':'ffn_layer1', 'w2': 'ffn_layer2',
                'weight': 'w', 'bias': 'b'}
    else:
        map_dict={'wq':'query', 'wk':'key', 'wv':'value', 'wo':'post', 'w1': 'ffn_layer1_gate', 'w3': 'ffn_layer1', 'w2': 'ffn_layer2',
            'weight': 'w', 'bias': 'b'}
    # LayerNorm parameters use different names in the checkpoint.
    ln_dict={'weight':'scale','bias':'bias'}
    E, H, D = model.config.dim, model.config.n_head, model.config.head_dim
    N = model.config.vocab_size
    state_dict = {}
    for k, v in model.named_parameters():
        if k == 'tok_embeddings.weight':
            # keep only the first N (vocab_size) rows of the checkpoint embedding
            v = w['state.mdl_vars.params.lm.embedding_lookup.emb_var'][:N,:]
        elif k == 'norm.weight':
            v = w['state.mdl_vars.params.lm.final_ln.scale']
        elif k == 'norm.bias':
            v = w['state.mdl_vars.params.lm.final_ln.bias']
        elif k == 'output.weight':
            v = w['state.mdl_vars.params.lm.softmax.logits_ffn.linear.w'].T[:N,:]  # E,N -> N,E
        elif 'dense_bs' in k: # static dense w
            # the layer index is the last component of the parameter name
            lidx = int(k.split('.')[-1])
            v = w[f'state.mdl_vars.params.lm.transformer.dense_conn_{lidx}']
        elif 'dynamic_dense' in k:
            lidx = int(k.split('.')[1])
            widx = int(k.split('.')[2][-1]) # 1 or 2 in w1, w2
            v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.dynamic_dense_conn{widx}_{lidx}'].T
        else:
            # everything else must be a per-layer parameter: 'layers.<lidx>....'
            assert 'layers' in k
            lidx = int(k.split('.')[1])
            if '.attention.' in k:
                # expects exactly five dot-separated components
                _, _, _, ptype, wtype = k.split('.')
                if ptype in ['wq', 'wk', 'wv', 'wo']:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.{map_dict.get(wtype, wtype)}']#.reshape(E,E)
                    if wtype == 'weight':
                        v = v.reshape(E,E)
                    elif wtype == 'bias':
                        v = v.reshape(E)
                    # every projection weight except 'wo' is transposed
                    if ptype != 'wo' and wtype == 'weight':
                        v = v.T
                elif ptype in ['q_norm', 'k_norm']:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.scale']
            elif 'feed_forward' in k:
                ptype = k.split('.')[3] # w1, w3,w2
                wtype = k.split('.')[4] # weight or bias
                if wtype=='weight':
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.linear.{map_dict[wtype]}']
                    v = v.T
                elif wtype=='bias':
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.bias.{map_dict[wtype]}']
            elif 'ffn_norm' in k: # mlp layernorm
                wtype = k.split('.')[-1] # weight or bias
                v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.layer_norm.{ln_dict[wtype]}']
            elif 'attention_norm' in k: # attention layernorm
                wtype = k.split('.')[-1] # weight or bias
                if 'attention_norms' in k:
                    # per-stream norms: the fourth name component is the norm index
                    ln_idx = int(k.split('.')[3])
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norms_{ln_idx}.{ln_dict[wtype]}']
                else:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norm.{ln_dict[wtype]}']
        # In pythia mode every norm weight except q_norm/k_norm is shifted by +1
        # (presumably the checkpoint stores scale as an offset from 1 -- confirm).
        if pythia and 'weight' in k and 'norm' in k and 'q_norm' not in k and 'k_norm' not in k:
            v = v+1
        state_dict[k] = torch.tensor(v)
    model.load_state_dict(state_dict, strict=strict)
    return model