Bi-JROS: Bi-level Learning of Task-Specific Decoders for Joint Registration and One-Shot Medical Image Segmentation
Xin Fan1, Xiaolin Wang1, Jiaxin Gao1, Jia Wang1, Zhongxuan Luo1, Risheng Liu1
1School of Software Technology, Dalian University of Technology, Dalian, China
- [2025/04/28]: ✨Adapt Different Encoders (e.g., SAM, SynthSeg) to Our Framework (updating).
- [2025/04/27]: ✨We release the Bi-JROS model weights for Step 1 (pretrain the shared encoder) on 🤗 Huggingface.
- [2024/04/23]: ✨We release the train and inference code.
- [2024/02/27]: ✨This paper was accepted by CVPR 2024!
This code requires the following:
- Python==3.8
- PyTorch==1.12.1
- Torchvision==0.13.1
- Torchaudio==0.12.1
- Numpy==1.24.3
- Scipy==1.10.1
- Scikit-image==0.21.0
- Nibabel==5.2.0
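A quick way to confirm that an environment matches the pins above is sketched below. It is only an illustration: the PyPI distribution names (for example `torch` for PyTorch) and the helper `find_mismatches` are assumptions of this sketch, not part of the repository.

```python
from importlib import metadata

# Pins copied from the requirements list above; the PyPI distribution
# names (e.g. "torch" for PyTorch) are assumptions of this sketch.
REQUIRED = {
    "torch": "1.12.1",
    "torchvision": "0.13.1",
    "torchaudio": "0.12.1",
    "numpy": "1.24.3",
    "scipy": "1.10.1",
    "scikit-image": "0.21.0",
    "nibabel": "5.2.0",
}


def find_mismatches(required, get_version=metadata.version):
    """Return {package: (wanted, found)} for every pin that is not met."""
    mismatches = {}
    for name, wanted in required.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed at all
        # Local build tags such as "1.12.1+cu113" still count as a match.
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(REQUIRED).items():
        print(f"{name}: wanted {wanted}, found {found or 'not installed'}")
```

Running the script prints one line per package whose installed version differs from the pin, and nothing if the environment matches.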
The datasets used in the paper, ABIDE, ADNI, PPMI, and OASIS, are publicly available for download.
For example, ADNI can be applied for and downloaded through the following link: https://adni.loni.usc.edu/data-samples/adni-data/#AccessData.
The download process for ABIDE is described at https://fcon_1000.projects.nitrc.org/indi/abide/databases.html.
Preprocessed ABIDE data can be accessed at http://preprocessed-connectomes-project.org/abide/index.html.
Clone the repo:
git clone https://github.com/Coradlut/Bi-JROS.git
To train the model:
python train.py
Before running the code, you may need to adjust some parameters to match your setup.
To test the performance:
python infer.py
In this section, we demonstrate how we adapt different encoders to our framework. Specifically, we focus on integrating four encoders: SAM, SynthSeg, and two of our own proposed methods. We will showcase the results of applying these encoders and provide a brief introduction to each of the methods.

SAM introduces prompt-based guidance to enable fast segmentation of arbitrary targets within an image. Prompts can take various forms, such as points, boxes, masks, or text descriptions. Based on these prompts, the model generates valid segmentation masks. The encoder of SAM consists of two parts: the image encoder and the prompt encoder. The image encoder generates a one-time embedding that captures the overall representation of the input image. The prompt encoder encodes points, boxes, text, or masks into embedding vectors in real time, which are then combined with the image embedding to guide the segmentation process.
- One-Time Embedding: Generates a single global image representation with a pre-trained Vision Transformer.
- Prompt Encoding: Dynamically embeds points, boxes, masks, and text to guide segmentation.
Reference: SAM Paper (Kirillov et al., 2023)
SynthSeg trains the network with on-the-fly synthesized images using a Bayesian generative model and domain randomization, enabling it to learn domain-agnostic features and perform direct segmentation on real images without retraining. SynthSeg uses a 3D U-Net encoder to extract domain-agnostic features from on-the-fly synthesized images. By processing inputs with randomized contrast, resolution, and artifacts, the encoder learns robust structural representations for accurate segmentation across diverse domains.
- Domain Randomization: Trained on images with randomized contrast, resolution, noise, and artifacts to ensure robustness.
- 3D U-Net Architecture: Uses deep hierarchical features and skip connections to capture both local and global structure.
- On-the-Fly Data Augmentation: Continuously exposed to new synthetic inputs during training, enhancing generalization without retraining.
Reference: SynthSeg Paper (Billot et al., 2021)
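The idea of synthesizing training images from label maps with randomized contrast can be illustrated with a heavily simplified sketch. It works on a 2D nested list rather than a 3D volume, and it covers only intensity and noise randomization (no resolution, artifact, or deformation changes). The function name and the sampling ranges are assumptions made for illustration, not SynthSeg's actual priors.

```python
import random


def synthesize_image(label_map, seed=None):
    """Turn a 2D label map into a random-contrast image.

    Every label gets a freshly sampled mean intensity and noise level, so
    the same anatomy appears with a different contrast on each call.
    """
    rng = random.Random(seed)
    labels = sorted({label for row in label_map for label in row})
    # Illustrative ranges only, not SynthSeg's actual priors.
    mean = {label: rng.uniform(0.0, 1.0) for label in labels}
    std = {label: rng.uniform(0.01, 0.1) for label in labels}
    # Sample each pixel from the Gaussian assigned to its label.
    return [
        [rng.gauss(mean[label], std[label]) for label in row]
        for row in label_map
    ]


label_map = [[0, 0, 1], [0, 2, 2]]
image_a = synthesize_image(label_map, seed=0)
image_b = synthesize_image(label_map, seed=1)  # same labels, new contrast
```

Because the label map is the ground truth for every synthesized image, a segmentation network can be trained on an unlimited stream of such pairs.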
Bi-JROS introduces a novel bi-level learning framework for one-shot medical image segmentation, using a pretrained fixed shared encoder to stabilize training and enhance adaptability. It treats registration as the major objective and segmentation as a learnable constraint, while leveraging appearance conformity to generate style-consistent pseudo-labels for data augmentation. The pretrained and fixed shared encoder extracts stable, domain-adaptive features from medical images, providing a common feature space for both the registration and segmentation decoders.
- Pretrained on diverse unlabeled data: Extracts stable and domain-adaptive features for one-shot medical segmentation.
- Fixed parameters after pretraining: Prevents feature drift and ensures stable optimization during downstream tasks.
- Shared feature space for tasks: Provides unified features for both registration and segmentation decoders.
Reference: Bi-JROS Paper (Fan et al., 2024)
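The fixed shared encoder with two task-specific decoders might be wired up in PyTorch roughly as follows. This is a minimal sketch, not the repository's implementation: the layer sizes, the number of classes, and the variable names are all assumptions, the encoder is randomly initialized instead of loaded from the released weights, and the bi-level optimization itself is not shown.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the real networks; all sizes are arbitrary.
# In practice the encoder would be loaded from pretrained weights here.
encoder = nn.Sequential(nn.Conv3d(1, 8, kernel_size=3, padding=1), nn.ReLU())
reg_decoder = nn.Conv3d(16, 3, kernel_size=3, padding=1)  # 3-channel displacement field
seg_decoder = nn.Conv3d(8, 4, kernel_size=3, padding=1)   # 4 class logits (assumed)

# Freeze the shared encoder: no gradients, evaluation mode.
for param in encoder.parameters():
    param.requires_grad_(False)
encoder.eval()

# Only the decoder parameters are handed to the optimizer.
optimizer = torch.optim.Adam(
    list(reg_decoder.parameters()) + list(seg_decoder.parameters()), lr=1e-4
)

moving = torch.rand(1, 1, 8, 8, 8)  # (batch, channel, depth, height, width)
fixed = torch.rand(1, 1, 8, 8, 8)

with torch.no_grad():  # encoder features need no autograd graph
    feat_moving = encoder(moving)
    feat_fixed = encoder(fixed)

# Both decoders read from the same feature space.
flow = reg_decoder(torch.cat([feat_moving, feat_fixed], dim=1))
seg_logits = seg_decoder(feat_moving)
```

Keeping the encoder out of the optimizer is what prevents feature drift: gradients from either task can only change the decoders.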
RRL-MedSAM adapts SAM for one-shot 3D medical image segmentation by introducing a dual-stage knowledge distillation (DSKD) strategy and a mutual-EMA mechanism to train lightweight general and medical-specific encoders. The General Encoder is distilled from SAM for domain-agnostic feature learning, the Medical Encoder is specialized for fine-grained 3D medical image segmentation, and the two are jointly optimized via mutual-EMA collaboration.
- Dual encoders: General and Medical encoders collaboratively learn domain-agnostic and medical-specific features.
- Dual-stage knowledge distillation: Two-step distillation with mutual-EMA ensures robust and harmonized feature representation.
- Auto-prompting decoder: Automatically generates prompts from coarse masks to enhance segmentation without manual interaction.
Reference: RRL-SAM Paper (Fan et al., 2024) [on hold]
In this section, we present the results of applying the different encoders to our framework. The Dice coefficient is used as the evaluation metric to compare the segmentation performance of each method.
| Encoder Method | Dice (%) [OASIS] | Dice (%) [Dataset 2] | Dice (%) [Dataset 3] |
|---|---|---|---|
| SAM | 71.58 | -- | -- |
| SynthSeg | -- | -- | -- |
| Bi-JROS | 81.4 | -- | -- |
| RRL-SAM | 82.4 | -- | -- |
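For reference, the Dice coefficient for a single label can be computed as in the sketch below. It operates on flattened label sequences using only the standard library; the function name is hypothetical, and how per-label scores are averaged for the table above is not stated here, so no averaging is shown.

```python
def dice_score(pred, target, label):
    """Dice = 2|A ∩ B| / (|A| + |B|) for one label over flattened voxels."""
    pred = list(pred)
    target = list(target)
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of voxels")
    in_pred = sum(1 for p in pred if p == label)
    in_target = sum(1 for t in target if t == label)
    overlap = sum(1 for p, t in zip(pred, target) if p == label and t == label)
    if in_pred + in_target == 0:
        # Label absent from both maps: treated as perfect agreement
        # (a convention chosen for this sketch).
        return 1.0
    return 2.0 * overlap / (in_pred + in_target)


pred = [0, 1, 1, 2, 2, 2]
target = [0, 1, 2, 2, 2, 2]
for label in (1, 2):
    print(f"label {label}: {100 * dice_score(pred, target, label):.2f}%")
```

A 3D segmentation stored as a NumPy array can be passed after flattening, for example with `volume.ravel()`.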
- Download the pretrained weights for training from Huggingface.
Set the hyperparameter ‘enc’ (e.g., sam, synthseg, bi-jros, or rrl-sam) to select which encoder to adapt to our framework.
python train_arbi_enc4dec.py -enc bi-jros
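A flag like `-enc` is commonly parsed with `argparse`, as in the sketch below. This is not the contents of `train_arbi_enc4dec.py`: the parser layout, the default value, and the exact spelling of the accepted encoder names are assumptions for illustration.

```python
import argparse

# Encoder names follow the options listed above; their exact spelling in
# the real script, and everything else here, is an assumption.
ENCODERS = ("sam", "synthseg", "bi-jros", "rrl-sam")


def build_parser():
    parser = argparse.ArgumentParser(
        description="Train decoders on top of a chosen encoder"
    )
    parser.add_argument(
        "-enc",
        choices=ENCODERS,
        default="bi-jros",  # hypothetical default
        help="which pretrained encoder to plug into the framework",
    )
    return parser


# Equivalent to running: python <script> -enc bi-jros
args = build_parser().parse_args(["-enc", "bi-jros"])
print(args.enc)
```

Restricting the flag with `choices` makes a misspelled encoder name fail immediately with a usage message instead of surfacing later during model construction.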
