Commit afb7521 (verified) by zhengyun21 · 1 parent: c3f0bb5

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,101 @@
+{
+  "_name_or_path": "../../output_ft/nv_4_l512",
+  "add_eos": true,
+  "add_pad_token": true,
+  "architectures": [
+    "NVEmbedModel"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_nvembed.NVEmbedConfig",
+    "AutoModel": "modeling_nvembed.NVEmbedModel"
+  },
+  "hidden_size": 4096,
+  "is_mask_instruction": true,
+  "latent_attention_config": {
+    "model_type": "latent_attention"
+  },
+  "mask_type": "b",
+  "model_type": "nvembed",
+  "padding_side": "right",
+  "text_config": {
+    "_name_or_path": "nvidia/NV-Embed-v2",
+    "add_cross_attention": false,
+    "architectures": [
+      "MistralModel"
+    ],
+    "attention_dropout": 0.0,
+    "bad_words_ids": null,
+    "begin_suppress_tokens": null,
+    "bos_token_id": 1,
+    "chunk_size_feed_forward": 0,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": 2,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "silu",
+    "hidden_size": 4096,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "initializer_range": 0.02,
+    "intermediate_size": 14336,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "max_position_embeddings": 32768,
+    "min_length": 0,
+    "model_type": "bidir_mistral",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 32,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_hidden_layers": 32,
+    "num_key_value_heads": 8,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": null,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "rms_norm_eps": 1e-05,
+    "rope_theta": 10000.0,
+    "sep_token_id": null,
+    "sliding_window": 4096,
+    "suppress_tokens": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": false,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": "float32",
+    "torchscript": false,
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "use_cache": true,
+    "vocab_size": 32000
+  },
+  "torch_dtype": "float32",
+  "transformers_version": "4.41.0"
+}
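
The "auto_map" entries above point AutoConfig/AutoModel at the custom classes shipped in this commit, so the checkpoint is loaded with trust_remote_code=True. A minimal loading sketch (not part of the commit; "<this-repo-id>" is a placeholder for the actual Hub repo id):

import torch
from transformers import AutoModel

# trust_remote_code lets transformers import the repo's
# configuration_nvembed.py and modeling_nvembed.py
model = AutoModel.from_pretrained("<this-repo-id>", trust_remote_code=True,
                                  torch_dtype=torch.float32)
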
configuration_nvembed.py ADDED
@@ -0,0 +1,92 @@
+
+from typing import Literal
+from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.auto import CONFIG_MAPPING
+from transformers.models.mistral import MistralConfig
+
+NVEMBED_TYPE = "nvembed"
+LATENT_ATTENTION_TYPE = "latent_attention"
+BIDIR_MISTRAL_TYPE = "bidir_mistral"
+
+class NVEmbedConfig(PretrainedConfig):
+    model_type = "nvembed"
+    is_composition = False
+
+    def __init__(
+        self,
+        latent_attention_config=None,
+        text_config=None,
+        padding_side: Literal["right", "left"]="right",
+        add_pad_token: bool=True,
+        is_mask_instruction: bool = True,
+        add_eos: bool=True,
+        mask_type: str="b",
+        **kwargs,
+    ):
+        if isinstance(latent_attention_config, dict):
+            latent_attention_config["model_type"] = (
+                latent_attention_config["model_type"] if "model_type" in latent_attention_config else LATENT_ATTENTION_TYPE
+            )
+            latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
+        elif latent_attention_config is None:
+            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
+
+        self.latent_attention_config = latent_attention_config
+
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
+            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+        elif text_config is None:
+            text_config = None
+
+        self.text_config = text_config
+        self.padding_side = padding_side
+        self.is_mask_instruction = is_mask_instruction
+        self.add_pad_token = add_pad_token
+        self.add_eos = add_eos
+        self.mask_type = mask_type
+        if "hidden_size" in kwargs:
+            self.hidden_size = kwargs["hidden_size"]
+        else:
+            self.hidden_size = 4096
+
+        super().__init__(**kwargs)
+
+
+class LatentAttentionConfig(PretrainedConfig):
+    model_type = LATENT_ATTENTION_TYPE
+    is_composition = False
+    _name_or_path = "latent_attention"
+
+    def __init__(
+        self,
+        num_latents_value: int=512,
+        num_cross_heads: int=8,
+        output_normalize: bool=True,
+        hidden_dim: int=4096,
+        latent_dim: int=4096,
+        cross_dim_head: int=4096,
+        **kwargs,
+    ):
+        self.num_latents_value = num_latents_value
+        self.num_cross_heads = num_cross_heads
+        self.output_normalize = output_normalize
+        self.hidden_dim = hidden_dim
+        self.latent_dim = latent_dim
+        self.cross_dim_head = cross_dim_head
+
+        super().__init__(**kwargs)
+
+
+class BidirectionalMistralConfig(MistralConfig):
+    model_type = BIDIR_MISTRAL_TYPE
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+AutoConfig.register(NVEMBED_TYPE, NVEmbedConfig)
+AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
+AutoConfig.register(BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig)
+
+NVEmbedConfig.register_for_auto_class()
+LatentAttentionConfig.register_for_auto_class()
+BidirectionalMistralConfig.register_for_auto_class()
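
A quick illustration of the defaults above (a sketch, not repo code; it assumes you run it next to configuration_nvembed.py so its module-level AutoConfig.register calls have executed):

from configuration_nvembed import NVEmbedConfig, LatentAttentionConfig

cfg = NVEmbedConfig()  # no arguments
assert isinstance(cfg.latent_attention_config, LatentAttentionConfig)  # filled from CONFIG_MAPPING
assert cfg.text_config is None      # stays None until a dict or config is supplied
assert cfg.hidden_size == 4096      # fallback when "hidden_size" is not in kwargs
assert cfg.padding_side == "right" and cfg.add_eos and cfg.mask_type == "b"
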
data_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:232af92e4a73796f60edeade96320093a2290e2264a7a2b0d3e4b0f5eec4c060
+size 1000
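
This file, like the .safetensors shards below, is stored as a Git LFS pointer: three lines giving the spec URL, the SHA-256 of the real blob, and its byte size. A small verification sketch for a downloaded blob (paths are illustrative):

import hashlib

def verify_lfs_object(path: str, expected_oid: str, expected_size: int) -> bool:
    # Stream-hash the blob and compare against the pointer's oid and size.
    h = hashlib.sha256()
    n = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
            n += len(chunk)
    return h.hexdigest() == expected_oid and n == expected_size

ok = verify_lfs_object(
    "data_args.bin",
    "232af92e4a73796f60edeade96320093a2290e2264a7a2b0d3e4b0f5eec4c060",
    1000,
)
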
model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6fa030d0c9dd85bb107655f316ff0ab66551690615c554c68fabecd43e22fca
+size 4995698456
model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55aa8d1c62376700abb304fc66ff69e6000b0c4975056fd9ad14c79357d0b960
+size 4999813600
model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d23852c39e57fb7daa4c3fdaa5f1d35da035cdee46f7afbce612e44b91caf7cb
+size 4999813624
model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0033197f8a25fb2c31b5030156bdf0ae8546e0879af2805a068f86c52a9e9c79
+size 4832007968
model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52bb48e04cfa4df62ca412edb2bad23437de511da766a736a16f396b23fc911b
+size 4999813656
model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fa3b2be383b214d4d8b9834113d98e7de2085a1b3cc445a73c046822a1797e2
+size 4999813656
model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2ee1638913495e91391ff6ecfdc5cc92872c2d22017d59202ae1e100297322b
+size 1577142096
model.safetensors.index.json ADDED
@@ -0,0 +1,311 @@
+{
+  "metadata": {
+    "total_size": 31404064768
+  },
+  "weight_map": {
+    "embedding_model.embed_tokens.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.input_layernorm.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.1.input_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.1.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.1.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.1.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.1.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.1.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.1.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.1.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.1.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
+    "embedding_model.layers.10.input_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.10.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.input_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.11.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.input_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.12.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.13.input_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.13.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.13.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.13.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.13.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.13.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.13.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.13.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.13.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.14.input_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.14.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.input_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.15.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.input_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.16.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.input_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.17.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.18.input_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.18.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.18.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.18.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.18.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.18.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.18.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.18.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.18.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
+    "embedding_model.layers.19.input_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.19.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.2.input_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.2.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.20.input_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.20.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.input_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.21.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.input_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.22.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.input_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.23.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.24.input_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.24.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.24.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.24.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.24.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.24.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.24.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.24.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.24.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
+    "embedding_model.layers.25.input_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.25.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.input_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.26.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.input_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.27.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.input_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.28.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.input_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.29.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.3.input_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.3.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.30.input_layernorm.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.30.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.30.mlp.gate_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.30.mlp.up_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.30.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.30.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.30.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.30.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.30.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
+    "embedding_model.layers.31.input_layernorm.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.mlp.gate_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.mlp.up_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.self_attn.k_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.self_attn.o_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.self_attn.q_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.31.self_attn.v_proj.weight": "model-00007-of-00007.safetensors",
+    "embedding_model.layers.4.input_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.4.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.input_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.5.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.input_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.6.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.7.input_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.7.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.7.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.7.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.7.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.7.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.7.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.7.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.7.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
+    "embedding_model.layers.8.input_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.8.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.input_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.layers.9.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
+    "embedding_model.norm.weight": "model-00007-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.0.fn.to_kv.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.0.fn.to_out.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.0.fn.to_q.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.0.norm.bias": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.0.norm.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.0.norm_context.bias": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.0.norm_context.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.1.fn.net.0.bias": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.1.fn.net.0.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.1.fn.net.2.bias": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.1.fn.net.2.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.1.norm.bias": "model-00001-of-00007.safetensors",
+    "latent_attention_model.cross_attend_blocks.1.norm.weight": "model-00001-of-00007.safetensors",
+    "latent_attention_model.latents": "model-00001-of-00007.safetensors"
+  }
+}
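
The index maps every tensor name to the shard that stores it; transformers reads it at load time to fetch only the shards it needs. A small sketch (assuming the files sit in the current directory) that groups tensors by shard:

import json
from collections import defaultdict

with open("model.safetensors.index.json") as f:
    index = json.load(f)

shards = defaultdict(list)
for tensor_name, shard_file in index["weight_map"].items():
    shards[shard_file].append(tensor_name)

assert len(shards) == 7  # matches model-0000{1..7}-of-00007.safetensors
print(index["metadata"]["total_size"])  # 31404064768 bytes, ~31.4 GB of float32 weights
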
model_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef1b9796736d41a9e119be11274e67fd39c6b7d4ffbbc3208265509578f675a6
+size 1068
modeling_nvembed.py ADDED
@@ -0,0 +1,441 @@
+from typing import List, Union, Dict, Mapping, Optional, Tuple, TypedDict
+import torch
+import os
+import json
+import numpy as np
+from functools import partial
+from contextlib import nullcontext
+from transformers import AutoModel, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.auto import AutoTokenizer
+from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from transformers import MistralModel, MistralConfig
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    logging,
+)
+from einops import rearrange, repeat
+from tqdm.auto import tqdm
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from .configuration_nvembed import NVEmbedConfig, LatentAttentionConfig, BidirectionalMistralConfig
+
+logger = logging.get_logger(__name__)
+
+class NVEmbedFeatures(TypedDict):
+    input_dict: torch.Tensor
+    attention_mask: torch.Tensor
+    pool_mask: torch.Tensor
+
+class BidirectionalMistralModel(MistralModel):
+    config_class = BidirectionalMistralConfig
+
+    def __init__(self, config: MistralConfig):
+        super().__init__(config)
+        for layer in self.layers:
+            layer.self_attn.is_causal = False
+        self._attn_implementation = "eager"
+
+    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        past_key_values_length = 0
+
+        if use_cache:
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                )
+
+        if self._attn_implementation == "flash_attention_2":
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                attention_mask, inputs_embeds.dtype
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_attention_mask(
+                attention_mask, inputs_embeds.dtype,
+            )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+def _move_to_device(maybe_tensor, device: torch.device):
+    if torch.is_tensor(maybe_tensor):
+        return maybe_tensor.to(device, non_blocking=device.type == "cuda")
+    elif isinstance(maybe_tensor, dict):
+        return {key: _move_to_device(value, device) for key, value in maybe_tensor.items()}
+    elif isinstance(maybe_tensor, list):
+        return [_move_to_device(x, device) for x in maybe_tensor]
+    elif isinstance(maybe_tensor, tuple):
+        return tuple([_move_to_device(x, device) for x in maybe_tensor])
+    elif isinstance(maybe_tensor, Mapping):
+        return type(maybe_tensor)({k: _move_to_device(v, device) for k, v in maybe_tensor.items()})
+    else:
+        return maybe_tensor
+
+def move_to_device(sample, device: torch.device):
+    if device.type == "cpu":
+        return sample
+
+    if len(sample) == 0:
+        return {}
+    return _move_to_device(sample, device)
+
+
+def input_transform_func(
+    tokenizer: PreTrainedTokenizerFast,
+    examples: Dict[str, List],
+    always_add_eos: bool,
+    max_length: int,
+    instruction: str,
+) -> BatchEncoding:
+    if always_add_eos:
+        examples['input_texts'] = [instruction + input_example + tokenizer.eos_token for input_example in examples['input_texts']]
+    batch_dict = tokenizer(
+        examples['input_texts'],
+        max_length=max_length,
+        padding=True,
+        return_token_type_ids=False,
+        return_tensors="pt",
+        truncation=True)
+    return batch_dict
+
+
+class PreNorm(torch.nn.Module):
+    def __init__(self, dim, fn, context_dim = None):
+        super().__init__()
+        self.fn = fn
+        self.norm = torch.nn.LayerNorm(dim)
+        self.norm_context = torch.nn.LayerNorm(context_dim) if exists(context_dim) else None
+
+    def forward(self, x, **kwargs):
+        x = self.norm(x)
+        if exists(self.norm_context):
+            context = kwargs['context']
+            normed_context = self.norm_context(context)
+            kwargs.update(context = normed_context)
+        return self.fn(x, **kwargs)
+
+class GEGLU(torch.nn.Module):
+    def forward(self, x):
+        x, gates = x.chunk(2, dim = -1)
+        return x * torch.nn.functional.gelu(gates)
+
+class FeedForward(torch.nn.Module):
+    def __init__(self, dim, mult = 4):
+        super().__init__()
+        self.net = torch.nn.Sequential(torch.nn.Linear(dim, dim * mult * 2),
+                                       GEGLU(),
+                                       torch.nn.Linear(dim * mult, dim))
+
+    def forward(self, x):
+        return self.net(x)
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+class Attention(torch.nn.Module):
+    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias = False)
+        self.to_kv = torch.nn.Linear(context_dim, inner_dim * 2, bias = False)
+        self.to_out = torch.nn.Linear(inner_dim, query_dim, bias = False)
+
+    def forward(self, x, context = None, mask = None):
+        h = self.heads
+        q = self.to_q(x)
+        context = default(context, x)
+        k, v = self.to_kv(context).chunk(2, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
+        return self.to_out(out)
+
+
+class LatentAttentionModel(PreTrainedModel):
+    config_class = LatentAttentionConfig
+
+    def __init__(self, config: LatentAttentionConfig):
+        super().__init__(config)
+        ## cross-attention block
+        num_latents, latent_dim, cross_heads, cross_dim_head = config.num_latents_value, config.latent_dim, config.num_cross_heads, config.cross_dim_head
+        dim = config.hidden_dim
+        # init latent_attention and latents
+        self.cross_attend_blocks = torch.nn.ModuleList([
+            PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head),
+                    context_dim = dim),
+            PreNorm(latent_dim, FeedForward(latent_dim)),
+        ])
+        self.output_normalize = config.output_normalize
+        self.register_parameter("latents", torch.nn.Parameter(torch.randn(num_latents, latent_dim)))
+
+    def forward(self, hiddens, attention_mask: torch.Tensor=None):
+        ## cross-attention block
+        cross_attn, cross_ff = self.cross_attend_blocks
+        b, *_, device = *hiddens.shape, hiddens.device
+        x = repeat(self.latents, 'n d -> b n d', b = b)
+        hiddens = cross_attn(hiddens, context = x, mask = None) + hiddens
+        hiddens = cross_ff(hiddens) + hiddens
+        if attention_mask is not None:
+            s = torch.sum(hiddens * attention_mask.unsqueeze(-1).float(), dim=1)
+            d = attention_mask.sum(dim=1, keepdim=True).float()
+            hiddens = s / d
+            if self.output_normalize:
+                hiddens = torch.nn.functional.normalize(hiddens, p=2, dim=-1)
+        return hiddens
+
+class NVEmbedModel(PreTrainedModel):
+    config_class = NVEmbedConfig
+    _no_split_modules = ["MistralDecoderLayer", "LatentAttentionModel"]
+
+    def __init__(self, config: NVEmbedConfig):
+        super().__init__(config)
+        self.latent_attention_model = AutoModel.from_config(config.latent_attention_config)
+        self.embedding_model = AutoModel.from_config(
+            config.text_config,
+        ) if config.text_config is not None else None
+        self.tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path) if config.text_config is not None else None
+        self.padding_side = config.padding_side
+        self.is_mask_instruction = config.is_mask_instruction
+        self.add_eos = config.add_eos
+        self.mask_type = config.mask_type
+        if config.add_pad_token and self.tokenizer is not None:
+            self.add_pad_token()
+
+    def add_pad_token(self):
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.padding_side = self.padding_side
+
+    def prepare_kwargs_from_batch(self, batch_dict: dict, instruction_lens: int, device: torch.device):
+        batch_dict = move_to_device(batch_dict, device)
+        attention_mask = batch_dict['attention_mask'].clone() if 'attention_mask' in batch_dict else None
+        if (attention_mask is not None and
+            self.padding_side == "right" and
+            self.is_mask_instruction == True and
+            instruction_lens > 0):
+            # Mask out the instruction tokens for mean-pooling
+            attention_mask[:, :instruction_lens] = 0
+        features: NVEmbedFeatures = {
+            'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
+            'attention_mask': batch_dict['attention_mask'],
+            'pool_mask': attention_mask,
+        }
+        return features
+
+    @torch.no_grad()
+    def _do_encode(self,
+                   prompts: List[str],
+                   batch_size: int=1,
+                   instruction: str="",
+                   max_length: int=4096,
+                   num_workers: int=32,
+                   **kwargs
+                   ) -> Union[np.ndarray, torch.FloatTensor]:
+        dataset: Dataset = Dataset.from_dict({'input_texts': prompts})
+        dataset.set_transform(partial(input_transform_func,
+                                      self.tokenizer,
+                                      always_add_eos=True,
+                                      max_length=max_length,
+                                      instruction=instruction))
+
+        data_collator = DataCollatorWithPadding(self.tokenizer)
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=num_workers,
+            collate_fn=data_collator,
+            pin_memory=True)
+
+        if self.padding_side == "right" and self.is_mask_instruction == True and len(instruction) > 0:
+            instruction_lens = len(self.tokenizer.tokenize(instruction))
+        else:
+            instruction_lens = 0
+
+        encoded_embeds = []
+        device = next(self.embedding_model.parameters()).device
+        for batch_dict in tqdm(data_loader, desc='encoding', mininterval=10):
+            features = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
+            embeds = self(**features)["sentence_embeddings"].squeeze(1)
+            encoded_embeds.append(embeds)
+        encoded_embeds = torch.cat(encoded_embeds, axis=0)
+        if "return_numpy" in kwargs and kwargs.get("return_numpy"):
+            encoded_embeds = encoded_embeds.cpu().detach().numpy()
+        return encoded_embeds
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pool_mask: Optional[torch.Tensor]=None, return_dict: bool=True):
+        autocast_ctx = torch.autocast if torch.cuda.is_available() else nullcontext
+        with autocast_ctx("cuda"):
+            ## decoder only layer
+            outputs = self.embedding_model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+            )
+            ## latent attention layer
+            embeds = self.latent_attention_model(
+                outputs.last_hidden_state,
+                pool_mask,
+            )
+        if not return_dict:
+            return (embeds,)
+        return {"sentence_embeddings": embeds}
+
+
+    @torch.no_grad()
+    def encode(self, prompts: List[str], instruction: str="", max_length: int=4096, **kwargs):
+        if self.padding_side == "right" and self.is_mask_instruction == True and len(instruction) > 0:
+            instruction_lens = len(self.tokenizer.tokenize(instruction))
+        else:
+            instruction_lens = 0
+
+        device = next(self.embedding_model.parameters()).device
+        batch_dict = input_transform_func(self.tokenizer,
+                                          {"input_texts": [prompt for prompt in prompts]},
+                                          always_add_eos=True,
+                                          max_length=max_length,
+                                          instruction=instruction)
+
+        features: NVEmbedFeatures = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
+        return self(**features)["sentence_embeddings"].squeeze(1)
+
+
+## AutoModel Register
+AutoModel.register(NVEmbedConfig, NVEmbedModel)
+AutoModel.register(LatentAttentionConfig, LatentAttentionModel)
+AutoModel.register(BidirectionalMistralConfig, BidirectionalMistralModel)
+
+## Register for auto class
+NVEmbedModel.register_for_auto_class("AutoModel")
+LatentAttentionModel.register_for_auto_class("AutoModel")
+BidirectionalMistralModel.register_for_auto_class("AutoModel")
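
The pipeline in this file: BidirectionalMistralModel disables the causal mask (is_causal = False) so every token attends in both directions; NVEmbedModel runs it, pools the hidden states through LatentAttentionModel with a pool_mask that zeroes out instruction tokens, and returns one L2-normalized vector per input. A minimal usage sketch (the instruction wording and repo id are illustrative, not prescribed by this commit):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<this-repo-id>", trust_remote_code=True).eval()

# NV-Embed-style instruction prefix; encode() masks its tokens out of the pooling.
instruction = "Instruct: Given a question, retrieve passages that answer the question.\nQuery: "
queries = ["What are the side effects of metformin?"]

embeddings = model.encode(queries, instruction=instruction, max_length=512)
print(embeddings.shape)  # (1, 4096): one normalized embedding per input
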
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+{
+  "bos_token": {
+    "content": "<s>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
+size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,41 @@
+{
+  "add_bos_token": true,
+  "add_eos_token": false,
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<unk>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2": {
+      "content": "</s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": [],
+  "bos_token": "<s>",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "</s>",
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "</s>",
+  "sp_model_kwargs": {},
+  "spaces_between_special_tokens": false,
+  "tokenizer_class": "LlamaTokenizer",
+  "unk_token": "<unk>",
+  "use_default_system_prompt": false
+}
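
Note that both special_tokens_map.json and tokenizer_config.json alias the pad token to "</s>" (the EOS token), matching NVEmbedModel.add_pad_token, which sets pad_token = eos_token and right padding at load time. A quick check (a sketch; "<this-repo-id>" is a placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<this-repo-id>")
assert tok.pad_token == tok.eos_token == "</s>"
# add_eos_token is false here: modeling_nvembed.py's input_transform_func
# appends the EOS string to each text itself before tokenizing.
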
training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6671918340699bcf27aab69a7decf1909736dd1b9d933ddf9347a7168ac167c3
+size 5112