1package session
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "time"
8
9 "tangled.sh/tangled.sh/core/appview/cache"
10)
11
// OAuthSession is the per-account OAuth state persisted by SaveSession and
// read back by GetSession.
//
// It is serialised with encoding/json and carries no struct tags, so the Go
// field names are the stored JSON keys; renaming a field changes the cached
// format. Field order is significant for unkeyed composite literals and must
// not be changed.
type OAuthSession struct {
	// Did also selects the cache key (sessionKey) the session is stored under.
	Handle, Did, PdsUrl string

	// Replaced together with Expiry by RefreshSession.
	AccessJwt, RefreshJwt string

	AuthServerIss string

	// UpdateNonce rewrites DpopAuthserverNonce only; DpopPdsNonce is left as-is.
	DpopPdsNonce, DpopAuthserverNonce string
	DpopPrivateJwk                    string

	// Expiry is kept as an opaque string; its format is chosen by the caller.
	Expiry string
}
24
// OAuthRequest is an in-flight OAuth authorization request. SaveRequest stores
// it (keyed by Did) together with a State->Did mapping, and GetRequestByState
// finds it again from the State value alone.
//
// Like OAuthSession it is serialised with encoding/json and has no struct
// tags, so field names are the stored JSON keys. Field order must not change.
//
// NOTE(review): the issuer field is spelled AuthserverIss here but
// AuthServerIss on OAuthSession; both spellings are part of the Go API and of
// the cached JSON, so unifying them needs a coordinated migration.
type OAuthRequest struct {
	AuthserverIss, Handle string

	// State is what the callback presents; SaveRequest maps it to Did.
	State, Did string

	PdsUrl, PkceVerifier string

	DpopAuthserverNonce, DpopPrivateJwk string
}
35
36type SessionStore struct {
37 cache *cache.Cache
38}
39
40const (
41 stateKey = "oauthstate:%s"
42 requestKey = "oauthrequest:%s"
43 sessionKey = "oauthsession:%s"
44)
45
46func New(cache *cache.Cache) *SessionStore {
47 return &SessionStore{cache: cache}
48}
49
50func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error {
51 key := fmt.Sprintf(sessionKey, session.Did)
52 data, err := json.Marshal(session)
53 if err != nil {
54 return err
55 }
56
57 // set with ttl (7 days)
58 ttl := 7 * 24 * time.Hour
59
60 return s.cache.Set(ctx, key, data, ttl).Err()
61}
62
63// SaveRequest stores the OAuth request to be later fetched in the callback. Since
64// the fetching happens by comparing the state we get in the callback params, we
65// store an additional state->did mapping which then lets us fetch the whole OAuth request.
66func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error {
67 key := fmt.Sprintf(requestKey, request.Did)
68 data, err := json.Marshal(request)
69 if err != nil {
70 return err
71 }
72
73 // oauth flow must complete within 30 minutes
74 err = s.cache.Set(ctx, key, data, 30*time.Minute).Err()
75 if err != nil {
76 return fmt.Errorf("error saving request: %w", err)
77 }
78
79 stateKey := fmt.Sprintf(stateKey, request.State)
80 err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err()
81 if err != nil {
82 return fmt.Errorf("error saving state->did mapping: %w", err)
83 }
84
85 return nil
86}
87
88func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) {
89 key := fmt.Sprintf(sessionKey, did)
90 val, err := s.cache.Get(ctx, key).Result()
91 if err != nil {
92 return nil, err
93 }
94
95 var session OAuthSession
96 err = json.Unmarshal([]byte(val), &session)
97 if err != nil {
98 return nil, err
99 }
100 return &session, nil
101}
102
103func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) {
104 didKey, err := s.getRequestKeyFromState(ctx, state)
105 if err != nil {
106 return nil, err
107 }
108
109 val, err := s.cache.Get(ctx, didKey).Result()
110 if err != nil {
111 return nil, err
112 }
113
114 var request OAuthRequest
115 err = json.Unmarshal([]byte(val), &request)
116 if err != nil {
117 return nil, err
118 }
119
120 return &request, nil
121}
122
123func (s *SessionStore) DeleteSession(ctx context.Context, did string) error {
124 key := fmt.Sprintf(sessionKey, did)
125 return s.cache.Del(ctx, key).Err()
126}
127
128func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error {
129 didKey, err := s.getRequestKeyFromState(ctx, state)
130 if err != nil {
131 return err
132 }
133
134 err = s.cache.Del(ctx, fmt.Sprintf(stateKey, state)).Err()
135 if err != nil {
136 return err
137 }
138
139 return s.cache.Del(ctx, didKey).Err()
140}
141
142func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error {
143 session, err := s.GetSession(ctx, did)
144 if err != nil {
145 return err
146 }
147 session.AccessJwt = access
148 session.RefreshJwt = refresh
149 session.Expiry = expiry
150 return s.SaveSession(ctx, *session)
151}
152
153func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error {
154 session, err := s.GetSession(ctx, did)
155 if err != nil {
156 return err
157 }
158 session.DpopAuthserverNonce = nonce
159 return s.SaveSession(ctx, *session)
160}
161
162func (s *SessionStore) getRequestKeyFromState(ctx context.Context, state string) (string, error) {
163 key := fmt.Sprintf(stateKey, state)
164 did, err := s.cache.Get(ctx, key).Result()
165 if err != nil {
166 return "", err
167 }
168
169 didKey := fmt.Sprintf(requestKey, did)
170 return didKey, nil
171}