Skip to content

Commit a479c0c

Browse files
pablovela5620nikolausWestemilk
authored
Add new ARKitScenes example (#1538)
Co-authored-by: Nikolaus West <nikolaus.west@me.com> Co-authored-by: Emil Ernerfeldt <emil.ernerfeldt@gmail.com>
1 parent bf27c8e commit a479c0c

4 files changed

Lines changed: 779 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
dataset/**
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
# Copied from https://github.com/apple/ARKitScenes/blob/main/download_data.py
2+
# Licensing information: https://github.com/apple/ARKitScenes/blob/main/LICENSE
3+
import math
4+
import os
5+
import subprocess
6+
from pathlib import Path
7+
from typing import Final, List, Optional
8+
9+
import pandas as pd
10+
11+
# Base URL under which every ARKitScenes file requested by this module is hosted.
ARkitscense_url = "https://docs-assets.developer.apple.com/ml-research/datasets/arkitscenes/v1"
# Dataset split names; they are used verbatim in download URLs and in local directory paths.
TRAINING: Final = "Training"
VALIDATION: Final = "Validation"
# Asset that `raw_files` only emits for videos flagged `is_in_upsampling` in the metadata.
HIGRES_DEPTH_ASSET_NAME: Final = "highres_depth"
# Sub-folder of the download directory that receives laser scanner point clouds.
POINT_CLOUDS_FOLDER: Final = "laser_scanner_point_clouds"

# Recording (video) ids. NOTE(review): presumably the recordings this example offers - confirm against callers.
AVAILABLE_RECORDINGS: Final = ["48458663", "42444949", "41069046", "41125722", "41125763", "42446167"]
# Local dataset root: a `dataset` directory located next to this file.
DATASET_DIR: Final = Path(os.path.dirname(__file__)) / "dataset"
19+
20+
# Asset names of the raw dataset.
# NOTE(review): not referenced by the code visible in this file; presumably meant as the
# default for the `raw_dataset_assets` argument of `download_data` - confirm against callers.
default_raw_dataset_assets = [
    "mov",
    "annotation",
    "mesh",
    "confidence",
    "highres_depth",
    "lowres_depth",
    "lowres_wide.traj",
    "lowres_wide",
    "lowres_wide_intrinsics",
    "ultrawide",
    "ultrawide_intrinsics",
    "vga_wide",
    "vga_wide_intrinsics",
]
35+
36+
# Video ids for which `raw_files` leaves out the "mesh", "annotation" and
# "lowres_wide.traj" assets.
missing_3dod_assets_video_ids = [
    "47334522",
    "47334523",
    "42897421",
    "45261582",
    "47333152",
    "47333155",
    "48458535",
    "48018733",
    "47429677",
    "48458541",
    "42897848",
    "47895482",
    "47333960",
    "47430089",
    "42899148",
    "42897612",
    "42899153",
    "42446164",
    "48018149",
    "47332198",
    "47334515",
    "45663223",
    "45663226",
    "45663227",
]
62+
63+
64+
def raw_files(video_id: str, assets: List[str], metadata: pd.DataFrame) -> List[str]:
    """
    Translate raw-dataset asset names into the file names to fetch for one video.

    Args:
    ----
    video_id: id of the recording; matched against the `video_id` metadata column as a float
    assets: asset names to resolve
    metadata: raw dataset metadata; the `video_id` and `is_in_upsampling` columns are read
        when `highres_depth` is requested

    Returns: the file names to download, in the order of `assets`; assets that are skipped
        for this video are left out

    Raises: Exception if an asset name is not a known raw dataset asset
    """
    # Assets that are shipped as `<asset>.zip`.
    zipped_assets = (
        "confidence",
        "highres_depth",
        "lowres_depth",
        "lowres_wide",
        "lowres_wide_intrinsics",
        "ultrawide",
        "ultrawide_intrinsics",
        "wide",
        "wide_intrinsics",
        "vga_wide",
        "vga_wide_intrinsics",
    )
    # Assets that are skipped for the ids listed in `missing_3dod_assets_video_ids`.
    threedod_files = {
        "mesh": f"{video_id}_3dod_mesh.ply",
        "annotation": f"{video_id}_3dod_annotation.json",
        "lowres_wide.traj": "lowres_wide.traj",
    }

    resolved = []
    for asset in assets:
        if asset == HIGRES_DEPTH_ASSET_NAME:
            rows_for_video = metadata["video_id"] == float(video_id)
            in_upsampling = metadata.loc[rows_for_video, ["is_in_upsampling"]].iat[0, 0]
            if not in_upsampling:
                print(f"Skipping asset {asset} for video_id {video_id} - Video not in upsampling dataset")
                # highres_depth asset only available for video ids from upsampling dataset
                continue

        if asset in zipped_assets:
            resolved.append(f"{asset}.zip")
        elif asset == "mov":
            resolved.append(f"{video_id}.mov")
        elif asset in threedod_files:
            if video_id not in missing_3dod_assets_video_ids:
                resolved.append(threedod_files[asset])
        else:
            raise Exception(f"No asset = {asset} in raw dataset")
    return resolved
101+
102+
103+
def download_file(url: str, file_name: str, dst: Path) -> bool:
    """
    Download `url` to `dst / file_name` with `curl`, unless that file already exists.

    The payload is first written to `<file_name>.tmp` and only renamed to its final name
    after curl succeeded, so a failed download never leaves a file under the final name.

    Args:
    ----
    url: address of the file to download
    file_name: name of the destination file inside `dst`
    dst: destination directory; created if it does not exist

    Returns: True if the file is present after the call (freshly downloaded or already
        there), False if the download failed
    """
    os.makedirs(dst, exist_ok=True)
    filepath = os.path.join(dst, file_name)

    if os.path.isfile(filepath):
        print(f"WARNING: skipping download of existing file: {filepath}")
        return True

    tmp_filepath = filepath + ".tmp"
    # An argument list (no shell) hands the URL and the path to curl verbatim, so spaces
    # or shell metacharacters such as `&` and `?` cannot break or alter the command.
    command = ["curl", url, "-o", tmp_filepath, "--fail"]
    print(f"Downloading file {filepath}")
    try:
        subprocess.check_call(command)
    except (subprocess.CalledProcessError, OSError) as error:
        # CalledProcessError: curl exited non-zero; OSError: curl could not be started.
        print(f"Error downloading {url}, error: {error}")
        if os.path.isfile(tmp_filepath):
            os.remove(tmp_filepath)  # do not leave a partial download behind
        return False
    os.rename(tmp_filepath, filepath)
    return True
119+
120+
121+
def unzip_file(file_name: str, dst: Path, keep_zip: bool = True) -> bool:
    """
    Extract the archive `dst / file_name` into `dst`, overwriting files that already exist.

    Args:
    ----
    file_name: name of the zip archive inside `dst`
    dst: directory that holds the archive and receives its content
    keep_zip: whether to keep the archive after a successful extraction

    Returns: True if the archive was extracted, False if it is missing, unreadable or
        could not be extracted (the archive is never deleted in that case)
    """
    # Function-scope import: `zipfile` is not among the module-level imports of this file.
    import zipfile

    filepath = os.path.join(dst, file_name)
    print(f"Unzipping zip file {filepath}")
    # The standard library extracts the archive directly: no external `unzip` binary is
    # needed and paths containing spaces or shell metacharacters are handled correctly.
    try:
        with zipfile.ZipFile(filepath) as archive:
            archive.extractall(dst)
    except (zipfile.BadZipFile, NotImplementedError, OSError) as error:
        print(f"Error unzipping {filepath}, error: {error}")
        return False
    if not keep_zip:
        os.remove(filepath)
    return True
133+
134+
135+
def download_laser_scanner_point_clouds_for_video(video_id: str, metadata: pd.DataFrame, download_dir: Path) -> None:
    """
    Download every laser scanner point cloud of the visit that `video_id` belongs to.

    Prints a warning and returns without downloading when the video is not listed in
    `metadata`, when the metadata marks it as having no point clouds, or when its visit id
    is NaN or not a whole number.

    Args:
    ----
    video_id: id of the recording; matched against the `video_id` metadata column as a float
    metadata: dataset metadata with the columns `video_id`, `visit_id` and
        `has_laser_scanner_point_clouds`
    download_dir: directory handed on to the point cloud download helpers
    """
    video_metadata = metadata.loc[metadata["video_id"] == float(video_id)]
    if video_metadata.empty:
        # Without this guard the `.iat[0]` lookups below raise an IndexError.
        print(f"Warning: Laser scanner point clouds for video {video_id} are not available - unknown video id")
        return

    visit_id = video_metadata["visit_id"].iat[0]
    has_laser_scanner_point_clouds = video_metadata["has_laser_scanner_point_clouds"].iat[0]

    if not has_laser_scanner_point_clouds:
        print(f"Warning: Laser scanner point clouds for video {video_id} are not available")
        return

    # Normalise to a builtin float so the checks do not depend on the column's scalar type.
    visit_id_value = float(visit_id)
    if math.isnan(visit_id_value) or not visit_id_value.is_integer():
        print(f"Warning: Downloading laser scanner point clouds for video {video_id} failed - Bad visit id {visit_id}")
        return

    visit_id = int(visit_id_value)  # Expecting an 8 digit integer
    laser_scanner_point_clouds_ids = laser_scanner_point_clouds_for_visit_id(visit_id, download_dir)

    for point_cloud_id in laser_scanner_point_clouds_ids:
        download_laser_scanner_point_clouds(point_cloud_id, visit_id, download_dir)
153+
154+
155+
def laser_scanner_point_clouds_for_visit_id(visit_id: int, download_dir: Path) -> List[str]:
    """
    Look up the ids of the laser scanner point clouds recorded during a visit.

    The lookup uses `laser_scanner_point_clouds_mapping.csv`, which is downloaded into
    `download_dir` first if it is not there yet.

    Args:
    ----
    visit_id: id of the visit to look up
    download_dir: directory that holds (or receives) the mapping file

    Returns: the values of the `laser_scanner_point_clouds_id` column for all rows whose
        `visit_id` matches; an empty list if there is no match or the mapping file could
        not be downloaded
    """
    mapping_filename = "laser_scanner_point_clouds_mapping.csv"
    mapping_filepath = os.path.join(download_dir, mapping_filename)

    # Check the location the file is downloaded to and read from - not the bare file
    # name, which would be resolved against the current working directory.
    if not os.path.exists(mapping_filepath):
        mapping_url = f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{mapping_filename}"
        if not download_file(mapping_url, mapping_filename, download_dir):
            print(f"Error downloading point cloud for visit_id {visit_id} at location {mapping_url}")
            return []

    mapping = pd.read_csv(mapping_filepath)
    matching_ids = mapping.loc[mapping["visit_id"] == visit_id, "laser_scanner_point_clouds_id"]
    return matching_ids.tolist()
181+
182+
183+
def download_laser_scanner_point_clouds(laser_scanner_point_cloud_id: str, visit_id: int, download_dir: Path) -> None:
    """
    Download the `.ply` point cloud and its `_pose.txt` file for one laser scan.

    Both files are stored in `download_dir / POINT_CLOUDS_FOLDER / <visit_id>`; a file that
    already exists there is left untouched.

    Args:
    ----
    laser_scanner_point_cloud_id: id of the point cloud; used as the file name stem
    visit_id: id of the visit the point cloud belongs to
    download_dir: root directory of the downloaded dataset
    """
    laser_scanner_point_clouds_folder_path = download_dir / POINT_CLOUDS_FOLDER / str(visit_id)
    os.makedirs(laser_scanner_point_clouds_folder_path, exist_ok=True)

    for extension in [".ply", "_pose.txt"]:
        filename = f"{laser_scanner_point_cloud_id}{extension}"
        filepath = os.path.join(laser_scanner_point_clouds_folder_path, filename)
        if os.path.exists(filepath):
            # Skip only this file: returning here would never fetch the pose file when
            # the point cloud itself is already present.
            continue
        file_url = f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{visit_id}/{filename}"
        download_file(file_url, filename, laser_scanner_point_clouds_folder_path)
194+
195+
196+
def get_metadata(dataset: str, download_dir: Path) -> Optional[pd.DataFrame]:
    """
    Download (if not already present) and load the `metadata.csv` of a dataset.

    Args:
    ----
    dataset: name of the dataset; "3dod" is fetched from the `threedod` folder of the
        server, any other name is used as the remote folder name unchanged
    download_dir: root download directory; the file is stored in `download_dir / dataset`

    Returns: the parsed metadata, or None if the file could not be downloaded
    """
    filename = "metadata.csv"
    remote_folder = "threedod" if dataset == "3dod" else dataset
    url = f"{ARkitscense_url}/{remote_folder}/{filename}"
    dst_folder = download_dir / dataset

    if not download_file(url, filename, dst_folder):
        return None

    return pd.read_csv(dst_folder / filename)
207+
208+
209+
def download_data(
    dataset: str,
    video_ids: List[str],
    dataset_splits: List[str],
    download_dir: Path,
    keep_zip: bool,
    raw_dataset_assets: Optional[List[str]] = None,
    should_download_laser_scanner_point_cloud: bool = False,
) -> None:
    """
    Downloads data from the specified dataset and video IDs to the given download directory.

    Files end up in `download_dir / dataset / split` (plus a per-video sub-folder for the
    raw dataset). Downloaded zip archives are extracted in place.

    Args:
    ----
    dataset: the name of the dataset to download from (raw, 3dod, or upsampling)
    video_ids: the list of video IDs to download data for; duplicates are processed once
    dataset_splits: the split of each entry of `video_ids`, at the same index; the value is
        used verbatim in URLs and paths (e.g. "Training" or "Validation")
    download_dir: the directory to download data to
    keep_zip: whether to keep the downloaded zip files after extracting them
    raw_dataset_assets: a list of asset types to download from the raw dataset, if dataset is "raw"
    should_download_laser_scanner_point_cloud: whether to download the laser scanner point cloud data, if available

    Returns: None

    Raises: Exception if `dataset` is not one of the supported dataset names
    """
    metadata = get_metadata(dataset, download_dir)
    if metadata is None:
        print(f"Error retrieving metadata for dataset {dataset}")
        return

    for video_id in sorted(set(video_ids)):
        # For a duplicated id the split of its first occurrence is used.
        split = dataset_splits[video_ids.index(video_id)]
        dst_dir = download_dir / dataset / split
        if dataset == "raw":
            url_prefix = ""
            file_names = []
            if not raw_dataset_assets:
                print(f"Warning: No raw assets given for video id {video_id}")
            else:
                dst_dir = dst_dir / str(video_id)
                url_prefix = f"{ARkitscense_url}/raw/{split}/{video_id}" + "/{}"
                file_names = raw_files(video_id, raw_dataset_assets, metadata)
        elif dataset == "3dod":
            url_prefix = f"{ARkitscense_url}/threedod/{split}" + "/{}"
            file_names = [
                f"{video_id}.zip",
            ]
        elif dataset == "upsampling":
            url_prefix = f"{ARkitscense_url}/upsampling/{split}" + "/{}"
            file_names = [
                f"{video_id}.zip",
            ]
        else:
            raise Exception(f"No such dataset = {dataset}")

        if should_download_laser_scanner_point_cloud and dataset == "raw":
            # Point clouds only available for the raw dataset
            download_laser_scanner_point_clouds_for_video(video_id, metadata, download_dir)

        for file_name in file_names:
            dst_path = os.path.join(dst_dir, file_name)
            url = url_prefix.format(file_name)
            is_zip = file_name.endswith(".zip")

            # A directory named like the archive without ".zip" is taken as a previous
            # extraction, in which case the archive is not fetched again.
            if is_zip and os.path.isdir(dst_path[: -len(".zip")]):
                print(f"WARNING: skipping download of existing zip file: {dst_path}")
            else:
                # Pass the bare file name: `download_file` joins it with `dst_dir` itself,
                # so passing `dst_path` would double the directory for relative paths.
                download_file(url, file_name, dst_dir)
            if is_zip and os.path.isfile(dst_path):
                unzip_file(file_name, dst_dir, keep_zip)
277+
278+
279+
def ensure_recording_downloaded(video_id: str, include_highres: bool) -> Path:
    """
    Download the raw assets of one recording. Only downloads from validation set.

    Args:
    ----
    video_id: id of the recording to download
    include_highres: whether to additionally request the "highres_depth", "wide" and
        "wide_intrinsics" assets

    Returns: the recording's directory, `DATASET_DIR / "raw" / "Validation" / video_id`
    """
    recording_dir = DATASET_DIR.joinpath("raw", "Validation", video_id)

    highres_assets = ["highres_depth", "wide", "wide_intrinsics"] if include_highres else []
    requested_assets = [
        "lowres_wide",
        "lowres_depth",
        "lowres_wide_intrinsics",
        "lowres_wide.traj",
        "annotation",
        "mesh",
        *highres_assets,
    ]

    download_data(
        dataset="raw",
        video_ids=[video_id],
        dataset_splits=[VALIDATION],
        download_dir=DATASET_DIR,
        keep_zip=False,
        raw_dataset_assets=requested_assets,
        should_download_laser_scanner_point_cloud=False,
    )
    return recording_dir
302+
303+
304+
def ensure_recording_available(video_id: str, include_highres: bool) -> Path:
    """
    Returns the path to the recording for a given video_id, downloading it if needed.

    Args:
        video_id (str): Identifier for the recording.
        include_highres (bool): Passed on to `ensure_recording_downloaded`, which then also
            requests the high resolution assets.

    Returns
    -------
    Path: Path object representing the path to the recording.

    Raises
    ------
    AssertionError: If the recording path does not exist.

    """
    recording_path = ensure_recording_downloaded(video_id, include_highres)
    if not recording_path.exists():
        # Raised explicitly instead of using `assert`, so the check is not stripped when
        # Python runs with `-O`; the exception type is unchanged.
        raise AssertionError(f"Recording path {recording_path} does not exist.")
    return recording_path  # Return the path to the recording

0 commit comments

Comments
 (0)