1package server
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "time"
9
10 "github.com/Azure/go-autorest/autorest/to"
11 "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/crypto"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "github.com/bluesky-social/indigo/events"
15 "github.com/bluesky-social/indigo/repo"
16 "github.com/bluesky-social/indigo/util"
17 "github.com/haileyok/cocoon/internal/helpers"
18 "github.com/haileyok/cocoon/models"
19 "github.com/labstack/echo/v4"
20 "golang.org/x/crypto/bcrypt"
21 "gorm.io/gorm"
22)
23
// ComAtprotoServerCreateAccountRequest is the JSON body accepted by
// com.atproto.server.createAccount.
//
// Did is optional, so its validation is guarded with omitempty. Assuming
// e.Validate is backed by go-playground/validator (confirm), a nil pointer
// carrying a non-omitempty tag is reported as failing that tag. Every request
// that omitted "did" would then fail validation on Did, and if only the first
// failing field is surfaced, that could hide genuine Password/InviteCode
// failures.
type ComAtprotoServerCreateAccountRequest struct {
	Email      string  `json:"email" validate:"required,email"`
	Handle     string  `json:"handle" validate:"required,atproto-handle"`
	Did        *string `json:"did" validate:"omitempty,atproto-did"`
	Password   string  `json:"password" validate:"required"`
	InviteCode string  `json:"inviteCode" validate:"required"`
}
31
// ComAtprotoServerCreateAccountResponse is the JSON body returned by a
// successful com.atproto.server.createAccount call.
type ComAtprotoServerCreateAccountResponse struct {
	// AccessJwt is the access token of the session created for the account.
	AccessJwt string `json:"accessJwt"`
	// RefreshJwt is the refresh token of the session created for the account.
	RefreshJwt string `json:"refreshJwt"`
	// Handle is the handle the account was registered with.
	Handle string `json:"handle"`
	// Did is the DID the account was registered under.
	Did string `json:"did"`
}
38
39func (s *Server) handleCreateAccount(e echo.Context) error {
40 var request ComAtprotoServerCreateAccountRequest
41
42 var signupDid string
43 customDidHeader := e.Request().Header.Get("authorization")
44 if customDidHeader != "" {
45 pts := strings.Split(customDidHeader, " ")
46 if len(pts) != 2 {
47 return helpers.InputError(e, to.StringPtr("InvalidDid"))
48 }
49
50 _, err := syntax.ParseDID(pts[1])
51 if err != nil {
52 return helpers.InputError(e, to.StringPtr("InvalidDid"))
53 }
54
55 signupDid = pts[1]
56 }
57
58 if err := e.Bind(&request); err != nil {
59 s.logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err)
60 return helpers.ServerError(e, nil)
61 }
62
63 request.Handle = strings.ToLower(request.Handle)
64
65 if err := e.Validate(request); err != nil {
66 s.logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err)
67
68 var verr ValidationError
69 if errors.As(err, &verr) {
70 if verr.Field == "Email" {
71 // TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc
72 return helpers.InputError(e, to.StringPtr("InvalidEmail"))
73 }
74
75 if verr.Field == "Handle" {
76 return helpers.InputError(e, to.StringPtr("InvalidHandle"))
77 }
78
79 if verr.Field == "Password" {
80 return helpers.InputError(e, to.StringPtr("InvalidPassword"))
81 }
82
83 if verr.Field == "InviteCode" {
84 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
85 }
86 }
87 }
88
89 // see if the handle is already taken
90 _, err := s.getActorByHandle(request.Handle)
91 if err != nil && err != gorm.ErrRecordNotFound {
92 s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err)
93 return helpers.ServerError(e, nil)
94 }
95 if err == nil {
96 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
97 }
98
99 if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != "" {
100 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
101 }
102
103 var ic models.InviteCode
104 if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
105 if err == gorm.ErrRecordNotFound {
106 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
107 }
108 s.logger.Error("error getting invite code from db", "error", err)
109 return helpers.ServerError(e, nil)
110 }
111
112 if ic.RemainingUseCount < 1 {
113 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
114 }
115
116 // see if the email is already taken
117 _, err = s.getRepoByEmail(request.Email)
118 if err != nil && err != gorm.ErrRecordNotFound {
119 s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err)
120 return helpers.ServerError(e, nil)
121 }
122 if err == nil {
123 return helpers.InputError(e, to.StringPtr("EmailNotAvailable"))
124 }
125
126 // TODO: unsupported domains
127
128 k, err := crypto.GeneratePrivateKeyK256()
129 if err != nil {
130 s.logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err)
131 return helpers.ServerError(e, nil)
132 }
133
134 if signupDid == "" {
135 did, op, err := s.plcClient.CreateDID(k, "", request.Handle)
136 if err != nil {
137 s.logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err)
138 return helpers.ServerError(e, nil)
139 }
140
141 if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil {
142 s.logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err)
143 return helpers.ServerError(e, nil)
144 }
145 signupDid = did
146 }
147
148 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10)
149 if err != nil {
150 s.logger.Error("error hashing password", "error", err)
151 return helpers.ServerError(e, nil)
152 }
153
154 urepo := models.Repo{
155 Did: signupDid,
156 CreatedAt: time.Now(),
157 Email: request.Email,
158 EmailVerificationCode: to.StringPtr(fmt.Sprintf("%s-%s", helpers.RandomVarchar(6), helpers.RandomVarchar(6))),
159 Password: string(hashed),
160 SigningKey: k.Bytes(),
161 }
162
163 actor := models.Actor{
164 Did: signupDid,
165 Handle: request.Handle,
166 }
167
168 if err := s.db.Create(&urepo, nil).Error; err != nil {
169 s.logger.Error("error inserting new repo", "error", err)
170 return helpers.ServerError(e, nil)
171 }
172
173 if err := s.db.Create(&actor, nil).Error; err != nil {
174 s.logger.Error("error inserting new actor", "error", err)
175 return helpers.ServerError(e, nil)
176 }
177
178 if customDidHeader == "" {
179 bs := s.getBlockstore(signupDid)
180 r := repo.NewRepo(context.TODO(), signupDid, bs)
181
182 root, rev, err := r.Commit(context.TODO(), urepo.SignFor)
183 if err != nil {
184 s.logger.Error("error committing", "error", err)
185 return helpers.ServerError(e, nil)
186 }
187
188 if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil {
189 s.logger.Error("error updating repo after commit", "error", err)
190 return helpers.ServerError(e, nil)
191 }
192
193 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
194 RepoHandle: &atproto.SyncSubscribeRepos_Handle{
195 Did: urepo.Did,
196 Handle: request.Handle,
197 Seq: time.Now().UnixMicro(), // TODO: no
198 Time: time.Now().Format(util.ISO8601),
199 },
200 })
201
202 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
203 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{
204 Did: urepo.Did,
205 Handle: to.StringPtr(request.Handle),
206 Seq: time.Now().UnixMicro(), // TODO: no
207 Time: time.Now().Format(util.ISO8601),
208 },
209 })
210 }
211
212 if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
213 s.logger.Error("error decrementing use count", "error", err)
214 return helpers.ServerError(e, nil)
215 }
216
217 sess, err := s.createSession(&urepo)
218 if err != nil {
219 s.logger.Error("error creating new session", "error", err)
220 return helpers.ServerError(e, nil)
221 }
222
223 go func() {
224 if err := s.sendEmailVerification(urepo.Email, actor.Handle, *urepo.EmailVerificationCode); err != nil {
225 s.logger.Error("error sending email verification email", "error", err)
226 }
227 if err := s.sendWelcomeMail(urepo.Email, actor.Handle); err != nil {
228 s.logger.Error("error sending welcome email", "error", err)
229 }
230 }()
231
232 return e.JSON(200, ComAtprotoServerCreateAccountResponse{
233 AccessJwt: sess.AccessToken,
234 RefreshJwt: sess.RefreshToken,
235 Handle: request.Handle,
236 Did: signupDid,
237 })
238}