1package server
2
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"math"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/google/uuid"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/labstack/echo/v4"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
)
19
// ServerGetServiceAuthRequest holds the query parameters accepted by
// handleServerGetServiceAuth (com.atproto.server.getServiceAuth).
type ServerGetServiceAuthRequest struct {
	// Aud is the DID the minted token is addressed to. It is required and must
	// pass the "atproto-did" validation rule.
	Aud string `query:"aud" validate:"required,atproto-did"`
	// Exp is the requested expiration, which the handler compares against Unix
	// seconds; zero means "use the default expiration".
	// It must be a float because some clients send a non-integer expiration.
	Exp float64 `query:"exp"`
	// Lxm optionally scopes the token to a single lexicon method. When set, it
	// is emitted as the token's "lxm" claim.
	Lxm string `query:"lxm"`
}
26
27func (s *Server) handleServerGetServiceAuth(e echo.Context) error {
28 var req ServerGetServiceAuthRequest
29 if err := e.Bind(&req); err != nil {
30 s.logger.Error("could not bind service auth request", "error", err)
31 return helpers.ServerError(e, nil)
32 }
33
34 if err := e.Validate(req); err != nil {
35 return helpers.InputError(e, nil)
36 }
37
38 exp := int64(req.Exp)
39 now := time.Now().Unix()
40 if exp == 0 {
41 exp = now + 60 // default
42 }
43
44 if req.Lxm == "com.atproto.server.getServiceAuth" {
45 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively"))
46 }
47
48 var maxExp int64
49 if req.Lxm != "" {
50 maxExp = now + (60 * 60)
51 } else {
52 maxExp = now + 60
53 }
54 if exp > maxExp {
55 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please"))
56 }
57
58 repo := e.Get("repo").(*models.RepoActor)
59
60 header := map[string]string{
61 "alg": "ES256K",
62 "crv": "secp256k1",
63 "typ": "JWT",
64 }
65 hj, err := json.Marshal(header)
66 if err != nil {
67 s.logger.Error("error marshaling header", "error", err)
68 return helpers.ServerError(e, nil)
69 }
70
71 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
72
73 payload := map[string]any{
74 "iss": repo.Repo.Did,
75 "aud": req.Aud,
76 "jti": uuid.NewString(),
77 "exp": exp,
78 "iat": now,
79 }
80 if req.Lxm != "" {
81 payload["lxm"] = req.Lxm
82 }
83 pj, err := json.Marshal(payload)
84 if err != nil {
85 s.logger.Error("error marashaling payload", "error", err)
86 return helpers.ServerError(e, nil)
87 }
88
89 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
90
91 input := fmt.Sprintf("%s.%s", encheader, encpayload)
92 hash := sha256.Sum256([]byte(input))
93
94 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
95 if err != nil {
96 s.logger.Error("can't load private key", "error", err)
97 return err
98 }
99
100 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
101 if err != nil {
102 s.logger.Error("error signing", "error", err)
103 return helpers.ServerError(e, nil)
104 }
105
106 rBytes := R.Bytes()
107 sBytes := S.Bytes()
108
109 rPadded := make([]byte, 32)
110 sPadded := make([]byte, 32)
111 copy(rPadded[32-len(rBytes):], rBytes)
112 copy(sPadded[32-len(sBytes):], sBytes)
113
114 rawsig := append(rPadded, sPadded...)
115 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
116 token := fmt.Sprintf("%s.%s", input, encsig)
117
118 return e.JSON(200, map[string]string{
119 "token": token,
120 })
121}