An atproto PDS written in Go
at main 3.0 kB view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "strings" 10 "time" 11 12 "github.com/Azure/go-autorest/autorest/to" 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20type ServerGetServiceAuthRequest struct { 21 Aud string `query:"aud" validate:"required,atproto-did"` 22 // exp should be a float, as some clients will send a non-integer expiration 23 Exp float64 `query:"exp"` 24 Lxm string `query:"lxm"` 25} 26 27func (s *Server) handleServerGetServiceAuth(e echo.Context) error { 28 var req ServerGetServiceAuthRequest 29 if err := e.Bind(&req); err != nil { 30 s.logger.Error("could not bind service auth request", "error", err) 31 return helpers.ServerError(e, nil) 32 } 33 34 if err := e.Validate(req); err != nil { 35 return helpers.InputError(e, nil) 36 } 37 38 exp := int64(req.Exp) 39 now := time.Now().Unix() 40 if exp == 0 { 41 exp = now + 60 // default 42 } 43 44 if req.Lxm == "com.atproto.server.getServiceAuth" { 45 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively")) 46 } 47 48 var maxExp int64 49 if req.Lxm != "" { 50 maxExp = now + (60 * 60) 51 } else { 52 maxExp = now + 60 53 } 54 if exp > maxExp { 55 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 56 } 57 58 repo := e.Get("repo").(*models.RepoActor) 59 60 header := map[string]string{ 61 "alg": "ES256K", 62 "crv": "secp256k1", 63 "typ": "JWT", 64 } 65 hj, err := json.Marshal(header) 66 if err != nil { 67 s.logger.Error("error marshaling header", "error", err) 68 return helpers.ServerError(e, nil) 69 } 70 71 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 72 73 payload := map[string]any{ 74 "iss": repo.Repo.Did, 75 "aud": req.Aud, 76 "jti": uuid.NewString(), 77 "exp": exp, 78 "iat": now, 79 } 80 if req.Lxm != "" { 81 payload["lxm"] = req.Lxm 82 } 83 pj, err := json.Marshal(payload) 84 if err != nil { 85 s.logger.Error("error marashaling payload", "error", err) 86 return helpers.ServerError(e, nil) 87 } 88 89 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 90 91 input := fmt.Sprintf("%s.%s", encheader, encpayload) 92 hash := sha256.Sum256([]byte(input)) 93 94 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 95 if err != nil { 96 s.logger.Error("can't load private key", "error", err) 97 return err 98 } 99 100 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 101 if err != nil { 102 s.logger.Error("error signing", "error", err) 103 return helpers.ServerError(e, nil) 104 } 105 106 rBytes := R.Bytes() 107 sBytes := S.Bytes() 108 109 rPadded := make([]byte, 32) 110 sPadded := make([]byte, 32) 111 copy(rPadded[32-len(rBytes):], rBytes) 112 copy(sPadded[32-len(sBytes):], sBytes) 113 114 rawsig := append(rPadded, sPadded...) 115 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 116 token := fmt.Sprintf("%s.%s", input, encsig) 117 118 return e.JSON(200, map[string]string{ 119 "token": token, 120 }) 121}