1package middleware
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strconv"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/xrpc"
12 "tangled.sh/tangled.sh/core/appview"
13 "tangled.sh/tangled.sh/core/appview/auth"
14 "tangled.sh/tangled.sh/core/appview/pagination"
15)
16
// Middleware is a function that takes an http.Handler and returns an
// http.Handler, allowing request handling to be wrapped with extra behavior.
type Middleware func(http.Handler) http.Handler
18
19func AuthMiddleware(a *auth.Auth) Middleware {
20 return func(next http.Handler) http.Handler {
21 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
23 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
24 }
25 if r.Header.Get("HX-Request") == "true" {
26 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
27 w.Header().Set("HX-Redirect", "/login")
28 w.WriteHeader(http.StatusOK)
29 }
30 }
31
32 session, err := a.GetSession(r)
33 if session.IsNew || err != nil {
34 log.Printf("not logged in, redirecting")
35 redirectFunc(w, r)
36 return
37 }
38
39 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
40 if !ok || !authorized {
41 log.Printf("not logged in, redirecting")
42 redirectFunc(w, r)
43 return
44 }
45
46 // refresh if nearing expiry
47 // TODO: dedup with /login
48 expiryStr := session.Values[appview.SessionExpiry].(string)
49 expiry, err := time.Parse(time.RFC3339, expiryStr)
50 if err != nil {
51 log.Println("invalid expiry time", err)
52 redirectFunc(w, r)
53 return
54 }
55 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
56 did, ok2 := session.Values[appview.SessionDid].(string)
57 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
58
59 if !ok1 || !ok2 || !ok3 {
60 log.Println("invalid expiry time", err)
61 redirectFunc(w, r)
62 return
63 }
64
65 if time.Now().After(expiry) {
66 log.Println("token expired, refreshing ...")
67
68 client := xrpc.Client{
69 Host: pdsUrl,
70 Auth: &xrpc.AuthInfo{
71 Did: did,
72 AccessJwt: refreshJwt,
73 RefreshJwt: refreshJwt,
74 },
75 }
76 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
77 if err != nil {
78 log.Println("failed to refresh session", err)
79 redirectFunc(w, r)
80 return
81 }
82
83 sessionish := auth.RefreshSessionWrapper{atSession}
84
85 err = a.StoreSession(r, w, &sessionish, pdsUrl)
86 if err != nil {
87 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
88 return
89 }
90
91 log.Println("successfully refreshed token")
92 }
93
94 next.ServeHTTP(w, r)
95 })
96 }
97}
98
99func Paginate(next http.Handler) http.Handler {
100 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
101 page := pagination.FirstPage()
102
103 offsetVal := r.URL.Query().Get("offset")
104 if offsetVal != "" {
105 offset, err := strconv.Atoi(offsetVal)
106 if err != nil {
107 log.Println("invalid offset")
108 } else {
109 page.Offset = offset
110 }
111 }
112
113 limitVal := r.URL.Query().Get("limit")
114 if limitVal != "" {
115 limit, err := strconv.Atoi(limitVal)
116 if err != nil {
117 log.Println("invalid limit")
118 } else {
119 page.Limit = limit
120 }
121 }
122
123 ctx := context.WithValue(r.Context(), "page", page)
124 next.ServeHTTP(w, r.WithContext(ctx))
125 })
126}