forked from tangled.org/core
this repo has no description
at knot-xrpc 4.1 kB view raw
1package session 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "time" 8 9 "tangled.sh/tangled.sh/core/appview/cache" 10) 11 12type OAuthSession struct { 13 Handle string 14 Did string 15 PdsUrl string 16 AccessJwt string 17 RefreshJwt string 18 AuthServerIss string 19 DpopPdsNonce string 20 DpopAuthserverNonce string 21 DpopPrivateJwk string 22 Expiry string 23} 24 25type OAuthRequest struct { 26 AuthserverIss string 27 Handle string 28 State string 29 Did string 30 PdsUrl string 31 PkceVerifier string 32 DpopAuthserverNonce string 33 DpopPrivateJwk string 34} 35 36type SessionStore struct { 37 cache *cache.Cache 38} 39 40const ( 41 stateKey = "oauthstate:%s" 42 requestKey = "oauthrequest:%s" 43 sessionKey = "oauthsession:%s" 44) 45 46func New(cache *cache.Cache) *SessionStore { 47 return &SessionStore{cache: cache} 48} 49 50func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error { 51 key := fmt.Sprintf(sessionKey, session.Did) 52 data, err := json.Marshal(session) 53 if err != nil { 54 return err 55 } 56 57 // set with ttl (7 days) 58 ttl := 7 * 24 * time.Hour 59 60 return s.cache.Set(ctx, key, data, ttl).Err() 61} 62 63// SaveRequest stores the OAuth request to be later fetched in the callback. Since 64// the fetching happens by comparing the state we get in the callback params, we 65// store an additional state->did mapping which then lets us fetch the whole OAuth request. 66func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error { 67 key := fmt.Sprintf(requestKey, request.Did) 68 data, err := json.Marshal(request) 69 if err != nil { 70 return err 71 } 72 73 // oauth flow must complete within 30 minutes 74 err = s.cache.Set(ctx, key, data, 30*time.Minute).Err() 75 if err != nil { 76 return fmt.Errorf("error saving request: %w", err) 77 } 78 79 stateKey := fmt.Sprintf(stateKey, request.State) 80 err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err() 81 if err != nil { 82 return fmt.Errorf("error saving state->did mapping: %w", err) 83 } 84 85 return nil 86} 87 88func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) { 89 key := fmt.Sprintf(sessionKey, did) 90 val, err := s.cache.Get(ctx, key).Result() 91 if err != nil { 92 return nil, err 93 } 94 95 var session OAuthSession 96 err = json.Unmarshal([]byte(val), &session) 97 if err != nil { 98 return nil, err 99 } 100 return &session, nil 101} 102 103func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) { 104 didKey, err := s.getRequestKeyFromState(ctx, state) 105 if err != nil { 106 return nil, err 107 } 108 109 val, err := s.cache.Get(ctx, didKey).Result() 110 if err != nil { 111 return nil, err 112 } 113 114 var request OAuthRequest 115 err = json.Unmarshal([]byte(val), &request) 116 if err != nil { 117 return nil, err 118 } 119 120 return &request, nil 121} 122 123func (s *SessionStore) DeleteSession(ctx context.Context, did string) error { 124 key := fmt.Sprintf(sessionKey, did) 125 return s.cache.Del(ctx, key).Err() 126} 127 128func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error { 129 didKey, err := s.getRequestKeyFromState(ctx, state) 130 if err != nil { 131 return err 132 } 133 134 err = s.cache.Del(ctx, fmt.Sprintf(stateKey, state)).Err() 135 if err != nil { 136 return err 137 } 138 139 return s.cache.Del(ctx, didKey).Err() 140} 141 142func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error { 143 session, err := s.GetSession(ctx, did) 144 if err != nil { 145 return err 146 } 147 session.AccessJwt = access 148 session.RefreshJwt = refresh 149 session.Expiry = expiry 150 return s.SaveSession(ctx, *session) 151} 152 153func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error { 154 session, err := s.GetSession(ctx, did) 155 if err != nil { 156 return err 157 } 158 session.DpopAuthserverNonce = nonce 159 return s.SaveSession(ctx, *session) 160} 161 162func (s *SessionStore) getRequestKeyFromState(ctx context.Context, state string) (string, error) { 163 key := fmt.Sprintf(stateKey, state) 164 did, err := s.cache.Get(ctx, key).Result() 165 if err != nil { 166 return "", err 167 } 168 169 didKey := fmt.Sprintf(requestKey, did) 170 return didKey, nil 171}