Ligeng-Zhu committed on
Commit
d48af03
·
verified ·
1 Parent(s): c3082d0

Upload files with `vila-upload`.

Browse files
Files changed (1) hide show
  1. modeling_vila.py +7 -6
modeling_vila.py CHANGED
@@ -1082,7 +1082,8 @@ class VILAForCasualLM(VILAPretrainedModel):
1082
 
1083
  return outputs
1084
 
1085
- @torch.inference_mode()
 
1086
  def generate(
1087
  self,
1088
  input_ids: Optional[torch.FloatTensor] = None,
@@ -1096,14 +1097,14 @@ class VILAForCasualLM(VILAPretrainedModel):
1096
  input_tokens: <image> describe the image
1097
  media: [Tensor(1, 3, 384, 384), ]
1098
  ----------->
1099
- input_tokens: 36000 001 002 003 004
1100
  input_emds: <media emd> 001 002 003 004
1101
  """
1102
  # NOTE: hard code to move to GPU
1103
- input_ids = input_ids.cuda()
1104
- media = {k: [v.cuda() for v in media[k]] for k in media}
1105
- if attention_mask is not None:
1106
- attention_mask = attention_mask.cuda()
1107
 
1108
  inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
1109
  output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
 
1082
 
1083
  return outputs
1084
 
1085
+ # TODO(ligeng): check how qwen implements this function
1086
+ # @torch.inference_mode()
1087
  def generate(
1088
  self,
1089
  input_ids: Optional[torch.FloatTensor] = None,
 
1097
  input_tokens: <image> describe the image
1098
  media: [Tensor(1, 3, 384, 384), ]
1099
  ----------->
1100
+ input_tokens: 36000 001 002 003 004
1101
  input_emds: <media emd> 001 002 003 004
1102
  """
1103
  # NOTE: hard code to move to GPU
1104
+ # input_ids = input_ids.cuda()
1105
+ # media = {k: [v.cuda() if v is not None for v in media[k]] for k in media}
1106
+ # if attention_mask is not None:
1107
+ # attention_mask = attention_mask.cuda()
1108
 
1109
  inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
1110
  output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)