# Wavelet Policy

**Imitation Policy Learning in the Frequency Domain with Wavelet Transforms**

[Quick-Start Demo](https://youtu.be/WnUJzu8MQBk) · Paper on arXiv · Presentation Video

> πŸ” **Abstract:** > Most imitation learning policies treat the problem as a time-series prediction task, directly mapping high-dimensional observationsβ€”such as visual input and proprioceptionβ€”into action space. While time-series methods focus on spatial-domain modeling, they often overlook inherent temporal patterns in action sequences. To address this, we recast imitation learning policies in the frequency domain and propose **Wavelet Policy**. Our approach applies discrete wavelet transforms (WT) for feature preprocessing, then uses a Single-Encoder-Multiple-Decoder (SE2MD) architecture to extract multi-scale frequency-domain features. To further enrich feature mapping and boost capacity, we insert a Learnable Frequency-Domain Filter (LFDF) after each frequency decoder, improving robustness under varying visual conditions. Experiments show that Wavelet Policy outperforms state-of-the-art end-to-end methods by over 10 % across four challenging robotic-arm tasks while keeping model size comparable. In long-horizon settings, its performance degrades more gracefully as task complexity increases. The code will be released publicly. --- ## πŸ“ Striving for a Simple and Efficient Embodied Intelligence Model

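The Learnable Frequency-Domain Filter (LFDF) mentioned in the abstract can be pictured, in a deliberately simplified form, as a per-band element-wise gain applied to wavelet coefficients before reconstruction. The toy sketch below (plain NumPy Haar transform; the function names and gain values are a hypothetical interpretation, not the released implementation) shows how scaling one band reshapes the reconstructed sequence:

```python
import numpy as np

def haar_dwt_1d(x):
    """One Haar analysis step: returns (approximation, detail)."""
    x = np.asarray(x, dtype=float)
    return (x[0::2] + x[1::2]) / np.sqrt(2.0), (x[0::2] - x[1::2]) / np.sqrt(2.0)

def haar_idwt_1d(approx, detail):
    """Inverse of haar_dwt_1d: interleave the reconstructed samples."""
    out = np.empty(2 * len(approx))
    out[0::2] = (approx + detail) / np.sqrt(2.0)
    out[1::2] = (approx - detail) / np.sqrt(2.0)
    return out

def filtered_reconstruction(x, detail_gain):
    """Single-level frequency-domain filter: scale the detail band,
    then invert the transform."""
    approx, detail = haar_dwt_1d(x)
    return haar_idwt_1d(approx, detail_gain * detail)

signal = np.array([0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0])

# gain 1.0 reproduces the input exactly (orthonormal transform);
# gain 0.0 suppresses the high-frequency band entirely.
print(filtered_reconstruction(signal, detail_gain=1.0))  # matches `signal`
print(filtered_reconstruction(signal, detail_gain=0.0))  # pairwise averages
```

In the actual model the filter is a learnable parameter trained jointly with the frequency decoders, rather than a fixed scalar gain.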
---

## 💻 System Requirements

| Component        | Requirement                              |
| ---------------- | ---------------------------------------- |
| Operating System | Ubuntu 20.04                             |
| GPU Memory       | Training: ≥ 4 GB; Inference: ≥ 2 GB      |
| Disk Space       | 100–200 GB (datasets)                    |
| GPU Support      | NVIDIA GPU with CUDA support recommended |

> 📌 For detailed hardware specs, see the "Experimental Setup" section in the paper.

---

## ⚙️ Environment Configuration

We recommend using a `conda` environment. Quick install steps:

```bash
git clone https://github.com/lurenjia384/Wavelet_Policy
cd Wavelet_Policy

# Create conda environment
conda create -n Wavelet_Policy python=3.7.16 -y
conda activate Wavelet_Policy

# Install dependencies
pip install -r requirements.txt
```

---

## 📁 Project Structure

```bash
Wavelet_Policy/
├── images             # Images for GitHub display
├── pre_model          # Pretrained models
├── log                # Log files
├── assets             # Robot fixtures or CAD assets
├── pytorch_wavelets   # Wavelet transform utilities
├── vid_path           # Saved inference videos
├── infer.py           # Inference script
├── model.py           # Network definitions
├── utils.py           # Configuration and helper code
├── requirements.txt   # Python dependencies
├── LICENSE
└── README.md
```

---

## 🔗 Pre-trained Model Download

Pre-trained weights and configurations are available on Hugging Face:
👉 [WaveletPolicy-base](https://huggingface.co/lurenjia384/wavelet_policy_model)

| Cameras | Dataset            | Params (M) | Download (closed)                                                                                                                 |
| :-----: | :----------------: | :--------: | :-------------------------------------------------------------------------------------------------------------------------------- |
|   One   | Transfer Cube      |   17.22    | [best_model_11.pt](https://huggingface.co/lurenjia384/wavelet_policy_model/resolve/main/task_1/best_model_11.pt?download=true)  |
|   Two   | Transfer Cube      |     —      | coming soon                                                                                                                        |
|   One   | Bimanual Insertion |   17.22    | coming soon                                                                                                                        |
|   Two   | Bimanual Insertion |     —      | coming soon                                                                                                                        |
|   One   | Transfer Plus      |   17.22    | [best_model_13.pt](https://huggingface.co/lurenjia384/wavelet_policy_model/resolve/main/task_3/best_model_13.pt?download=true)  |
|   Two   | Transfer Plus      |     —      | coming soon                                                                                                                        |
|   One   | Stack Two Blocks   |   17.22    | coming soon                                                                                                                        |
|   Two   | Stack Two Blocks   |     —      | coming soon                                                                                                                        |

After downloading, place the model files under:

```
Wavelet_Policy/
├── pre_model
```

You can also load models directly with `huggingface_hub` (no need for `--netdir` or `--stats_path`):

```python
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="lurenjia384/wavelet_policy_model",
    filename="task_3/best_model_13.pt"
)
stats_path = hf_hub_download(
    repo_id="lurenjia384/wavelet_policy_model",
    filename="task_3/task_3.pkl"
)
```

---

## 🚀 Quick Start

**Run inference:**

```bash
python infer.py --task_name {task_name} \
    --stats_path {data_pkl_path.pkl} \
    --netdir {model_path.pt} \
    --no_visualization {0|1}
```

**Example:** Transfer Plus task, with visualization enabled:

```bash
python infer.py --task_name sim_transfer_cube_scripted_plus --no_visualization 0
```

If you run into difficulties, refer to the [Quick-Start Demo](https://youtu.be/WnUJzu8MQBk) video.

Other valid values for `--task_name` are:

* `sim_transfer_cube_scripted` — Transfer Cube
* `sim_insertion_scripted` — Bimanual Insertion
* `Put` — Stack Two Blocks

**Train the model:** Coming soon…

---

## 📊 Experimental Results

**Table:** Success rates (%) of Wavelet Policy vs.
five baselines across four tasks and three stages

| Model     |   TC-1 |   TC-2 |   TC-3 |   BI-1 |   BI-2 |   BI-3 |   TP-1 |   TP-2 |   TP-3 |   ST-1 |   ST-2 |   ST-3 |
| --------- | -----: | -----: | -----: | -----: | -----: | -----: | -----: | -----: | -----: | -----: | -----: | -----: |
| DP (DDIM) |      9 |      6 |      4 |      4 |      3 |      1 |      2 |      1 |      1 |      1 |      1 |      1 |
| ACT       |     98 |     96 |     94 |     81 |     73 |     68 |     66 |     57 |     57 |     85 |     67 |     50 |
| NL-ACT    |     94 |     91 |     90 |     83 |     74 |     70 |     62 |     55 |     55 |     82 |     65 |     48 |
| HACT-Vq   |     98 |     98 |     97 |     87 |     82 |     76 | **79** |     68 |     68 |     90 |     76 |     55 |
| InterACT  | **98** |     88 |     82 | **88** |     78 |     44 |      — |      — |      — |      — |      — |      — |
| Ours      |     98 | **98** | **97** |     87 | **82** | **78** |     78 | **70** | **70** | **96** | **79** | **59** |

> **Note:** Bold entries denote the best performance in each column.
> `—` indicates no experiment for that method.
> `TC`: Transfer Cube; `BI`: Bimanual Insertion; `TP`: Transfer Plus; `ST`: Stack Two Blocks.
> The number after each task name indicates the stage.
> See the paper for full experimental details.