|
| 1 | +# Copied from https://github.com/apple/ARKitScenes/blob/main/download_data.py |
| 2 | +# Licensing information: https://github.com/apple/ARKitScenes/blob/main/LICENSE |
import math
import os
import subprocess
import zipfile
from pathlib import Path
from typing import Final, List, Optional

import pandas as pd
| 10 | + |
# Base URL for all ARKitScenes asset downloads (v1 release).
ARkitscense_url = "https://docs-assets.developer.apple.com/ml-research/datasets/arkitscenes/v1"
# Split names as they appear in the download URL paths.
TRAINING: Final = "Training"
VALIDATION: Final = "Validation"
# Asset that is only published for videos in the upsampling subset.
HIGRES_DEPTH_ASSET_NAME: Final = "highres_depth"
# Folder name (relative to the download dir) where laser scanner scans land.
POINT_CLOUDS_FOLDER: Final = "laser_scanner_point_clouds"

# Video ids this module is known to work with — presumably a curated
# validation subset; TODO confirm against the upstream dataset listing.
AVAILABLE_RECORDINGS: Final = ["48458663", "42444949", "41069046", "41125722", "41125763", "42446167"]
# All data is downloaded next to this file, under ./dataset.
DATASET_DIR: Final = Path(os.path.dirname(__file__)) / "dataset"

# Asset types that may be requested per video from the "raw" dataset
# (see raw_files for how each maps to a remote file name).
default_raw_dataset_assets = [
    "mov",
    "annotation",
    "mesh",
    "confidence",
    "highres_depth",
    "lowres_depth",
    "lowres_wide.traj",
    "lowres_wide",
    "lowres_wide_intrinsics",
    "ultrawide",
    "ultrawide_intrinsics",
    "vga_wide",
    "vga_wide_intrinsics",
]

# Video ids whose 3D object-detection assets (mesh, annotation, trajectory)
# are missing upstream; raw_files skips those assets for these ids.
missing_3dod_assets_video_ids = [
    "47334522",
    "47334523",
    "42897421",
    "45261582",
    "47333152",
    "47333155",
    "48458535",
    "48018733",
    "47429677",
    "48458541",
    "42897848",
    "47895482",
    "47333960",
    "47430089",
    "42899148",
    "42897612",
    "42899153",
    "42446164",
    "48018149",
    "47332198",
    "47334515",
    "45663223",
    "45663226",
    "45663227",
]
| 62 | + |
| 63 | + |
def raw_files(video_id: str, assets: List[str], metadata: pd.DataFrame) -> List[str]:
    """Map requested raw-dataset asset types to downloadable file names.

    Assets that do not exist for this video (highres depth outside the
    upsampling subset; 3dod-derived assets for known-missing videos) are
    skipped rather than reported as errors.

    Raises:
        Exception: if an asset name is not part of the raw dataset.
    """
    # Assets shipped as per-video zip archives named after the asset itself.
    zipped_assets = {
        "confidence",
        "highres_depth",
        "lowres_depth",
        "lowres_wide",
        "lowres_wide_intrinsics",
        "ultrawide",
        "ultrawide_intrinsics",
        "wide",
        "wide_intrinsics",
        "vga_wide",
        "vga_wide_intrinsics",
    }
    names: List[str] = []
    for asset in assets:
        if asset == HIGRES_DEPTH_ASSET_NAME:
            # highres_depth asset only available for video ids from upsampling dataset
            upsampling_flags = metadata.loc[metadata["video_id"] == float(video_id), ["is_in_upsampling"]]
            if not upsampling_flags.iat[0, 0]:
                print(f"Skipping asset {asset} for video_id {video_id} - Video not in upsampling dataset")
                continue

        if asset in zipped_assets:
            names.append(f"{asset}.zip")
        elif asset == "mov":
            names.append(f"{video_id}.mov")
        elif asset == "mesh":
            if video_id not in missing_3dod_assets_video_ids:
                names.append(f"{video_id}_3dod_mesh.ply")
        elif asset == "annotation":
            if video_id not in missing_3dod_assets_video_ids:
                names.append(f"{video_id}_3dod_annotation.json")
        elif asset == "lowres_wide.traj":
            if video_id not in missing_3dod_assets_video_ids:
                names.append("lowres_wide.traj")
        else:
            raise Exception(f"No asset = {asset} in raw dataset")
    return names
| 101 | + |
| 102 | + |
def download_file(url: str, file_name: str, dst: Path) -> bool:
    """Download ``url`` into ``dst/file_name`` via curl.

    The payload is written to ``<file_name>.tmp`` and renamed into place on
    success, so an interrupted transfer never leaves a partial file under the
    final name. If the destination file already exists the download is
    skipped and treated as a success.

    Args:
        url: remote location to fetch.
        file_name: destination file name (relative to ``dst``).
        dst: destination directory; created if missing.

    Returns:
        True on success or when the file already exists, False on failure.
    """
    os.makedirs(dst, exist_ok=True)
    filepath = os.path.join(dst, file_name)

    if os.path.isfile(filepath):
        print(f"WARNING: skipping download of existing file: {filepath}")
        return True

    print(f"Downloading file {filepath}")
    try:
        # List form with shell=False: URLs / file names containing spaces or
        # shell metacharacters are passed through verbatim instead of being
        # interpolated into a shell command string.
        subprocess.check_call(["curl", url, "-o", f"{file_name}.tmp", "--fail"], cwd=dst)
    except Exception as error:
        print(f"Error downloading {url}, error: {error}")
        return False
    os.rename(filepath + ".tmp", filepath)
    return True
| 119 | + |
| 120 | + |
def unzip_file(file_name: str, dst: Path, keep_zip: bool = True) -> bool:
    """Extract the archive ``dst/file_name`` into ``dst``.

    Uses the stdlib ``zipfile`` module instead of shelling out to ``unzip``,
    which removes the external-binary dependency and avoids building a shell
    command from unescaped paths. Existing files are overwritten, matching
    the previous ``unzip -o`` behavior.

    Args:
        file_name: archive file name (relative to ``dst``).
        dst: directory containing the archive and receiving its contents.
        keep_zip: when False, delete the archive after successful extraction.

    Returns:
        True on success, False if extraction failed.
    """
    filepath = os.path.join(dst, file_name)
    print(f"Unzipping zip file {filepath}")
    try:
        with zipfile.ZipFile(filepath) as archive:
            archive.extractall(dst)
    except Exception as error:
        print(f"Error unzipping {filepath}, error: {error}")
        return False
    if not keep_zip:
        os.remove(filepath)
    return True
| 133 | + |
| 134 | + |
def download_laser_scanner_point_clouds_for_video(video_id: str, metadata: pd.DataFrame, download_dir: Path) -> None:
    """Download every laser-scanner point cloud recorded for ``video_id``'s visit.

    Silently returns (with a warning) when no point clouds exist for the
    video or when its visit id is not a clean integer.
    """
    video_row = metadata.loc[metadata["video_id"] == float(video_id)]
    visit_id = video_row["visit_id"].iat[0]
    available = video_row["has_laser_scanner_point_clouds"].iat[0]

    if not available:
        print(f"Warning: Laser scanner point clouds for video {video_id} are not available")
        return

    # visit_id comes out of the csv as a float; reject NaN / fractional values.
    if math.isnan(visit_id) or not visit_id.is_integer():
        print(f"Warning: Downloading laser scanner point clouds for video {video_id} failed - Bad visit id {visit_id}")
        return

    visit_id = int(visit_id)  # Expecting an 8 digit integer
    for scan_id in laser_scanner_point_clouds_for_visit_id(visit_id, download_dir):
        download_laser_scanner_point_clouds(scan_id, visit_id, download_dir)
| 153 | + |
| 154 | + |
def laser_scanner_point_clouds_for_visit_id(visit_id: int, download_dir: Path) -> List[str]:
    """Return the laser-scanner point-cloud ids recorded for ``visit_id``.

    Downloads the point-cloud-to-visit-id mapping csv into ``download_dir``
    on first use.

    Returns:
        The matching ``laser_scanner_point_clouds_id`` values, or an empty
        list when the mapping file could not be fetched.
    """
    mapping_filename = "laser_scanner_point_clouds_mapping.csv"
    mapping_filepath = os.path.join(download_dir, mapping_filename)

    # Bug fix: the existence check previously tested the bare file name
    # (i.e. the current working directory) while the file is downloaded to
    # and read from download_dir, so the csv was re-fetched on every call.
    if not os.path.exists(mapping_filepath):
        mapping_url = f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{mapping_filename}"
        if not download_file(mapping_url, mapping_filename, download_dir):
            print(
                f"Error downloading point cloud for visit_id {visit_id} at location "
                f"{mapping_url}"
            )
            return []

    mapping = pd.read_csv(mapping_filepath)
    matches = mapping.loc[
        mapping["visit_id"] == visit_id,
        "laser_scanner_point_clouds_id",
    ]
    return matches.tolist()
| 181 | + |
| 182 | + |
def download_laser_scanner_point_clouds(laser_scanner_point_cloud_id: str, visit_id: int, download_dir: Path) -> None:
    """Download the ``.ply`` scan and its ``_pose.txt`` for one point-cloud id.

    Files are placed under ``download_dir/laser_scanner_point_clouds/<visit_id>``.
    If the first file of the pair already exists locally, the pair is assumed
    to be complete and nothing is downloaded.
    """
    target_folder = download_dir / POINT_CLOUDS_FOLDER / str(visit_id)
    os.makedirs(target_folder, exist_ok=True)

    for extension in [".ply", "_pose.txt"]:
        filename = f"{laser_scanner_point_cloud_id}{extension}"
        filepath = os.path.join(target_folder, filename)
        if os.path.exists(filepath):
            return
        # Bug fix: the remote file name was missing from the URL (it ended in
        # ".../{visit_id}/(unknown)"), so every request hit a bad address.
        file_url = f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{visit_id}/{filename}"
        download_file(file_url, filename, target_folder)
| 194 | + |
| 195 | + |
def get_metadata(dataset: str, download_dir: Path) -> Optional[pd.DataFrame]:
    """Download (if needed) and parse the metadata csv for ``dataset``.

    Args:
        dataset: dataset name ("raw", "3dod", or "upsampling").
        download_dir: root directory the csv is stored under.

    Returns:
        The metadata DataFrame, or None when the download failed.
    """
    filename = "metadata.csv"
    # Bug fix: both URL variants previously ended in the literal "(unknown)"
    # instead of the metadata file name; the 3dod dataset is hosted under
    # the "threedod" path segment.
    if dataset == "3dod":
        url = f"{ARkitscense_url}/threedod/{filename}"
    else:
        url = f"{ARkitscense_url}/{dataset}/{filename}"
    dst_folder = download_dir / dataset
    dst_file = dst_folder / filename

    if not download_file(url, filename, dst_folder):
        return None

    return pd.read_csv(dst_file)
| 207 | + |
| 208 | + |
def download_data(
    dataset: str,
    video_ids: List[str],
    dataset_splits: List[str],
    download_dir: Path,
    keep_zip: bool,
    raw_dataset_assets: Optional[List[str]] = None,
    should_download_laser_scanner_point_cloud: bool = False,
) -> None:
    """
    Download the requested videos of a dataset into ``download_dir``.

    Args:
    ----
        dataset: the name of the dataset to download from (raw, 3dod, or upsampling)
        video_ids: the list of video IDs to download data for
        dataset_splits: the list of splits for each video ID (train, validation, or test)
        download_dir: the directory to download data to
        keep_zip: whether to keep the downloaded zip files after extracting them
        raw_dataset_assets: a list of asset types to download from the raw dataset, if dataset is "raw"
        should_download_laser_scanner_point_cloud: whether to download the laser scanner point cloud data, if available

    Returns: None
    """
    metadata = get_metadata(dataset, download_dir)
    if metadata is None:
        print(f"Error retrieving metadata for dataset {dataset}")
        return

    for video_id in sorted(set(video_ids)):
        # The split is looked up by the first occurrence of the id in video_ids.
        split = dataset_splits[video_ids.index(video_id)]
        dst_dir = download_dir / dataset / split

        if dataset == "raw":
            if raw_dataset_assets:
                dst_dir = dst_dir / str(video_id)
                url_prefix = f"{ARkitscense_url}/raw/{split}/{video_id}" + "/{}"
                file_names = raw_files(video_id, raw_dataset_assets, metadata)
            else:
                print(f"Warning: No raw assets given for video id {video_id}")
                url_prefix = ""
                file_names = []
        elif dataset in ("3dod", "upsampling"):
            remote_segment = "threedod" if dataset == "3dod" else "upsampling"
            url_prefix = f"{ARkitscense_url}/{remote_segment}/{split}" + "/{}"
            file_names = [f"{video_id}.zip"]
        else:
            raise Exception(f"No such dataset = {dataset}")

        if should_download_laser_scanner_point_cloud and dataset == "raw":
            # Point clouds only available for the raw dataset
            download_laser_scanner_point_clouds_for_video(video_id, metadata, download_dir)

        for file_name in file_names:
            dst_path = os.path.join(dst_dir, file_name)
            url = url_prefix.format(file_name)

            # Skip the download when the zip was already extracted in place.
            already_extracted = file_name.endswith(".zip") and os.path.isdir(dst_path[: -len(".zip")])
            if already_extracted:
                print(f"WARNING: skipping download of existing zip file: {dst_path}")
            else:
                download_file(url, dst_path, dst_dir)
            if file_name.endswith(".zip") and os.path.isfile(dst_path):
                unzip_file(file_name, dst_dir, keep_zip)
| 277 | + |
| 278 | + |
def ensure_recording_downloaded(video_id: str, include_highres: bool) -> Path:
    """Only downloads from validation set."""
    assets = [
        "lowres_wide",
        "lowres_depth",
        "lowres_wide_intrinsics",
        "lowres_wide.traj",
        "annotation",
        "mesh",
    ]
    if include_highres:
        # High-resolution assets are optional and only exist for a subset of videos.
        assets += ["highres_depth", "wide", "wide_intrinsics"]
    download_data(
        dataset="raw",
        video_ids=[video_id],
        dataset_splits=[VALIDATION],
        download_dir=DATASET_DIR,
        keep_zip=False,
        raw_dataset_assets=assets,
        should_download_laser_scanner_point_cloud=False,
    )
    return DATASET_DIR / "raw" / "Validation" / video_id
| 302 | + |
| 303 | + |
def ensure_recording_available(video_id: str, include_highres: bool) -> Path:
    """
    Download the recording for ``video_id`` if needed and return its path.

    Args:
        video_id (str): Identifier for the recording.

    Returns
    -------
    Path: Path object representing the path to the recording.

    Raises
    ------
    AssertionError: If the recording path does not exist.
    """
    recording_path = ensure_recording_downloaded(video_id, include_highres)
    if not recording_path.exists():
        raise AssertionError(f"Recording path {recording_path} does not exist.")
    return recording_path