An atproto PDS written in Go
1package server 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/Azure/go-autorest/autorest/to" 11 "github.com/golang-jwt/jwt/v4" 12 "github.com/haileyok/cocoon/internal/helpers" 13 "github.com/haileyok/cocoon/models" 14 "github.com/haileyok/cocoon/oauth/provider" 15 "github.com/labstack/echo/v4" 16 "gitlab.com/yawning/secp256k1-voi" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 "gorm.io/gorm" 19) 20 21func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 22 return func(e echo.Context) error { 23 username, password, ok := e.Request().BasicAuth() 24 if !ok || username != "admin" || password != s.config.AdminPassword { 25 return helpers.InputError(e, to.StringPtr("Unauthorized")) 26 } 27 28 if err := next(e); err != nil { 29 e.Error(err) 30 } 31 32 return nil 33 } 34} 35 36func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 37 return func(e echo.Context) error { 38 authheader := e.Request().Header.Get("authorization") 39 if authheader == "" { 40 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 41 } 42 43 pts := strings.Split(authheader, " ") 44 if len(pts) != 2 { 45 return helpers.ServerError(e, nil) 46 } 47 48 // move on to oauth session middleware if this is a dpop token 49 if pts[0] == "DPoP" { 50 return next(e) 51 } 52 53 tokenstr := pts[1] 54 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 55 claims, ok := token.Claims.(jwt.MapClaims) 56 if !ok { 57 return helpers.InvalidTokenError(e) 58 } 59 60 var did string 61 var repo *models.RepoActor 62 63 // service auth tokens 64 lxm, hasLxm := claims["lxm"] 65 if hasLxm { 66 pts := strings.Split(e.Request().URL.String(), "/") 67 if lxm != pts[len(pts)-1] { 68 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 69 return helpers.InputError(e, nil) 70 } 71 72 maybeDid, ok := claims["iss"].(string) 73 if !ok { 74 s.logger.Error("no iss in service auth token", "error", err) 75 return helpers.InputError(e, nil) 76 } 77 did = maybeDid 78 79 maybeRepo, err := s.getRepoActorByDid(did) 80 if err != nil { 81 s.logger.Error("error fetching repo", "error", err) 82 return helpers.ServerError(e, nil) 83 } 84 repo = maybeRepo 85 } 86 87 if token.Header["alg"] != "ES256K" { 88 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 89 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 90 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 91 } 92 return s.privateKey.Public(), nil 93 }) 94 if err != nil { 95 s.logger.Error("error parsing jwt", "error", err) 96 return helpers.ExpiredTokenError(e) 97 } 98 99 if !token.Valid { 100 return helpers.InvalidTokenError(e) 101 } 102 } else { 103 kpts := strings.Split(tokenstr, ".") 104 signingInput := kpts[0] + "." + kpts[1] 105 hash := sha256.Sum256([]byte(signingInput)) 106 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 107 if err != nil { 108 s.logger.Error("error decoding signature bytes", "error", err) 109 return helpers.ServerError(e, nil) 110 } 111 112 if len(sigBytes) != 64 { 113 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 114 return helpers.ServerError(e, nil) 115 } 116 117 rBytes := sigBytes[:32] 118 sBytes := sigBytes[32:] 119 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 120 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 121 122 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 123 if err != nil { 124 s.logger.Error("can't load private key", "error", err) 125 return err 126 } 127 128 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 129 if !ok { 130 s.logger.Error("error getting public key from sk") 131 return helpers.ServerError(e, nil) 132 } 133 134 verified := pubKey.VerifyRaw(hash[:], rr, ss) 135 if !verified { 136 s.logger.Error("error verifying", "error", err) 137 return helpers.ServerError(e, nil) 138 } 139 } 140 141 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 142 scope, _ := claims["scope"].(string) 143 144 if isRefresh && scope != "com.atproto.refresh" { 145 return helpers.InvalidTokenError(e) 146 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 147 return helpers.InvalidTokenError(e) 148 } 149 150 table := "tokens" 151 if isRefresh { 152 table = "refresh_tokens" 153 } 154 155 if isRefresh { 156 type Result struct { 157 Found bool 158 } 159 var result Result 160 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 161 if err == gorm.ErrRecordNotFound { 162 return helpers.InvalidTokenError(e) 163 } 164 165 s.logger.Error("error getting token from db", "error", err) 166 return helpers.ServerError(e, nil) 167 } 168 169 if !result.Found { 170 return helpers.InvalidTokenError(e) 171 } 172 } 173 174 exp, ok := claims["exp"].(float64) 175 if !ok { 176 s.logger.Error("error getting iat from token") 177 return helpers.ServerError(e, nil) 178 } 179 180 if exp < float64(time.Now().UTC().Unix()) { 181 return helpers.ExpiredTokenError(e) 182 } 183 184 if repo == nil { 185 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 186 if err != nil { 187 s.logger.Error("error fetching repo", "error", err) 188 return helpers.ServerError(e, nil) 189 } 190 repo = maybeRepo 191 did = repo.Repo.Did 192 } 193 194 e.Set("repo", repo) 195 e.Set("did", did) 196 e.Set("token", tokenstr) 197 198 if err := next(e); err != nil { 199 return helpers.InvalidTokenError(e) 200 } 201 202 return nil 203 } 204} 205 206func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 207 return func(e echo.Context) error { 208 authheader := e.Request().Header.Get("authorization") 209 if authheader == "" { 210 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 211 } 212 213 pts := strings.Split(authheader, " ") 214 if len(pts) != 2 { 215 return helpers.ServerError(e, nil) 216 } 217 218 if pts[0] != "DPoP" { 219 return next(e) 220 } 221 222 accessToken := pts[1] 223 224 nonce := s.oauthProvider.NextNonce() 225 if nonce != "" { 226 e.Response().Header().Set("DPoP-Nonce", nonce) 227 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 228 } 229 230 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 231 if err != nil { 232 s.logger.Error("invalid dpop proof", "error", err) 233 return helpers.InputError(e, to.StringPtr(err.Error())) 234 } 235 236 var oauthToken provider.OauthToken 237 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 238 s.logger.Error("error finding access token in db", "error", err) 239 return helpers.InputError(e, nil) 240 } 241 242 if oauthToken.Token == "" { 243 return helpers.InvalidTokenError(e) 244 } 245 246 if *oauthToken.Parameters.DpopJkt != proof.JKT { 247 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 248 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 249 } 250 251 if time.Now().After(oauthToken.ExpiresAt) { 252 return helpers.ExpiredTokenError(e) 253 } 254 255 repo, err := s.getRepoActorByDid(oauthToken.Sub) 256 if err != nil { 257 s.logger.Error("could not find actor in db", "error", err) 258 return helpers.ServerError(e, nil) 259 } 260 261 e.Set("repo", repo) 262 e.Set("did", repo.Repo.Did) 263 e.Set("token", accessToken) 264 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 265 266 return next(e) 267 } 268}