SANSA: Unleashing the Hidden Semantics in SAM2 for Few-Shot Segmentation
Claudia Cuttano* · Gabriele Trivigno* · Giuseppe Averta · Carlo Masone
✨ NeurIPS 2025 Spotlight ✨
SANSA unlocks the hidden semantics of Segment Anything 2, turning it into a powerful few-shot segmenter for both objects and parts.
- No fine-tuning of SAM2 weights.
- Fully promptable: points · boxes · scribbles · masks, making it ideal for real-world labeling.
- State-of-the-art on few-shot object & part segmentation benchmarks.
- Lightweight: 3–5× faster, 4–5× smaller!
Demo video: `SANSA_demo.mp4`
To get started, create a Conda environment and install the required dependencies. SANSA is compatible with any PyTorch ≥ 2.0. The experiments in the paper were run with PyTorch 2.7.1 (with CUDA 12.6), which we provide as a reference configuration. To set up the environment using Conda, run:
```bash
conda create --name sansa python=3.10 -y
conda activate sansa
pip install -r requirements.txt
```
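To verify that the environment is set up correctly (a quick optional check, not part of the official instructions), you can print the installed PyTorch version and confirm that CUDA is visible:

```bash
# Optional sanity check: PyTorch >= 2.0 and CUDA availability
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```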
In this repository, you will find:
1. SANSA Universal Model: a single model, fully promptable (points · boxes · scribbles · masks), for both objects & parts.
   · We release this model on TorchHub, and include an interactive demo to try it on your own data.
   · Note: this is not the model used for the paper benchmarks.
2. Paper Results & Training: strict few-shot and in-context benchmarks, with results and training scripts for reproducibility.
Run on your own data (objects & parts, promptable with points · boxes · scribbles · masks).
Quick Links: 📥 Download Weights · 🧑‍💻 Interactive Notebook · 📦 TorchHub
Curious about SANSA? The Notebook lets you try it out. Mark an object or part in one image (point, box, scribble, or mask), and SANSA will segment the same class in the following images.
💡 Example: draw a quick box around a person, and SANSA finds all the people in the next images.
Below is a minimal example showing how to load SANSA from TorchHub and run inference. Use point, box, or mask prompts depending on your application.
Expand for 'def format_prompt' function
```python
import torch


def format_prompt(n_shots: int, prompt_input, prompt_type: str, device: torch.device = torch.device('cuda')):
    """
    Format a prompt to be fed to the SANSA model. Alternatively, import as 'from util.demo_sansa import format_prompt'.
    """
    assert prompt_type in ['mask', 'point', 'box']
    prompt_dict = {0: {}, 'shots': n_shots}
    prompt_d = prompt_input
    if prompt_type == 'point':
        # (x, y) point coordinates with positive labels
        pts = torch.as_tensor(prompt_input, dtype=torch.float32, device=device).view(-1, 2)
        prompt_d = {'point_coords': pts.view(1, -1, 2),
                    'point_labels': torch.ones(1, pts.shape[0], dtype=torch.int32, device=device)}
    elif prompt_type == 'box':
        # Boxes are encoded as two corner points, labeled 2 (top-left) and 3 (bottom-right)
        b = torch.as_tensor(prompt_input, dtype=torch.float32, device=device).view(-1, 4)
        x0y0 = torch.minimum(b[:, :2], b[:, 2:])
        x1y1 = torch.maximum(b[:, :2], b[:, 2:])
        point_coords = torch.stack([x0y0, x1y1], dim=1).view(1, -1, 2)
        n = point_coords.shape[1] // 2
        point_labels = torch.tensor([2, 3], dtype=torch.int32, device=device).repeat(1, n)
        prompt_d = {'point_coords': point_coords, 'point_labels': point_labels}
    prompt_dict[0][0] = {'prompt_type': prompt_type, 'prompt': prompt_d}
    return prompt_dict
```

```python
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import numpy as np

sup_img_path = 'assets/demo/images_demo/image011.jpg'
q_img_path = 'assets/demo/images_demo/image005.jpg'
sup_mask_path = 'assets/demo/masks_demo/image011_dog.png'
device = torch.device('cuda')

img_size = 640
_transform = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Stack the support and query images into a two-frame "video": [batch, frames, channels, H, W]
sup_t, q_t = _transform(Image.open(sup_img_path)), _transform(Image.open(q_img_path))
video = torch.stack([sup_t, q_t], dim=0)[None].to(device)  # [1, 2, 3, H, W]

model = torch.hub.load('ClaudiaCuttano/SANSA', 'sansa', pretrained=True, trust_repo=True, device=device)

point = np.array([[320, 330]], dtype=np.float32)  # point prompt
# box = np.array([300, 560, 440, 320], dtype=np.float32)  # box prompt
# mask = TF.to_tensor(Image.open(sup_mask_path).convert("L").resize((img_size, img_size), Image.NEAREST))[None].to(device)  # mask prompt

point_prompt = format_prompt(n_shots=1, prompt_input=point, prompt_type='point', device=device)
# box_prompt = format_prompt(n_shots=1, prompt_input=box, prompt_type='box', device=device)
# mask_prompt = format_prompt(n_shots=1, prompt_input=mask, prompt_type='mask', device=device)

with torch.no_grad():
    out = model(video, point_prompt)  # choose one among [point_prompt, box_prompt, mask_prompt]
pred_mask = out["pred_masks"][1].sigmoid() > 0.5  # predicted mask for the query frame (index 1)
```

Reproduce benchmarks (strict few-shot & in-context segmentation) and training.
To train and reproduce our results, set up your dataset: please refer to data.md for detailed data preparation.
Once organized, the directory structure should look like this:
```
SANSA/
├── data/
│   ├── COCO2014/
│   ├── FSS-1000/
│   └── ...
├── datasets/
├── models/
│   ├── sam2/
│   ├── sansa/
│   └── ...
...
```
· Purpose. Exact checkpoints and commands to match the paper numbers.
· Tracks. (1) Strict few-shot segmentation · (2) Generalist in-context segmentation.
· Note. Models in this section support mask prompts only, to ensure a fair comparison with prior works.
· Tip. If you just want one versatile, promptable model for your own data, use the SANSA Universal Model above.
Standard novel-class protocol with disjoint partitions: LVIS-92i (10 folds) and COCO-20i (4 folds); FSS-1000 has a single fixed split.
We release one adapter per fold and report per-fold and mean IoU. Choose shots at eval with `--shot {1|5}`.
Reference objects are given as masks.
| Dataset | Pretrained adapters | Fold 0 | Fold 1 | Fold 2 | Fold 3 | Fold 4 | Fold 5 | Fold 6 | Fold 7 | Fold 8 | Fold 9 | Mean IoU |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| LVIS-92i | 📥 LVIS (10) | 48.4 | 48.3 | 51.5 | 50.7 | 44.8 | 50.1 | 51.1 | 50.5 | 45.9 | 46.3 | 48.8 |
| COCO-20i | 📥 COCO (4) | 58.9 | 62.6 | 61.5 | 58.0 | | | | | | | 60.2 |
| FSS-1000 | 📥 FSS-1000 | 91.4 | | | | | | | | | | 91.4 |
Command to replicate the results:
```bash
python inference_fss.py \
    --dataset_file {coco|lvis|fss} \
    --fold {FOLD} \                  # omit for FSS
    --resume /path/to/adapter_{ds}_fold{FOLD}.pth \
    --name_exp eval_{coco|lvis|fss} \
    --shot {1|5} \
    --adaptformer_stages 2,3 \
    --prompt mask
```
Optionally, add `--visualize` to visualize the results.
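For instance, a fully spelled-out call for 1-shot evaluation on COCO-20i, fold 0, could look like the sketch below (the `--resume` path assumes you saved the downloaded adapter under `pretrain/`; adjust it to your location):

```bash
# 1-shot evaluation on COCO-20i, fold 0 (checkpoint path is illustrative)
python inference_fss.py \
    --dataset_file coco \
    --fold 0 \
    --resume pretrain/adapter_coco_fold0.pth \
    --name_exp eval_coco \
    --shot 1 \
    --adaptformer_stages 2,3 \
    --prompt mask
```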
Single generalist adapter trained on COCO+ADE20K+LVIS+PACO for in-context few-shot segmentation: one model across datasets and tasks (object + part segmentation). Reference objects are given as masks.
Note: if you want a single generalist promptable model, please refer to SANSA Universal Model.
| Pretrained adapters | LVIS-92i (Seg.) | COCO-20i (Seg.) | FSS-1000 (Seg.) | Pascal-Part (Part) | PACO-Part (Part) |
|---|---|---|---|---|---|
| 📥 In-Context Generalist | 50.3 | 75.6 | 90.0 | 49.1 | 43.0 |
Command to replicate the results:
```bash
python inference_fss.py \
    --dataset_file {coco|lvis|fss|pascal_part|paco_part} \
    --fold {FOLD} \                  # LVIS: 0–9, COCO: 0–3, FSS: omit/0, Pascal/PACO: 0–3
    --resume pretrain/adapter_generalist.pth \
    --name_exp eval_generalist_fss_{coco|lvis|fss|pascal_part|paco_part} \
    --shot {1|5} \
    --channel_factor 0.8 \
    --adaptformer_stages 2,3 \
    --prompt mask
```
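As a concrete instance of the template above, evaluating the generalist adapter in the 1-shot setting on PACO-Part, fold 0, would look like this:

```bash
# 1-shot in-context evaluation of the generalist adapter on PACO-Part, fold 0
python inference_fss.py \
    --dataset_file paco_part \
    --fold 0 \
    --resume pretrain/adapter_generalist.pth \
    --name_exp eval_generalist_fss_paco_part \
    --shot 1 \
    --channel_factor 0.8 \
    --adaptformer_stages 2,3 \
    --prompt mask
```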
To train SANSA on strict few-shot segmentation, use the generic command below and adjust the flags as needed:
```bash
python main.py \
    --batch_size 32 \                 # global batch size (tune to your GPU memory)
    --name_exp train_{ds}_f{FOLD} \   # run name
    --dataset_file {coco|lvis|fss} \  # choose the benchmark
    --fold {FOLD} \                   # fold to EVALUATE on; training uses the REMAINING folds
    --adaptformer_stages 2 3 \        # adapters in the last two Hiera encoder stages
    --prompt mask
```
Notes:
- Strict few-shot protocol: passing `--fold F` means evaluate on fold F and train on the other folds.
- Folds: COCO-20i `F ∈ {0,1,2,3}` · LVIS-92i `F ∈ {0,…,9}` · FSS-1000: fixed split, omit `--fold`.
- Use `--prompt multi` for promptable strict few-shot segmentation: training samples among `mask`/`scribble`/`box`/`point` prompts each episode (see the promptable example below).
- Frozen SAM2-Large: the backbone and decoder remain frozen; only the adapter is trained.
Example:
```bash
# COCO-20i, fold 0 (strict few-shot)
python main.py --batch_size 32 --name_exp train_coco_f0 --dataset_file coco --fold 0 --adaptformer_stages 2 3 --prompt mask
```
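The promptable variant of the same run only changes the prompt flag (the experiment name below is illustrative):

```bash
# COCO-20i, fold 0 (promptable strict few-shot: prompt type sampled per episode)
python main.py --batch_size 32 --name_exp train_coco_f0_multi --dataset_file coco --fold 0 --adaptformer_stages 2 3 --prompt multi
```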
Train one adapter jointly on multiple datasets:
```bash
python main.py \
    --batch_size 32 \
    --name_exp train_generalist \
    --multi_train \
    --dataset_file lvis, coco, ade20k, paco_part \
    --ds_weight 0.4, 0.45, 0.1, 0.05 \
    --fold -1 \
    --adaptformer_stages 2 3 \
    --channel_factor 0.8 \
    --prompt mask
```
Notes:
- `--fold -1` disables strict fold splitting: for multi-dataset training we don't use disjoint train/test folds (as we do in strict FSS, where the goal is to evaluate generalization on unseen categories).
- `--ds_weight` sets per-dataset sampling proportions (same order as `--dataset_file`).
- To replicate our SANSA Universal Model, simply add `--prompt multi` (see the sketch below).
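Putting the notes together, a sketch of the SANSA Universal Model training run is simply the generalist command above with `--prompt multi` (the experiment name is illustrative):

```bash
# Generalist training with sampled prompt types (SANSA Universal Model recipe)
python main.py \
    --batch_size 32 \
    --name_exp train_universal \
    --multi_train \
    --dataset_file lvis, coco, ade20k, paco_part \
    --ds_weight 0.4, 0.45, 0.1, 0.05 \
    --fold -1 \
    --adaptformer_stages 2 3 \
    --channel_factor 0.8 \
    --prompt multi
```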
If you find this work useful in your research, please cite it using the BibTeX entry below:
```bibtex
@misc{cuttano2025sansa,
  title  = {SANSA: Unleashing the Hidden Semantics in SAM2 for Few-Shot Segmentation},
  author = {Claudia Cuttano and Gabriele Trivigno and Giuseppe Averta and Carlo Masone},
  year   = {2025},
  eprint = {2505.21795},
  url    = {https://arxiv.org/abs/2505.21795},
}
```
This project builds upon code from the following libraries and repositories:
