1#!/usr/bin/env nix-shell
2#! nix-shell -i "python3 -I" -p "python3.withPackages(p: with p; [ rich structlog ])"
3
4from abc import ABC, abstractmethod
5from contextlib import contextmanager
6from pathlib import Path
7from structlog.contextvars import bound_contextvars as log_context
8from typing import ClassVar, List, Tuple
9
10import hashlib, logging, re, structlog
11
12
# Structured logger shared by every part of this script.
logger = structlog.get_logger("sha-to-SRI")
14
15
class _LowerClassName:
    """Descriptor returning the lower-cased name of the class it is read through.

    Works on a class (`Nix32.name`) as well as on its instances (`e.name`).
    It replaces a `@classmethod` + `@property` stack, which was deprecated in
    Python 3.11 and no longer works in Python 3.13.
    """

    def __get__(self, obj, owner = None) -> str:
        if owner is None:
            owner = type(obj)
        return owner.__name__.lower()


class Encoding(ABC):
    """A textual encoding of a hash digest that can be converted to SRI form.

    Each subclass describes one encoding (its `alphabet`, expected `length`
    and `decode`); each instance is bound to one hash algorithm `h`.
    """

    # Characters allowed in an encoded digest, written as the body of a
    # regex character class `[...]`.
    alphabet: ClassVar[str]

    # Lower-cased class name (e.g. "nix32"), readable on classes and instances.
    name = _LowerClassName()

    def __init__(self, h):
        # `h` is a hashlib hash object; only its digest size and name are kept.
        self.n = h.digest_size
        self.hashName = h.name

    @classmethod
    def all(cls, h) -> "List[Encoding]":
        """Instantiate every direct subclass of `cls` for the hash `h`."""
        return [c(h) for c in cls.__subclasses__()]

    def toSRI(self, s: str) -> str:
        """Decode `s` and re-encode it as an SRI string `<hashName>-<base64>`.

        Raises ValueError if `s` does not decode to exactly `self.n` bytes
        (callers skip a definition on ValueError rather than the whole file).
        """
        from base64 import b64encode

        digest = self.decode(s)
        if len(digest) != self.n:
            raise ValueError(
                f"Decoded {len(digest)} bytes, expected {self.n}: '{s}'"
            )

        return f"{self.hashName}-{b64encode(digest).decode()}"

    @property
    @abstractmethod
    def length(self) -> int:
        """Number of characters in an encoded digest."""

    @property
    def regex(self) -> str:
        """Regex matching exactly one encoded digest (no anchors, no groups)."""
        return f"[{self.alphabet}]{{{self.length}}}"

    @abstractmethod
    def decode(self, s: str) -> bytes:
        """Decode the textual digest `s` into raw bytes."""
50
51
class Nix32(Encoding):
    """Nix's base-32 hash encoding.

    The alphabet leaves out 'e', 'o', 't' and 'u'. The LAST character of the
    string holds the lowest-order 5 bits of the digest.
    """

    alphabet = "0123456789abcdfghijklmnpqrsvwxyz"
    # Character -> 5-bit digit value.
    inverted = {c: i for i, c in enumerate(alphabet)}

    @property
    def length(self):
        # ceil(8n / 5) characters, written as Nix computes it. The former
        # `1 + 8n // 5` agreed for SHA-256/512 but over-counted by one whenever
        # 8n is a multiple of 5 (e.g. 20-byte SHA-1 digests).
        return (8 * self.n - 1) // 5 + 1

    def decode(self, s: str):
        """Decode a Nix base-32 string into `self.n` raw bytes.

        Raises ValueError if `s` has the wrong length, contains a character
        outside the alphabet, or sets bits beyond the last digest byte.
        """
        if len(s) != self.length:
            raise ValueError(f"Invalid nix32 hash length: '{s}'")
        out = bytearray(self.n)

        # Walk the string backwards: character n carries bits [5n, 5n + 5) of
        # the digest, where bit k lives in byte k // 8 at position k % 8.
        for n, c in enumerate(reversed(s)):
            digit = self.inverted.get(c)
            if digit is None:
                raise ValueError(f"Invalid nix32 hash: '{s}'")

            i, j = divmod(5 * n, 8)
            out[i] |= (digit << j) & 0xFF
            rem = digit >> (8 - j)  # bits that spill into the next byte
            if rem == 0:
                continue
            elif i + 1 < self.n:
                out[i + 1] = rem
            else:
                # Non-zero bits past the final byte: not a valid encoding.
                # (Checking `i < self.n` here would index out of range.)
                raise ValueError(f"Invalid nix32 hash: '{s}'")

        return bytes(out)
77
78
class Hex(Encoding):
    """Hexadecimal digests; upper- and lower-case digits are both accepted."""

    alphabet = "0-9A-Fa-f"

    @property
    def length(self):
        # One byte takes two hex digits.
        return self.n * 2

    def decode(self, s: str):
        import binascii

        # binascii.Error (a ValueError) is raised for malformed input.
        return binascii.unhexlify(s)
90
91
class Base64(Encoding):
    """Standard, padded base64 digests (the form used inside SRI strings)."""

    alphabet = "A-Za-z0-9+/"

    @property
    def format(self) -> Tuple[int, int]:
        """Number of characters in data and padding."""
        groups, leftover = divmod(self.n, 3)
        # Each full 3-byte group takes 4 characters; a trailing group of
        # 1 or 2 bytes takes 2 or 3 characters plus '=' padding up to 4.
        dataChars = 4 * groups
        if leftover:
            dataChars += leftover + 1
        paddingChars = (3 - leftover) % 3
        return dataChars, paddingChars

    @property
    def length(self):
        dataChars, paddingChars = self.format
        return dataChars + paddingChars

    @property
    def regex(self):
        dataChars, paddingChars = self.format
        charClass = f"[{self.alphabet}]"
        return f"{charClass}{{{dataChars}}}={{{paddingChars}}}"

    def decode(self, s):
        import base64

        # validate = True rejects characters outside the base64 alphabet.
        return base64.b64decode(s, validate = True)
114
115
# Hash algorithms whose definitions get converted. Lower-case names are the
# canonical hashlib spelling, accepted by both OpenSSL and builtin backends.
_HASHES = tuple(hashlib.new(n) for n in ("sha256", "sha512"))
# hashlib name -> one Encoding instance per supported encoding.
ENCODINGS = {h.name: Encoding.all(h) for h in _HASHES}

# hashlib name -> regex alternation matching a digest in any supported
# encoding. Each alternative is captured in a group named `<hash>_<encoding>`;
# base64 digests may already carry an SRI `<hash>-` prefix.
RE = {
    h: "|".join(
        (f"({h}-)?" if e.name == "base64" else "") + f"(?P<{h}_{e.name}>{e.regex})"
        for e in encodings
    )
    for h, encodings in ENCODINGS.items()
}

# Matches a complete `<hash> = "<digest>";` definition (either quote style).
# The leading `\b` stops it from matching the tail of a longer identifier such
# as `foo_sha256`, which would otherwise be rewritten to `foo_hash`.
_DEF_RE = re.compile(
    "|".join(
        f"(?P<{h}>\\b{h} = (?P<{h}_quote>['\"])({pattern})(?P={h}_quote);)"
        for h, pattern in RE.items()
    )
)
133
134
def defToSRI(s: str) -> str:
    """Rewrite each hash definition matched by `_DEF_RE` in `s` as
    `hash = "<SRI>";`.

    A definition whose conversion raises ValueError is logged and left
    exactly as it was; the rest of `s` is returned unchanged.
    """

    def toHashDef(match: re.Match[str]) -> str:
        try:
            # Only one top-level alternative of _DEF_RE can match at a time.
            hashName = next(
                (h for h in ENCODINGS if match.group(h) is not None),
                None,
            )
            if hashName is None:
                raise ValueError("Match with no hash")

            for encoding in ENCODINGS[hashName]:
                digest = match.group(f"{hashName}_{encoding.name}")
                if digest is not None:
                    return f'hash = "{encoding.toSRI(digest)}";'

            raise ValueError(f"Match with '{hashName}' but no subgroup")

        except ValueError as exn:
            logger.error("Skipping", exc_info = exn)
            return match.group()

    return _DEF_RE.sub(toHashDef, s)
158
159
@contextmanager
def atomicFileUpdate(target: Path):
    """Atomically replace the contents of a file.

    Guarantees that no temporary files are left behind, and `target` is either
    left untouched, or overwritten with new content if no exception was raised.

    Yields a pair `(original, new)` of open files.
    `original` is the pre-existing file at `target`, open for reading;
    `new` is an empty, temporary file in the same folder, open for writing.

    Upon exiting the context, the files are closed; if no exception was
    raised, `new` receives `target`'s permission bits and then (atomically)
    replaces `target`; otherwise it is deleted.
    """
    # That's mostly copied from noto-emoji.py, should DRY it out
    from shutil import copymode
    from tempfile import NamedTemporaryFile

    # Stays None until the temporary file exists, so a failure while opening
    # `target` re-raises the real error instead of an UnboundLocalError.
    tmpPath = None
    try:
        with target.open() as original:
            with NamedTemporaryFile(
                dir = target.parent,
                prefix = target.stem,
                suffix = target.suffix,
                delete = False,
                mode = "w",  # otherwise the file would be opened in binary mode by default
            ) as new:
                tmpPath = Path(new.name)
                yield (original, new)

        # NamedTemporaryFile creates files with mode 0600; keep the target's.
        copymode(target, tmpPath)
        tmpPath.replace(target)

    except BaseException:
        # BaseException so that e.g. KeyboardInterrupt also cleans up.
        if tmpPath is not None:
            tmpPath.unlink(missing_ok = True)
        raise
194
195
def fileToSRI(p: Path):
    """Rewrite `p` in place, line by line, converting hash definitions to SRI.

    The file is replaced atomically. Each line's 1-based number is bound in
    the log context so diagnostics match what an editor shows.
    """
    with atomicFileUpdate(p) as (og, new):
        for lineno, line in enumerate(og, start = 1):
            with log_context(line = lineno):
                new.write(defToSRI(line))
201
202
203_SKIP_RE = re.compile("(generated by)|(do not edit)", re.IGNORECASE)
204_IGNORE = frozenset({
205 "gemset.nix",
206 "yarn.nix",
207})
208
if __name__ == "__main__":
    from sys import argv

    logger.info("Starting!")

    def handleFile(p: Path, skipLevel = logging.INFO):
        """Convert the hash definitions in `p` to SRI form.

        If the first non-blank line looks autogenerated, log at `skipLevel`
        and leave the file alone. Never raises: any exception is logged and
        the file is skipped.
        """
        with log_context(file = str(p)):
            try:
                with p.open() as f:
                    # Default to "" so an empty file doesn't leave the name
                    # unbound (which used to surface as an "Unhandled exception").
                    firstLine = next((line for line in f if line.strip()), "")

                if _SKIP_RE.search(firstLine):
                    logger.log(skipLevel, "File looks autogenerated, skipping!")
                    return

                fileToSRI(p)

            except Exception as exn:
                logger.error(
                    "Unhandled exception, skipping file!",
                    exc_info = exn,
                )
            else:
                logger.info("Finished processing file")

    for arg in argv[1:]:
        p = Path(arg)
        with log_context(arg = arg):
            if p.is_file():
                # Explicitly named files warn when skipped.
                handleFile(p, skipLevel = logging.WARNING)

            elif p.is_dir():
                logger.info("Recursing into directory")
                for q in p.glob("**/*.nix"):
                    if not q.is_file():
                        continue

                    if q.name in _IGNORE or "generated" in q.name:
                        with log_context(file = str(q)):
                            logger.info("File looks autogenerated, skipping!")
                        continue

                    handleFile(q)

            else:
                # Previously such arguments (e.g. typos) were silently ignored.
                logger.warning("Not a file or directory, skipping!")