1package server
2
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/oauth/provider"
	"github.com/labstack/echo/v4"
	"gitlab.com/yawning/secp256k1-voi"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
	"gorm.io/gorm"
)
20
21func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
22 return func(e echo.Context) error {
23 username, password, ok := e.Request().BasicAuth()
24 if !ok || username != "admin" || password != s.config.AdminPassword {
25 return helpers.InputError(e, to.StringPtr("Unauthorized"))
26 }
27
28 if err := next(e); err != nil {
29 e.Error(err)
30 }
31
32 return nil
33 }
34}
35
36func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
37 return func(e echo.Context) error {
38 authheader := e.Request().Header.Get("authorization")
39 if authheader == "" {
40 return e.JSON(401, map[string]string{"error": "Unauthorized"})
41 }
42
43 pts := strings.Split(authheader, " ")
44 if len(pts) != 2 {
45 return helpers.ServerError(e, nil)
46 }
47
48 // move on to oauth session middleware if this is a dpop token
49 if pts[0] == "DPoP" {
50 return next(e)
51 }
52
53 tokenstr := pts[1]
54 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
55 claims, ok := token.Claims.(jwt.MapClaims)
56 if !ok {
57 return helpers.InputError(e, to.StringPtr("InvalidToken"))
58 }
59
60 var did string
61 var repo *models.RepoActor
62
63 // service auth tokens
64 lxm, hasLxm := claims["lxm"]
65 if hasLxm {
66 pts := strings.Split(e.Request().URL.String(), "/")
67 if lxm != pts[len(pts)-1] {
68 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
69 return helpers.InputError(e, nil)
70 }
71
72 maybeDid, ok := claims["iss"].(string)
73 if !ok {
74 s.logger.Error("no iss in service auth token", "error", err)
75 return helpers.InputError(e, nil)
76 }
77 did = maybeDid
78
79 maybeRepo, err := s.getRepoActorByDid(did)
80 if err != nil {
81 s.logger.Error("error fetching repo", "error", err)
82 return helpers.ServerError(e, nil)
83 }
84 repo = maybeRepo
85 }
86
87 if token.Header["alg"] != "ES256K" {
88 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
89 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
90 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
91 }
92 return s.privateKey.Public(), nil
93 })
94 if err != nil {
95 s.logger.Error("error parsing jwt", "error", err)
96 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319
97 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
98 }
99
100 if !token.Valid {
101 return helpers.InputError(e, to.StringPtr("InvalidToken"))
102 }
103 } else {
104 kpts := strings.Split(tokenstr, ".")
105 signingInput := kpts[0] + "." + kpts[1]
106 hash := sha256.Sum256([]byte(signingInput))
107 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
108 if err != nil {
109 s.logger.Error("error decoding signature bytes", "error", err)
110 return helpers.ServerError(e, nil)
111 }
112
113 if len(sigBytes) != 64 {
114 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
115 return helpers.ServerError(e, nil)
116 }
117
118 rBytes := sigBytes[:32]
119 sBytes := sigBytes[32:]
120 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
121 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
122
123 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
124 if err != nil {
125 s.logger.Error("can't load private key", "error", err)
126 return err
127 }
128
129 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
130 if !ok {
131 s.logger.Error("error getting public key from sk")
132 return helpers.ServerError(e, nil)
133 }
134
135 verified := pubKey.VerifyRaw(hash[:], rr, ss)
136 if !verified {
137 s.logger.Error("error verifying", "error", err)
138 return helpers.ServerError(e, nil)
139 }
140 }
141
142 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
143 scope, _ := claims["scope"].(string)
144
145 if isRefresh && scope != "com.atproto.refresh" {
146 return helpers.InputError(e, to.StringPtr("InvalidToken"))
147 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
148 return helpers.InputError(e, to.StringPtr("InvalidToken"))
149 }
150
151 table := "tokens"
152 if isRefresh {
153 table = "refresh_tokens"
154 }
155
156 if isRefresh {
157 type Result struct {
158 Found bool
159 }
160 var result Result
161 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
162 if err == gorm.ErrRecordNotFound {
163 return helpers.InputError(e, to.StringPtr("InvalidToken"))
164 }
165
166 s.logger.Error("error getting token from db", "error", err)
167 return helpers.ServerError(e, nil)
168 }
169
170 if !result.Found {
171 return helpers.InputError(e, to.StringPtr("InvalidToken"))
172 }
173 }
174
175 exp, ok := claims["exp"].(float64)
176 if !ok {
177 s.logger.Error("error getting iat from token")
178 return helpers.ServerError(e, nil)
179 }
180
181 if exp < float64(time.Now().UTC().Unix()) {
182 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
183 }
184
185 if repo == nil {
186 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
187 if err != nil {
188 s.logger.Error("error fetching repo", "error", err)
189 return helpers.ServerError(e, nil)
190 }
191 repo = maybeRepo
192 did = repo.Repo.Did
193 }
194
195 e.Set("repo", repo)
196 e.Set("did", did)
197 e.Set("token", tokenstr)
198
199 if err := next(e); err != nil {
200 e.Error(err)
201 }
202
203 return nil
204 }
205}
206
207func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
208 return func(e echo.Context) error {
209 authheader := e.Request().Header.Get("authorization")
210 if authheader == "" {
211 return e.JSON(401, map[string]string{"error": "Unauthorized"})
212 }
213
214 pts := strings.Split(authheader, " ")
215 if len(pts) != 2 {
216 return helpers.ServerError(e, nil)
217 }
218
219 if pts[0] != "DPoP" {
220 return next(e)
221 }
222
223 accessToken := pts[1]
224
225 nonce := s.oauthProvider.NextNonce()
226 if nonce != "" {
227 e.Response().Header().Set("DPoP-Nonce", nonce)
228 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
229 }
230
231 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
232 if err != nil {
233 s.logger.Error("invalid dpop proof", "error", err)
234 return helpers.InputError(e, to.StringPtr(err.Error()))
235 }
236
237 var oauthToken provider.OauthToken
238 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
239 s.logger.Error("error finding access token in db", "error", err)
240 return helpers.InputError(e, nil)
241 }
242
243 if oauthToken.Token == "" {
244 return helpers.InputError(e, to.StringPtr("InvalidToken"))
245 }
246
247 if *oauthToken.Parameters.DpopJkt != proof.JKT {
248 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
249 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
250 }
251
252 if time.Now().After(oauthToken.ExpiresAt) {
253 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
254 }
255
256 repo, err := s.getRepoActorByDid(oauthToken.Sub)
257 if err != nil {
258 s.logger.Error("could not find actor in db", "error", err)
259 return helpers.ServerError(e, nil)
260 }
261
262 e.Set("repo", repo)
263 e.Set("did", repo.Repo.Did)
264 e.Set("token", accessToken)
265 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
266
267 return next(e)
268 }
269}