izhx and Samoed committed · Commit c937797 · verified · 1 parent: 0e7360a

Integrate sentence transformers (#9)


- Base Integration with SentenceTransformers (2df56dc77ce0fbca90c338d38d2cffb4de4c9ea0)
- Update custom_st.py (f70ad8d07b0f89da5c767bc57d3453b8923938ac)
- Update README.md (27b4e411bf08eee10f4bd27941807530cce3099f)
- Update README.md (5b4def35a3fdbbe9c547cd2fce25896feb492d97)
- Update README.md (5ff08124217cda0a6a0a48ab93ada3aa2ac0a1c8)
- Update README.md (4f9008b5898196357c0cc9c767d767af1235cc32)


Co-authored-by: Solomatin Roman <Samoed@users.noreply.huggingface.co>

1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 1536,
+   "pooling_mode_cls_token": false,
+   "pooling_mode_mean_tokens": false,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": true,
+   "include_prompt": true
+ }
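This configuration enables only `pooling_mode_lasttoken`, i.e. the sentence embedding is the hidden state of the final (non-padding) token. A minimal sketch of the equivalent module construction, assuming a recent sentence-transformers release that supports `pooling_mode="lasttoken"` and `include_prompt` (the 1536-dim size comes from the config above):

```python
# Sketch only: building the pooling stage declared in 1_Pooling/config.json by hand.
from sentence_transformers import models

pooling = models.Pooling(
    word_embedding_dimension=1536,  # hidden size of the 2B model
    pooling_mode="lasttoken",       # use the last token's hidden state as the embedding
    include_prompt=True,            # prompt tokens are not masked out before pooling
)
# Loading the saved directory is equivalent:
# pooling = models.Pooling.load("1_Pooling")
```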
README.md CHANGED
@@ -3692,46 +3692,90 @@ The `GME` models support three types of input: **text**, **image**, and **image-
 |[`gme-Qwen2-VL-7B`](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct) | 8.29B | 32768 | 3584 | 67.48 | 69.73 | 67.44 |
 
 ## Usage
+ **Use with sentence_transformers**
+
+ The `encode` function accepts `str` inputs or `dict` inputs with key(s) in `{'text', 'image', 'prompt'}`.
+
+ **Do not pass `prompt` as an argument to `encode`**; instead, pass the input as a `dict` with a `prompt` key.
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+
+ t2i_prompt = 'Find an image that matches the given text.'
+ texts = [
+     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
+     "Alibaba office.",
+ ]
+ images = [
+     'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
+     'https://upload.wikimedia.org/wikipedia/commons/e/e0/TaobaoCity_Alibaba_Xixi_Park.jpg',
+ ]
+
+
+ gme_st = SentenceTransformer("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
+
+ # Single-modal embedding
+ e_text = gme_st.encode(texts, convert_to_tensor=True)
+ e_image = gme_st.encode([dict(image=i) for i in images], convert_to_tensor=True)
+ print('Single-modal', (e_text @ e_image.T).tolist())
+ ## Single-modal [[0.356201171875, 0.06536865234375], [0.041717529296875, 0.37890625]]
+
+ # How to set embedding instruction
+ e_query = gme_st.encode([dict(text=t, prompt=t2i_prompt) for t in texts], convert_to_tensor=True)
+ # If no prompt is given, the default instruction is always used.
+ e_corpus = gme_st.encode([dict(image=i) for i in images], convert_to_tensor=True)
+ print('Single-modal with instruction', (e_query @ e_corpus.T).tolist())
+ ## Single-modal with instruction [[0.425537109375, 0.1158447265625], [0.049835205078125, 0.413818359375]]
+
+ # Fused-modal embedding
+ e_fused = gme_st.encode([dict(text=t, image=i) for t, i in zip(texts, images)], convert_to_tensor=True)
+ print('Fused-modal', (e_fused @ e_fused.T).tolist())
+ ## Fused-modal [[0.99951171875, 0.0556640625], [0.0556640625, 0.99951171875]]
+ ```
+
+
 **Use with custom code**
 
 ```python
 # You can find the script gme_inference.py in https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
 from gme_inference import GmeQwen2VL
 
+ t2i_prompt = 'Find an image that matches the given text.'
 texts = [
-     "What kind of car is this?",
-     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023."
+     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
+     "Alibaba office.",
 ]
 images = [
-     'https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg',
-     'https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg',
+     'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
+     'https://upload.wikimedia.org/wikipedia/commons/e/e0/TaobaoCity_Alibaba_Xixi_Park.jpg',
 ]
 
+
 gme = GmeQwen2VL("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
 
 # Single-modal embedding
 e_text = gme.get_text_embeddings(texts=texts)
 e_image = gme.get_image_embeddings(images=images)
- print((e_text * e_image).sum(-1))
- ## tensor([0.2281, 0.6001], dtype=torch.float16)
+ print('Single-modal', (e_text @ e_image.T).tolist())
+ ## [[0.359619140625, 0.0655517578125], [0.04180908203125, 0.374755859375]]
 
 # How to set embedding instruction
- e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')
+ e_query = gme.get_text_embeddings(texts=texts, instruction=t2i_prompt)
 # If is_query=False, we always use the default instruction.
 e_corpus = gme.get_image_embeddings(images=images, is_query=False)
- print((e_query * e_corpus).sum(-1))
- ## tensor([0.2433, 0.7051], dtype=torch.float16)
+ print('Single-modal with instruction', (e_query @ e_corpus.T).tolist())
+ ## [[0.429931640625, 0.11505126953125], [0.049835205078125, 0.409423828125]]
 
 # Fused-modal embedding
 e_fused = gme.get_fused_embeddings(texts=texts, images=images)
- print((e_fused[0] * e_fused[1]).sum())
- ## tensor(0.6108, dtype=torch.float16)
-
+ print('Fused-modal', (e_fused @ e_fused.T).tolist())
+ ## [[1.0, 0.05511474609375], [0.05511474609375, 1.0]]
 ```
 
 ## Evaluation
 
- We validated the performance on our universal multimodal retrieval benchmark (**UMRB**) among others.
+ We validated the performance on our universal multimodal retrieval benchmark (**UMRB**, see [Release UMRB](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct/discussions/2)) among others.
 
 | | | Single-modal | | Cross-modal | | | Fused-modal | | | | Avg. |
 |--------------------|------|:------------:|:---------:|:-----------:|:-----------:|:---------:|:-----------:|:----------:|:----------:|:-----------:|:----------:|
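Because the pipeline ends with a `Normalize` module (see `modules.json` below), the embeddings are unit-length and the dot products printed above are cosine similarities. A minimal sketch of the same computation via the built-in helper, assuming sentence-transformers >= 3.0 and the `gme_st`, `e_query`, `e_corpus` objects from the README snippet above:

```python
# Sketch only: model.similarity() defaults to cosine similarity, which here equals
# the dot product because the embeddings are already L2-normalized.
scores = gme_st.similarity(e_query, e_corpus)
print(scores)
```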
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "prompts": {
+     "query": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+   },
+   "default_prompt_name": null,
+   "similarity_fn_name": null
+ }
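A minimal sketch of how these values surface on a loaded model, assuming a recent sentence-transformers release and the `gme_st` object from the README snippet above:

```python
# Sketch only: the fields of config_sentence_transformers.json map to model attributes.
print(gme_st.prompts)              # {'query': '<|im_start|>system\n...'}
print(gme_st.default_prompt_name)  # None: no prompt is applied unless requested
print(gme_st.similarity_fn_name)   # unset in the config; recent releases fall back to cosine
```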
custom_st.py ADDED
@@ -0,0 +1,223 @@
+ from io import BytesIO
+ from typing import Any, Dict, Optional, List
+ import torch
+ from PIL import Image
+ from sentence_transformers.models import Transformer as BaseTransformer
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+
+
+ class MultiModalTransformer(BaseTransformer):
+     def __init__(
+         self,
+         model_name_or_path: str,
+         cache_dir: Optional[str] = None,
+         tokenizer_args: Optional[Dict[str, Any]] = None,
+         min_image_tokens: int = 256,
+         max_image_tokens: int = 1280,
+         max_length: int = 1800,
+         **kwargs,
+     ):
+         super().__init__(model_name_or_path, **kwargs)
+         if tokenizer_args is None:
+             tokenizer_args = {}
+         tokenizer_args.pop("trust_remote_code", None)
+
+         # Initialize processor
+         min_pixels = min_image_tokens * 28 * 28
+         max_pixels = max_image_tokens * 28 * 28
+         self.processor = AutoProcessor.from_pretrained(
+             model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
+         )
+         self.processor.tokenizer.padding_side = 'right'
+         self.sep = ' '
+         self.max_length = max_length
+         self.normalize = True
+
+     def _load_model(
+         self,
+         model_name_or_path: str,
+         config,
+         cache_dir: str,
+         backend: str,
+         is_peft_model: bool,
+         **model_args,
+     ) -> None:
+         model_args.pop("trust_remote_code", None)
+         self.auto_model = AutoModelForVision2Seq.from_pretrained(
+             model_name_or_path, torch_dtype=torch.float16, **model_args
+         )
+
+     def forward(
+         self, features: Dict[str, torch.Tensor], **kwargs
+     ) -> Dict[str, torch.Tensor]:
+         if features.get("inputs_embeds", None) is None:
+             features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
+         if features.get("pixel_values", None) is not None:
+             features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
+             image_embeds = self.auto_model.visual(
+                 features["pixel_values"], grid_thw=features["image_grid_thw"]
+             )
+             image_mask = features["input_ids"] == self.auto_model.config.image_token_id
+             features["inputs_embeds"][image_mask] = image_embeds
+         # features.pop("pixel_values")
+         # features.pop("image_grid_thw")
+         # features.pop("input_ids")
+         inputs = {k: v for k, v in features.items() if k in 'position_ids,attention_mask,inputs_embeds'}
+         outputs = self.auto_model.model(
+             **inputs,
+             return_dict=True,
+             output_hidden_states=True,
+             # **kwargs
+         )
+         # pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
+         # left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
+         # if left_padding:
+         #     embeddings = outputs.last_hidden_state
+         # else:
+         #     sequence_lengths = pooling_mask.sum(dim=1) - 1
+         #     embeddings = outputs.last_hidden_state[torch.arange(
+         #         outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
+         #     ), sequence_lengths]
+         features.update({"token_embeddings": outputs.last_hidden_state})
+         return features
+
+     def tokenize(self, texts: List[Dict[str, Any]] | List[str]) -> Dict[str, torch.Tensor]:
+         default_instruction = 'You are a helpful assistant.'
+
+         all_texts, all_images = list(), list()
+         for item in texts:
+             if isinstance(item, str):
+                 txt, img, inst = item, None, default_instruction
+             elif isinstance(item, dict):
+                 txt = item.get('text', None)
+                 img = item.get('image', None)
+                 inst = item.get('prompt', default_instruction)
+             else:
+                 raise RuntimeError(f'Input format not supported! {item=}')
+
+             input_str = ''
+             if img is None:
+                 all_images = None  # All examples in the same batch are consistent
+                 # or will have ValueError: Could not make a flat list of images from xxxx
+             else:
+                 input_str += '<|vision_start|><|image_pad|><|vision_end|>'
+                 img = fetch_image(img)
+                 all_images.append(img)
+             if txt is not None:
+                 input_str += txt
+             msg = f'<|im_start|>system\n{inst}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+             all_texts.append(msg)
+
+         inputs = self.processor(
+             text=all_texts,
+             images=all_images,
+             padding="longest",
+             truncation=True,
+             max_length=self.max_seq_length,
+             return_tensors='pt'
+         )
+         return inputs
+
+
+ ### Copied from qwen_vl_utils.vision_process.py
+ import base64
+ import logging
+ import math
+ from io import BytesIO
+ import requests
+
+ IMAGE_FACTOR = 28
+ MIN_PIXELS = 4 * 28 * 28
+ MAX_PIXELS = 16384 * 28 * 28
+ MAX_RATIO = 200
+
+
+ def round_by_factor(number: int, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+
+ def ceil_by_factor(number: int, factor: int) -> int:
+     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+     return math.ceil(number / factor) * factor
+
+
+ def floor_by_factor(number: int, factor: int) -> int:
+     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+     return math.floor(number / factor) * factor
+
+
+ def smart_resize(
+     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+ ) -> tuple[int, int]:
+     """
+     Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = floor_by_factor(height / beta, factor)
+         w_bar = floor_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = ceil_by_factor(height * beta, factor)
+         w_bar = ceil_by_factor(width * beta, factor)
+
+     if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
+         logging.warning(
+             f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
+         )
+         if h_bar > w_bar:
+             h_bar = w_bar * MAX_RATIO
+         else:
+             w_bar = h_bar * MAX_RATIO
+     return h_bar, w_bar
+
+
+ def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+     image_obj = None
+     if isinstance(image, Image.Image):
+         image_obj = image
+     elif image.startswith("http://") or image.startswith("https://"):
+         image_obj = Image.open(requests.get(image, stream=True).raw)
+     elif image.startswith("file://"):
+         image_obj = Image.open(image[7:])
+     elif image.startswith("data:image"):
+         if "base64," in image:
+             _, base64_data = image.split("base64,", 1)
+             data = base64.b64decode(base64_data)
+             image_obj = Image.open(BytesIO(data))
+     else:
+         image_obj = Image.open(image)
+     if image_obj is None:
+         raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+     image = image_obj.convert("RGB")
+     ## resize
+     # if "resized_height" in ele and "resized_width" in ele:
+     #     resized_height, resized_width = smart_resize(
+     #         ele["resized_height"],
+     #         ele["resized_width"],
+     #         factor=size_factor,
+     #     )
+     # else:
+     width, height = image.size
+     # min_pixels = ele.get("min_pixels", MIN_PIXELS)
+     # max_pixels = ele.get("max_pixels", MAX_PIXELS)
+     resized_height, resized_width = smart_resize(
+         height,
+         width,
+         factor=size_factor,
+         min_pixels=MIN_PIXELS,
+         max_pixels=MAX_PIXELS,
+     )
+     image = image.resize((resized_width, resized_height))
+
+     return image
+ ###
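For reference, `fetch_image` above accepts an http(s) URL, a `file://` URI, a `data:image/...;base64` string, a plain local path, or a `PIL.Image`, and always returns an RGB image resized by `smart_resize`. A minimal usage sketch; the local paths are hypothetical placeholders:

```python
# Sketch only: the supported input forms of custom_st.fetch_image.
from PIL import Image
from custom_st import fetch_image

img_url  = fetch_image("https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg")
img_path = fetch_image("./local_image.jpg")            # plain local path (placeholder)
img_uri  = fetch_image("file:///tmp/local_image.jpg")  # file:// URI (placeholder)
img_pil  = fetch_image(Image.open("./local_image.jpg"))
print(img_url.size)  # both dimensions are multiples of the 28-pixel patch factor
```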
modules.json ADDED
@@ -0,0 +1,20 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "custom_st.MultiModalTransformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   },
+   {
+     "idx": 2,
+     "name": "2",
+     "path": "2_Normalize",
+     "type": "sentence_transformers.models.Normalize"
+   }
+ ]
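`modules.json` declares the three-stage pipeline that `SentenceTransformer` assembles when this repository is loaded: the custom multimodal transformer, last-token pooling, and L2 normalization. A minimal sketch of the equivalent manual composition, assuming `custom_st.py` from this repository is importable from the working directory; in practice, loading by repository id as shown in the README is the intended path:

```python
# Sketch only: assembling the pipeline declared in modules.json by hand.
from sentence_transformers import SentenceTransformer, models
from custom_st import MultiModalTransformer

transformer = MultiModalTransformer("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")  # module 0
pooling = models.Pooling(1536, pooling_mode="lasttoken")                     # module 1 (1_Pooling)
normalize = models.Normalize()                                               # module 2 (2_Normalize)

model = SentenceTransformer(modules=[transformer, pooling, normalize])
emb = model.encode(["A short test sentence."], convert_to_tensor=True)
print(emb.shape)  # expected: torch.Size([1, 1536])
```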