1package server
2
3import (
4 "errors"
5 "strings"
6
7 "github.com/bluesky-social/indigo/atproto/syntax"
8 "github.com/gorilla/sessions"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/models"
11 "github.com/labstack/echo-contrib/session"
12 "github.com/labstack/echo/v4"
13 "golang.org/x/crypto/bcrypt"
14 "gorm.io/gorm"
15)
16
// OauthSigninRequest is the form payload posted by the account sign-in page.
type OauthSigninRequest struct {
	// Username identifies the account: a DID, a handle, or an email address.
	Username string `form:"username"`
	// Password is the plaintext password, checked against the stored bcrypt hash.
	Password string `form:"password"`
	// QueryParams is the encoded query string the sign-in page was loaded with;
	// when non-empty, a successful sign-in redirects to /oauth/authorize with it.
	QueryParams string `form:"query_params"`
}
22
23func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
24 sess, err := session.Get("session", e)
25 if err != nil {
26 return nil, nil, err
27 }
28
29 did, ok := sess.Values["did"].(string)
30 if !ok {
31 return nil, sess, errors.New("did was not set in session")
32 }
33
34 repo, err := s.getRepoActorByDid(did)
35 if err != nil {
36 return nil, sess, err
37 }
38
39 return repo, sess, nil
40}
41
42func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any {
43 defer sess.Save(e.Request(), e.Response())
44 return map[string]any{
45 "errors": sess.Flashes("error"),
46 "successes": sess.Flashes("success"),
47 }
48}
49
50func (s *Server) handleAccountSigninGet(e echo.Context) error {
51 _, sess, err := s.getSessionRepoOrErr(e)
52 if err == nil {
53 return e.Redirect(303, "/account")
54 }
55
56 return e.Render(200, "signin.html", map[string]any{
57 "flashes": getFlashesFromSession(e, sess),
58 "QueryParams": e.QueryParams().Encode(),
59 })
60}
61
62func (s *Server) handleAccountSigninPost(e echo.Context) error {
63 var req OauthSigninRequest
64 if err := e.Bind(&req); err != nil {
65 s.logger.Error("error binding sign in req", "error", err)
66 return helpers.ServerError(e, nil)
67 }
68
69 sess, _ := session.Get("session", e)
70
71 req.Username = strings.ToLower(req.Username)
72 var idtype string
73 if _, err := syntax.ParseDID(req.Username); err == nil {
74 idtype = "did"
75 } else if _, err := syntax.ParseHandle(req.Username); err == nil {
76 idtype = "handle"
77 } else {
78 idtype = "email"
79 }
80
81 // TODO: we should make this a helper since we do it for the base create_session as well
82 var repo models.RepoActor
83 var err error
84 switch idtype {
85 case "did":
86 err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
87 case "handle":
88 err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
89 case "email":
90 err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
91 }
92 if err != nil {
93 if err == gorm.ErrRecordNotFound {
94 sess.AddFlash("Handle or password is incorrect", "error")
95 } else {
96 sess.AddFlash("Something went wrong!", "error")
97 }
98 sess.Save(e.Request(), e.Response())
99 return e.Redirect(303, "/account/signin")
100 }
101
102 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
103 if err != bcrypt.ErrMismatchedHashAndPassword {
104 sess.AddFlash("Handle or password is incorrect", "error")
105 } else {
106 sess.AddFlash("Something went wrong!", "error")
107 }
108 sess.Save(e.Request(), e.Response())
109 return e.Redirect(303, "/account/signin")
110 }
111
112 sess.Options = &sessions.Options{
113 Path: "/",
114 MaxAge: int(AccountSessionMaxAge.Seconds()),
115 HttpOnly: true,
116 }
117
118 sess.Values = map[any]any{}
119 sess.Values["did"] = repo.Repo.Did
120
121 if err := sess.Save(e.Request(), e.Response()); err != nil {
122 return err
123 }
124
125 if req.QueryParams != "" {
126 return e.Redirect(303, "/oauth/authorize?"+req.QueryParams)
127 } else {
128 return e.Redirect(303, "/account")
129 }
130}