# ft-bitvla-bitsiglipL-224px-libero_object-bf16 / bitvla_for_action_prediction.py
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import LlavaForConditionalGeneration, PretrainedConfig
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

from configuration_bit_vla import Bitvla_Config
from prismatic.training.train_utils import (
    get_current_action_mask,
    get_next_actions_mask,
)
from prismatic.vla.constants import (
    ACTION_DIM,
    ACTION_PROPRIO_NORMALIZATION_TYPE,
    NUM_ACTIONS_CHUNK,
    NormalizationType,
)


class BitVLAForActionPrediction(LlavaForConditionalGeneration):
    config_class: PretrainedConfig = Bitvla_Config

    def __init__(self, config) -> None:
        super().__init__(config)
        self.norm_stats = config.norm_stats
        # Compute uniform action bins over [-1, 1] and their midpoints (used for discrete action decoding)
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
        self.vocab_size = self.config.vocab_size

    def set_constant(self, image_token_idx, proprio_pad_idx, ignore_idx, action_token_begin_idx, stop_index):
        """Register the special token indices used to locate image, proprio, action, and stop tokens."""
        self.image_token_idx = image_token_idx
        self.proprio_pad_idx = proprio_pad_idx
        self.action_token_begin_idx = action_token_begin_idx
        self.stop_index = stop_index
        self.ignore_idx = ignore_idx
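
    # Illustrative only -- the token strings and the ignore index below are assumptions about the
    # tokenizer used at training time, not values shipped with this module:
    #
    #   model.set_constant(
    #       image_token_idx=tokenizer.convert_tokens_to_ids("<image>"),
    #       proprio_pad_idx=tokenizer.convert_tokens_to_ids("<proprio>"),
    #       ignore_idx=-100,  # standard Hugging Face label-ignore index
    #       action_token_begin_idx=...,  # first id of the reserved action-token block
    #       stop_index=tokenizer.eos_token_id,
    #   )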

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        proprio=None,
        proprio_projector=None,
        cache_position: Optional[torch.LongTensor] = None,
        vision_feature_layer=None,
        vision_feature_select_strategy=None,
    ) -> LlavaCausalLMOutputWithPast:
        """Run a forward pass through the VLM, returning a LlavaCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]  # type: ignore

        # === Handle Multimodal Forward ===
        if (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Get input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)  # (B, seq_len, D)

            # Replace the image placeholder embeddings with the projected vision features
            if pixel_values is not None:
                vision_feature_layer = (
                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
                )
                vision_feature_select_strategy = (
                    vision_feature_select_strategy
                    if vision_feature_select_strategy is not None
                    else self.config.vision_feature_select_strategy
                )
                # pixel_values: (b, num_images, c, h, w) --> (b * num_images, c, h, w);
                # each image is encoded by `self.get_image_features`, and the per-image features are concatenated
                b, num_images, c, h, w = pixel_values.shape
                pixel_values = pixel_values.view(-1, c, h, w)  # (b * num_images, c, h, w)
                image_embeds = self.get_image_features(
                    pixel_values=pixel_values,
                    vision_feature_layer=vision_feature_layer,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                )
                # image_embeds: (b * num_images, num_patches, D) --> (b * num_images * num_patches, D)
                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
                n_image_tokens = (input_ids == self.image_token_idx).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                # Scatter the image features into the positions of the image placeholder tokens
                mask = input_ids == self.image_token_idx
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            # Replace the proprio placeholder embedding with the projected proprio features
            if proprio_projector is not None and proprio is not None:
                # proprio: (bsz, proprio_dim) or (proprio_dim,)
                proprio = proprio.reshape(batch_size, -1)  # (bsz, proprio_dim)
                proprio_features = proprio_projector(proprio)  # (bsz, llm_dim)
                proprio_features = proprio_features.unsqueeze(dim=1)  # (bsz, 1, llm_dim)
                # (bsz, 1, llm_dim) --> (bsz * 1, llm_dim)
                proprio_features = proprio_features.view(-1, proprio_features.shape[-1])
                n_proprio_tokens = (input_ids == self.proprio_pad_idx).sum().item()
                n_proprio_features = proprio_features.shape[0]
                if n_proprio_tokens != n_proprio_features:
                    raise ValueError(
                        f"Proprio features and proprio tokens do not match: tokens: {n_proprio_tokens}, features {n_proprio_features}"
                    )
                mask = input_ids == self.proprio_pad_idx
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                proprio_mask = mask_expanded.to(inputs_embeds.device)
                proprio_features = proprio_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(proprio_mask, proprio_features)

            # Extract action masks
            # Action tokens are the label positions that are neither the ignore index nor the stop token
            all_actions_mask = (labels != self.ignore_idx) & (labels != self.stop_index)

            # Replace the embeddings of the action tokens with zeros
            # (Later on, the positional embeddings will be added to them)
            all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
            inputs_embeds = inputs_embeds * ~all_actions_mask

            outputs = LlavaForConditionalGeneration.forward(
                self,
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=None,
                pixel_values=None,
                labels=labels,
                inputs_embeds=inputs_embeds,
                past_key_values=None,
                use_cache=None,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

        # === Otherwise =>> Assume Invalid! ===
        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
            raise ValueError("Non-homogeneous batch of (text, image) input -- forward() does not support mixed batches!")
        else:
            raise ValueError(
                "Invalid `BitVLAForActionPrediction.forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `inputs_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        return outputs

    def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
        """Prepares input for action prediction by adding necessary tokens"""
        # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
        placeholder_action_token_ids = (
            torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
        )
        input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)

        # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
        stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * self.stop_index
        input_ids = torch.cat([input_ids, stop_token_id], dim=-1)

        # Extend the attention mask to fit the new shape of input
        # Note: Only batch size == 1 supported right now
        mask_extension = (
            torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
            .to(attention_mask.device)
            .to(attention_mask.dtype)
        )
        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)

        return input_ids, attention_mask
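
    # For example, with the (assumed) OpenVLA-OFT-style constants ACTION_DIM = 7 and
    # NUM_ACTIONS_CHUNK = 8, the helper above appends 7 * 8 = 56 action placeholder tokens plus
    # one stop token, i.e. the sequence grows by 57 positions and the attention mask is extended
    # by the same amount.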

    def _prepare_labels_for_action_prediction(self, labels, input_ids):
        """Creates labels tensor for action prediction if not provided"""
        # Extend labels tensor with fake action labels
        ARBITRARY_ACTION_TOKEN_IDX = self.action_token_begin_idx + 1
        labels_extension = (
            torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
            * ARBITRARY_ACTION_TOKEN_IDX
        )
        labels = torch.cat([labels, labels_extension], dim=-1)

        # Replace last label token with stop token
        labels[:, -1] = self.stop_index

        return labels

    def _process_action_masks(self, labels):
        """Helper to get action masks from labels"""
        current_action_mask = get_current_action_mask(
            labels, ignore_index=self.ignore_idx, action_token_begin_idx=self.action_token_begin_idx
        )
        next_actions_mask = get_next_actions_mask(
            labels, ignore_index=self.ignore_idx, action_token_begin_idx=self.action_token_begin_idx
        )
        all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)
        return all_actions_mask

    def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
        """Unnormalize actions using dataset statistics"""
        action_norm_stats = self.get_action_stats(unnorm_key)

        if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
            mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
            action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
        elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
            mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
            action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        else:
            raise ValueError("Unsupported action/proprio normalization type detected!")

        # Map masked dimensions from [-1, 1] back to [action_low, action_high]; leave unmasked dimensions untouched
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
            normalized_actions,
        )

        return actions
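
    # Worked example (BOUNDS case): with action_low = -1.0 and action_high = 3.0, a normalized
    # value of 0.0 maps to 0.5 * (0.0 + 1) * (3.0 - (-1.0) + 1e-8) + (-1.0) ≈ 1.0, the midpoint of
    # the original range; the 1e-8 term only guards against a zero-width range.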

    def _regression_or_discrete_prediction(
        self,
        input_ids,
        input_embeddings,
        all_actions_mask,
        attention_mask,
        labels,
        action_head=None,
        pixel_values=None,
    ):
        """Run L1 regression-based continuous action prediction or discrete action token prediction."""
        # Zero out action token embeddings
        all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
        input_embeddings = input_embeddings * ~all_actions_mask

        llava_output = LlavaForConditionalGeneration.forward(
            self,
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=None,
            pixel_values=None,
            labels=None,
            inputs_embeds=input_embeddings,
            past_key_values=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Align hidden states with the labels shifted by one position (the hidden state at index t
        # corresponds to the label at index t + 1), dropping the final hidden state
        all_actions_mask = self._process_action_masks(labels[:, 1:])

        # Extract hidden states for action tokens
        last_hidden_states = llava_output.hidden_states[-1]  # (B, seq_len, D)
        last_hidden_states = last_hidden_states[:, :-1, :]  # (B, seq_len - 1, D)

        # Use the action mask to extract the hidden states of the actions
        actions_hidden_states = last_hidden_states[all_actions_mask.squeeze(-1)].unsqueeze(0)  # (B, act_chunk_len, D)

        # Handle different prediction methods
        if action_head is not None:
            # L1 regression prediction
            normalized_actions = action_head.predict_action(actions_hidden_states)
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
            normalized_actions = normalized_actions.float().cpu().detach().numpy()
        else:
            # Discrete token-based prediction
            predicted_action_token_ids = (
                llava_output.logits[all_actions_mask.squeeze(-1)]
                .unsqueeze(0)
                .argmax(dim=2)
                .cpu()
                .numpy()
            )
            # FIXME: Discrete action prediction is not supported right now.
            # `self.vocab_size` here comes from the model config, i.e. the size of the logit layer,
            # which is larger than the tokenizer's vocabulary; the mapping below actually needs the
            # tokenizer's vocab_size.
            discretized_actions = self.vocab_size - predicted_action_token_ids
            discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
            normalized_actions = self.bin_centers[discretized_actions]
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

        return normalized_actions, actions_hidden_states

    def predict_action(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        unnorm_key: Optional[str] = None,
        proprio=None,
        proprio_projector=None,
        action_head=None,
        vision_feature_layer=None,
        vision_feature_select_strategy=None,
        **kwargs,
    ) -> Tuple[np.ndarray, torch.Tensor]:
        """Predict actions from an input sequence, with options for different prediction methods.

        Args:
            input_ids: Input token ids
            unnorm_key: Key for unnormalization statistics
            proprio: Proprioceptive features
            proprio_projector: Projector for proprioceptive features
            action_head: Optional head for L1 regression prediction
            **kwargs: Additional arguments, including pixel_values and attention_mask

        Returns:
            Tuple of (unnormalized_actions, action_hidden_states)
        """
        pixel_values = kwargs["pixel_values"]
        attention_mask = kwargs["attention_mask"]

        # Create fake labels tensor (needed for action mask)
        labels = input_ids.clone()
        labels[:] = self.ignore_idx

        # Prepare inputs by adding necessary tokens
        input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)

        # Update labels tensor for action mask computation later
        labels = self._prepare_labels_for_action_prediction(labels, input_ids)

        # Get input embeddings and action masks
        input_embeddings = self.get_input_embeddings()(input_ids)
        all_actions_mask = self._process_action_masks(labels)

        # Replace the image placeholder embeddings with the projected vision features
        if pixel_values is not None:
            vision_feature_layer = (
                vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
            )
            vision_feature_select_strategy = (
                vision_feature_select_strategy
                if vision_feature_select_strategy is not None
                else self.config.vision_feature_select_strategy
            )
            # pixel_values: (b, num_images, c, h, w) --> (b * num_images, c, h, w);
            # each image is encoded by `self.get_image_features`, and the per-image features are concatenated
            b, num_images, c, h, w = pixel_values.shape
            pixel_values = pixel_values.view(-1, c, h, w)  # (b * num_images, c, h, w)
            image_embeds = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            # image_embeds: (b * num_images, num_patches, D) --> (b * num_images * num_patches, D)
            image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
            n_image_tokens = (input_ids == self.image_token_idx).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            # Scatter the image features into the positions of the image placeholder tokens
            mask = input_ids == self.image_token_idx
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(input_embeddings)
            image_mask = mask_expanded.to(input_embeddings.device)
            image_embeds = image_embeds.to(input_embeddings.device, input_embeddings.dtype)
            input_embeddings = input_embeddings.masked_scatter(image_mask, image_embeds)

        # Add proprioceptive state if provided, replacing the proprio placeholder embedding
        use_proprio = proprio_projector is not None and proprio is not None
        if use_proprio:
            batch_size = input_ids.shape[0] if input_ids is not None else input_embeddings.shape[0]  # type: ignore
            proprio = torch.Tensor(proprio).to(input_embeddings.device, dtype=input_embeddings.dtype)
            # proprio: (bsz, proprio_dim) or (proprio_dim,)
            proprio = proprio.reshape(batch_size, -1)  # (bsz, proprio_dim)
            proprio_features = proprio_projector(proprio)  # (bsz, llm_dim)
            proprio_features = proprio_features.unsqueeze(dim=1)  # (bsz, 1, llm_dim)
            # (bsz, 1, llm_dim) --> (bsz * 1, llm_dim)
            proprio_features = proprio_features.view(-1, proprio_features.shape[-1])
            n_proprio_tokens = (input_ids == self.proprio_pad_idx).sum().item()
            n_proprio_features = proprio_features.shape[0]
            if n_proprio_tokens != n_proprio_features:
                raise ValueError(
                    f"Proprio features and proprio tokens do not match: tokens: {n_proprio_tokens}, features {n_proprio_features}"
                )
            mask = input_ids == self.proprio_pad_idx
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(input_embeddings)
            proprio_mask = mask_expanded.to(input_embeddings.device)
            proprio_features = proprio_features.to(input_embeddings.device, input_embeddings.dtype)
            input_embeddings = input_embeddings.masked_scatter(proprio_mask, proprio_features)

        # Run regression or discrete token-based prediction
        normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
            input_ids,
            input_embeddings,
            all_actions_mask,
            attention_mask,
            labels,
            action_head,
            pixel_values,
        )

        # Unnormalize predicted actions
        actions = self._unnormalize_actions(normalized_actions, unnorm_key)

        return actions, actions_hidden_states

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        """Validate and resolve the unnormalization key for action statistics"""
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset; "
                f"please pass an `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics; "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["min"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
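

# Illustrative usage sketch (assumptions, not part of the checkpoint's documented API):
# `processor`, `observation_images`, `task_prompt`, `robot_state`, `proprio_projector`, and
# `action_head` stand in for objects created by the surrounding training/inference code, and the
# special token indices must already have been registered via `set_constant` (see the note after
# that method above).
#
#   model = BitVLAForActionPrediction.from_pretrained(checkpoint_dir, torch_dtype=torch.bfloat16)
#   inputs = processor(images=observation_images, text=task_prompt, return_tensors="pt")
#   actions, action_hidden_states = model.predict_action(
#       input_ids=inputs["input_ids"],
#       unnorm_key=None,  # or a dataset key from model.norm_stats
#       proprio=robot_state,
#       proprio_projector=proprio_projector,
#       action_head=action_head,  # omit to fall back to (currently unsupported) discrete decoding
#       pixel_values=inputs["pixel_values"],  # expected shape: (b, num_images, c, h, w)
#       attention_mask=inputs["attention_mask"],
#   )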