1package middleware
2
3import (
4 "log"
5 "net/http"
6 "time"
7
8 comatproto "github.com/bluesky-social/indigo/api/atproto"
9 "github.com/bluesky-social/indigo/xrpc"
10 "tangled.sh/tangled.sh/core/appview"
11 "tangled.sh/tangled.sh/core/appview/auth"
12)
13
// Middleware wraps an http.Handler and returns the http.Handler that should
// be served in its place.
type Middleware func(http.Handler) http.Handler
15
16func AuthMiddleware(a *auth.Auth) Middleware {
17 return func(next http.Handler) http.Handler {
18 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
20 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
21 }
22 if r.Header.Get("HX-Request") == "true" {
23 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
24 w.Header().Set("HX-Redirect", "/login")
25 w.WriteHeader(http.StatusOK)
26 }
27 }
28
29 session, err := a.GetSession(r)
30 if session.IsNew || err != nil {
31 log.Printf("not logged in, redirecting")
32 redirectFunc(w, r)
33 return
34 }
35
36 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
37 if !ok || !authorized {
38 log.Printf("not logged in, redirecting")
39 redirectFunc(w, r)
40 return
41 }
42
43 // refresh if nearing expiry
44 // TODO: dedup with /login
45 expiryStr := session.Values[appview.SessionExpiry].(string)
46 expiry, err := time.Parse(time.RFC3339, expiryStr)
47 if err != nil {
48 log.Println("invalid expiry time", err)
49 redirectFunc(w, r)
50 return
51 }
52 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
53 did, ok2 := session.Values[appview.SessionDid].(string)
54 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
55
56 if !ok1 || !ok2 || !ok3 {
57 log.Println("invalid expiry time", err)
58 redirectFunc(w, r)
59 return
60 }
61
62 if time.Now().After(expiry) {
63 log.Println("token expired, refreshing ...")
64
65 client := xrpc.Client{
66 Host: pdsUrl,
67 Auth: &xrpc.AuthInfo{
68 Did: did,
69 AccessJwt: refreshJwt,
70 RefreshJwt: refreshJwt,
71 },
72 }
73 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
74 if err != nil {
75 log.Println("failed to refresh session", err)
76 redirectFunc(w, r)
77 return
78 }
79
80 sessionish := auth.RefreshSessionWrapper{atSession}
81
82 err = a.StoreSession(r, w, &sessionish, pdsUrl)
83 if err != nil {
84 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
85 return
86 }
87
88 log.Println("successfully refreshed token")
89 }
90
91 next.ServeHTTP(w, r)
92 })
93 }
94}