import functools
import hashlib
import json
import multiprocessing as mp
import re
import shutil
import subprocess
import sys
import tomllib
from os.path import islink, realpath
from pathlib import Path
from typing import Any, TypedDict, cast
from urllib.parse import unquote

import requests
from requests.adapters import HTTPAdapter, Retry

# Log to stderr so stdout stays clean (this script is driven from a build).
eprint = functools.partial(print, file=sys.stderr)


def load_toml(path: Path) -> dict[str, Any]:
    """Parse a TOML file; tomllib requires the file opened in binary mode."""
    with open(path, "rb") as f:
        return tomllib.load(f)


def get_lockfile_version(cargo_lock_toml: dict[str, Any]) -> int:
    """Return the Cargo.lock format version declared in the parsed lockfile."""
    # lockfile v1 and v2 don't have the `version` key, so assume v2
    version = cargo_lock_toml.get("version", 2)

    # TODO: add logic for differentiating between v1 and v2

    return version


def create_http_session() -> requests.Session:
    """Build a requests session that retries transient 5xx errors with backoff."""
    retries = Retry(
        total=5,
        backoff_factor=0.5,
        status_forcelist=[500, 502, 503, 504],
    )
    session = requests.Session()
    session.mount("http://", HTTPAdapter(max_retries=retries))
    session.mount("https://", HTTPAdapter(max_retries=retries))
    return session


def download_file_with_checksum(session: requests.Session, url: str, destination_path: Path) -> str:
    """Stream `url` into `destination_path`, returning the SHA-256 hex digest.

    Raises:
        Exception: if the HTTP response status is not OK.
    """
    sha256_hash = hashlib.sha256()
    with session.get(url, stream=True) as response:
        if not response.ok:
            raise Exception(f"Failed to fetch file from {url}. Status code: {response.status_code}")
        with open(destination_path, "wb") as file:
            for chunk in response.iter_content(1024):  # Download in chunks
                if chunk:  # Filter out keep-alive chunks
                    file.write(chunk)
                    sha256_hash.update(chunk)

    # Compute the final checksum
    checksum = sha256_hash.hexdigest()
    return checksum


def get_download_url_for_tarball(pkg: dict[str, Any]) -> str:
    """Return the crates.io download URL for a registry package entry."""
    # TODO: support other registries
    # maybe fetch config.json from the registry root and get the dl key
    # See: https://doc.rust-lang.org/cargo/reference/registry-index.html#index-configuration
    if pkg["source"] != "registry+https://github.com/rust-lang/crates.io-index":
        raise Exception("Only the default crates.io registry is supported.")

    return f"https://crates.io/api/v1/crates/{pkg['name']}/{pkg['version']}/download"


def download_tarball(session: requests.Session, pkg: dict[str, Any], out_dir: Path) -> None:
    """Fetch one registry crate tarball into out_dir/tarballs and verify its checksum."""
    url = get_download_url_for_tarball(pkg)
    filename = f"{pkg['name']}-{pkg['version']}.tar.gz"

    # TODO: allow legacy checksum specification, see importCargoLock for example
    # also, don't forget about the other usage of the checksum
    expected_checksum = pkg["checksum"]

    tarball_out_dir = out_dir / "tarballs" / filename
    # fix: report the actual destination filename (message previously printed a
    # garbled "(unknown)" placeholder instead of the file being fetched)
    eprint(f"Fetching {url} -> tarballs/{filename}")

    calculated_checksum = download_file_with_checksum(session, url, tarball_out_dir)

    if calculated_checksum != expected_checksum:
        raise Exception(f"Hash mismatch! File fetched from {url} had checksum {calculated_checksum}, expected {expected_checksum}.")


def download_git_tree(url: str, git_sha_rev: str, out_dir: Path) -> None:
    """Fetch a git repository at a specific revision into out_dir/git/<rev>."""
    tree_out_dir = out_dir / "git" / git_sha_rev
    eprint(f"Fetching {url}#{git_sha_rev} -> git/{git_sha_rev}")

    cmd = ["nix-prefetch-git", "--builder", "--quiet", "--fetch-submodules", "--url", url, "--rev", git_sha_rev, "--out", str(tree_out_dir)]
    subprocess.check_output(cmd)


# Matches Cargo.lock git sources, e.g.
#   git+https://example.com/repo?rev=abc123#abc123...
# `?type=value` (rev/tag/branch) is optional; the fragment is the commit sha.
GIT_SOURCE_REGEX = re.compile(r"git\+(?P<url>[^?]+)(\?(?P<type>rev|tag|branch)=(?P<value>.*))?#(?P<git_sha_rev>.*)")


class GitSourceInfo(TypedDict):
    # Parsed fields of a `git+...` source string.
    url: str
    type: str | None
    value: str | None
    git_sha_rev: str


def parse_git_source(source: str, lockfile_version: int) -> GitSourceInfo:
    """Split a `git+` source string into url, ref type/value and commit sha.

    Raises:
        Exception: if the source string does not match the expected format.
    """
    match = GIT_SOURCE_REGEX.match(source)
    if match is None:
        raise Exception(f"Unable to process git source: {source}.")

    source_info = cast(GitSourceInfo, match.groupdict(default=None))

    # the source URL is URL-encoded in lockfile_version >=4
    # since we just used regex to parse it we have to manually decode the escaped branch/tag name
    if lockfile_version >= 4 and source_info["value"] is not None:
        source_info["value"] = unquote(source_info["value"])

    return source_info


def create_vendor_staging(lockfile_path: Path, out_dir: Path) -> None:
    """Download every remote dependency from a Cargo.lock into a staging dir.

    Layout produced: out_dir/Cargo.lock, out_dir/git/<sha>/, out_dir/tarballs/*.tar.gz
    """
    cargo_lock_toml = load_toml(lockfile_path)
    lockfile_version = get_lockfile_version(cargo_lock_toml)

    git_packages: list[dict[str, Any]] = []
    registry_packages: list[dict[str, Any]] = []

    for pkg in cargo_lock_toml["package"]:
        # ignore local dependencies
        if "source" not in pkg.keys():
            eprint(f"Skipping local dependency: {pkg['name']}")
            continue
        source = pkg["source"]

        if source.startswith("git+"):
            git_packages.append(pkg)
        elif source.startswith("registry+"):
            registry_packages.append(pkg)
        else:
            raise Exception(f"Can't process source: {source}.")

    # deduplicate git fetches: multiple crates may come from the same tree
    git_sha_rev_to_url: dict[str, str] = {}
    for pkg in git_packages:
        source_info = parse_git_source(pkg["source"], lockfile_version)
        git_sha_rev_to_url[source_info["git_sha_rev"]] = source_info["url"]

    out_dir.mkdir(exist_ok=True)
    shutil.copy(lockfile_path, out_dir / "Cargo.lock")

    # fetch git trees sequentially, since fetching concurrently leads to flaky behaviour
    if git_packages:
        (out_dir / "git").mkdir()
        for git_sha_rev, url in git_sha_rev_to_url.items():
            download_git_tree(url, git_sha_rev, out_dir)

    if registry_packages:
        (out_dir / "tarballs").mkdir()
        session = create_http_session()
        # run tarball download jobs in parallel, with at most 5 concurrent download jobs
        with mp.Pool(min(5, mp.cpu_count())) as pool:
            tarball_args_gen = ((session, pkg, out_dir) for pkg in registry_packages)
            pool.starmap(download_tarball, tarball_args_gen)


def get_manifest_metadata(manifest_path: Path) -> dict[str, Any]:
    """Return `cargo metadata` output (format version 1, no deps) for a manifest."""
    cmd = ["cargo", "metadata", "--format-version", "1", "--no-deps", "--manifest-path", str(manifest_path)]
    output = subprocess.check_output(cmd)
    return json.loads(output)


def try_get_crate_manifest_path_from_mainfest_path(manifest_path: Path, crate_name: str) -> Path | None:
    """Return the manifest path of `crate_name` if this manifest's metadata lists it."""
    metadata = get_manifest_metadata(manifest_path)

    for pkg in metadata["packages"]:
        if pkg["name"] == crate_name:
            return Path(pkg["manifest_path"])

    return None


def find_crate_manifest_in_tree(tree: Path, crate_name: str) -> Path:
    """Locate the Cargo.toml belonging to `crate_name` anywhere inside `tree`.

    Raises:
        Exception: if no manifest in the tree declares the crate.
    """
    # in some cases Cargo.toml is not located at the top level, so we also look at subdirectories
    manifest_paths = tree.glob("**/Cargo.toml")

    for manifest_path in manifest_paths:
        res = try_get_crate_manifest_path_from_mainfest_path(manifest_path, crate_name)
        if res is not None:
            return res

    raise Exception(f"Couldn't find manifest for crate {crate_name} inside {tree}.")


def copy_and_patch_git_crate_subtree(git_tree: Path, crate_name: str, crate_out_dir: Path) -> None:
    """Copy one crate out of a fetched git tree and inline workspace values.

    Invalid symlinks (cyclic, dangling, or pointing outside the git tree) are
    skipped during the copy; if the crate manifest references a workspace, its
    inherited values are materialized via `replace-workspace-values`.
    """

    # This function will get called by copytree to decide which entries of a directory should be copied
    # We'll copy everything except symlinks that are invalid
    def ignore_func(dir_str: str, path_strs: list[str]) -> list[str]:
        ignorelist: list[str] = []

        # renamed from `dir` to avoid shadowing the builtin
        base_dir = Path(realpath(dir_str, strict=True))

        for path_str in path_strs:
            path = base_dir / path_str
            if not islink(path):
                continue

            # Filter out cyclic symlinks and symlinks pointing at nonexistent files
            try:
                target_path = Path(realpath(path, strict=True))
            except OSError:
                ignorelist.append(path_str)
                eprint(f"Failed to resolve symlink, ignoring: {path}")
                continue

            # Filter out symlinks that point outside of the current crate's base git tree
            # This can be useful if the nix build sandbox is turned off and there is a symlink to a common absolute path
            if not target_path.is_relative_to(git_tree):
                ignorelist.append(path_str)
                eprint(f"Symlink points outside of the crate's base git tree, ignoring: {path} -> {target_path}")
                continue

        return ignorelist

    crate_manifest_path = find_crate_manifest_in_tree(git_tree, crate_name)
    crate_tree = crate_manifest_path.parent

    eprint(f"Copying to {crate_out_dir}")
    shutil.copytree(crate_tree, crate_out_dir, ignore=ignore_func)
    crate_out_dir.chmod(0o755)

    with open(crate_manifest_path, "r") as f:
        manifest_data = f.read()

    if "workspace" in manifest_data:
        crate_manifest_metadata = get_manifest_metadata(crate_manifest_path)
        workspace_root = Path(crate_manifest_metadata["workspace_root"])

        root_manifest_path = workspace_root / "Cargo.toml"
        manifest_path = crate_out_dir / "Cargo.toml"

        # the copied manifest may be read-only; make it writable before patching
        manifest_path.chmod(0o644)
        eprint(f"Patching {manifest_path}")

        cmd = ["replace-workspace-values", str(manifest_path), str(root_manifest_path)]
        subprocess.check_output(cmd)


def extract_crate_tarball_contents(tarball_path: Path, crate_out_dir: Path) -> None:
    """Unpack a crate tarball into `crate_out_dir`, dropping the top-level directory."""
    eprint(f"Unpacking to {crate_out_dir}")
    crate_out_dir.mkdir()
    cmd = ["tar", "xf", str(tarball_path), "-C", str(crate_out_dir), "--strip-components=1"]
    subprocess.check_output(cmd)


def create_vendor(vendor_staging_dir: Path, out_dir: Path) -> None:
    """Turn a staging dir produced by create_vendor_staging into a cargo vendor dir.

    Writes one directory per crate, a `.cargo-checksum.json` for each, and a
    `.cargo/config.toml` that redirects crates.io and every git source to the
    vendored copies.
    """
    lockfile_path = vendor_staging_dir / "Cargo.lock"
    out_dir.mkdir(exist_ok=True)
    shutil.copy(lockfile_path, out_dir / "Cargo.lock")

    cargo_lock_toml = load_toml(lockfile_path)
    lockfile_version = get_lockfile_version(cargo_lock_toml)

    config_lines = [
        '[source.vendored-sources]',
        'directory = "@vendor@"',
        '[source.crates-io]',
        'replace-with = "vendored-sources"',
    ]

    # several crates can share one git source; only emit each source section once
    seen_source_keys = set()
    for pkg in cargo_lock_toml["package"]:

        # ignore local dependencies
        if "source" not in pkg.keys():
            continue

        source: str = pkg["source"]

        dir_name = f"{pkg['name']}-{pkg['version']}"
        crate_out_dir = out_dir / dir_name

        if source.startswith("git+"):

            source_info = parse_git_source(pkg["source"], lockfile_version)

            git_sha_rev = source_info["git_sha_rev"]
            git_tree = vendor_staging_dir / "git" / git_sha_rev

            copy_and_patch_git_crate_subtree(git_tree, pkg["name"], crate_out_dir)

            # git based crates allow having no checksum information
            with open(crate_out_dir / ".cargo-checksum.json", "w") as f:
                json.dump({"files": {}}, f)

            # strip the `#<sha>` fragment: the config section key is the bare source
            source_key = source[0:source.find("#")]

            if source_key in seen_source_keys:
                continue

            seen_source_keys.add(source_key)

            config_lines.append(f'[source."{source_key}"]')
            config_lines.append(f'git = "{source_info["url"]}"')
            if source_info["type"] is not None:
                config_lines.append(f'{source_info["type"]} = "{source_info["value"]}"')
            config_lines.append('replace-with = "vendored-sources"')

        elif source.startswith("registry+"):

            filename = f"{pkg['name']}-{pkg['version']}.tar.gz"
            tarball_path = vendor_staging_dir / "tarballs" / filename

            extract_crate_tarball_contents(tarball_path, crate_out_dir)

            # non-git based crates need the package checksum at minimum
            with open(crate_out_dir / ".cargo-checksum.json", "w") as f:
                json.dump({"files": {}, "package": pkg["checksum"]}, f)

        else:
            raise Exception(f"Can't process source: {source}.")

    (out_dir / ".cargo").mkdir()
    with open(out_dir / ".cargo" / "config.toml", "w") as config_file:
        config_file.writelines(line + "\n" for line in config_lines)


def main() -> None:
    """CLI entry point: dispatch to a subcommand given as argv[1]."""
    subcommand = sys.argv[1]

    subcommand_func_dict = {
        "create-vendor-staging": lambda: create_vendor_staging(lockfile_path=Path(sys.argv[2]), out_dir=Path(sys.argv[3])),
        "create-vendor": lambda: create_vendor(vendor_staging_dir=Path(sys.argv[2]), out_dir=Path(sys.argv[3])),
    }

    subcommand_func = subcommand_func_dict.get(subcommand)

    if subcommand_func is None:
        raise Exception(f"Unknown subcommand: '{subcommand}'. Must be one of {list(subcommand_func_dict.keys())}")

    subcommand_func()


if __name__ == "__main__":
    main()