izhx committed on
Commit e54cb53 · verified · 1 Parent(s): 8a1422a
Files changed (4)
  1. README.md +12 -0
  2. config.json +6 -3
  3. custom_st.py +1 -1
  4. modeling_gme_qwen2vl.py +40 -16
README.md CHANGED
@@ -3698,7 +3698,19 @@ The `GME` models support three types of input: **text**, **image**, and **image-
 
 **Transformers**
 
+The remote code has some issues with `transformers>=4.52.0`, please downgrade or use `sentence_transformers`
+
 ```python
+from transformers import AutoModel
+from transformers.utils.versions import require_version
+
+
+require_version(
+    "transformers<4.52.0",
+    "The remote code has some issues with transformers>=4.52.0, please downgrade: pip install transformers==4.51.3"
+)
+
+
 t2i_prompt = 'Find an image that matches the given text.'
 texts = [
     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
config.json CHANGED
@@ -1,9 +1,12 @@
 {
   "_name_or_path": "Alibaba-NLP/gme-Qwen2-VL-7B-Instruct",
-  "architectures": ["GmeQwen2VLForVision2Seq"],
+  "architectures": [
+    "Qwen2VLForConditionalGeneration",
+    "GmeQwen2VL"
+  ],
   "auto_map": {
-    "AutoModel": "modeling_gme_qwen2vl.GmeQwen2VLForVision2Seq",
-    "AutoConfig": "modeling_gme_qwen2vl.GmeQwen2VLConfig"
+    "AutoConfig": "modeling_gme_qwen2vl.GmeQwen2VLConfig",
+    "AutoModel": "modeling_gme_qwen2vl.GmeQwen2VL"
   },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
custom_st.py CHANGED
@@ -51,7 +51,7 @@ class MultiModalTransformer(BaseTransformer):
         self, features: Dict[str, torch.Tensor], **kwargs
     ) -> Dict[str, torch.Tensor]:
         if features.get("inputs_embeds", None) is None:
-            features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
+            features["inputs_embeds"] = self.auto_model.base_model.get_input_embeddings()(features["input_ids"])
         if features.get("pixel_values", None) is not None:
             features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
             image_embeds = self.auto_model.visual(
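The single changed line swaps the private `embed_tokens` attribute for the public `get_input_embeddings()` accessor that every `PreTrainedModel` exposes, so the lookup no longer depends on where the embedding layer lives internally. A self-contained illustration of the two access paths, using a tiny randomly initialized Qwen2 text model as a stand-in for the real checkpoint:

```python
import torch
from transformers import Qwen2Config, Qwen2Model

# Tiny random model purely to illustrate the accessor; the weights are meaningless.
cfg = Qwen2Config(
    vocab_size=128, hidden_size=32, intermediate_size=64,
    num_hidden_layers=1, num_attention_heads=2, num_key_value_heads=2,
)
model = Qwen2Model(cfg)

input_ids = torch.tensor([[1, 2, 3]])
via_accessor = model.get_input_embeddings()(input_ids)  # public, version-stable API
via_attribute = model.embed_tokens(input_ids)           # private attribute the old line reached into
assert torch.equal(via_accessor, via_attribute)
```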
modeling_gme_qwen2vl.py CHANGED
@@ -12,16 +12,25 @@ import torch
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
-from transformers import (
-    AutoProcessor,
-    PreTrainedModel,
+from transformers import AutoProcessor, PreTrainedModel
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+    Qwen2VisionTransformerPretrainedModel,
     Qwen2VLConfig,
     Qwen2VLForConditionalGeneration,
+    Qwen2VLModel,
+)
+from transformers.utils.versions import require_version
+
+
+require_version(
+    "transformers<4.52.0",
+    "This code has some issues with transformers>=4.52.0, please downgrade: pip install transformers==4.51.3"
 )
-import os
 
 
 class GmeQwen2VLConfig(Qwen2VLConfig):
+    # model_type = ''
+
     def __init__(
         self,
         min_image_tokens: int = 256,
@@ -35,14 +44,25 @@ class GmeQwen2VLConfig(Qwen2VLConfig):
         self.max_length = max_length
 
 
-class GmeQwen2VLForVision2Seq(PreTrainedModel):
+class GmeQwen2VL(PreTrainedModel):
     config_class = GmeQwen2VLConfig
-    base_model_prefix: str = "base"
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
+    # _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    # _supports_cache_class = True
+    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
+    # _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
         super().__init__(config)
-        self.base = Qwen2VLForConditionalGeneration.from_pretrained(config._name_or_path)
-        self.base.tie_weights()  # It's important to produce same outputs.
+        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
+        self.model = Qwen2VLModel(config)
+        self.vocab_size = config.vocab_size
+        # self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.rope_deltas = None  # cache rope_deltas here
 
         min_pixels: int = config.min_image_tokens * 28 * 28
         max_pixels: int = config.max_image_tokens * 28 * 28
@@ -55,6 +75,9 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         self.default_instruction: str = "You are a helpful assistant."
         self.sep: str = " "
 
+        # Initialize weights and apply final processing
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -70,21 +93,21 @@
         **kwargs
     ) -> torch.Tensor:
         if inputs_embeds is None:
-            inputs_embeds = self.base.model.embed_tokens(input_ids)
+            inputs_embeds = self.model.get_input_embeddings()(input_ids)
         if pixel_values is not None:
-            pixel_values = pixel_values.type(self.base.visual.get_dtype())
-            image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
-            image_mask = input_ids == self.base.config.image_token_id
+            pixel_values = pixel_values.type(self.visual.get_dtype())
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+            image_mask = input_ids == self.config.image_token_id
             inputs_embeds[image_mask] = image_embeds
         # if pixel_values_videos is not None:
-        #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
-        #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
-        #     video_mask = input_ids == self.base.config.video_token_id
+        #     pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+        #     video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
+        #     video_mask = input_ids == self.config.video_token_id
         #     inputs_embeds[video_mask] = video_embeds
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-        outputs = self.base.model(
+        outputs = self.model(
             input_ids=None,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -311,3 +334,4 @@ def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Im
 
     return image
 ###
+
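The rewritten `forward()` keeps the original embedding flow but addresses the new submodules directly: token embeddings come from `self.model.get_input_embeddings()`, the vision tower output is cast to its dtype, and rows of `inputs_embeds` at image-token positions are overwritten through a boolean mask. A stand-in sketch of that masking step with dummy tensors (shapes, ids, and values are made up for illustration):

```python
import torch

# Dummy tensors standing in for forward()'s inputs: rows of inputs_embeds at
# image-token positions are replaced by vision-tower features via a boolean mask.
hidden_size = 8
image_token_id = 99
input_ids = torch.tensor([[1, 99, 99, 2]])        # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, hidden_size)    # from self.model.get_input_embeddings()(input_ids)
image_embeds = torch.randn(2, hidden_size)        # from self.visual(...), one row per placeholder

image_mask = input_ids == image_token_id          # shape (1, 4), True where the image tokens sit
inputs_embeds[image_mask] = image_embeds          # same assignment used in forward()
assert torch.equal(inputs_embeds[0, 1:3], image_embeds)
```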