at master 6.8 kB view raw
#!/usr/bin/env nix-shell
#! nix-shell -i "python3 -I" -p "python3.withPackages(p: with p; [ rich structlog ])"

from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from structlog.contextvars import bound_contextvars as log_context
from typing import ClassVar, List, Tuple

import hashlib, logging, re, structlog


logger = structlog.getLogger("sha-to-SRI")


class _LowercaseClassName:
    """Read-only descriptor evaluating to the lower-cased name of a class.

    Works for both class and instance access, e.g. `Nix32.name` and
    `Nix32(h).name` are both `"nix32"`. It replaces a `@classmethod @property`
    stack, which is deprecated since Python 3.11 and no longer works on 3.13+.
    """

    def __get__(self, instance, owner = None) -> str:
        cls = owner if owner is not None else type(instance)
        return cls.__name__.lower()


class Encoding(ABC):
    """A textual encoding of a hash digest (nix32, hex, base64, ...).

    Subclasses provide `alphabet` (the body of a regex character class),
    the encoded `length` and `decode`. Each instance is bound to one hash
    algorithm through that algorithm's digest size and name.
    """

    alphabet: ClassVar[str]

    # Lower-cased class name (e.g. "nix32"); used to build regex group names.
    name = _LowercaseClassName()

    def toSRI(self, s: str) -> str:
        """Convert `s`, a digest in this encoding, to an SRI string `<hash>-<base64>`.

        Raises ValueError if `s` cannot be decoded or does not decode to a
        digest of this hash's size.
        """
        from base64 import b64encode

        digest = self.decode(s)
        # Raise (rather than assert) so callers' ValueError handlers see it,
        # and so the check survives `python -O`.
        if len(digest) != self.n:
            raise ValueError(
                f"Decoded {self.name} digest is {len(digest)} bytes, expected {self.n}: '{s}'"
            )

        return f"{self.hashName}-{b64encode(digest).decode()}"

    @classmethod
    def all(cls, h) -> "List[Encoding]":
        """Instantiate every direct subclass of `cls` for the hash object `h`."""
        return [c(h) for c in cls.__subclasses__()]

    def __init__(self, h):
        # `h` is a hashlib hash object; only its digest size and name are kept.
        self.n = h.digest_size
        self.hashName = h.name

    @property
    @abstractmethod
    def length(self) -> int:
        """Number of characters in an encoded digest."""

    @property
    def regex(self) -> str:
        """Regex matching exactly one encoded digest."""
        return f"[{self.alphabet}]{{{self.length}}}"

    @abstractmethod
    def decode(self, s: str) -> bytes:
        """Decode `s` to the raw digest bytes (raising ValueError if invalid)."""
class Nix32(Encoding):
    """Nix's own base-32 encoding of digests.

    The alphabet is the ten digits plus the lowercase letters except
    e, o, t and u. The string is read back to front: the k-th character
    from the end holds bits 5k .. 5k+4 of the digest, where bit b lives in
    byte b // 8 at position b % 8.
    """

    alphabet = "0123456789abcdfghijklmnpqrsvwxyz"
    inverted = {c: i for i, c in enumerate(alphabet)}

    @property
    def length(self):
        # ceil(8n / 5) characters. This equals the old `1 + 8n // 5` for
        # SHA-256/512, and is also right when 8n is a multiple of 5 (e.g. SHA-1).
        return (8 * self.n - 1) // 5 + 1

    def decode(self, s: str):
        """Decode the nix32 string `s` into the raw digest.

        Raises ValueError on a wrong length, a character outside the
        alphabet, or nonzero bits beyond the end of the digest.
        """
        if len(s) != self.length:
            raise ValueError(
                f"Invalid nix32 hash length ({len(s)}, expected {self.length}): '{s}'"
            )

        out = bytearray(self.n)
        for pos, c in enumerate(reversed(s)):
            digit = self.inverted.get(c)
            if digit is None:
                raise ValueError(f"Invalid nix32 hash: '{s}'")

            i, j = divmod(5 * pos, 8)
            out[i] |= (digit << j) & 0xFF
            # High bits of `digit` that spill over into the following byte.
            carry = digit >> (8 - j)
            if carry == 0:
                continue
            if i + 1 < self.n:
                out[i + 1] |= carry
            else:
                # Bits set past the last byte: not a digest of this size.
                raise ValueError(f"Invalid nix32 hash: '{s}'")

        return bytes(out)


class Hex(Encoding):
    """Base-16 encoding, accepting upper- and lowercase digits."""

    alphabet = "0-9A-Fa-f"

    @property
    def length(self):
        return 2 * self.n

    def decode(self, s: str):
        from binascii import unhexlify

        # binascii.Error is a subclass of ValueError.
        return unhexlify(s)


class Base64(Encoding):
    """Standard base64 encoding (`+` and `/`, `=`-padded)."""

    alphabet = "A-Za-z0-9+/"

    @property
    def format(self) -> Tuple[int, int]:
        """Number of characters in data and padding."""
        i, k = divmod(self.n, 3)
        return 4 * i + (0 if k == 0 else k + 1), (3 - k) % 3

    @property
    def length(self):
        return sum(self.format)

    @property
    def regex(self):
        data, padding = self.format
        return f"[{self.alphabet}]{{{data}}}={{{padding}}}"

    def decode(self, s):
        from base64 import b64decode

        # validate=True rejects non-alphabet characters (binascii.Error, a ValueError).
        return b64decode(s, validate = True)


# Hash algorithms handled by this script. Only their `name` ("sha256",
# "sha512") and `digest_size` are used. The named constructors avoid relying
# on the OpenSSL-only "SHA-256" alias accepted by hashlib.new().
_HASHES = (hashlib.sha256(), hashlib.sha512())
ENCODINGS = {h.name: Encoding.all(h) for h in _HASHES}

# For each hash name, an alternation with one named group per encoding
# (e.g. `sha256_nix32`). Base64 digests may carry a `sha256-` style prefix.
RE = {
    h: "|".join(
        (f"({h}-)?" if e.name == "base64" else "") + f"(?P<{h}_{e.name}>{e.regex})"
        for e in encodings
    )
    for h, encodings in ENCODINGS.items()
}

# Matches `<hash> = "<digest>";` (single or double quotes). The lookbehind
# keeps e.g. `my_sha256 = ...` or `foo-sha256 = ...` from being rewritten
# into `my_hash = ...`.
_DEF_RE = re.compile(
    "|".join(
        f"(?P<{h}>(?<![\\w'-]){h} = (?P<{h}_quote>['\"])({alternatives})(?P={h}_quote);)"
        for h, alternatives in RE.items()
    )
)


def defToSRI(s: str) -> str:
    """Rewrite each `sha256 = "...";` / `sha512 = "...";` definition in `s`
    as an SRI `hash = "...";` definition.

    Definitions whose digest cannot be converted are logged and left as-is.
    """

    def replace(m: re.Match[str]) -> str:
        try:
            for h, encodings in ENCODINGS.items():
                if m.group(h) is None:
                    continue

                for e in encodings:
                    encoded = m.group(f"{h}_{e.name}")
                    if encoded is not None:
                        return f'hash = "{e.toSRI(encoded)}";'

                raise ValueError(f"Match with '{h}' but no subgroup")
            raise ValueError("Match with no hash")

        except ValueError as exn:
            logger.error(
                "Skipping",
                exc_info = exn,
            )
            return m.group()

    return _DEF_RE.sub(replace, s)


@contextmanager
def atomicFileUpdate(target: Path):
    """Atomically replace the contents of a file.

    Guarantees that no temporary files are left behind, and `target` is either
    left untouched, or overwritten with new content if no exception was raised.

    Yields a pair `(original, new)` of open text files (UTF-8, with line
    endings passed through unchanged).
    `original` is the pre-existing file at `target`, open for reading;
    `new` is an empty, temporary file in the same directory, open for writing.

    Upon exiting the context, the files are closed; if no exception was
    raised, `new` (atomically) replaces the `target`, keeping `target`'s
    permission bits; otherwise it is deleted.
    """
    # That's mostly copied from noto-emoji.py, should DRY it out
    from shutil import copymode
    from tempfile import NamedTemporaryFile

    # Stays None if opening either file fails, so cleanup cannot hit an
    # unbound name and mask the original error.
    tmpPath = None
    try:
        with target.open(encoding = "utf-8", newline = "") as original:
            with NamedTemporaryFile(
                dir = target.parent,
                prefix = target.stem,
                suffix = target.suffix,
                delete = False,
                mode = "w",  # otherwise the file would be opened in binary mode by default
                encoding = "utf-8",
                newline = "",  # write line endings exactly as they were read
            ) as new:
                tmpPath = Path(new.name)
                yield (original, new)

        # Temporary files are created owner-only (0600); keep target's mode.
        copymode(target, tmpPath)
        tmpPath.replace(target)

    except BaseException:
        # BaseException so that e.g. KeyboardInterrupt also removes the temp file.
        if tmpPath is not None:
            tmpPath.unlink(missing_ok = True)
        raise


def fileToSRI(p: Path):
    """Rewrite, atomically and in place, every convertible hash definition in `p`."""
    with atomicFileUpdate(p) as (og, new):
        for lineNo, line in enumerate(og, start = 1):  # 1-based, like editors
            with log_context(line = lineNo):
                new.write(defToSRI(line))


_SKIP_RE = re.compile("(generated by)|(do not edit)", re.IGNORECASE)
_IGNORE = frozenset({
    "gemset.nix",
    "yarn.nix",
})

if __name__ == "__main__":
    from sys import argv

    logger.info("Starting!")

    def handleFile(p: Path, skipLevel = logging.INFO):
        """Convert one file, unless its first non-blank line marks it as generated.

        Errors are logged and the file is skipped; they are not propagated.
        """
        with log_context(file = str(p)):
            try:
                with p.open(encoding = "utf-8") as f:
                    # First non-blank line, or "" for an empty/all-blank file.
                    firstLine = next((line for line in f if line.strip()), "")

                if _SKIP_RE.search(firstLine):
                    logger.log(skipLevel, "File looks autogenerated, skipping!")
                    return

                fileToSRI(p)

            except Exception as exn:
                logger.error(
                    "Unhandled exception, skipping file!",
                    exc_info = exn,
                )
            else:
                logger.info("Finished processing file")

    for arg in argv[1:]:
        p = Path(arg)
        with log_context(arg = arg):
            if p.is_file():
                # Explicitly named files: a skip is worth a warning.
                handleFile(p, skipLevel = logging.WARNING)

            elif p.is_dir():
                logger.info("Recursing into directory")
                for q in p.glob("**/*.nix"):
                    if not q.is_file():
                        continue
                    if q.name in _IGNORE or "generated" in q.name:
                        logger.info("File looks autogenerated, skipping!")
                        continue

                    handleFile(q)

            else:
                logger.error("Argument is neither a file nor a directory, skipping!")