1package server
2
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"math"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/google/uuid"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/labstack/echo/v4"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
)
19
// ServerGetServiceAuthRequest is the set of query parameters bound and
// validated by handleServerGetServiceAuth.
type ServerGetServiceAuthRequest struct {
	// Aud must be a DID; it becomes the token's "aud" claim.
	Aud string `query:"aud" validate:"required,atproto-did"`

	// Exp is the requested expiry in Unix seconds (compared against
	// time.Now().Unix() by the handler); zero means "use the default".
	// It is a float64, not an integer, because some clients send
	// non-integer expirations.
	Exp float64 `query:"exp"`

	// Lxm must be an NSID; it becomes the token's "lxm" claim.
	Lxm string `query:"lxm" validate:"required,atproto-nsid"`
}
26
27func (s *Server) handleServerGetServiceAuth(e echo.Context) error {
28 var req ServerGetServiceAuthRequest
29 if err := e.Bind(&req); err != nil {
30 s.logger.Error("could not bind service auth request", "error", err)
31 return helpers.ServerError(e, nil)
32 }
33
34 if err := e.Validate(req); err != nil {
35 return helpers.InputError(e, nil)
36 }
37
38 exp := int64(req.Exp)
39 now := time.Now().Unix()
40 if exp == 0 {
41 exp = now + 60 // default
42 }
43
44 if req.Lxm == "com.atproto.server.getServiceAuth" {
45 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively"))
46 }
47
48 maxExp := now + (60 * 30)
49 if exp > maxExp {
50 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please"))
51 }
52
53 repo := e.Get("repo").(*models.RepoActor)
54
55 header := map[string]string{
56 "alg": "ES256K",
57 "crv": "secp256k1",
58 "typ": "JWT",
59 }
60 hj, err := json.Marshal(header)
61 if err != nil {
62 s.logger.Error("error marshaling header", "error", err)
63 return helpers.ServerError(e, nil)
64 }
65
66 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
67
68 payload := map[string]any{
69 "iss": repo.Repo.Did,
70 "aud": req.Aud,
71 "lxm": req.Lxm,
72 "jti": uuid.NewString(),
73 "exp": exp,
74 "iat": now,
75 }
76 pj, err := json.Marshal(payload)
77 if err != nil {
78 s.logger.Error("error marashaling payload", "error", err)
79 return helpers.ServerError(e, nil)
80 }
81
82 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
83
84 input := fmt.Sprintf("%s.%s", encheader, encpayload)
85 hash := sha256.Sum256([]byte(input))
86
87 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
88 if err != nil {
89 s.logger.Error("can't load private key", "error", err)
90 return err
91 }
92
93 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
94 if err != nil {
95 s.logger.Error("error signing", "error", err)
96 return helpers.ServerError(e, nil)
97 }
98
99 rBytes := R.Bytes()
100 sBytes := S.Bytes()
101
102 rPadded := make([]byte, 32)
103 sPadded := make([]byte, 32)
104 copy(rPadded[32-len(rBytes):], rBytes)
105 copy(sPadded[32-len(sBytes):], sBytes)
106
107 rawsig := append(rPadded, sPadded...)
108 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
109 token := fmt.Sprintf("%s.%s", input, encsig)
110
111 return e.JSON(200, map[string]string{
112 "token": token,
113 })
114}