nixos server configurations
1"""
2This HTTP service implements the /tls-check endpoint used by Caddy's
3on_demand_tls.ask configuration. Behavior:
4
5- If the query parameter `domain` is ALLOWED -> return HTTP 200 OK with body "OK".
6- Otherwise proxy the request (including the query string) to the real PDS
7 tls-check endpoint at http://127.0.0.1:<PDS_PORT>/tls-check and forward
8 the upstream status and body.
9"""
10
11import logging
12import os
13import socket
14import sys
15import urllib.error
16import urllib.request
17from http.server import BaseHTTPRequestHandler, HTTPServer
18from urllib.parse import parse_qs, urlparse
19
20# Configuration (can be overridden with env vars)
21PDS_PORT = int(os.environ.get("PDS_PORT", "3000"))
22LISTEN_HOST = os.environ.get("LISTEN_HOST", "127.0.0.1")
23LISTEN_PORT = int(os.environ.get("LISTEN_PORT", "8081"))
24TIMEOUT = float(os.environ.get("TIMEOUT", "5.0"))
25
26# Allowed domain values (lowercase)
27ALLOWED = {"pds", "knot", "spindle"}
28
29# Configure logging to stderr (systemd/journal-friendly)
30logging.basicConfig(
31 level=logging.INFO,
32 format="%(asctime)s %(levelname)s %(message)s",
33 stream=sys.stderr,
34)
35
36
37def filter_response_headers(headers):
38 """
39 Given an iterable of (header, value) pairs or a mapping-like object,
40 return a dict with hop-by-hop headers removed. This avoids sending
41 problematic headers to the client (Caddy).
42 """
43 hop_by_hop = {
44 "connection",
45 "keep-alive",
46 "proxy-authenticate",
47 "proxy-authorization",
48 "te",
49 "trailers",
50 "transfer-encoding",
51 "upgrade",
52 }
53 result = {}
54 if hasattr(headers, "items"):
55 iterator = headers.items()
56 else:
57 iterator = headers
58 for k, v in iterator:
59 if k.lower() not in hop_by_hop:
60 result[k] = v
61 return result
62
63
64class TLSCheckHandler(BaseHTTPRequestHandler):
65 # Reduce console noise from BaseHTTPRequestHandler
66 def log_message(self, format, *args):
67 # route to logging module at INFO level
68 logging.info("%s - %s", self.client_address[0], format % args)
69
70 def _send(self, status, body=b"", headers=None):
71 # Send status, headers, and body to the client
72 self.send_response(status)
73 if headers:
74 for k, v in headers.items():
75 # BaseHTTPRequestHandler will fold multiple headers set via send_header
76 # if necessary; we assume simple string values here.
77 try:
78 self.send_header(k, v)
79 except Exception:
80 # Ignore any header-setting errors; continue to send response
81 logging.debug("skipping header %r due to error", k)
82 else:
83 self.send_header("Content-Type", "text/plain; charset=utf-8")
84 self.end_headers()
85 if body:
86 if isinstance(body, str):
87 body = body.encode("utf-8")
88 try:
89 self.wfile.write(body)
90 except BrokenPipeError:
91 # Client disconnected early; nothing to do
92 pass
93
94 def _proxy_to_pds(self, path_with_query):
95 """
96 Proxy a request to http://127.0.0.1:<PDS_PORT><path_with_query>.
97 Returns (status, body_bytes, headers_dict).
98 """
99 target = f"http://127.0.0.1:{PDS_PORT}{path_with_query}"
100 logging.debug("proxying to upstream: %s", target)
101 req = urllib.request.Request(
102 target, headers={"User-Agent": "ondemand-tls-helper/1.0"}
103 )
104 try:
105 with urllib.request.urlopen(req, timeout=TIMEOUT) as resp:
106 data = resp.read()
107 headers = filter_response_headers(resp.getheaders())
108 return resp.status, data, headers
109 except urllib.error.HTTPError as e:
110 # Upstream returned an HTTP error; return its body and status
111 try:
112 data = e.read()
113 except Exception:
114 data = b""
115 status = getattr(e, "code", 502)
116 headers = {}
117 logging.info("upstream returned HTTPError %s for %s", status, target)
118 return status, data, headers
119 except Exception as e:
120 logging.exception("error proxying to upstream %s: %s", target, e)
121 # Return 502 Bad Gateway
122 return 502, f"upstream error: {e}".encode("utf-8"), {}
123
124 def _get_domain_param(self):
125 parsed = urlparse(self.path)
126 qs = parse_qs(parsed.query)
127 domain_vals = qs.get("domain") or []
128 if not domain_vals:
129 return ""
130 return domain_vals[0].strip().lower()
131
132 def do_GET(self):
133 parsed = urlparse(self.path)
134 path_with_query = parsed.path + ("?" + parsed.query if parsed.query else "")
135 domain = self._get_domain_param()
136
137 if domain in ALLOWED:
138 logging.debug("allowed domain %r -> returning 200", domain)
139 return self._send(200, "OK")
140
141 status, body, headers = self._proxy_to_pds(path_with_query)
142 return self._send(status, body, headers=headers)
143
144 def do_HEAD(self):
145 parsed = urlparse(self.path)
146 path_with_query = parsed.path + ("?" + parsed.query if parsed.query else "")
147 domain = self._get_domain_param()
148
149 if domain in ALLOWED:
150 logging.debug("allowed domain (HEAD) %r -> returning 200", domain)
151 return self._send(200, b"")
152
153 status, _, headers = self._proxy_to_pds(path_with_query)
154 return self._send(status, b"", headers=headers)
155
156
157def run():
158 server_address = (LISTEN_HOST, LISTEN_PORT)
159 try:
160 httpd = HTTPServer(server_address, TLSCheckHandler)
161 except OSError as e:
162 logging.error("cannot bind to %s:%s: %s", LISTEN_HOST, LISTEN_PORT, e)
163 sys.exit(1)
164
165 sa = httpd.socket.getsockname()
166 logging.info("ondemand TLS helper listening on %s:%s", sa[0], sa[1])
167 try:
168 httpd.serve_forever()
169 except KeyboardInterrupt:
170 logging.info("shutting down on keyboard interrupt")
171 except Exception:
172 logging.exception("server crashed")
173 finally:
174 try:
175 httpd.server_close()
176 except Exception:
177 pass
178
179
180if __name__ == "__main__":
181 # Allow override of config via arguments if invoked with simple flags:
182 # --pds-port N --listen-host HOST --listen-port N --timeout S
183 # to keep the script flexible for local testing.
184 args = sys.argv[1:]
185 it = iter(args)
186 while True:
187 try:
188 a = next(it)
189 except StopIteration:
190 break
191 if a in ("--pds-port",):
192 try:
193 PDS_PORT = int(next(it))
194 except StopIteration:
195 break
196 elif a in ("--listen-host",):
197 try:
198 LISTEN_HOST = next(it)
199 except StopIteration:
200 break
201 elif a in ("--listen-port",):
202 try:
203 LISTEN_PORT = int(next(it))
204 except StopIteration:
205 break
206 elif a in ("--timeout",):
207 try:
208 TIMEOUT = float(next(it))
209 except StopIteration:
210 break
211 else:
212 # ignore unknown args
213 continue
214
215 run()