1package server
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "time"
9
10 "github.com/Azure/go-autorest/autorest/to"
11 "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/atcrypto"
13 "github.com/bluesky-social/indigo/events"
14 "github.com/bluesky-social/indigo/repo"
15 "github.com/bluesky-social/indigo/util"
16 "github.com/haileyok/cocoon/internal/helpers"
17 "github.com/haileyok/cocoon/models"
18 "github.com/labstack/echo/v4"
19 "golang.org/x/crypto/bcrypt"
20 "gorm.io/gorm"
21)
22
// ComAtprotoServerCreateAccountRequest is the JSON body accepted by the
// com.atproto.server.createAccount endpoint.
type ComAtprotoServerCreateAccountRequest struct {
	// Email is the address for the new account; it must be present and
	// well-formed.
	Email string `json:"email" validate:"required,email"`
	// Handle is the requested handle; it must be present and pass the
	// "atproto-handle" validator.
	Handle string `json:"handle" validate:"required,atproto-handle"`
	// Did is an optional, already existing DID to create the account for.
	// "omitempty" is required here: without it the validator reports an unset
	// pointer as a failure, so every request that omits the DID would fail
	// validation on this field.
	Did *string `json:"did" validate:"omitempty,atproto-did"`
	// Password is the plain-text password for the new account; it must be
	// present.
	Password string `json:"password" validate:"required"`
	// InviteCode is an optional invite code.
	InviteCode string `json:"inviteCode" validate:"omitempty"`
}
30
// ComAtprotoServerCreateAccountResponse is the JSON body returned by the
// com.atproto.server.createAccount endpoint on success.
type ComAtprotoServerCreateAccountResponse struct {
	// Session tokens, serialized as "accessJwt" and "refreshJwt".
	AccessJwt  string `json:"accessJwt"`
	RefreshJwt string `json:"refreshJwt"`

	// Identity of the account, serialized as "handle" and "did".
	Handle string `json:"handle"`
	Did    string `json:"did"`
}
37
38func (s *Server) handleCreateAccount(e echo.Context) error {
39 var request ComAtprotoServerCreateAccountRequest
40
41 if err := e.Bind(&request); err != nil {
42 s.logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err)
43 return helpers.ServerError(e, nil)
44 }
45
46 request.Handle = strings.ToLower(request.Handle)
47
48 if err := e.Validate(request); err != nil {
49 s.logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err)
50
51 var verr ValidationError
52 if errors.As(err, &verr) {
53 if verr.Field == "Email" {
54 // TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc
55 return helpers.InputError(e, to.StringPtr("InvalidEmail"))
56 }
57
58 if verr.Field == "Handle" {
59 return helpers.InputError(e, to.StringPtr("InvalidHandle"))
60 }
61
62 if verr.Field == "Password" {
63 return helpers.InputError(e, to.StringPtr("InvalidPassword"))
64 }
65
66 if verr.Field == "InviteCode" {
67 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
68 }
69 }
70 }
71
72 var signupDid string
73 if request.Did != nil {
74 signupDid = *request.Did;
75
76 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1))
77 if token == "" {
78 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did"))
79 }
80 authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount")
81
82 if err != nil {
83 s.logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err)
84 return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token"))
85 }
86
87 if authDid != signupDid {
88 return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did"))
89 }
90 }
91
92 // see if the handle is already taken
93 actor, err := s.getActorByHandle(request.Handle)
94 if err != nil && err != gorm.ErrRecordNotFound {
95 s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err)
96 return helpers.ServerError(e, nil)
97 }
98 if err == nil && actor.Did != signupDid {
99 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
100 }
101
102 if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid {
103 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
104 }
105
106 var ic models.InviteCode
107 if s.config.RequireInvite {
108 if strings.TrimSpace(request.InviteCode) == "" {
109 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
110 }
111
112 if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
113 if err == gorm.ErrRecordNotFound {
114 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
115 }
116 s.logger.Error("error getting invite code from db", "error", err)
117 return helpers.ServerError(e, nil)
118 }
119
120 if ic.RemainingUseCount < 1 {
121 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
122 }
123 }
124
125 // see if the email is already taken
126 existingRepo, err := s.getRepoByEmail(request.Email)
127 if err != nil && err != gorm.ErrRecordNotFound {
128 s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err)
129 return helpers.ServerError(e, nil)
130 }
131 if err == nil && existingRepo.Did != signupDid {
132 return helpers.InputError(e, to.StringPtr("EmailNotAvailable"))
133 }
134
135 // TODO: unsupported domains
136
137 var k *atcrypto.PrivateKeyK256
138
139 if signupDid != "" {
140 reservedKey, err := s.getReservedKey(signupDid)
141 if err != nil {
142 s.logger.Error("error looking up reserved key", "error", err)
143 }
144 if reservedKey != nil {
145 k, err = atcrypto.ParsePrivateBytesK256(reservedKey.PrivateKey)
146 if err != nil {
147 s.logger.Error("error parsing reserved key", "error", err)
148 k = nil
149 } else {
150 defer func() {
151 if delErr := s.deleteReservedKey(reservedKey.KeyDid, reservedKey.Did); delErr != nil {
152 s.logger.Error("error deleting reserved key", "error", delErr)
153 }
154 }()
155 }
156 }
157 }
158
159 if k == nil {
160 k, err = atcrypto.GeneratePrivateKeyK256()
161 if err != nil {
162 s.logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err)
163 return helpers.ServerError(e, nil)
164 }
165 }
166
167 if signupDid == "" {
168 did, op, err := s.plcClient.CreateDID(k, "", request.Handle)
169 if err != nil {
170 s.logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err)
171 return helpers.ServerError(e, nil)
172 }
173
174 if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil {
175 s.logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err)
176 return helpers.ServerError(e, nil)
177 }
178 signupDid = did
179 }
180
181 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10)
182 if err != nil {
183 s.logger.Error("error hashing password", "error", err)
184 return helpers.ServerError(e, nil)
185 }
186
187 urepo := models.Repo{
188 Did: signupDid,
189 CreatedAt: time.Now(),
190 Email: request.Email,
191 EmailVerificationCode: to.StringPtr(fmt.Sprintf("%s-%s", helpers.RandomVarchar(6), helpers.RandomVarchar(6))),
192 Password: string(hashed),
193 SigningKey: k.Bytes(),
194 }
195
196 if actor == nil {
197 actor = &models.Actor{
198 Did: signupDid,
199 Handle: request.Handle,
200 }
201
202 if err := s.db.Create(&urepo, nil).Error; err != nil {
203 s.logger.Error("error inserting new repo", "error", err)
204 return helpers.ServerError(e, nil)
205 }
206
207 if err := s.db.Create(&actor, nil).Error; err != nil {
208 s.logger.Error("error inserting new actor", "error", err)
209 return helpers.ServerError(e, nil)
210 }
211 } else {
212 if err := s.db.Save(&actor, nil).Error; err != nil {
213 s.logger.Error("error inserting new actor", "error", err)
214 return helpers.ServerError(e, nil)
215 }
216 }
217
218 if request.Did == nil || *request.Did == "" {
219 bs := s.getBlockstore(signupDid)
220 r := repo.NewRepo(context.TODO(), signupDid, bs)
221
222 root, rev, err := r.Commit(context.TODO(), urepo.SignFor)
223 if err != nil {
224 s.logger.Error("error committing", "error", err)
225 return helpers.ServerError(e, nil)
226 }
227
228 if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil {
229 s.logger.Error("error updating repo after commit", "error", err)
230 return helpers.ServerError(e, nil)
231 }
232
233 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
234 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{
235 Did: urepo.Did,
236 Handle: to.StringPtr(request.Handle),
237 Seq: time.Now().UnixMicro(), // TODO: no
238 Time: time.Now().Format(util.ISO8601),
239 },
240 })
241 }
242
243 if s.config.RequireInvite {
244 if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
245 s.logger.Error("error decrementing use count", "error", err)
246 return helpers.ServerError(e, nil)
247 }
248 }
249
250 sess, err := s.createSession(&urepo)
251 if err != nil {
252 s.logger.Error("error creating new session", "error", err)
253 return helpers.ServerError(e, nil)
254 }
255
256 go func() {
257 if err := s.sendEmailVerification(urepo.Email, actor.Handle, *urepo.EmailVerificationCode); err != nil {
258 s.logger.Error("error sending email verification email", "error", err)
259 }
260 if err := s.sendWelcomeMail(urepo.Email, actor.Handle); err != nil {
261 s.logger.Error("error sending welcome email", "error", err)
262 }
263 }()
264
265 return e.JSON(200, ComAtprotoServerCreateAccountResponse{
266 AccessJwt: sess.AccessToken,
267 RefreshJwt: sess.RefreshToken,
268 Handle: request.Handle,
269 Did: signupDid,
270 })
271}