from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import LlavaForConditionalGeneration, PretrainedConfig
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

from configuration_bit_vla import Bitvla_Config
from prismatic.training.train_utils import (
    get_current_action_mask,
    get_next_actions_mask,
)
from prismatic.vla.constants import (
    ACTION_DIM,
    ACTION_PROPRIO_NORMALIZATION_TYPE,
    NUM_ACTIONS_CHUNK,
    NormalizationType,
)


class BitVLAForActionPrediction(LlavaForConditionalGeneration):
    # Config class used by the `from_pretrained` / `save_pretrained` machinery.
    config_class: PretrainedConfig = Bitvla_Config

    def __init__(self, config) -> None:
        super().__init__(config)

        # Dataset statistics used to un-normalize predicted actions at inference time.
        self.norm_stats = config.norm_stats

        # Uniform bins over [-1, 1] for discrete action de-tokenization.
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        self.vocab_size = self.config.vocab_size

    def set_constant(self, image_token_idx, proprio_pad_idx, ignore_idx, action_token_begin_idx, stop_index):
        """Register the special token ids used to locate image, proprio, and action positions."""
        self.image_token_idx = image_token_idx
        self.proprio_pad_idx = proprio_pad_idx
        self.action_token_begin_idx = action_token_begin_idx
        self.stop_index = stop_index
        self.ignore_idx = ignore_idx

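    # A minimal sketch of how `set_constant` is expected to be wired up after loading the model.
    # The tokens and ids below are hypothetical placeholders, not values taken from this repository:
    #
    #   vla = BitVLAForActionPrediction.from_pretrained("path/to/bitvla-checkpoint")
    #   vla.set_constant(
    #       image_token_idx=processor.tokenizer.convert_tokens_to_ids("<image>"),
    #       proprio_pad_idx=processor.tokenizer.convert_tokens_to_ids("<proprio>"),   # hypothetical token
    #       ignore_idx=-100,                       # standard HF ignore index for labels
    #       action_token_begin_idx=31743,          # hypothetical: id of the first action token
    #       stop_index=processor.tokenizer.eos_token_id,
    #   )
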
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        proprio=None,
        proprio_projector=None,
        cache_position: Optional[torch.LongTensor] = None,
        vision_feature_layer=None,
        vision_feature_select_strategy=None,
    ) -> LlavaCausalLMOutputWithPast:
        """Run a forward pass through the VLM, returning a LlavaCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The KV cache is only useful for autoregressive generation, never during training.
        use_cache = use_cache and not self.training

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        # Standard multimodal forward: one image stack per text sequence in the batch.
        if (input_ids is not None and input_ids.shape[0] == pixel_values.shape[0]) or (
            inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0]
        ):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Embed the full token sequence; image/proprio placeholder positions are overwritten below.
            inputs_embeds = self.get_input_embeddings()(input_ids)

            if pixel_values is not None:
                vision_feature_layer = (
                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
                )
                vision_feature_select_strategy = (
                    vision_feature_select_strategy
                    if vision_feature_select_strategy is not None
                    else self.config.vision_feature_select_strategy
                )

                # Flatten the (batch, num_images) axes so the vision tower sees a plain image batch.
                b, num_images, c, h, w = pixel_values.shape
                pixel_values = pixel_values.view(-1, c, h, w)
                image_embeds = self.get_image_features(
                    pixel_values=pixel_values,
                    vision_feature_layer=vision_feature_layer,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                )

                # One row per visual token; must match the number of image placeholder tokens.
                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
                n_image_tokens = (input_ids == self.image_token_idx).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )

                # Scatter the visual features into the placeholder positions of the text embeddings.
                mask = input_ids == self.image_token_idx
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            # Optionally splice projected proprioceptive state into its placeholder token(s).
            if proprio_projector is not None and proprio is not None:
                proprio = proprio.reshape(batch_size, -1)
                proprio_features = proprio_projector(proprio)
                proprio_features = proprio_features.unsqueeze(dim=1)

                proprio_features = proprio_features.view(-1, proprio_features.shape[-1])
                n_proprio_tokens = (input_ids == self.proprio_pad_idx).sum().item()
                n_proprio_features = proprio_features.shape[0]
                if n_proprio_tokens != n_proprio_features:
                    raise ValueError(
                        f"Proprio features and proprio tokens do not match: tokens: {n_proprio_tokens}, features {n_proprio_features}"
                    )

                mask = input_ids == self.proprio_pad_idx
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                proprio_mask = mask_expanded.to(inputs_embeds.device)

                proprio_features = proprio_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(proprio_mask, proprio_features)

            # Zero out the embeddings at action positions so the model cannot peek at ground-truth
            # actions; it must predict them from the preceding context instead.
            all_actions_mask = (labels != self.ignore_idx) & (labels != self.stop_index)
            all_actions_mask = all_actions_mask.unsqueeze(-1)
            inputs_embeds = inputs_embeds * ~all_actions_mask

            outputs = LlavaForConditionalGeneration.forward(
                self,
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=None,
                pixel_values=None,
                labels=labels,
                inputs_embeds=inputs_embeds,
                past_key_values=None,
                use_cache=None,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

        elif (input_ids is not None and input_ids.shape[0] != pixel_values.shape[0]) or (
            inputs_embeds is not None and inputs_embeds.shape[0] != pixel_values.shape[0]
        ):
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid BitVLAForActionPrediction `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `inputs_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        return outputs

    def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
        """Prepare input for action prediction by appending placeholder action tokens and a stop token."""
        # Append one placeholder token per action dimension in the chunk; their embeddings are
        # zeroed out later, so the particular token id (here: 1) does not matter.
        placeholder_action_token_ids = (
            torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
        )
        input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)

        # Append the stop token that terminates the action sequence.
        stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * self.stop_index
        input_ids = torch.cat([input_ids, stop_token_id], dim=-1)

        # Extend the attention mask so the newly appended tokens are attended to.
        mask_extension = (
            torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
            .to(attention_mask.device)
            .to(attention_mask.dtype)
        )
        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)

        return input_ids, attention_mask

    def _prepare_labels_for_action_prediction(self, labels, input_ids):
        """Create a labels tensor for action prediction if not provided."""
        # Fill the appended positions with an arbitrary action-token id; only the positions matter,
        # since these labels are used solely to build the action masks.
        ARBITRARY_ACTION_TOKEN_IDX = self.action_token_begin_idx + 1
        labels_extension = (
            torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
            * ARBITRARY_ACTION_TOKEN_IDX
        )
        labels = torch.cat([labels, labels_extension], dim=-1)

        # The final position corresponds to the stop token.
        labels[:, -1] = self.stop_index

        return labels

    def _process_action_masks(self, labels):
        """Helper to get action masks from labels."""
        current_action_mask = get_current_action_mask(
            labels, ignore_index=self.ignore_idx, action_token_begin_idx=self.action_token_begin_idx
        )
        next_actions_mask = get_next_actions_mask(
            labels, ignore_index=self.ignore_idx, action_token_begin_idx=self.action_token_begin_idx
        )
        all_actions_mask = current_action_mask | next_actions_mask
        return all_actions_mask

    def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
        """Un-normalize actions using dataset statistics."""
        action_norm_stats = self.get_action_stats(unnorm_key)

        if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
            mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
            action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
        elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
            mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
            action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        else:
            raise ValueError("Unsupported action/proprio normalization type detected!")

        # Map actions from [-1, 1] back to the dataset's [low, high] range; dimensions with
        # mask == False are passed through unchanged.
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
            normalized_actions,
        )

        return actions

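    # Worked example of the un-normalization above (hypothetical numbers, ignoring the 1e-8 term):
    # with action_low = -0.05, action_high = 0.05, and a normalized value of 0.2, the mapping
    #   0.5 * (0.2 + 1) * (0.05 - (-0.05)) + (-0.05) = 0.5 * 1.2 * 0.1 - 0.05 = 0.01
    # recovers an action of 0.01 in the original units.
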
    def _regression_or_discrete_prediction(
        self,
        input_ids,
        input_embeddings,
        all_actions_mask,
        attention_mask,
        labels,
        action_head=None,
        pixel_values=None,
    ):
        """Run L1 regression-based continuous action prediction or discrete action-token prediction."""
        # Zero out the embeddings at action positions; the model must infer them from context.
        all_actions_mask = all_actions_mask.unsqueeze(-1)
        input_embeddings = input_embeddings * ~all_actions_mask

        llava_output = LlavaForConditionalGeneration.forward(
            self,
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=None,
            pixel_values=None,
            labels=None,
            inputs_embeds=input_embeddings,
            past_key_values=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Shift labels by one so the mask lines up with the hidden states that predict them.
        all_actions_mask = self._process_action_masks(labels[:, 1:])

        last_hidden_states = llava_output.hidden_states[-1]
        last_hidden_states = last_hidden_states[:, :-1, :]

        # Gather hidden states at action positions: (1, NUM_ACTIONS_CHUNK * ACTION_DIM, hidden_dim) for one sample.
        actions_hidden_states = last_hidden_states[all_actions_mask.squeeze(-1)].unsqueeze(0)

        if action_head is not None:
            # Continuous prediction via the (L1 regression) action head.
            normalized_actions = action_head.predict_action(actions_hidden_states)
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
            normalized_actions = normalized_actions.float().cpu().detach().numpy()
        else:
            # Discrete prediction: take the argmax token at each action position. Align the logits
            # with the shifted mask the same way the hidden states are aligned above.
            predicted_action_token_ids = (
                llava_output.logits[:, :-1][all_actions_mask.squeeze(-1)]
                .unsqueeze(0)
                .argmax(dim=2)
                .cpu()
                .numpy()
            )

            # Action tokens occupy the tail of the vocabulary; map token ids back to bin indices,
            # then look up the corresponding bin centers in [-1, 1].
            discretized_actions = self.vocab_size - predicted_action_token_ids
            discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
            normalized_actions = self.bin_centers[discretized_actions]
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

        return normalized_actions, actions_hidden_states

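    # A small numeric sketch of the discrete decoding path above (hypothetical values): with
    # vocab_size = 32000 and n_action_bins = 256 (so 255 bin centers), a predicted token id of
    # 31900 maps to 32000 - 31900 - 1 = 99, i.e. bin_centers[99] in [-1, 1].
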
    def predict_action(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        unnorm_key: Optional[str] = None,
        proprio=None,
        proprio_projector=None,
        action_head=None,
        vision_feature_layer=None,
        vision_feature_select_strategy=None,
        **kwargs,
    ) -> Tuple[np.ndarray, torch.Tensor]:
        """Predict actions from an input sequence, with options for different prediction methods.

        Args:
            input_ids: Input token ids
            unnorm_key: Key for un-normalization statistics
            proprio: Proprioceptive features
            proprio_projector: Projector for proprioceptive features
            action_head: Optional head for L1 regression prediction
            **kwargs: Additional arguments, including pixel_values and attention_mask

        Returns:
            Tuple of (unnormalized_actions, action_hidden_states)
        """
        pixel_values = kwargs["pixel_values"]
        attention_mask = kwargs["attention_mask"]

        # Build labels that ignore the prompt; the action positions are appended and marked below.
        labels = input_ids.clone()
        labels[:] = self.ignore_idx

        # Append placeholder action tokens plus a stop token, and extend the attention mask.
        input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)

        # Mark the appended positions as action tokens so the masks can locate them.
        labels = self._prepare_labels_for_action_prediction(labels, input_ids)

        input_embeddings = self.get_input_embeddings()(input_ids)
        all_actions_mask = self._process_action_masks(labels)

        if pixel_values is not None:
            vision_feature_layer = (
                vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
            )
            vision_feature_select_strategy = (
                vision_feature_select_strategy
                if vision_feature_select_strategy is not None
                else self.config.vision_feature_select_strategy
            )

            # Flatten the (batch, num_images) axes so the vision tower sees a plain image batch.
            b, num_images, c, h, w = pixel_values.shape
            pixel_values = pixel_values.view(-1, c, h, w)
            image_embeds = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            # One row per visual token; must match the number of image placeholder tokens.
            image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
            n_image_tokens = (input_ids == self.image_token_idx).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            # Scatter the visual features into the placeholder positions of the text embeddings.
            mask = input_ids == self.image_token_idx
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(input_embeddings)
            image_mask = mask_expanded.to(input_embeddings.device)

            image_embeds = image_embeds.to(input_embeddings.device, input_embeddings.dtype)
            input_embeddings = input_embeddings.masked_scatter(image_mask, image_embeds)

        # Optionally splice projected proprioceptive state into its placeholder token(s).
        use_proprio = proprio_projector is not None and proprio is not None
        if use_proprio:
            batch_size = input_ids.shape[0] if input_ids is not None else input_embeddings.shape[0]
            proprio = torch.Tensor(proprio).to(input_embeddings.device, dtype=input_embeddings.dtype)

            proprio = proprio.reshape(batch_size, -1)
            proprio_features = proprio_projector(proprio)
            proprio_features = proprio_features.unsqueeze(dim=1)

            proprio_features = proprio_features.view(-1, proprio_features.shape[-1])
            n_proprio_tokens = (input_ids == self.proprio_pad_idx).sum().item()
            n_proprio_features = proprio_features.shape[0]
            if n_proprio_tokens != n_proprio_features:
                raise ValueError(
                    f"Proprio features and proprio tokens do not match: tokens: {n_proprio_tokens}, features {n_proprio_features}"
                )

            mask = input_ids == self.proprio_pad_idx
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(input_embeddings)
            proprio_mask = mask_expanded.to(input_embeddings.device)

            proprio_features = proprio_features.to(input_embeddings.device, input_embeddings.dtype)
            input_embeddings = input_embeddings.masked_scatter(proprio_mask, proprio_features)

        # Run the VLM once and decode either continuous (action head) or discrete (token) actions.
        normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
            input_ids,
            input_embeddings,
            all_actions_mask,
            attention_mask,
            labels,
            action_head,
            pixel_values,
        )

        # Map actions from [-1, 1] back to the original units of the chosen dataset.
        actions = self._unnormalize_actions(normalized_actions, unnorm_key)

        return actions, actions_hidden_states

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        """Validate and resolve the un-normalization key for action statistics."""
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset; "
                f"please pass a `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics; "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["min"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]

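# A minimal, hypothetical inference sketch (paths, prompt format, dataset key, and the processor
# are assumptions, not taken from this repository). It shows the intended call order:
# load -> set_constant -> process inputs -> predict_action -> un-normalized action chunk.
#
#   from transformers import AutoProcessor
#
#   processor = AutoProcessor.from_pretrained("path/to/bitvla-checkpoint")
#   vla = BitVLAForActionPrediction.from_pretrained("path/to/bitvla-checkpoint").to("cuda").eval()
#   vla.set_constant(image_token_idx=..., proprio_pad_idx=..., ignore_idx=-100,
#                    action_token_begin_idx=..., stop_index=processor.tokenizer.eos_token_id)
#
#   inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
#   actions, _ = vla.predict_action(
#       input_ids=inputs["input_ids"],
#       unnorm_key="bridge_orig",                             # hypothetical dataset key
#       pixel_values=inputs["pixel_values"].unsqueeze(1),     # expects (batch, num_images, C, H, W)
#       attention_mask=inputs["attention_mask"],
#       action_head=None,                                     # or a trained L1-regression head
#   )
#   # actions has shape (NUM_ACTIONS_CHUNK, ACTION_DIM)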
|