An atproto PDS written in Go
at main 8.8 kB view raw
1package server 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/Azure/go-autorest/autorest/to" 11 "github.com/bluesky-social/indigo/api/atproto" 12 "github.com/bluesky-social/indigo/atproto/atcrypto" 13 "github.com/bluesky-social/indigo/events" 14 "github.com/bluesky-social/indigo/repo" 15 "github.com/bluesky-social/indigo/util" 16 "github.com/haileyok/cocoon/internal/helpers" 17 "github.com/haileyok/cocoon/models" 18 "github.com/labstack/echo/v4" 19 "golang.org/x/crypto/bcrypt" 20 "gorm.io/gorm" 21) 22 23type ComAtprotoServerCreateAccountRequest struct { 24 Email string `json:"email" validate:"required,email"` 25 Handle string `json:"handle" validate:"required,atproto-handle"` 26 Did *string `json:"did" validate:"atproto-did"` 27 Password string `json:"password" validate:"required"` 28 InviteCode string `json:"inviteCode" validate:"omitempty"` 29} 30 31type ComAtprotoServerCreateAccountResponse struct { 32 AccessJwt string `json:"accessJwt"` 33 RefreshJwt string `json:"refreshJwt"` 34 Handle string `json:"handle"` 35 Did string `json:"did"` 36} 37 38func (s *Server) handleCreateAccount(e echo.Context) error { 39 var request ComAtprotoServerCreateAccountRequest 40 41 if err := e.Bind(&request); err != nil { 42 s.logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err) 43 return helpers.ServerError(e, nil) 44 } 45 46 request.Handle = strings.ToLower(request.Handle) 47 48 if err := e.Validate(request); err != nil { 49 s.logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err) 50 51 var verr ValidationError 52 if errors.As(err, &verr) { 53 if verr.Field == "Email" { 54 // TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc 55 return helpers.InputError(e, to.StringPtr("InvalidEmail")) 56 } 57 58 if verr.Field == "Handle" { 59 return helpers.InputError(e, to.StringPtr("InvalidHandle")) 60 } 61 62 if verr.Field == "Password" { 63 return helpers.InputError(e, to.StringPtr("InvalidPassword")) 64 } 65 66 if verr.Field == "InviteCode" { 67 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 68 } 69 } 70 } 71 72 var signupDid string 73 if request.Did != nil { 74 signupDid = *request.Did; 75 76 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) 77 if token == "" { 78 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) 79 } 80 authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount") 81 82 if err != nil { 83 s.logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err) 84 return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token")) 85 } 86 87 if authDid != signupDid { 88 return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did")) 89 } 90 } 91 92 // see if the handle is already taken 93 actor, err := s.getActorByHandle(request.Handle) 94 if err != nil && err != gorm.ErrRecordNotFound { 95 s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 96 return helpers.ServerError(e, nil) 97 } 98 if err == nil && actor.Did != signupDid { 99 return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 100 } 101 102 if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid { 103 return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 104 } 105 106 var ic models.InviteCode 107 if s.config.RequireInvite { 108 if strings.TrimSpace(request.InviteCode) == "" { 109 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 110 } 111 112 if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 113 if err == gorm.ErrRecordNotFound { 114 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 115 } 116 s.logger.Error("error getting invite code from db", "error", err) 117 return helpers.ServerError(e, nil) 118 } 119 120 if ic.RemainingUseCount < 1 { 121 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 122 } 123 } 124 125 // see if the email is already taken 126 existingRepo, err := s.getRepoByEmail(request.Email) 127 if err != nil && err != gorm.ErrRecordNotFound { 128 s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 129 return helpers.ServerError(e, nil) 130 } 131 if err == nil && existingRepo.Did != signupDid { 132 return helpers.InputError(e, to.StringPtr("EmailNotAvailable")) 133 } 134 135 // TODO: unsupported domains 136 137 var k *atcrypto.PrivateKeyK256 138 139 if signupDid != "" { 140 reservedKey, err := s.getReservedKey(signupDid) 141 if err != nil { 142 s.logger.Error("error looking up reserved key", "error", err) 143 } 144 if reservedKey != nil { 145 k, err = atcrypto.ParsePrivateBytesK256(reservedKey.PrivateKey) 146 if err != nil { 147 s.logger.Error("error parsing reserved key", "error", err) 148 k = nil 149 } else { 150 defer func() { 151 if delErr := s.deleteReservedKey(reservedKey.KeyDid, reservedKey.Did); delErr != nil { 152 s.logger.Error("error deleting reserved key", "error", delErr) 153 } 154 }() 155 } 156 } 157 } 158 159 if k == nil { 160 k, err = atcrypto.GeneratePrivateKeyK256() 161 if err != nil { 162 s.logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err) 163 return helpers.ServerError(e, nil) 164 } 165 } 166 167 if signupDid == "" { 168 did, op, err := s.plcClient.CreateDID(k, "", request.Handle) 169 if err != nil { 170 s.logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err) 171 return helpers.ServerError(e, nil) 172 } 173 174 if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil { 175 s.logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err) 176 return helpers.ServerError(e, nil) 177 } 178 signupDid = did 179 } 180 181 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10) 182 if err != nil { 183 s.logger.Error("error hashing password", "error", err) 184 return helpers.ServerError(e, nil) 185 } 186 187 urepo := models.Repo{ 188 Did: signupDid, 189 CreatedAt: time.Now(), 190 Email: request.Email, 191 EmailVerificationCode: to.StringPtr(fmt.Sprintf("%s-%s", helpers.RandomVarchar(6), helpers.RandomVarchar(6))), 192 Password: string(hashed), 193 SigningKey: k.Bytes(), 194 } 195 196 if actor == nil { 197 actor = &models.Actor{ 198 Did: signupDid, 199 Handle: request.Handle, 200 } 201 202 if err := s.db.Create(&urepo, nil).Error; err != nil { 203 s.logger.Error("error inserting new repo", "error", err) 204 return helpers.ServerError(e, nil) 205 } 206 207 if err := s.db.Create(&actor, nil).Error; err != nil { 208 s.logger.Error("error inserting new actor", "error", err) 209 return helpers.ServerError(e, nil) 210 } 211 } else { 212 if err := s.db.Save(&actor, nil).Error; err != nil { 213 s.logger.Error("error inserting new actor", "error", err) 214 return helpers.ServerError(e, nil) 215 } 216 } 217 218 if request.Did == nil || *request.Did == "" { 219 bs := s.getBlockstore(signupDid) 220 r := repo.NewRepo(context.TODO(), signupDid, bs) 221 222 root, rev, err := r.Commit(context.TODO(), urepo.SignFor) 223 if err != nil { 224 s.logger.Error("error committing", "error", err) 225 return helpers.ServerError(e, nil) 226 } 227 228 if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil { 229 s.logger.Error("error updating repo after commit", "error", err) 230 return helpers.ServerError(e, nil) 231 } 232 233 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 234 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{ 235 Did: urepo.Did, 236 Handle: to.StringPtr(request.Handle), 237 Seq: time.Now().UnixMicro(), // TODO: no 238 Time: time.Now().Format(util.ISO8601), 239 }, 240 }) 241 } 242 243 if s.config.RequireInvite { 244 if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 245 s.logger.Error("error decrementing use count", "error", err) 246 return helpers.ServerError(e, nil) 247 } 248 } 249 250 sess, err := s.createSession(&urepo) 251 if err != nil { 252 s.logger.Error("error creating new session", "error", err) 253 return helpers.ServerError(e, nil) 254 } 255 256 go func() { 257 if err := s.sendEmailVerification(urepo.Email, actor.Handle, *urepo.EmailVerificationCode); err != nil { 258 s.logger.Error("error sending email verification email", "error", err) 259 } 260 if err := s.sendWelcomeMail(urepo.Email, actor.Handle); err != nil { 261 s.logger.Error("error sending welcome email", "error", err) 262 } 263 }() 264 265 return e.JSON(200, ComAtprotoServerCreateAccountResponse{ 266 AccessJwt: sess.AccessToken, 267 RefreshJwt: sess.RefreshToken, 268 Handle: request.Handle, 269 Did: signupDid, 270 }) 271}