izhx commited on
Commit
6021c8f
·
verified ·
1 Parent(s): 8a1422a

Fix remote code

Browse files
Files changed (1) hide show
  1. custom_st.py +1 -1
custom_st.py CHANGED
@@ -51,7 +51,7 @@ class MultiModalTransformer(BaseTransformer):
51
  self, features: Dict[str, torch.Tensor], **kwargs
52
  ) -> Dict[str, torch.Tensor]:
53
  if features.get("inputs_embeds", None) is None:
54
- features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
55
  if features.get("pixel_values", None) is not None:
56
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
57
  image_embeds = self.auto_model.visual(
 
51
  self, features: Dict[str, torch.Tensor], **kwargs
52
  ) -> Dict[str, torch.Tensor]:
53
  if features.get("inputs_embeds", None) is None:
54
+ features["inputs_embeds"] = self.auto_model.base_model.get_input_embeddings()(features["input_ids"])
55
  if features.get("pixel_values", None) is not None:
56
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
57
  image_embeds = self.auto_model.visual(