An atproto PDS written in Go
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "strings" 10 "time" 11 12 "github.com/Azure/go-autorest/autorest/to" 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20type ServerGetServiceAuthRequest struct { 21 Aud string `query:"aud" validate:"required,atproto-did"` 22 Exp int64 `query:"exp"` 23 Lxm string `query:"lxm" validate:"required,atproto-nsid"` 24} 25 26func (s *Server) handleServerGetServiceAuth(e echo.Context) error { 27 var req ServerGetServiceAuthRequest 28 if err := e.Bind(&req); err != nil { 29 s.logger.Error("could not bind service auth request", "error", err) 30 return helpers.ServerError(e, nil) 31 } 32 33 if err := e.Validate(req); err != nil { 34 return helpers.InputError(e, nil) 35 } 36 37 now := time.Now().Unix() 38 if req.Exp == 0 { 39 req.Exp = now + 60 // default 40 } 41 42 if req.Lxm == "com.atproto.server.getServiceAuth" { 43 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively")) 44 } 45 46 maxExp := now + (60 * 30) 47 if req.Exp > maxExp { 48 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 49 } 50 51 repo := e.Get("repo").(*models.RepoActor) 52 53 header := map[string]string{ 54 "alg": "ES256K", 55 "crv": "secp256k1", 56 "typ": "JWT", 57 } 58 hj, err := json.Marshal(header) 59 if err != nil { 60 s.logger.Error("error marshaling header", "error", err) 61 return helpers.ServerError(e, nil) 62 } 63 64 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 65 66 payload := map[string]any{ 67 "iss": repo.Repo.Did, 68 "aud": req.Aud, 69 "lxm": req.Lxm, 70 "jti": uuid.NewString(), 71 "exp": req.Exp, 72 "iat": now, 73 } 74 pj, err := json.Marshal(payload) 75 if err != nil { 76 s.logger.Error("error marashaling payload", "error", err) 77 return helpers.ServerError(e, nil) 78 } 79 80 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 81 82 input := fmt.Sprintf("%s.%s", encheader, encpayload) 83 hash := sha256.Sum256([]byte(input)) 84 85 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 86 if err != nil { 87 s.logger.Error("can't load private key", "error", err) 88 return err 89 } 90 91 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 92 if err != nil { 93 s.logger.Error("error signing", "error", err) 94 return helpers.ServerError(e, nil) 95 } 96 97 rBytes := R.Bytes() 98 sBytes := S.Bytes() 99 100 rPadded := make([]byte, 32) 101 sPadded := make([]byte, 32) 102 copy(rPadded[32-len(rBytes):], rBytes) 103 copy(sPadded[32-len(sBytes):], sBytes) 104 105 rawsig := append(rPadded, sPadded...) 106 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 107 token := fmt.Sprintf("%s.%s", input, encsig) 108 109 return e.JSON(200, map[string]string{ 110 "token": token, 111 }) 112}