this repo has no description
1package main
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "log/slog"
9 "net"
10 "net/http"
11 "net/url"
12 "os"
13 "strings"
14
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 oauth "github.com/haileyok/atproto-oauth-golang"
17 _ "github.com/joho/godotenv/autoload"
18 "github.com/labstack/echo/v4"
19 "github.com/lestrrat-go/jwx/v2/jwk"
20 slogecho "github.com/samber/slog-echo"
21 "github.com/urfave/cli/v2"
22 "gorm.io/driver/sqlite"
23 "gorm.io/gorm"
24)
25
// Process-wide configuration. The OAUTH_TEST_* settings are read from the
// environment, presumably populated from a .env file by the
// godotenv/autoload import. The URLs below are derived from serverUrlRoot and
// are what this client advertises to authorization servers.
var (
	ctx = context.Background()

	// Settings read from the environment.
	serverAddr     = os.Getenv("OAUTH_TEST_SERVER_ADDR")
	serverUrlRoot  = os.Getenv("OAUTH_TEST_SERVER_URL_ROOT")
	staticFilePath = os.Getenv("OAUTH_TEST_SERVER_STATIC_PATH")
	pdsUrl         = os.Getenv("OAUTH_TEST_PDS_URL")

	// Public URLs derived from serverUrlRoot. serverMetadataUrl doubles as
	// the OAuth client_id.
	serverMetadataUrl = fmt.Sprintf("%s/oauth/client-metadata.json", serverUrlRoot)
	serverCallbackUrl = fmt.Sprintf("%s/callback", serverUrlRoot)

	// OAuth scope requested in pushed authorization requests.
	scope = "atproto transition:generic"
)
36
37func main() {
38 app := &cli.App{
39 Name: "atproto-oauth-golang-tester",
40 Action: run,
41 }
42
43 if serverUrlRoot == "" {
44 panic(fmt.Errorf("no server url root set in env file"))
45 }
46
47 app.RunAndExitOnError()
48}
49
// TestServer bundles the state shared by the OAuth tester's HTTP handlers.
// It is constructed by NewServer.
type TestServer struct {
	// httpd is the HTTP server that run starts; NewServer sets its Handler to e.
	httpd *http.Server

	// e is the echo router on which run registers every route.
	e *echo.Echo

	// db is the gorm handle (SQLite file "oauth.db" in NewServer) where
	// handleLoginSubmit stores pending OauthRequest records.
	db *gorm.DB

	// oauthClient is the atproto OAuth client built from the key in ./jwks.json.
	oauthClient *oauth.OauthClient

	// jwksResponse is the public-key JWK set served at /oauth/jwks.json.
	jwksResponse *oauth.JwksResponseObject
}
57
58func run(cmd *cli.Context) error {
59 s, err := NewServer()
60 if err != nil {
61 panic(err)
62 }
63
64 s.run()
65
66 return nil
67}
68
69func NewServer() (*TestServer, error) {
70 e := echo.New()
71
72 e.Use(slogecho.New(slog.Default()))
73
74 fmt.Println("atproto oauth golang tester server")
75
76 b, err := os.ReadFile("./jwks.json")
77 if err != nil {
78 if os.IsNotExist(err) {
79 return nil, fmt.Errorf(
80 "could not find jwks.json. does it exist? hint: run `go run ./cmd/cmd generate-jwks --prefix demo` to create one.",
81 )
82 }
83 return nil, err
84 }
85
86 k, err := jwk.ParseKey(b)
87 if err != nil {
88 return nil, err
89 }
90
91 pubKey, err := k.PublicKey()
92 if err != nil {
93 return nil, err
94 }
95
96 c, err := oauth.NewOauthClient(oauth.OauthClientArgs{
97 ClientJwk: k,
98 ClientId: serverMetadataUrl,
99 RedirectUri: serverCallbackUrl,
100 })
101 if err != nil {
102 return nil, err
103 }
104
105 httpd := &http.Server{
106 Addr: serverAddr,
107 Handler: e,
108 }
109
110 db, err := gorm.Open(sqlite.Open("oauth.db"), &gorm.Config{})
111 if err != nil {
112 return nil, err
113 }
114
115 db.AutoMigrate(&OauthRequest{})
116
117 return &TestServer{
118 httpd: httpd,
119 e: e,
120 db: db,
121 oauthClient: c,
122 jwksResponse: oauth.CreateJwksResponseObject(pubKey),
123 }, nil
124}
125
126func (s *TestServer) run() error {
127 s.e.File("/", s.getFilePath("index.html"))
128 s.e.File("/login", s.getFilePath("login.html"))
129 s.e.POST("/login", s.handleLoginSubmit)
130 s.e.GET("/oauth/client-metadata.json", s.handleClientMetadata)
131 s.e.GET("/oauth/jwks.json", s.handleJwks)
132
133 if err := s.httpd.ListenAndServe(); err != nil {
134 return err
135 }
136
137 return nil
138}
139
140func (s *TestServer) handleClientMetadata(e echo.Context) error {
141 metadata := map[string]any{
142 "client_id": serverMetadataUrl,
143 "client_name": "Atproto Oauth Golang Tester",
144 "client_uri": serverUrlRoot,
145 "logo_uri": fmt.Sprintf("%s/logo.png", serverUrlRoot),
146 "tos_uri": fmt.Sprintf("%s/tos", serverUrlRoot),
147 "policy_url": fmt.Sprintf("%s/policy", serverUrlRoot),
148 "redirect_uris": []string{serverCallbackUrl},
149 "grant_types": []string{"authorization_code", "refresh_token"},
150 "response_types": []string{"code"},
151 "application_type": "web",
152 "dpop_bound_access_tokens": true,
153 "jwks_uri": fmt.Sprintf("%s/oauth/jwks.json", serverUrlRoot),
154 "scope": "atproto transition:generic",
155 "token_endpoint_auth_method": "private_key_jwt",
156 "token_endpoint_auth_signing_alg": "ES256",
157 }
158
159 return e.JSON(200, metadata)
160}
161
162func (s *TestServer) handleJwks(e echo.Context) error {
163 return e.JSON(200, s.jwksResponse)
164}
165
166func (s *TestServer) handleLoginSubmit(e echo.Context) error {
167 handle := e.FormValue("handle")
168 if handle == "" {
169 return e.Redirect(302, "/login?e=handle-empty")
170 }
171
172 _, herr := syntax.ParseHandle(handle)
173 _, derr := syntax.ParseDID(handle)
174
175 if herr != nil && derr != nil {
176 return e.Redirect(302, "/login?e=handle-invalid")
177 }
178
179 var did string
180
181 if derr == nil {
182 did = handle
183 } else {
184 maybeDid, err := resolveHandle(e.Request().Context(), handle)
185 if err != nil {
186 return err
187 }
188
189 did = maybeDid
190 }
191
192 service, err := resolveService(ctx, did)
193 if err != nil {
194 return err
195 }
196
197 authserver, err := s.oauthClient.ResolvePDSAuthServer(ctx, service)
198 if err != nil {
199 return err
200 }
201
202 meta, err := s.oauthClient.FetchAuthServerMetadata(ctx, authserver)
203 if err != nil {
204 return err
205 }
206
207 dpopPrivateKey, err := oauth.GenerateKey(nil)
208 if err != nil {
209 return err
210 }
211
212 dpopPrivateKeyJson, err := json.Marshal(dpopPrivateKey)
213 if err != nil {
214 return err
215 }
216
217 parResp, err := s.oauthClient.SendParAuthRequest(
218 ctx,
219 authserver,
220 meta,
221 "",
222 scope,
223 dpopPrivateKey,
224 )
225
226 oauthRequest := OauthRequest{
227 State: "",
228 AuthserverIss: meta.Issuer,
229 Did: did,
230 PdsUrl: service,
231 PkceVerifier: parResp.PkceVerifier,
232 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
233 DpopPrivateJwk: string(dpopPrivateKeyJson),
234 }
235
236 if err := s.db.Create(&oauthRequest).Error; err != nil {
237 return err
238 }
239
240 u, _ := url.Parse(meta.AuthorizationEndpoint)
241 u.RawQuery = fmt.Sprintf(
242 "client_id=%s&request_uri=%s",
243 url.QueryEscape(serverMetadataUrl),
244 parResp.Resp["request_uri"].(string),
245 )
246
247 return e.Redirect(302, u.String())
248}
249
250func resolveHandle(ctx context.Context, handle string) (string, error) {
251 var did string
252
253 _, err := syntax.ParseHandle(handle)
254 if err != nil {
255 return "", err
256 }
257
258 recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle))
259 if err != nil {
260 return "", err
261 }
262
263 for _, rec := range recs {
264 if strings.HasPrefix(rec, "did=") {
265 did = strings.Split(rec, "did=")[1]
266 break
267 }
268 }
269
270 if did == "" {
271 req, err := http.NewRequestWithContext(
272 ctx,
273 "GET",
274 fmt.Sprintf("https://%s/.well-known/atproto-did", handle),
275 nil,
276 )
277 if err != nil {
278 return "", err
279 }
280
281 resp, err := http.DefaultClient.Do(req)
282 if err != nil {
283 return "", err
284 }
285 defer resp.Body.Close()
286
287 if resp.StatusCode != http.StatusOK {
288 io.Copy(io.Discard, resp.Body)
289 return "", fmt.Errorf("unable to resolve handle")
290 }
291
292 b, err := io.ReadAll(resp.Body)
293 if err != nil {
294 return "", err
295 }
296
297 maybeDid := string(b)
298
299 if _, err := syntax.ParseDID(maybeDid); err != nil {
300 return "", fmt.Errorf("unable to resolve handle")
301 }
302
303 did = maybeDid
304 }
305
306 // TODO: we can also support did:web here
307
308 if did == "" {
309 return "", fmt.Errorf("unable to resolve handle")
310 }
311
312 return did, nil
313}
314
// resolveService returns the atproto PDS endpoint advertised in the DID
// document for did.
//
// Each supported DID method fetches its document from a different place:
//   - did:plc: https://plc.directory/<did>.
//   - did:web: https://<host>/.well-known/did.json, where host is the
//     method-specific identifier with percent-encoding decoded (e.g. "%3A"
//     for a port separator). Path-based did:web identifiers (extra
//     ":"-separated segments) are rejected, because the atproto DID spec only
//     allows hostname-level did:web.
//
// Any other DID method is rejected as unsupported.
func resolveService(ctx context.Context, did string) (string, error) {
	switch {
	case strings.HasPrefix(did, "did:plc:"):
		return fetchPdsEndpoint(ctx, did, fmt.Sprintf("https://plc.directory/%s", did), "plc registry")

	case strings.HasPrefix(did, "did:web:"):
		// The document lives on the host named by the DID, not at a URL
		// containing the whole DID string.
		id := strings.TrimPrefix(did, "did:web:")
		if id == "" || strings.Contains(id, ":") {
			return "", fmt.Errorf("unsupported did:web identifier %q", did)
		}
		host, err := url.PathUnescape(id)
		if err != nil {
			return "", fmt.Errorf("decoding did:web host %q: %w", id, err)
		}
		// The DID is user-supplied; keep the constructed URL's structure intact.
		if strings.ContainsAny(host, "/?#@") {
			return "", fmt.Errorf("unsupported did:web identifier %q", did)
		}
		return fetchPdsEndpoint(ctx, did, fmt.Sprintf("https://%s/.well-known/did.json", host), "did:web document")

	default:
		return "", fmt.Errorf("did was not a supported did type")
	}
}

// fetchPdsEndpoint GETs the DID document for did at docUrl and returns the
// serviceEndpoint of its "#atproto_pds" service entry. source names where the
// document came from and is used only in the error for a non-200 response.
func fetchPdsEndpoint(ctx context.Context, did, docUrl, source string) (string, error) {
	type didDocument struct {
		Service []struct {
			ID              string `json:"id"`
			Type            string `json:"type"`
			ServiceEndpoint string `json:"serviceEndpoint"`
		} `json:"service"`
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, docUrl, nil)
	if err != nil {
		return "", err
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort drain so the connection can be reused.
		io.Copy(io.Discard, resp.Body)
		return "", fmt.Errorf("could not find identity in %s", source)
	}

	var doc didDocument
	if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
		return "", err
	}

	// The service id may be relative ("#atproto_pds") or qualified with the
	// DID itself ("<did>#atproto_pds").
	for _, svc := range doc.Service {
		if (svc.ID == "#atproto_pds" || svc.ID == did+"#atproto_pds") && svc.ServiceEndpoint != "" {
			return svc.ServiceEndpoint, nil
		}
	}

	return "", fmt.Errorf("could not find atproto_pds service in identity services")
}
402
403func (s *TestServer) getFilePath(file string) string {
404 return fmt.Sprintf("%s/%s", staticFilePath, file)
405}