"""Compute and save dataset normalization statistics (norm_stats) for OpenVLA.

Supports datasets from Open X-Embodiment (OXE) as well as custom RLDS datasets.
Requires TensorFlow 2.15.

Custom datasets must be in RLDS format. You may also need to register them in
prismatic/vla/datasets/rlds/oxe/configs.py and prismatic/vla/datasets/rlds/oxe/transforms.py,
and for new datasets you may need to update the get_oxe_dataset_kwargs_and_weights function.

Usage examples:

    # For Open-X datasets:
    python compute_save_norm_stats.py \\
        --dataset_path /path/to/open-x \\
        --output_dir /norm_stats_ucsd \\
        --dataset_name ucsd_pick_and_place_dataset_converted_externally_to_rlds

    # For custom RLDS datasets: inspect the dataset structure first
    python compute_save_norm_stats.py \\
        --dataset_path /path/to/my_custom_dataset \\
        --dataset_name my_custom_dataset \\
        --custom_dataset \\
        --inspect_only

    # Then compute the statistics
    python compute_save_norm_stats.py \\
        --dataset_path /path/to/my_custom_dataset \\
        --output_dir /norm_stats_custom \\
        --dataset_name my_custom_dataset \\
        --custom_dataset
"""
import argparse
import os
from pathlib import Path
import tensorflow as tf
from prismatic.vla.datasets import RLDSDataset, RLDSBatchTransform
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
from prismatic.models.backbones.llm.prompting import PurePromptBuilder
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS
from transformers import AutoTokenizer
import numpy as np
import torch
def get_custom_dataset_config(dataset_name: str) -> dict:
"""
Returns a default configuration for custom RLDS datasets.
Users can modify this function to match their dataset structure.
"""
# Default configuration - modify as needed for your dataset
return {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": "pos_euler", # or "pos_quat", "joint", "none"
"action_encoding": "eef_pos", # or "joint_pos"
}
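# The defaults above assume a single "image" camera plus end-effector and gripper state keys.
# If your dataset uses different observation names (run with --inspect_only to find out), edit
# the returned dict to match; for example, a hypothetical wrist camera could be wired in via
#   "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}.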
def inspect_dataset_structure(dataset_path: str, dataset_name: str):
"""
Inspects the structure of a custom RLDS dataset to help users configure it properly.
"""
import tensorflow_datasets as tfds
print(f"\\n=== Inspecting Dataset Structure: {dataset_name} ===")
try:
builder = tfds.builder(dataset_name, data_dir=dataset_path)
print(f"Dataset info: {builder.info}")
# Get a sample episode to inspect structure
ds = builder.as_dataset(split='train', shuffle_files=False).take(1)
for episode in ds:
print(f"\\nEpisode keys: {list(episode.keys())}")
if 'steps' in episode:
for step in episode['steps'].take(1):
print(f"Step keys: {list(step.keys())}")
if 'observation' in step:
obs = step['observation']
print(f"Observation keys: {list(obs.keys())}")
# Check image keys
image_keys = [k for k in obs.keys() if 'image' in k.lower() or 'rgb' in k.lower()]
if image_keys:
print(f"Image observation keys: {image_keys}")
# Check state/proprioception keys
state_keys = [k for k in obs.keys() if any(x in k.lower() for x in ['state', 'pose', 'joint', 'eef', 'gripper'])]
if state_keys:
print(f"State observation keys: {state_keys}")
if 'action' in step:
action = step['action']
print(f"Action shape: {action.shape}")
print(f"Action dtype: {action.dtype}")
# Check for language instructions
lang_keys = [k for k in step.keys() if 'language' in k.lower() or 'instruction' in k.lower() or 'task' in k.lower()]
if lang_keys:
print(f"Language keys: {lang_keys}")
break
except Exception as e:
print(f"Error inspecting dataset: {e}")
print("Make sure the dataset is in proper RLDS format.")
def compute_and_save_stats(dataset_path, output_dir, dataset_name="ucsd_pick_and_place_dataset_converted_externally_to_rlds", custom_dataset=False, inspect_only=False):
"""
Compute dataset statistics using OpenVLA's existing functions and save to specified directory.
Args:
dataset_path: Path to your RLDS dataset directory
        output_dir: Directory in which to save dataset_statistics.json
dataset_name: Name identifier for the dataset (should match OXE config key or custom dataset name)
custom_dataset: Whether this is a custom dataset (not in OXE)
inspect_only: If True, only inspect the dataset structure without computing stats
"""
print(f"Starting computation with dataset_path={dataset_path}, output_dir={output_dir}, dataset_name={dataset_name}")
print(f"Custom dataset: {custom_dataset}")
# If inspect_only, just show dataset structure
if inspect_only:
inspect_dataset_structure(dataset_path, dataset_name)
return
# Configure TensorFlow (same as OpenVLA)
tf.config.set_visible_devices([], "GPU")
print("TensorFlow configured")
# Check if dataset is registered in OXE configs
if not custom_dataset and dataset_name not in OXE_DATASET_CONFIGS:
print(f"Warning: {dataset_name} not found in OXE_DATASET_CONFIGS")
print(f"Available OXE datasets: {list(OXE_DATASET_CONFIGS.keys())}")
print("Consider using --custom_dataset flag for non-OXE datasets")
# Auto-detect if it should be treated as custom
custom_dataset = True
print("Automatically treating as custom dataset...")
# Create minimal components needed for RLDSDataset
print("Loading tokenizer...")
base_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium", padding_side="right")
base_tokenizer.add_special_tokens({"pad_token": "<pad>"})
action_tokenizer = ActionTokenizer(base_tokenizer)
print("Tokenizer loaded")
# Simple image transform (placeholder)
def dummy_image_transform(image):
return torch.zeros(3, 224, 224)
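    # Note: both the tokenizer and this dummy image transform exist only to satisfy the
    # RLDSBatchTransform / RLDSDataset constructors; normalization statistics are computed from
    # actions alone, so neither images nor text are actually processed here. Swap in a real
    # transform if you reuse this dataset object for training.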
# Create batch transform (needed by RLDSDataset but not used for statistics)
print("Creating batch transform...")
batch_transform = RLDSBatchTransform(
action_tokenizer=action_tokenizer,
base_tokenizer=base_tokenizer,
image_transform=dummy_image_transform,
prompt_builder_fn=PurePromptBuilder,
)
print("Batch transform created")
    # For custom datasets, use the direct TFDS approach, since they are not in the OXE registry
if custom_dataset:
print("Processing custom RLDS dataset...")
# First, inspect the dataset structure to help users
inspect_dataset_structure(dataset_path, dataset_name)
# Use direct TFDS approach for custom datasets
import tensorflow_datasets as tfds
# Set TFDS data dir
os.environ['TFDS_DATA_DIR'] = str(dataset_path)
try:
print(f"Building TFDS dataset for {dataset_name}...")
builder = tfds.builder(dataset_name, data_dir=dataset_path)
ds = builder.as_dataset(split='train', shuffle_files=False)
# Compute basic statistics manually
print("Computing statistics manually...")
print("Processing all episodes in the dataset...")
actions = []
episode_count = 0
total_transitions = 0
# Process ALL episodes
for episode in ds:
episode_count += 1
episode_transitions = 0
for step in episode['steps']:
action = step['action'].numpy()
actions.append(action)
episode_transitions += 1
total_transitions += 1
# Print progress every 100 episodes
if episode_count % 100 == 0:
print(f"Processed {episode_count} episodes, {total_transitions} transitions")
print(f"Finished processing {episode_count} episodes, {total_transitions} total transitions")
actions = np.array(actions)
action_dim = actions.shape[1]
print(f"Action dimensions: {action_dim}")
# For custom datasets, use a general approach for action mask
# Users can modify this based on their action space
if action_dim == 7:
# Standard 7D: pos (3) + ori (3) + gripper (1)
action_mask = [True, True, True, True, True, True, False]
elif action_dim == 4:
# Common 4D: pos (3) + gripper (1)
action_mask = [True, True, True, False]
else:
# Default: assume last dimension is gripper/discrete action
action_mask = [True] * (action_dim - 1) + [False]
print(f"Using action mask: {action_mask}")
# Compute statistics
dataset_statistics = {
dataset_name: {
"action": {
"mean": actions.mean(axis=0).tolist(),
"std": actions.std(axis=0).tolist(),
"min": actions.min(axis=0).tolist(),
"max": actions.max(axis=0).tolist(),
"q01": np.percentile(actions, 1, axis=0).tolist(),
"q99": np.percentile(actions, 99, axis=0).tolist(),
"mask": action_mask
},
"proprio": {
"mean": [0.0] * action_dim, # Match action dimensionality
"std": [0.0] * action_dim,
"min": [0.0] * action_dim,
"max": [0.0] * action_dim,
"q01": [0.0] * action_dim,
"q99": [0.0] * action_dim,
},
"num_transitions": total_transitions,
"num_trajectories": episode_count,
}
}
print("Successfully computed statistics for custom dataset")
except Exception as e:
print(f"Error processing custom dataset: {e}")
raise
else:
# Use OpenVLA's standard pipeline for OXE datasets
print("Processing Open-X dataset...")
# Check if dataset exists in the expected structure
print(f"Checking for dataset at {dataset_path}...")
dataset_full_path = Path(dataset_path) / dataset_name
if not dataset_full_path.exists():
print(f"Dataset not found at {dataset_full_path}")
print(f"Available datasets in {dataset_path}:")
if Path(dataset_path).exists():
for item in Path(dataset_path).iterdir():
if item.is_dir():
print(f" - {item.name}")
            # Fall back to a short list of known dataset names (extend this list as needed)
possible_names = [
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
]
found_dataset = None
for name in possible_names:
test_path = Path(dataset_path) / name
if test_path.exists():
found_dataset = name
dataset_name = name
break
if not found_dataset:
raise FileNotFoundError(f"Could not find dataset in {dataset_path}")
else:
print(f"Found dataset: {found_dataset}")
else:
print(f"Dataset found at {dataset_full_path}")
# Create RLDSDataset - this will compute statistics during initialization
try:
print(f"Creating RLDSDataset with data_root_dir={dataset_path}, data_mix={dataset_name}")
dataset = RLDSDataset(
data_root_dir=Path(dataset_path),
data_mix=dataset_name,
batch_transform=batch_transform,
resize_resolution=(224, 224),
shuffle_buffer_size=1000,
train=True,
image_aug=False,
)
# Extract dataset statistics
dataset_statistics = dataset.dataset_statistics
print("Successfully created dataset and extracted statistics")
except Exception as e:
print(f"Error creating RLDSDataset: {e}")
print("Trying alternative approach...")
# Alternative: try direct dataset creation if the above fails
from prismatic.vla.datasets.rlds.oxe import get_oxe_dataset_kwargs_and_weights
from prismatic.vla.datasets.rlds.dataset import make_interleaved_dataset
from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType
# Create mixture spec for single dataset
mixture_spec = [(dataset_name, 1.0)]
try:
print("Trying get_oxe_dataset_kwargs_and_weights...")
per_dataset_kwargs, weights = get_oxe_dataset_kwargs_and_weights(
Path(dataset_path),
mixture_spec,
load_camera_views=("primary",),
load_depth=False,
load_proprio=False,
load_language=True,
action_proprio_normalization_type=NormalizationType.BOUNDS_Q99,
)
rlds_config = dict(
traj_transform_kwargs=dict(
window_size=1,
future_action_window_size=0,
skip_unlabeled=True,
goal_relabeling_strategy="uniform",
),
frame_transform_kwargs=dict(
resize_size=(224, 224),
num_parallel_calls=1,
),
dataset_kwargs_list=per_dataset_kwargs,
shuffle_buffer_size=1000,
sample_weights=weights,
balance_weights=True,
traj_transform_threads=1,
traj_read_threads=1,
train=True,
)
print("Creating interleaved dataset...")
dataset, dataset_length, dataset_statistics = make_interleaved_dataset(**rlds_config)
print("Successfully created dataset using alternative approach")
except Exception as e2:
print(f"Alternative approach also failed: {e2}")
print("Falling back to direct TFDS approach...")
# Fallback to the same approach as custom datasets
import tensorflow_datasets as tfds
# Set TFDS data dir
os.environ['TFDS_DATA_DIR'] = str(dataset_path)
try:
print(f"Building TFDS dataset for {dataset_name}...")
builder = tfds.builder(dataset_name, data_dir=dataset_path)
ds = builder.as_dataset(split='train', shuffle_files=False)
# Compute basic statistics manually
print("Computing statistics manually...")
print("Processing all episodes in the dataset...")
actions = []
episode_count = 0
total_transitions = 0
# Process ALL episodes
for episode in ds:
episode_count += 1
episode_transitions = 0
for step in episode['steps']:
action = step['action'].numpy()
actions.append(action)
episode_transitions += 1
total_transitions += 1
# Print progress every 100 episodes
if episode_count % 100 == 0:
print(f"Processed {episode_count} episodes, {total_transitions} transitions")
print(f"Finished processing {episode_count} episodes, {total_transitions} total transitions")
actions = np.array(actions)
# Determine action dimensionality and create appropriate mask
action_dim = actions.shape[1]
print(f"Original action dimensions: {action_dim}")
# For UCSD pick and place, we need to pad to 7 dimensions if it's only 4
if action_dim == 4:
print("Padding 4D actions to 7D (3D pos + 3D ori + gripper)")
# Current: [x, y, z, gripper]
# Target: [x, y, z, rx, ry, rz, gripper]
# Pad with zeros for orientation (dimensions 3, 4, 5)
padded_actions = np.zeros((actions.shape[0], 7))
padded_actions[:, 0:3] = actions[:, 0:3] # position
padded_actions[:, 6] = actions[:, 3] # gripper
# Dimensions 3, 4, 5 remain zero (orientation)
actions = padded_actions
action_dim = 7
# Standard 7D mask: pos (3) + ori (3) + gripper (1)
action_mask = [True, True, True, True, True, True, False]
elif action_dim == 7:
# Already 7D
action_mask = [True, True, True, True, True, True, False]
else:
# Default: assume last dimension is gripper
action_mask = [True] * (action_dim - 1) + [False]
print(f"Final action dimensions: {action_dim}")
# Compute statistics
dataset_statistics = {
dataset_name: {
"action": {
"mean": actions.mean(axis=0).tolist(),
"std": actions.std(axis=0).tolist(),
"min": actions.min(axis=0).tolist(),
"max": actions.max(axis=0).tolist(),
"q01": np.percentile(actions, 1, axis=0).tolist(),
"q99": np.percentile(actions, 99, axis=0).tolist(),
"mask": action_mask
},
"proprio": {
"mean": [0.0] * action_dim, # Match action dimensionality
"std": [0.0] * action_dim,
"min": [0.0] * action_dim,
"max": [0.0] * action_dim,
"q01": [0.0] * action_dim,
"q99": [0.0] * action_dim,
},
"num_transitions": total_transitions,
"num_trajectories": episode_count,
}
}
print("Successfully computed statistics manually")
except Exception as e3:
print(f"All approaches failed. Final error: {e3}")
raise
# Save statistics
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Save using OpenVLA's save function
save_dataset_statistics(dataset_statistics, output_path)
print(f"Statistics saved to {output_path / 'dataset_statistics.json'}")
# Print summary for the specific dataset
if dataset_name in dataset_statistics:
stats = dataset_statistics[dataset_name]
action_stats = stats["action"]
print(f"\\nAction Statistics Summary for {dataset_name}:")
print(f" Dimensions: {len(action_stats['mean'])}")
print(f" Mean: {action_stats['mean']}")
print(f" Std: {action_stats['std']}")
print(f" Min: {action_stats['min']}")
print(f" Max: {action_stats['max']}")
print(f" Q01: {action_stats['q01']}")
print(f" Q99: {action_stats['q99']}")
if "mask" in action_stats:
print(f" Mask: {action_stats['mask']}")
if "num_trajectories" in stats:
print(f" Number of trajectories: {stats['num_trajectories']}")
if "num_transitions" in stats:
print(f" Number of transitions: {stats['num_transitions']}")
else:
print(f"Available datasets in statistics: {list(dataset_statistics.keys())}")
return dataset_statistics
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Compute dataset normalization statistics for OpenVLA")
parser.add_argument("--dataset_path", type=str, required=True,
help="Path to root directory containing datasets (for OXE) or path to custom RLDS dataset")
parser.add_argument("--output_dir", type=str, required=True,
help="Output directory for dataset_statistics.json")
parser.add_argument("--dataset_name", type=str,
default="ucsd_pick_and_place_dataset_converted_externally_to_rlds",
help="Dataset name (should match directory name for OXE or dataset name for custom RLDS)")
parser.add_argument("--custom_dataset", action="store_true",
help="Use this flag for custom RLDS datasets (not in Open-X)")
parser.add_argument("--inspect_only", action="store_true",
help="Only inspect dataset structure without computing statistics")
args = parser.parse_args()
print("Script started")
compute_and_save_stats(
dataset_path=args.dataset_path,
output_dir=args.output_dir,
dataset_name=args.dataset_name,
custom_dataset=args.custom_dataset,
inspect_only=args.inspect_only
)
print("Script completed")