1import functools
2import hashlib
3import json
4import multiprocessing as mp
5import re
6import shutil
7import subprocess
8import sys
9import tomllib
10from os.path import islink, realpath
11from pathlib import Path
12from typing import Any, TypedDict, cast
13from urllib.parse import unquote
14
15import requests
16from requests.adapters import HTTPAdapter, Retry
17
# Diagnostic printer: behaves like print(), with the `file` argument pre-bound to stderr.
eprint = functools.partial(
    print,
    file=sys.stderr,
)
19
20
21def load_toml(path: Path) -> dict[str, Any]:
22 with open(path, "rb") as f:
23 return tomllib.load(f)
24
25
def get_lockfile_version(cargo_lock_toml: dict[str, Any]) -> int:
    """Return the lockfile format version recorded in a parsed ``Cargo.lock``.

    Returns the value of the top-level ``version`` key, or ``2`` when the key
    is absent.
    """
    # TODO: add logic for differentiating between v1 and v2
    try:
        return cargo_lock_toml["version"]
    except KeyError:
        # lockfile v1 and v2 don't have the `version` key, so assume v2
        return 2
33
34
def create_http_session() -> requests.Session:
    """Build a ``requests`` session whose HTTP and HTTPS adapters retry failed requests.

    The retry policy allows 5 retries in total with a backoff factor of 0.5,
    and treats the 500, 502, 503 and 504 status codes as retryable.
    """
    retry_policy = Retry(total=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504])

    http_session = requests.Session()
    # Each scheme gets its own adapter instance, all sharing the same retry policy.
    for scheme_prefix in ("http://", "https://"):
        http_session.mount(scheme_prefix, HTTPAdapter(max_retries=retry_policy))
    return http_session
45
46
def download_file_with_checksum(
    session: requests.Session,
    url: str,
    destination_path: Path,
    *,
    timeout: float | tuple[float, float] = (30.0, 120.0),
    chunk_size: int = 64 * 1024,
) -> str:
    """Stream *url* into *destination_path* and return the SHA-256 of the bytes written.

    Args:
        session: session used to issue the GET request.
        url: address of the file to download.
        destination_path: file that receives the response body (opened with ``"wb"``).
        timeout: value forwarded to ``session.get(timeout=...)``; either a single
            number or a ``(connect, read)`` pair, in seconds. Without it a stalled
            connection would block forever.
        chunk_size: size, in bytes, requested from ``iter_content`` per iteration.

    Returns:
        The lowercase hexadecimal SHA-256 digest of the downloaded content.

    Raises:
        Exception: if the response status is not OK. The destination file is not
            opened in that case.
    """
    sha256_hash = hashlib.sha256()
    with session.get(url, stream=True, timeout=timeout) as response:
        if not response.ok:
            raise Exception(f"Failed to fetch file from {url}. Status code: {response.status_code}")
        with open(destination_path, "wb") as file:
            # The digest is updated chunk by chunk so the body is never held in memory as a whole.
            for chunk in response.iter_content(chunk_size):
                if chunk:  # Filter out keep-alive chunks
                    file.write(chunk)
                    sha256_hash.update(chunk)

    return sha256_hash.hexdigest()
61
62
def get_download_url_for_tarball(pkg: dict[str, Any]) -> str:
    """Return the crates.io download URL for a lockfile package entry.

    Reads the ``source``, ``name`` and ``version`` keys of *pkg*.

    Raises:
        Exception: if the package's source is not the default crates.io registry.
    """
    # TODO: support other registries
    # maybe fetch config.json from the registry root and get the dl key
    # See: https://doc.rust-lang.org/cargo/reference/registry-index.html#index-configuration
    default_registry = "registry+https://github.com/rust-lang/crates.io-index"
    if pkg["source"] == default_registry:
        crate_name = pkg["name"]
        crate_version = pkg["version"]
        return f"https://crates.io/api/v1/crates/{crate_name}/{crate_version}/download"

    raise Exception("Only the default crates.io registry is supported.")
71
72
def download_tarball(session: requests.Session, pkg: dict[str, Any], out_dir: Path) -> None:
    """Download one registry crate into ``out_dir/tarballs`` and verify its checksum.

    The file is named ``<name>-<version>.tar.gz``. The SHA-256 of the downloaded
    bytes is compared against the package's ``checksum`` entry.

    Raises:
        Exception: if the computed checksum differs from the expected one.
    """
    download_url = get_download_url_for_tarball(pkg)
    tarball_name = "{}-{}.tar.gz".format(pkg["name"], pkg["version"])

    # TODO: allow legacy checksum specification, see importCargoLock for example
    # also, don't forget about the other usage of the checksum
    checksum_from_lockfile = pkg["checksum"]

    tarball_path = out_dir / "tarballs" / tarball_name
    eprint(f"Fetching {download_url} -> tarballs/{tarball_name}")

    checksum_of_download = download_file_with_checksum(session, download_url, tarball_path)
    if checksum_of_download == checksum_from_lockfile:
        return

    raise Exception(
        f"Hash mismatch! File fetched from {download_url} had checksum {checksum_of_download}, expected {checksum_from_lockfile}."
    )
89
90
def download_git_tree(url: str, git_sha_rev: str, out_dir: Path) -> None:
    """Fetch revision *git_sha_rev* of the repository at *url* into ``out_dir/git/<git_sha_rev>``.

    The fetch is delegated to the external ``nix-prefetch-git`` command, which is
    asked to fetch submodules as well. A non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    checkout_dir = out_dir / "git" / git_sha_rev
    eprint(f"Fetching {url}#{git_sha_rev} -> git/{git_sha_rev}")

    prefetch_cmd = [
        "nix-prefetch-git",
        "--builder",
        "--quiet",
        "--fetch-submodules",
        "--url",
        url,
        "--rev",
        git_sha_rev,
        "--out",
        str(checkout_dir),
    ]
    subprocess.check_output(prefetch_cmd)
98
99
# Matches a lockfile git source of the form
#   git+<url>[?(rev|tag|branch)=<value>]#<git_sha_rev>
GIT_SOURCE_REGEX = re.compile(
    r"git\+(?P<url>[^?]+)"
    r"(\?(?P<type>rev|tag|branch)=(?P<value>.*))?"
    r"#(?P<git_sha_rev>.*)"
)
101
102
class GitSourceInfo(TypedDict):
    """Named groups captured by GIT_SOURCE_REGEX from a ``git+...`` lockfile source string."""

    # Repository URL: the text between the "git+" prefix and the optional "?" query / the "#".
    url: str
    # Kind of reference named in the query ("rev", "tag" or "branch"); None when there is no query.
    type: str | None
    # Name of that reference; None when there is no query.
    value: str | None
    # Text following the "#" (presumably the pinned commit hash -- confirm against Cargo's lockfile format).
    git_sha_rev: str
108
109
def parse_git_source(source: str, lockfile_version: int) -> GitSourceInfo:
    """Split a ``git+...`` lockfile source string into its components.

    Args:
        source: the package's ``source`` value from the lockfile.
        lockfile_version: lockfile format version; from version 4 on, the
            branch/tag/rev value is percent-decoded.

    Raises:
        Exception: if *source* does not match ``GIT_SOURCE_REGEX``.
    """
    parsed = GIT_SOURCE_REGEX.match(source)
    if parsed is None:
        raise Exception(f"Unable to process git source: {source}.")

    git_info = cast(GitSourceInfo, parsed.groupdict(default=None))

    # the source URL is URL-encoded in lockfile_version >=4
    # since we just used regex to parse it we have to manually decode the escaped branch/tag name
    needs_decoding = lockfile_version >= 4 and git_info["value"] is not None
    if needs_decoding:
        git_info["value"] = unquote(git_info["value"])

    return git_info
123
124
def create_vendor_staging(lockfile_path: Path, out_dir: Path) -> None:
    """Download every non-local dependency listed in a ``Cargo.lock`` into *out_dir*.

    Resulting layout:
      * ``out_dir/Cargo.lock``      -- copy of the lockfile
      * ``out_dir/git/<rev>``       -- one tree per distinct git revision (only if git deps exist)
      * ``out_dir/tarballs/<file>`` -- one tarball per registry package (only if registry deps exist)

    Packages without a ``source`` key are treated as local and skipped.

    Raises:
        Exception: if a package's source is neither ``git+`` nor ``registry+``,
            or if a download helper fails.
    """
    cargo_lock_toml = load_toml(lockfile_path)
    lockfile_version = get_lockfile_version(cargo_lock_toml)

    git_packages: list[dict[str, Any]] = []
    registry_packages: list[dict[str, Any]] = []

    for pkg in cargo_lock_toml["package"]:
        # ignore local dependencies
        if "source" not in pkg:
            eprint(f"Skipping local dependency: {pkg['name']}")
            continue
        source = pkg["source"]

        if source.startswith("git+"):
            git_packages.append(pkg)
        elif source.startswith("registry+"):
            registry_packages.append(pkg)
        else:
            raise Exception(f"Can't process source: {source}.")

    # Several packages can come from the same git revision; keying by revision fetches each tree once.
    git_sha_rev_to_url: dict[str, str] = {}
    for pkg in git_packages:
        source_info = parse_git_source(pkg["source"], lockfile_version)
        git_sha_rev_to_url[source_info["git_sha_rev"]] = source_info["url"]

    out_dir.mkdir(exist_ok=True)
    shutil.copy(lockfile_path, out_dir / "Cargo.lock")

    # fetch git trees sequentially, since fetching concurrently leads to flaky behaviour
    if git_packages:
        (out_dir / "git").mkdir()
        for git_sha_rev, url in git_sha_rev_to_url.items():
            download_git_tree(url, git_sha_rev, out_dir)

    # Only spawn worker processes when there is actually something to download.
    if registry_packages:
        (out_dir / "tarballs").mkdir()
        session = create_http_session()
        tarball_args = [(session, pkg, out_dir) for pkg in registry_packages]
        # run tarball download jobs in parallel, with at most 5 concurrent download jobs
        with mp.Pool(min(5, mp.cpu_count())) as pool:
            pool.starmap(download_tarball, tarball_args)
167
168
def get_manifest_metadata(manifest_path: Path) -> dict[str, Any]:
    """Run ``cargo metadata`` for *manifest_path* and return its parsed JSON output.

    The command is invoked with ``--format-version 1`` and ``--no-deps``. A
    non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    cargo_metadata_cmd = [
        "cargo",
        "metadata",
        "--format-version",
        "1",
        "--no-deps",
        "--manifest-path",
        str(manifest_path),
    ]
    raw_json = subprocess.check_output(cargo_metadata_cmd)
    return json.loads(raw_json)
173
174
def try_get_crate_manifest_path_from_mainfest_path(manifest_path: Path, crate_name: str) -> Path | None:
    """Look up *crate_name* among the packages that ``cargo metadata`` reports for *manifest_path*.

    Returns the ``manifest_path`` of the first package whose name equals
    *crate_name*, or ``None`` when no package matches.
    """
    reported_packages = get_manifest_metadata(manifest_path)["packages"]
    matching_manifests = (
        Path(package["manifest_path"])
        for package in reported_packages
        if package["name"] == crate_name
    )
    return next(matching_manifests, None)
183
184
def find_crate_manifest_in_tree(tree: Path, crate_name: str) -> Path:
    """Locate the manifest of *crate_name* somewhere below *tree*.

    Every ``Cargo.toml`` found under *tree* is queried in turn until one of them
    reports a package named *crate_name*.

    Raises:
        Exception: if no manifest under *tree* yields the crate.
    """
    # in some cases Cargo.toml is not located at the top level, so we also look at subdirectories
    lookups = (
        try_get_crate_manifest_path_from_mainfest_path(candidate, crate_name)
        for candidate in tree.glob("**/Cargo.toml")
    )
    # The generators are lazy, so the search stops at the first successful lookup.
    crate_manifest = next((hit for hit in lookups if hit is not None), None)

    if crate_manifest is None:
        raise Exception(f"Couldn't find manifest for crate {crate_name} inside {tree}.")
    return crate_manifest
195
196
def copy_and_patch_git_crate_subtree(git_tree: Path, crate_name: str, crate_out_dir: Path) -> None:
    """Copy the directory of *crate_name* out of *git_tree* and patch its manifest if needed.

    Steps:
      1. Find the crate's ``Cargo.toml`` inside *git_tree*.
      2. Copy the manifest's directory to *crate_out_dir*, leaving out symlinks
         that cannot be resolved or that point outside *git_tree*.
      3. If the manifest text contains ``workspace``, run the external
         ``replace-workspace-values`` command on the copied manifest, passing it
         the workspace root manifest.

    Raises:
        Exception: if the crate's manifest cannot be found in *git_tree*.
        subprocess.CalledProcessError: if an external command fails.
    """
    # Symlink targets are canonicalised with realpath() below, so the tree root must be
    # canonicalised too. Otherwise a relative or symlinked `git_tree` would make every
    # in-tree symlink look as if it pointed outside the tree.
    canonical_git_tree = Path(realpath(git_tree))

    # This function will get called by copytree to decide which entries of a directory should be copied
    # We'll copy everything except symlinks that are invalid
    def ignore_func(dir_str: str, path_strs: list[str]) -> list[str]:
        ignorelist: list[str] = []

        base_dir = Path(realpath(dir_str, strict=True))

        for path_str in path_strs:
            path = base_dir / path_str
            if not islink(path):
                continue

            # Filter out cyclic symlinks and symlinks pointing at nonexistent files
            try:
                target_path = Path(realpath(path, strict=True))
            except OSError:
                ignorelist.append(path_str)
                eprint(f"Failed to resolve symlink, ignoring: {path}")
                continue

            # Filter out symlinks that point outside of the current crate's base git tree
            # This can be useful if the nix build sandbox is turned off and there is a symlink to a common absolute path
            if not target_path.is_relative_to(canonical_git_tree):
                ignorelist.append(path_str)
                eprint(f"Symlink points outside of the crate's base git tree, ignoring: {path} -> {target_path}")
                continue

        return ignorelist

    crate_manifest_path = find_crate_manifest_in_tree(git_tree, crate_name)
    crate_tree = crate_manifest_path.parent

    eprint(f"Copying to {crate_out_dir}")
    shutil.copytree(crate_tree, crate_out_dir, ignore=ignore_func)
    # Make the copied directory writable (presumably the source tree can be read-only -- confirm).
    crate_out_dir.chmod(0o755)

    # TOML documents are UTF-8 by specification, so do not depend on the locale's default encoding.
    with open(crate_manifest_path, "r", encoding="utf-8") as f:
        manifest_data = f.read()

    # Cheap textual check; manifests that never mention "workspace" need no patching.
    if "workspace" in manifest_data:
        crate_manifest_metadata = get_manifest_metadata(crate_manifest_path)
        workspace_root = Path(crate_manifest_metadata["workspace_root"])

        root_manifest_path = workspace_root / "Cargo.toml"
        manifest_path = crate_out_dir / "Cargo.toml"

        # The copied manifest must be writable before it is patched in place.
        manifest_path.chmod(0o644)
        eprint(f"Patching {manifest_path}")

        cmd = ["replace-workspace-values", str(manifest_path), str(root_manifest_path)]
        subprocess.check_output(cmd)
250
251
def extract_crate_tarball_contents(tarball_path: Path, crate_out_dir: Path) -> None:
    """Unpack *tarball_path* into a newly created *crate_out_dir*.

    Extraction is delegated to the external ``tar`` command with
    ``--strip-components=1``, so the archive's top-level directory is dropped.
    ``crate_out_dir`` must not exist yet.
    """
    eprint(f"Unpacking to {crate_out_dir}")
    crate_out_dir.mkdir()

    tar_cmd = [
        "tar",
        "xf",
        str(tarball_path),
        "-C",
        str(crate_out_dir),
        "--strip-components=1",
    ]
    subprocess.check_output(tar_cmd)
257
258
def create_vendor(vendor_staging_dir: Path, out_dir: Path) -> None:
    """Turn a vendor staging directory into a vendor directory under *out_dir*.

    For every non-local package in the staged ``Cargo.lock`` this creates
    ``out_dir/<name>-<version>`` (copied from the staged git tree or unpacked
    from the staged tarball) together with a ``.cargo-checksum.json`` file.
    Finally ``out_dir/.cargo/config.toml`` is written with source-replacement
    entries for crates.io and for each distinct git source.

    Raises:
        Exception: if a package's source is neither ``git+`` nor ``registry+``.
    """
    staged_lockfile = vendor_staging_dir / "Cargo.lock"
    out_dir.mkdir(exist_ok=True)
    shutil.copy(staged_lockfile, out_dir / "Cargo.lock")

    lock_data = load_toml(staged_lockfile)
    lock_version = get_lockfile_version(lock_data)

    # "@vendor@" is written verbatim here (presumably a placeholder substituted later -- confirm with the consumer).
    cargo_config = [
        "[source.vendored-sources]",
        'directory = "@vendor@"',
        "[source.crates-io]",
        'replace-with = "vendored-sources"',
    ]

    # Git sources (the part of `source` before "#") that already have a config entry.
    emitted_git_sources = set()

    for pkg in lock_data["package"]:
        # ignore local dependencies
        if "source" not in pkg:
            continue

        source: str = pkg["source"]
        crate_out_dir = out_dir / f"{pkg['name']}-{pkg['version']}"
        checksum_file = crate_out_dir / ".cargo-checksum.json"

        if source.startswith("git+"):
            git_info = parse_git_source(pkg["source"], lock_version)
            staged_tree = vendor_staging_dir / "git" / git_info["git_sha_rev"]

            copy_and_patch_git_crate_subtree(staged_tree, pkg["name"], crate_out_dir)

            # git based crates allow having no checksum information
            with open(checksum_file, "w") as f:
                json.dump({"files": {}}, f)

            # Emit one source-replacement entry per distinct git source.
            git_source_key = source[:source.find("#")]
            if git_source_key not in emitted_git_sources:
                emitted_git_sources.add(git_source_key)

                cargo_config.append(f'[source."{git_source_key}"]')
                cargo_config.append(f'git = "{git_info["url"]}"')
                if git_info["type"] is not None:
                    cargo_config.append(f'{git_info["type"]} = "{git_info["value"]}"')
                cargo_config.append('replace-with = "vendored-sources"')
            continue

        if source.startswith("registry+"):
            staged_tarball = vendor_staging_dir / "tarballs" / f"{pkg['name']}-{pkg['version']}.tar.gz"
            extract_crate_tarball_contents(staged_tarball, crate_out_dir)

            # non-git based crates need the package checksum at minimum
            with open(checksum_file, "w") as f:
                json.dump({"files": {}, "package": pkg["checksum"]}, f)
            continue

        raise Exception(f"Can't process source: {source}.")

    cargo_config_dir = out_dir / ".cargo"
    cargo_config_dir.mkdir()
    with open(cargo_config_dir / "config.toml", "w") as config_file:
        config_file.write("".join(f"{line}\n" for line in cargo_config))
329
330
def main() -> None:
    """Command-line entry point: dispatch ``sys.argv`` to the requested subcommand.

    Supported invocations:
      * ``create-vendor-staging LOCKFILE_PATH OUT_DIR``
      * ``create-vendor VENDOR_STAGING_DIR OUT_DIR``

    Arguments beyond the first three are ignored.

    Raises:
        Exception: if no subcommand is given, if the subcommand is unknown, or
            if fewer than two path arguments follow the subcommand.
    """
    usage = (
        "Usage: create-vendor-staging LOCKFILE_PATH OUT_DIR"
        " | create-vendor VENDOR_STAGING_DIR OUT_DIR"
    )

    # Handlers are zero-argument closures, so sys.argv[2:] is only read after validation below.
    subcommand_func_dict = {
        "create-vendor-staging": lambda: create_vendor_staging(lockfile_path=Path(sys.argv[2]), out_dir=Path(sys.argv[3])),
        "create-vendor": lambda: create_vendor(vendor_staging_dir=Path(sys.argv[2]), out_dir=Path(sys.argv[3])),
    }

    if len(sys.argv) < 2:
        raise Exception(f"Missing subcommand. {usage}")

    subcommand = sys.argv[1]
    subcommand_func = subcommand_func_dict.get(subcommand)

    if subcommand_func is None:
        raise Exception(f"Unknown subcommand: '{subcommand}'. Must be one of {list(subcommand_func_dict.keys())}")

    # Both subcommands take exactly two path arguments; report their absence clearly
    # instead of letting an IndexError escape from inside the handler.
    if len(sys.argv) < 4:
        raise Exception(f"Subcommand '{subcommand}' requires two path arguments. {usage}")

    subcommand_func()
345
346
# Run the command-line interface only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()