An atproto PDS written in Go
1package server 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/Azure/go-autorest/autorest/to" 11 "github.com/golang-jwt/jwt/v4" 12 "github.com/haileyok/cocoon/internal/helpers" 13 "github.com/haileyok/cocoon/models" 14 "github.com/haileyok/cocoon/oauth/provider" 15 "github.com/labstack/echo/v4" 16 "gitlab.com/yawning/secp256k1-voi" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 "gorm.io/gorm" 19) 20 21func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 22 return func(e echo.Context) error { 23 username, password, ok := e.Request().BasicAuth() 24 if !ok || username != "admin" || password != s.config.AdminPassword { 25 return helpers.InputError(e, to.StringPtr("Unauthorized")) 26 } 27 28 if err := next(e); err != nil { 29 e.Error(err) 30 } 31 32 return nil 33 } 34} 35 36func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 37 return func(e echo.Context) error { 38 authheader := e.Request().Header.Get("authorization") 39 if authheader == "" { 40 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 41 } 42 43 pts := strings.Split(authheader, " ") 44 if len(pts) != 2 { 45 return helpers.ServerError(e, nil) 46 } 47 48 // move on to oauth session middleware if this is a dpop token 49 if pts[0] == "DPoP" { 50 return next(e) 51 } 52 53 tokenstr := pts[1] 54 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 55 claims, ok := token.Claims.(jwt.MapClaims) 56 if !ok { 57 return helpers.InputError(e, to.StringPtr("InvalidToken")) 58 } 59 60 var did string 61 var repo *models.RepoActor 62 63 // service auth tokens 64 lxm, hasLxm := claims["lxm"] 65 if hasLxm { 66 pts := strings.Split(e.Request().URL.String(), "/") 67 if lxm != pts[len(pts)-1] { 68 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 69 return helpers.InputError(e, nil) 70 } 71 72 maybeDid, ok := claims["iss"].(string) 73 if !ok { 74 s.logger.Error("no iss in service auth token", "error", err) 75 return helpers.InputError(e, nil) 76 } 77 did = maybeDid 78 79 maybeRepo, err := s.getRepoActorByDid(did) 80 if err != nil { 81 s.logger.Error("error fetching repo", "error", err) 82 return helpers.ServerError(e, nil) 83 } 84 repo = maybeRepo 85 } 86 87 if token.Header["alg"] != "ES256K" { 88 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 89 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 90 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 91 } 92 return s.privateKey.Public(), nil 93 }) 94 if err != nil { 95 s.logger.Error("error parsing jwt", "error", err) 96 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 97 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 98 } 99 100 if !token.Valid { 101 return helpers.InputError(e, to.StringPtr("InvalidToken")) 102 } 103 } else { 104 kpts := strings.Split(tokenstr, ".") 105 signingInput := kpts[0] + "." + kpts[1] 106 hash := sha256.Sum256([]byte(signingInput)) 107 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 108 if err != nil { 109 s.logger.Error("error decoding signature bytes", "error", err) 110 return helpers.ServerError(e, nil) 111 } 112 113 if len(sigBytes) != 64 { 114 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 115 return helpers.ServerError(e, nil) 116 } 117 118 rBytes := sigBytes[:32] 119 sBytes := sigBytes[32:] 120 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 121 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 122 123 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 124 if err != nil { 125 s.logger.Error("can't load private key", "error", err) 126 return err 127 } 128 129 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 130 if !ok { 131 s.logger.Error("error getting public key from sk") 132 return helpers.ServerError(e, nil) 133 } 134 135 verified := pubKey.VerifyRaw(hash[:], rr, ss) 136 if !verified { 137 s.logger.Error("error verifying", "error", err) 138 return helpers.ServerError(e, nil) 139 } 140 } 141 142 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 143 scope, _ := claims["scope"].(string) 144 145 if isRefresh && scope != "com.atproto.refresh" { 146 return helpers.InputError(e, to.StringPtr("InvalidToken")) 147 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 148 return helpers.InputError(e, to.StringPtr("InvalidToken")) 149 } 150 151 table := "tokens" 152 if isRefresh { 153 table = "refresh_tokens" 154 } 155 156 if isRefresh { 157 type Result struct { 158 Found bool 159 } 160 var result Result 161 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 162 if err == gorm.ErrRecordNotFound { 163 return helpers.InputError(e, to.StringPtr("InvalidToken")) 164 } 165 166 s.logger.Error("error getting token from db", "error", err) 167 return helpers.ServerError(e, nil) 168 } 169 170 if !result.Found { 171 return helpers.InputError(e, to.StringPtr("InvalidToken")) 172 } 173 } 174 175 exp, ok := claims["exp"].(float64) 176 if !ok { 177 s.logger.Error("error getting iat from token") 178 return helpers.ServerError(e, nil) 179 } 180 181 if exp < float64(time.Now().UTC().Unix()) { 182 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 183 } 184 185 if repo == nil { 186 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 187 if err != nil { 188 s.logger.Error("error fetching repo", "error", err) 189 return helpers.ServerError(e, nil) 190 } 191 repo = maybeRepo 192 did = repo.Repo.Did 193 } 194 195 e.Set("repo", repo) 196 e.Set("did", did) 197 e.Set("token", tokenstr) 198 199 if err := next(e); err != nil { 200 e.Error(err) 201 } 202 203 return nil 204 } 205} 206 207func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 208 return func(e echo.Context) error { 209 authheader := e.Request().Header.Get("authorization") 210 if authheader == "" { 211 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 212 } 213 214 pts := strings.Split(authheader, " ") 215 if len(pts) != 2 { 216 return helpers.ServerError(e, nil) 217 } 218 219 if pts[0] != "DPoP" { 220 return next(e) 221 } 222 223 accessToken := pts[1] 224 225 nonce := s.oauthProvider.NextNonce() 226 if nonce != "" { 227 e.Response().Header().Set("DPoP-Nonce", nonce) 228 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 229 } 230 231 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 232 if err != nil { 233 s.logger.Error("invalid dpop proof", "error", err) 234 return helpers.InputError(e, to.StringPtr(err.Error())) 235 } 236 237 var oauthToken provider.OauthToken 238 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 239 s.logger.Error("error finding access token in db", "error", err) 240 return helpers.InputError(e, nil) 241 } 242 243 if oauthToken.Token == "" { 244 return helpers.InputError(e, to.StringPtr("InvalidToken")) 245 } 246 247 if *oauthToken.Parameters.DpopJkt != proof.JKT { 248 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 249 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 250 } 251 252 if time.Now().After(oauthToken.ExpiresAt) { 253 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 254 } 255 256 repo, err := s.getRepoActorByDid(oauthToken.Sub) 257 if err != nil { 258 s.logger.Error("could not find actor in db", "error", err) 259 return helpers.ServerError(e, nil) 260 } 261 262 e.Set("repo", repo) 263 e.Set("did", repo.Repo.Did) 264 e.Set("token", accessToken) 265 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 266 267 return next(e) 268 } 269}