An atproto PDS written in Go
at v0.5.1 2.2 kB view raw
1package dpop 2 3import ( 4 "crypto/hmac" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/binary" 8 "sync" 9 "time" 10 11 "github.com/haileyok/cocoon/internal/helpers" 12 "github.com/haileyok/cocoon/oauth/constants" 13) 14 15type Nonce struct { 16 rotationInterval time.Duration 17 secret []byte 18 19 mu sync.RWMutex 20 21 counter int64 22 prev string 23 curr string 24 next string 25} 26 27type NonceArgs struct { 28 RotationInterval time.Duration 29 Secret []byte 30 OnSecretCreated func([]byte) 31} 32 33func NewNonce(args NonceArgs) *Nonce { 34 if args.RotationInterval == 0 { 35 args.RotationInterval = constants.NonceMaxRotationInterval / 3 36 } 37 38 if args.RotationInterval > constants.NonceMaxRotationInterval { 39 args.RotationInterval = constants.NonceMaxRotationInterval 40 } 41 42 if args.Secret == nil { 43 args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength) 44 args.OnSecretCreated(args.Secret) 45 } 46 47 n := &Nonce{ 48 rotationInterval: args.RotationInterval, 49 secret: args.Secret, 50 mu: sync.RWMutex{}, 51 } 52 53 n.counter = n.currentCounter() 54 n.prev = n.compute(n.counter - 1) 55 n.curr = n.compute(n.counter) 56 n.next = n.compute(n.counter + 1) 57 58 return n 59} 60 61func (n *Nonce) currentCounter() int64 { 62 return time.Now().UnixNano() / int64(n.rotationInterval) 63} 64 65func (n *Nonce) compute(counter int64) string { 66 h := hmac.New(sha256.New, n.secret) 67 counterBytes := make([]byte, 8) 68 binary.BigEndian.PutUint64(counterBytes, uint64(counter)) 69 h.Write(counterBytes) 70 return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) 71} 72 73func (n *Nonce) rotate() { 74 counter := n.currentCounter() 75 diff := counter - n.counter 76 77 switch diff { 78 case 0: 79 // counter == n.counter, do nothing 80 case 1: 81 n.prev = n.curr 82 n.curr = n.next 83 n.next = n.compute(counter + 1) 84 case 2: 85 n.prev = n.next 86 n.curr = n.compute(counter) 87 n.next = n.compute(counter + 1) 88 default: 89 n.prev = n.compute(counter - 1) 90 n.curr = n.compute(counter) 91 n.next = n.compute(counter + 1) 92 } 93 94 n.counter = counter 95} 96 97func (n *Nonce) NextNonce() string { 98 n.mu.Lock() 99 defer n.mu.Unlock() 100 n.rotate() 101 return n.next 102} 103 104func (n *Nonce) Check(nonce string) bool { 105 n.mu.RLock() 106 defer n.mu.RUnlock() 107 return nonce == n.prev || nonce == n.curr || nonce == n.next 108}