1package session
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "time"
8
9 "tangled.org/core/appview/cache"
10)
11
// OAuthSession is the per-user OAuth session state persisted by SessionStore.
//
// It is stored with encoding/json and has no struct tags, so the Go field
// names are the JSON keys of the cached value (stored under
// "oauthsession:<Did>"). Renaming a field means previously stored sessions
// decode with that field left empty, so treat the field set as a wire format.
type OAuthSession struct {
	Handle string
	// Did is used to build this session's cache key.
	Did    string
	PdsUrl string
	// AccessJwt, RefreshJwt and Expiry are replaced together by
	// SessionStore.RefreshSession.
	AccessJwt  string
	RefreshJwt string
	// NOTE(review): spelled AuthServerIss here but AuthserverIss on
	// OAuthRequest; the JSON keys therefore differ between the two types.
	AuthServerIss string
	DpopPdsNonce  string
	// DpopAuthserverNonce is overwritten by SessionStore.UpdateNonce.
	DpopAuthserverNonce string
	DpopPrivateJwk      string
	// Expiry is kept as an opaque string. NOTE(review): its format is not
	// visible here; confirm with the code that sets it.
	Expiry string
}
24
// OAuthRequest captures an in-flight OAuth authorization request, saved by
// SessionStore.SaveRequest and read back in the callback via
// SessionStore.GetRequestByState.
//
// Like OAuthSession it is stored as untagged JSON, so the Go field names are
// the persisted keys; renaming a field affects requests already in the cache.
type OAuthRequest struct {
	// NOTE(review): spelled AuthserverIss here but AuthServerIss on
	// OAuthSession.
	AuthserverIss string
	Handle        string
	// State is the value the callback is looked up by; SaveRequest stores a
	// state -> Did mapping for it.
	State string
	// Did keys the stored request in the cache.
	Did                 string
	PdsUrl              string
	PkceVerifier        string
	DpopAuthserverNonce string
	DpopPrivateJwk      string
	// ReturnUrl is carried through the flow unchanged by this store.
	// NOTE(review): presumably where to redirect after login; confirm at caller.
	ReturnUrl string
}
36
// SessionStore persists OAuth sessions and in-flight OAuth requests as JSON
// values in a *cache.Cache. Construct it with New.
type SessionStore struct {
	cache *cache.Cache
}
40
// Cache key format strings; each %s is filled in with the identifier noted.
const (
	// stateKey maps an OAuth state value to the DID of its pending request.
	stateKey   = "oauthstate:%s"
	// requestKey holds a JSON-encoded OAuthRequest, keyed by DID.
	requestKey = "oauthrequest:%s"
	// sessionKey holds a JSON-encoded OAuthSession, keyed by DID.
	sessionKey = "oauthsession:%s"
)
46
47func New(cache *cache.Cache) *SessionStore {
48 return &SessionStore{cache: cache}
49}
50
51func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error {
52 key := fmt.Sprintf(sessionKey, session.Did)
53 data, err := json.Marshal(session)
54 if err != nil {
55 return err
56 }
57
58 // set with ttl (7 days)
59 ttl := 7 * 24 * time.Hour
60
61 return s.cache.Set(ctx, key, data, ttl).Err()
62}
63
64// SaveRequest stores the OAuth request to be later fetched in the callback. Since
65// the fetching happens by comparing the state we get in the callback params, we
66// store an additional state->did mapping which then lets us fetch the whole OAuth request.
67func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error {
68 key := fmt.Sprintf(requestKey, request.Did)
69 data, err := json.Marshal(request)
70 if err != nil {
71 return err
72 }
73
74 // oauth flow must complete within 30 minutes
75 err = s.cache.Set(ctx, key, data, 30*time.Minute).Err()
76 if err != nil {
77 return fmt.Errorf("error saving request: %w", err)
78 }
79
80 stateKey := fmt.Sprintf(stateKey, request.State)
81 err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err()
82 if err != nil {
83 return fmt.Errorf("error saving state->did mapping: %w", err)
84 }
85
86 return nil
87}
88
89func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) {
90 key := fmt.Sprintf(sessionKey, did)
91 val, err := s.cache.Get(ctx, key).Result()
92 if err != nil {
93 return nil, err
94 }
95
96 var session OAuthSession
97 err = json.Unmarshal([]byte(val), &session)
98 if err != nil {
99 return nil, err
100 }
101 return &session, nil
102}
103
104func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) {
105 didKey, err := s.getRequestKeyFromState(ctx, state)
106 if err != nil {
107 return nil, err
108 }
109
110 val, err := s.cache.Get(ctx, didKey).Result()
111 if err != nil {
112 return nil, err
113 }
114
115 var request OAuthRequest
116 err = json.Unmarshal([]byte(val), &request)
117 if err != nil {
118 return nil, err
119 }
120
121 return &request, nil
122}
123
124func (s *SessionStore) DeleteSession(ctx context.Context, did string) error {
125 key := fmt.Sprintf(sessionKey, did)
126 return s.cache.Del(ctx, key).Err()
127}
128
129func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error {
130 didKey, err := s.getRequestKeyFromState(ctx, state)
131 if err != nil {
132 return err
133 }
134
135 err = s.cache.Del(ctx, fmt.Sprintf(stateKey, state)).Err()
136 if err != nil {
137 return err
138 }
139
140 return s.cache.Del(ctx, didKey).Err()
141}
142
143func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error {
144 session, err := s.GetSession(ctx, did)
145 if err != nil {
146 return err
147 }
148 session.AccessJwt = access
149 session.RefreshJwt = refresh
150 session.Expiry = expiry
151 return s.SaveSession(ctx, *session)
152}
153
154func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error {
155 session, err := s.GetSession(ctx, did)
156 if err != nil {
157 return err
158 }
159 session.DpopAuthserverNonce = nonce
160 return s.SaveSession(ctx, *session)
161}
162
163func (s *SessionStore) getRequestKeyFromState(ctx context.Context, state string) (string, error) {
164 key := fmt.Sprintf(stateKey, state)
165 did, err := s.cache.Get(ctx, key).Result()
166 if err != nil {
167 return "", err
168 }
169
170 didKey := fmt.Sprintf(requestKey, did)
171 return didKey, nil
172}