An atproto PDS written in Go
at v0.5.1 2.9 kB view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "strings" 10 "time" 11 12 "github.com/Azure/go-autorest/autorest/to" 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20type ServerGetServiceAuthRequest struct { 21 Aud string `query:"aud" validate:"required,atproto-did"` 22 // exp should be a float, as some clients will send a non-integer expiration 23 Exp float64 `query:"exp"` 24 Lxm string `query:"lxm" validate:"required,atproto-nsid"` 25} 26 27func (s *Server) handleServerGetServiceAuth(e echo.Context) error { 28 var req ServerGetServiceAuthRequest 29 if err := e.Bind(&req); err != nil { 30 s.logger.Error("could not bind service auth request", "error", err) 31 return helpers.ServerError(e, nil) 32 } 33 34 if err := e.Validate(req); err != nil { 35 return helpers.InputError(e, nil) 36 } 37 38 exp := int64(req.Exp) 39 now := time.Now().Unix() 40 if exp == 0 { 41 exp = now + 60 // default 42 } 43 44 if req.Lxm == "com.atproto.server.getServiceAuth" { 45 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively")) 46 } 47 48 maxExp := now + (60 * 30) 49 if exp > maxExp { 50 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 51 } 52 53 repo := e.Get("repo").(*models.RepoActor) 54 55 header := map[string]string{ 56 "alg": "ES256K", 57 "crv": "secp256k1", 58 "typ": "JWT", 59 } 60 hj, err := json.Marshal(header) 61 if err != nil { 62 s.logger.Error("error marshaling header", "error", err) 63 return helpers.ServerError(e, nil) 64 } 65 66 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 67 68 payload := map[string]any{ 69 "iss": repo.Repo.Did, 70 "aud": req.Aud, 71 "lxm": req.Lxm, 72 "jti": uuid.NewString(), 73 "exp": exp, 74 "iat": now, 75 } 76 pj, err := json.Marshal(payload) 77 if err != nil { 78 s.logger.Error("error marashaling payload", "error", err) 79 return helpers.ServerError(e, nil) 80 } 81 82 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 83 84 input := fmt.Sprintf("%s.%s", encheader, encpayload) 85 hash := sha256.Sum256([]byte(input)) 86 87 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 88 if err != nil { 89 s.logger.Error("can't load private key", "error", err) 90 return err 91 } 92 93 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 94 if err != nil { 95 s.logger.Error("error signing", "error", err) 96 return helpers.ServerError(e, nil) 97 } 98 99 rBytes := R.Bytes() 100 sBytes := S.Bytes() 101 102 rPadded := make([]byte, 32) 103 sPadded := make([]byte, 32) 104 copy(rPadded[32-len(rBytes):], rBytes) 105 copy(sPadded[32-len(sBytes):], sBytes) 106 107 rawsig := append(rPadded, sPadded...) 108 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 109 token := fmt.Sprintf("%s.%s", input, encsig) 110 111 return e.JSON(200, map[string]string{ 112 "token": token, 113 }) 114}