1package server
2
3import (
4 "crypto/rand"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "strings"
10 "time"
11
12 "github.com/Azure/go-autorest/autorest/to"
13 "github.com/google/uuid"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/models"
16 "github.com/labstack/echo/v4"
17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18)
19
// ServerGetServiceAuthRequest holds the query parameters accepted by the
// getServiceAuth handler. Values are bound from the query string via the
// `query` tags and then checked against the `validate` tags.
type ServerGetServiceAuthRequest struct {
	// Aud is the audience DID the minted token is intended for (required).
	Aud string `query:"aud" validate:"required,atproto-did"`

	// Exp is the requested expiration in Unix seconds. Zero means the
	// caller did not supply one and the handler picks a default.
	Exp int64 `query:"exp"`

	// Lxm is the NSID of the lexicon method the token is scoped to (required).
	Lxm string `query:"lxm" validate:"required,atproto-nsid"`
}
25
26func (s *Server) handleServerGetServiceAuth(e echo.Context) error {
27 var req ServerGetServiceAuthRequest
28 if err := e.Bind(&req); err != nil {
29 s.logger.Error("could not bind service auth request", "error", err)
30 return helpers.ServerError(e, nil)
31 }
32
33 if err := e.Validate(req); err != nil {
34 return helpers.InputError(e, nil)
35 }
36
37 now := time.Now().Unix()
38 if req.Exp == 0 {
39 req.Exp = now + 60 // default
40 }
41
42 if req.Lxm == "com.atproto.server.getServiceAuth" {
43 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively"))
44 }
45
46 maxExp := now + (60 * 30)
47 if req.Exp > maxExp {
48 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please"))
49 }
50
51 repo := e.Get("repo").(*models.RepoActor)
52
53 header := map[string]string{
54 "alg": "ES256K",
55 "crv": "secp256k1",
56 "typ": "JWT",
57 }
58 hj, err := json.Marshal(header)
59 if err != nil {
60 s.logger.Error("error marshaling header", "error", err)
61 return helpers.ServerError(e, nil)
62 }
63
64 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
65
66 payload := map[string]any{
67 "iss": repo.Repo.Did,
68 "aud": req.Aud,
69 "lxm": req.Lxm,
70 "jti": uuid.NewString(),
71 "exp": req.Exp,
72 "iat": now,
73 }
74 pj, err := json.Marshal(payload)
75 if err != nil {
76 s.logger.Error("error marashaling payload", "error", err)
77 return helpers.ServerError(e, nil)
78 }
79
80 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
81
82 input := fmt.Sprintf("%s.%s", encheader, encpayload)
83 hash := sha256.Sum256([]byte(input))
84
85 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
86 if err != nil {
87 s.logger.Error("can't load private key", "error", err)
88 return err
89 }
90
91 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
92 if err != nil {
93 s.logger.Error("error signing", "error", err)
94 return helpers.ServerError(e, nil)
95 }
96
97 rBytes := R.Bytes()
98 sBytes := S.Bytes()
99
100 rPadded := make([]byte, 32)
101 sPadded := make([]byte, 32)
102 copy(rPadded[32-len(rBytes):], rBytes)
103 copy(sPadded[32-len(sBytes):], sBytes)
104
105 rawsig := append(rPadded, sPadded...)
106 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
107 token := fmt.Sprintf("%s.%s", input, encsig)
108
109 return e.JSON(200, map[string]string{
110 "token": token,
111 })
112}