1package server
2
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/oauth/dpop"
	"github.com/haileyok/cocoon/oauth/provider"
	"github.com/labstack/echo/v4"
	"gitlab.com/yawning/secp256k1-voi"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
	"gorm.io/gorm"
)
22
23func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
24 return func(e echo.Context) error {
25 username, password, ok := e.Request().BasicAuth()
26 if !ok || username != "admin" || password != s.config.AdminPassword {
27 return helpers.InputError(e, to.StringPtr("Unauthorized"))
28 }
29
30 if err := next(e); err != nil {
31 e.Error(err)
32 }
33
34 return nil
35 }
36}
37
38func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
39 return func(e echo.Context) error {
40 authheader := e.Request().Header.Get("authorization")
41 if authheader == "" {
42 return e.JSON(401, map[string]string{"error": "Unauthorized"})
43 }
44
45 pts := strings.Split(authheader, " ")
46 if len(pts) != 2 {
47 return helpers.ServerError(e, nil)
48 }
49
50 // move on to oauth session middleware if this is a dpop token
51 if pts[0] == "DPoP" {
52 return next(e)
53 }
54
55 tokenstr := pts[1]
56 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
57 claims, ok := token.Claims.(jwt.MapClaims)
58 if !ok {
59 return helpers.InvalidTokenError(e)
60 }
61
62 var did string
63 var repo *models.RepoActor
64
65 // service auth tokens
66 lxm, hasLxm := claims["lxm"]
67 if hasLxm {
68 pts := strings.Split(e.Request().URL.String(), "/")
69 if lxm != pts[len(pts)-1] {
70 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
71 return helpers.InputError(e, nil)
72 }
73
74 maybeDid, ok := claims["iss"].(string)
75 if !ok {
76 s.logger.Error("no iss in service auth token", "error", err)
77 return helpers.InputError(e, nil)
78 }
79 did = maybeDid
80
81 maybeRepo, err := s.getRepoActorByDid(did)
82 if err != nil {
83 s.logger.Error("error fetching repo", "error", err)
84 return helpers.ServerError(e, nil)
85 }
86 repo = maybeRepo
87 }
88
89 if token.Header["alg"] != "ES256K" {
90 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
91 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
92 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
93 }
94 return s.privateKey.Public(), nil
95 })
96 if err != nil {
97 s.logger.Error("error parsing jwt", "error", err)
98 return helpers.ExpiredTokenError(e)
99 }
100
101 if !token.Valid {
102 return helpers.InvalidTokenError(e)
103 }
104 } else {
105 kpts := strings.Split(tokenstr, ".")
106 signingInput := kpts[0] + "." + kpts[1]
107 hash := sha256.Sum256([]byte(signingInput))
108 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
109 if err != nil {
110 s.logger.Error("error decoding signature bytes", "error", err)
111 return helpers.ServerError(e, nil)
112 }
113
114 if len(sigBytes) != 64 {
115 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
116 return helpers.ServerError(e, nil)
117 }
118
119 rBytes := sigBytes[:32]
120 sBytes := sigBytes[32:]
121 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
122 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
123
124 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
125 if err != nil {
126 s.logger.Error("can't load private key", "error", err)
127 return err
128 }
129
130 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
131 if !ok {
132 s.logger.Error("error getting public key from sk")
133 return helpers.ServerError(e, nil)
134 }
135
136 verified := pubKey.VerifyRaw(hash[:], rr, ss)
137 if !verified {
138 s.logger.Error("error verifying", "error", err)
139 return helpers.ServerError(e, nil)
140 }
141 }
142
143 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
144 scope, _ := claims["scope"].(string)
145
146 if isRefresh && scope != "com.atproto.refresh" {
147 return helpers.InvalidTokenError(e)
148 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
149 return helpers.InvalidTokenError(e)
150 }
151
152 table := "tokens"
153 if isRefresh {
154 table = "refresh_tokens"
155 }
156
157 if isRefresh {
158 type Result struct {
159 Found bool
160 }
161 var result Result
162 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
163 if err == gorm.ErrRecordNotFound {
164 return helpers.InvalidTokenError(e)
165 }
166
167 s.logger.Error("error getting token from db", "error", err)
168 return helpers.ServerError(e, nil)
169 }
170
171 if !result.Found {
172 return helpers.InvalidTokenError(e)
173 }
174 }
175
176 exp, ok := claims["exp"].(float64)
177 if !ok {
178 s.logger.Error("error getting iat from token")
179 return helpers.ServerError(e, nil)
180 }
181
182 if exp < float64(time.Now().UTC().Unix()) {
183 return helpers.ExpiredTokenError(e)
184 }
185
186 if repo == nil {
187 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
188 if err != nil {
189 s.logger.Error("error fetching repo", "error", err)
190 return helpers.ServerError(e, nil)
191 }
192 repo = maybeRepo
193 did = repo.Repo.Did
194 }
195
196 e.Set("repo", repo)
197 e.Set("did", did)
198 e.Set("token", tokenstr)
199
200 if err := next(e); err != nil {
201 return helpers.InvalidTokenError(e)
202 }
203
204 return nil
205 }
206}
207
208func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
209 return func(e echo.Context) error {
210 authheader := e.Request().Header.Get("authorization")
211 if authheader == "" {
212 return e.JSON(401, map[string]string{"error": "Unauthorized"})
213 }
214
215 pts := strings.Split(authheader, " ")
216 if len(pts) != 2 {
217 return helpers.ServerError(e, nil)
218 }
219
220 if pts[0] != "DPoP" {
221 return next(e)
222 }
223
224 accessToken := pts[1]
225
226 nonce := s.oauthProvider.NextNonce()
227 if nonce != "" {
228 e.Response().Header().Set("DPoP-Nonce", nonce)
229 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
230 }
231
232 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
233 if err != nil {
234 if errors.Is(err, dpop.ErrUseDpopNonce) {
235 e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`)
236 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
237 return e.JSON(401, map[string]string{
238 "error": "use_dpop_nonce",
239 })
240 }
241 s.logger.Error("invalid dpop proof", "error", err)
242 return helpers.InputError(e, nil)
243 }
244
245 var oauthToken provider.OauthToken
246 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
247 s.logger.Error("error finding access token in db", "error", err)
248 return helpers.InputError(e, nil)
249 }
250
251 if oauthToken.Token == "" {
252 return helpers.InvalidTokenError(e)
253 }
254
255 if *oauthToken.Parameters.DpopJkt != proof.JKT {
256 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
257 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
258 }
259
260 if time.Now().After(oauthToken.ExpiresAt) {
261 e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`)
262 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
263 return e.JSON(401, map[string]string{
264 "error": "invalid_token",
265 "error_description": "Token expired",
266 })
267 }
268
269 repo, err := s.getRepoActorByDid(oauthToken.Sub)
270 if err != nil {
271 s.logger.Error("could not find actor in db", "error", err)
272 return helpers.ServerError(e, nil)
273 }
274
275 e.Set("repo", repo)
276 e.Set("did", repo.Repo.Did)
277 e.Set("token", accessToken)
278 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
279
280 return next(e)
281 }
282}