An atproto PDS written in Go
at v0.5.1 7.5 kB view raw
1package server 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "strings" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/golang-jwt/jwt/v4" 13 "github.com/haileyok/cocoon/internal/helpers" 14 "github.com/haileyok/cocoon/models" 15 "github.com/haileyok/cocoon/oauth/dpop" 16 "github.com/haileyok/cocoon/oauth/provider" 17 "github.com/labstack/echo/v4" 18 "gitlab.com/yawning/secp256k1-voi" 19 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 20 "gorm.io/gorm" 21) 22 23func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 24 return func(e echo.Context) error { 25 username, password, ok := e.Request().BasicAuth() 26 if !ok || username != "admin" || password != s.config.AdminPassword { 27 return helpers.InputError(e, to.StringPtr("Unauthorized")) 28 } 29 30 if err := next(e); err != nil { 31 e.Error(err) 32 } 33 34 return nil 35 } 36} 37 38func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 return func(e echo.Context) error { 40 authheader := e.Request().Header.Get("authorization") 41 if authheader == "" { 42 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 43 } 44 45 pts := strings.Split(authheader, " ") 46 if len(pts) != 2 { 47 return helpers.ServerError(e, nil) 48 } 49 50 // move on to oauth session middleware if this is a dpop token 51 if pts[0] == "DPoP" { 52 return next(e) 53 } 54 55 tokenstr := pts[1] 56 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 57 claims, ok := token.Claims.(jwt.MapClaims) 58 if !ok { 59 return helpers.InvalidTokenError(e) 60 } 61 62 var did string 63 var repo *models.RepoActor 64 65 // service auth tokens 66 lxm, hasLxm := claims["lxm"] 67 if hasLxm { 68 pts := strings.Split(e.Request().URL.String(), "/") 69 if lxm != pts[len(pts)-1] { 70 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 71 return helpers.InputError(e, nil) 72 } 73 74 maybeDid, ok := claims["iss"].(string) 75 if !ok { 76 s.logger.Error("no iss in service auth token", "error", err) 77 return helpers.InputError(e, nil) 78 } 79 did = maybeDid 80 81 maybeRepo, err := s.getRepoActorByDid(did) 82 if err != nil { 83 s.logger.Error("error fetching repo", "error", err) 84 return helpers.ServerError(e, nil) 85 } 86 repo = maybeRepo 87 } 88 89 if token.Header["alg"] != "ES256K" { 90 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 91 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 92 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 93 } 94 return s.privateKey.Public(), nil 95 }) 96 if err != nil { 97 s.logger.Error("error parsing jwt", "error", err) 98 return helpers.ExpiredTokenError(e) 99 } 100 101 if !token.Valid { 102 return helpers.InvalidTokenError(e) 103 } 104 } else { 105 kpts := strings.Split(tokenstr, ".") 106 signingInput := kpts[0] + "." + kpts[1] 107 hash := sha256.Sum256([]byte(signingInput)) 108 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 109 if err != nil { 110 s.logger.Error("error decoding signature bytes", "error", err) 111 return helpers.ServerError(e, nil) 112 } 113 114 if len(sigBytes) != 64 { 115 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 116 return helpers.ServerError(e, nil) 117 } 118 119 rBytes := sigBytes[:32] 120 sBytes := sigBytes[32:] 121 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 122 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 123 124 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 125 if err != nil { 126 s.logger.Error("can't load private key", "error", err) 127 return err 128 } 129 130 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 131 if !ok { 132 s.logger.Error("error getting public key from sk") 133 return helpers.ServerError(e, nil) 134 } 135 136 verified := pubKey.VerifyRaw(hash[:], rr, ss) 137 if !verified { 138 s.logger.Error("error verifying", "error", err) 139 return helpers.ServerError(e, nil) 140 } 141 } 142 143 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 144 scope, _ := claims["scope"].(string) 145 146 if isRefresh && scope != "com.atproto.refresh" { 147 return helpers.InvalidTokenError(e) 148 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 149 return helpers.InvalidTokenError(e) 150 } 151 152 table := "tokens" 153 if isRefresh { 154 table = "refresh_tokens" 155 } 156 157 if isRefresh { 158 type Result struct { 159 Found bool 160 } 161 var result Result 162 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 163 if err == gorm.ErrRecordNotFound { 164 return helpers.InvalidTokenError(e) 165 } 166 167 s.logger.Error("error getting token from db", "error", err) 168 return helpers.ServerError(e, nil) 169 } 170 171 if !result.Found { 172 return helpers.InvalidTokenError(e) 173 } 174 } 175 176 exp, ok := claims["exp"].(float64) 177 if !ok { 178 s.logger.Error("error getting iat from token") 179 return helpers.ServerError(e, nil) 180 } 181 182 if exp < float64(time.Now().UTC().Unix()) { 183 return helpers.ExpiredTokenError(e) 184 } 185 186 if repo == nil { 187 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 188 if err != nil { 189 s.logger.Error("error fetching repo", "error", err) 190 return helpers.ServerError(e, nil) 191 } 192 repo = maybeRepo 193 did = repo.Repo.Did 194 } 195 196 e.Set("repo", repo) 197 e.Set("did", did) 198 e.Set("token", tokenstr) 199 200 if err := next(e); err != nil { 201 return helpers.InvalidTokenError(e) 202 } 203 204 return nil 205 } 206} 207 208func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 209 return func(e echo.Context) error { 210 authheader := e.Request().Header.Get("authorization") 211 if authheader == "" { 212 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 213 } 214 215 pts := strings.Split(authheader, " ") 216 if len(pts) != 2 { 217 return helpers.ServerError(e, nil) 218 } 219 220 if pts[0] != "DPoP" { 221 return next(e) 222 } 223 224 accessToken := pts[1] 225 226 nonce := s.oauthProvider.NextNonce() 227 if nonce != "" { 228 e.Response().Header().Set("DPoP-Nonce", nonce) 229 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 230 } 231 232 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 233 if err != nil { 234 if errors.Is(err, dpop.ErrUseDpopNonce) { 235 return e.JSON(400, map[string]string{ 236 "error": "use_dpop_nonce", 237 }) 238 } 239 s.logger.Error("invalid dpop proof", "error", err) 240 return helpers.InputError(e, nil) 241 } 242 243 var oauthToken provider.OauthToken 244 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 245 s.logger.Error("error finding access token in db", "error", err) 246 return helpers.InputError(e, nil) 247 } 248 249 if oauthToken.Token == "" { 250 return helpers.InvalidTokenError(e) 251 } 252 253 if *oauthToken.Parameters.DpopJkt != proof.JKT { 254 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 255 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 256 } 257 258 if time.Now().After(oauthToken.ExpiresAt) { 259 return helpers.ExpiredTokenError(e) 260 } 261 262 repo, err := s.getRepoActorByDid(oauthToken.Sub) 263 if err != nil { 264 s.logger.Error("could not find actor in db", "error", err) 265 return helpers.ServerError(e, nil) 266 } 267 268 e.Set("repo", repo) 269 e.Set("did", repo.Repo.Did) 270 e.Set("token", accessToken) 271 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 272 273 return next(e) 274 } 275}