nixos server configurations
at main 7.2 kB view raw
"""
This HTTP service implements the /tls-check endpoint used by Caddy's
on_demand_tls.ask configuration. Behavior:

- If the query parameter `domain` is ALLOWED -> return HTTP 200 OK with body "OK".
- Otherwise proxy the request (same path and query string) to the real PDS
  at http://127.0.0.1:<PDS_PORT> and forward the upstream status, body and
  end-to-end headers.
"""

import logging
import os
import socket
import sys
import urllib.error
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs, urlparse

# Configuration (can be overridden with env vars)
PDS_PORT = int(os.environ.get("PDS_PORT", "3000"))
LISTEN_HOST = os.environ.get("LISTEN_HOST", "127.0.0.1")
LISTEN_PORT = int(os.environ.get("LISTEN_PORT", "8081"))
TIMEOUT = float(os.environ.get("TIMEOUT", "5.0"))

# Allowed domain values (lowercase)
ALLOWED = {"pds", "knot", "spindle"}

# Hop-by-hop header names (lowercase) that must not be relayed to the client.
# "trailer" is the actual field name; "trailers" is kept because the original
# list (following RFC 2616's spelling) contained it.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)

# Command-line flag -> (module-level setting name, value converter)
_CLI_FLAGS = {
    "--pds-port": ("PDS_PORT", int),
    "--listen-host": ("LISTEN_HOST", str),
    "--listen-port": ("LISTEN_PORT", int),
    "--timeout": ("TIMEOUT", float),
}

# Configure logging to stderr (systemd/journal-friendly)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    stream=sys.stderr,
)


def filter_response_headers(headers):
    """
    Return a dict of the given headers with hop-by-hop headers removed.

    `headers` is either an iterable of (name, value) pairs or a mapping-like
    object exposing `.items()`. Names are compared case-insensitively; the
    original spelling of each kept name is preserved. Because the result is a
    dict, a header name that occurs several times keeps only its last value.
    """
    pairs = headers.items() if hasattr(headers, "items") else headers
    return {
        name: value
        for name, value in pairs
        if name.lower() not in _HOP_BY_HOP_HEADERS
    }


class TLSCheckHandler(BaseHTTPRequestHandler):
    """Answer Caddy "ask" requests: allow-listed domains get 200, the rest go to the PDS."""

    # Reduce console noise from BaseHTTPRequestHandler
    def log_message(self, format, *args):  # noqa: A002 - name fixed by the base class
        """Route the base class's access-log lines to the logging module at INFO."""
        logging.info("%s - %s", self.client_address[0], format % args)

    def _send(self, status, body=b"", headers=None):
        """
        Write a complete response to the client.

        `headers` is a mapping of header name -> value; when it is empty or
        None a plain-text Content-Type is sent instead. `body` may be bytes or
        str (str is encoded as UTF-8); an empty body writes nothing after the
        header block.
        """
        self.send_response(status)
        if headers:
            for name, value in headers.items():
                try:
                    self.send_header(name, value)
                except Exception:
                    # Best effort: one bad header must not abort the response.
                    logging.debug("skipping header %r due to error", name)
        else:
            self.send_header("Content-Type", "text/plain; charset=utf-8")
        self.end_headers()

        if not body:
            return
        if isinstance(body, str):
            body = body.encode("utf-8")
        try:
            self.wfile.write(body)
        except (BrokenPipeError, ConnectionResetError):
            # Client disconnected early; nothing to do
            pass

    def _proxy_to_pds(self, path_with_query):
        """
        Proxy a GET to http://127.0.0.1:<PDS_PORT><path_with_query>.

        Returns (status, body_bytes, headers_dict). Upstream HTTP error
        responses are returned with their own status and body and an empty
        header dict; any other failure yields a 502. A `path_with_query` that
        does not start with "/" is refused with 400 without contacting the
        upstream.
        """
        # The target is built by string concatenation, so a path such as
        # "1/x" would turn ":3000" into ":30001/x" and reach a different
        # local port. Only origin-form paths are safe to append.
        if not path_with_query.startswith("/"):
            logging.warning(
                "refusing to proxy malformed request target %r", path_with_query
            )
            return 400, b"bad request target", {}

        target = f"http://127.0.0.1:{PDS_PORT}{path_with_query}"
        logging.debug("proxying to upstream: %s", target)
        req = urllib.request.Request(
            target, headers={"User-Agent": "ondemand-tls-helper/1.0"}
        )
        try:
            with urllib.request.urlopen(req, timeout=TIMEOUT) as resp:
                data = resp.read()
                headers = filter_response_headers(resp.getheaders())
                return resp.status, data, headers
        except urllib.error.HTTPError as e:
            # Upstream returned an HTTP error; return its body and status
            try:
                data = e.read()
            except Exception:
                data = b""
            status = getattr(e, "code", 502)
            logging.info("upstream returned HTTPError %s for %s", status, target)
            return status, data, {}
        except Exception as e:
            logging.exception("error proxying to upstream %s: %s", target, e)
            # Return 502 Bad Gateway
            return 502, f"upstream error: {e}".encode("utf-8"), {}

    def _get_domain_param(self):
        """Return the first `domain` query value, stripped and lowercased ("" if absent)."""
        values = parse_qs(urlparse(self.path).query).get("domain")
        if not values:
            return ""
        return values[0].strip().lower()

    def _handle(self, include_body):
        """
        Shared GET/HEAD logic.

        Allow-listed domains are answered locally with 200; everything else is
        proxied to the PDS. `include_body` is False for HEAD, in which case
        the status and headers are sent but the body is dropped.
        """
        domain = self._get_domain_param()
        if domain in ALLOWED:
            if include_body:
                logging.debug("allowed domain %r -> returning 200", domain)
                return self._send(200, "OK")
            logging.debug("allowed domain (HEAD) %r -> returning 200", domain)
            return self._send(200, b"")

        parsed = urlparse(self.path)
        # An empty path is sent upstream as "/" (urllib did this implicitly).
        path_with_query = parsed.path or "/"
        if parsed.query:
            path_with_query += "?" + parsed.query

        status, body, headers = self._proxy_to_pds(path_with_query)
        return self._send(status, body if include_body else b"", headers=headers)

    def do_GET(self):
        """Handle GET: 200 "OK" for allowed domains, otherwise the proxied upstream response."""
        return self._handle(include_body=True)

    def do_HEAD(self):
        """Handle HEAD: same decision as GET, but no response body is written."""
        return self._handle(include_body=False)


def run():
    """Bind to LISTEN_HOST:LISTEN_PORT and serve until interrupted; exit 1 if the bind fails."""
    server_address = (LISTEN_HOST, LISTEN_PORT)
    try:
        httpd = HTTPServer(server_address, TLSCheckHandler)
    except OSError as e:
        logging.error("cannot bind to %s:%s: %s", LISTEN_HOST, LISTEN_PORT, e)
        sys.exit(1)

    sa = httpd.socket.getsockname()
    logging.info("ondemand TLS helper listening on %s:%s", sa[0], sa[1])
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        logging.info("shutting down on keyboard interrupt")
    except Exception:
        logging.exception("server crashed")
    finally:
        try:
            httpd.server_close()
        except Exception:
            # Shutdown is best effort, but leave a trace for debugging.
            logging.debug("error while closing server socket", exc_info=True)


def _parse_cli_overrides(argv):
    """
    Parse `--pds-port N --listen-host HOST --listen-port N --timeout S`.

    Returns a dict mapping setting name (e.g. "PDS_PORT") to its converted
    value for every recognised flag. Unknown arguments are ignored, and a
    recognised flag that is the last argument (no value) stops parsing. A
    value that cannot be converted exits with an error message.
    """
    overrides = {}
    args = iter(argv)
    for flag in args:
        spec = _CLI_FLAGS.get(flag)
        if spec is None:
            # ignore unknown args
            continue
        raw = next(args, None)
        if raw is None:
            break
        setting, convert = spec
        try:
            overrides[setting] = convert(raw)
        except ValueError:
            raise SystemExit(f"invalid value for {flag}: {raw!r}") from None
    return overrides


if __name__ == "__main__":
    # Allow override of config via arguments if invoked with simple flags:
    #   --pds-port N --listen-host HOST --listen-port N --timeout S
    # to keep the script flexible for local testing.
    _overrides = _parse_cli_overrides(sys.argv[1:])
    PDS_PORT = _overrides.get("PDS_PORT", PDS_PORT)
    LISTEN_HOST = _overrides.get("LISTEN_HOST", LISTEN_HOST)
    LISTEN_PORT = _overrides.get("LISTEN_PORT", LISTEN_PORT)
    TIMEOUT = _overrides.get("TIMEOUT", TIMEOUT)

    run()