1package server
2
3import (
4 "errors"
5 "strings"
6
7 "github.com/Azure/go-autorest/autorest/to"
8 "github.com/bluesky-social/indigo/atproto/syntax"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/models"
11 "github.com/labstack/echo/v4"
12 "golang.org/x/crypto/bcrypt"
13 "gorm.io/gorm"
14)
15
// ComAtprotoServerCreateSessionRequest is the JSON body accepted by the
// com.atproto.server.createSession handler.
//
// Identifier may be a DID, a handle, or an email address; the handler
// lowercases it and picks the lookup by its shape. Password is compared
// against the account's stored bcrypt hash. AuthFactorToken is optional.
// NOTE(review): AuthFactorToken is not read by handleCreateSession in this file.
type ComAtprotoServerCreateSessionRequest struct {
	Identifier      string  `json:"identifier" validate:"required"`
	Password        string  `json:"password" validate:"required"`
	AuthFactorToken *string `json:"authFactorToken,omitempty"`
}
21
// ComAtprotoServerCreateSessionResponse is the JSON body returned by a
// successful com.atproto.server.createSession call.
//
// AccessJwt and RefreshJwt carry the newly issued session tokens.
// EmailConfirmed reports whether the account's email has a confirmation
// timestamp. EmailAuthFactor is currently always false in handleCreateSession.
// Status is omitted from the JSON output when nil.
type ComAtprotoServerCreateSessionResponse struct {
	AccessJwt       string  `json:"accessJwt"`
	RefreshJwt      string  `json:"refreshJwt"`
	Handle          string  `json:"handle"`
	Did             string  `json:"did"`
	Email           string  `json:"email"`
	EmailConfirmed  bool    `json:"emailConfirmed"`
	EmailAuthFactor bool    `json:"emailAuthFactor"`
	Active          bool    `json:"active"`
	Status          *string `json:"status,omitempty"`
}
33
34func (s *Server) handleCreateSession(e echo.Context) error {
35 var req ComAtprotoServerCreateSessionRequest
36 if err := e.Bind(&req); err != nil {
37 s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
38 return helpers.ServerError(e, nil)
39 }
40
41 if err := e.Validate(req); err != nil {
42 var verr ValidationError
43 if errors.As(err, &verr) {
44 if verr.Field == "Identifier" {
45 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
46 }
47
48 if verr.Field == "Password" {
49 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
50 }
51 }
52 }
53
54 req.Identifier = strings.ToLower(req.Identifier)
55 var idtype string
56 if _, err := syntax.ParseDID(req.Identifier); err == nil {
57 idtype = "did"
58 } else if _, err := syntax.ParseHandle(req.Identifier); err == nil {
59 idtype = "handle"
60 } else {
61 idtype = "email"
62 }
63
64 var repo models.RepoActor
65 var err error
66 switch idtype {
67 case "did":
68 err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
69 case "handle":
70 err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
71 case "email":
72 err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
73 }
74
75 if err != nil {
76 if err == gorm.ErrRecordNotFound {
77 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
78 }
79
80 s.logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
81 return helpers.ServerError(e, nil)
82 }
83
84 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
85 if err != bcrypt.ErrMismatchedHashAndPassword {
86 s.logger.Error("erorr comparing hash and password", "error", err)
87 }
88 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
89 }
90
91 sess, err := s.createSession(&repo.Repo)
92 if err != nil {
93 s.logger.Error("error creating session", "error", err)
94 return helpers.ServerError(e, nil)
95 }
96
97 return e.JSON(200, ComAtprotoServerCreateSessionResponse{
98 AccessJwt: sess.AccessToken,
99 RefreshJwt: sess.RefreshToken,
100 Handle: repo.Handle,
101 Did: repo.Repo.Did,
102 Email: repo.Email,
103 EmailConfirmed: repo.EmailConfirmedAt != nil,
104 EmailAuthFactor: false,
105 Active: repo.Active(),
106 Status: repo.Status(),
107 })
108}