1package dpop
2
3import (
4 "crypto/hmac"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/binary"
8 "sync"
9 "time"
10
11 "github.com/haileyok/cocoon/internal/helpers"
12 "github.com/haileyok/cocoon/oauth/constants"
13)
14
// Nonce issues and validates DPoP server nonces. Each nonce is an HMAC over a
// time-based counter, so no per-nonce state is kept; only the three nonces
// surrounding the current counter are cached, and they are refreshed as the
// counter advances.
type Nonce struct {
	// secret keys the HMAC-SHA256 used to derive every nonce.
	secret []byte
	// rotationInterval is the wall-clock length of one counter step.
	rotationInterval time.Duration

	// mu guards the cached counter and nonce values below.
	mu sync.RWMutex
	// counter is the step the cached nonces were last computed for.
	counter int64
	// prev, curr and next are the nonces for counter-1, counter and counter+1.
	prev, curr, next string
}
26
// NonceArgs configures NewNonce.
type NonceArgs struct {
	// RotationInterval is how long each nonce step lasts. Zero selects a
	// default, and values above constants.NonceMaxRotationInterval are
	// clamped to that maximum (see NewNonce).
	RotationInterval time.Duration
	// Secret keys the HMAC used to derive nonces. When nil, NewNonce
	// generates a random secret instead.
	Secret []byte
	// OnSecretCreated receives the secret NewNonce generates when Secret is
	// nil — presumably so the caller can persist it and reuse it across
	// restarts (TODO confirm against callers).
	OnSecretCreated func([]byte)
}
32
33func NewNonce(args NonceArgs) *Nonce {
34 if args.RotationInterval == 0 {
35 args.RotationInterval = constants.NonceMaxRotationInterval / 3
36 }
37
38 if args.RotationInterval > constants.NonceMaxRotationInterval {
39 args.RotationInterval = constants.NonceMaxRotationInterval
40 }
41
42 if args.Secret == nil {
43 args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength)
44 args.OnSecretCreated(args.Secret)
45 }
46
47 n := &Nonce{
48 rotationInterval: args.RotationInterval,
49 secret: args.Secret,
50 mu: sync.RWMutex{},
51 }
52
53 n.counter = n.currentCounter()
54 n.prev = n.compute(n.counter - 1)
55 n.curr = n.compute(n.counter)
56 n.next = n.compute(n.counter + 1)
57
58 return n
59}
60
61func (n *Nonce) currentCounter() int64 {
62 return time.Now().UnixNano() / int64(n.rotationInterval)
63}
64
65func (n *Nonce) compute(counter int64) string {
66 h := hmac.New(sha256.New, n.secret)
67 counterBytes := make([]byte, 8)
68 binary.BigEndian.PutUint64(counterBytes, uint64(counter))
69 h.Write(counterBytes)
70 return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
71}
72
73func (n *Nonce) rotate() {
74 counter := n.currentCounter()
75 diff := counter - n.counter
76
77 switch diff {
78 case 0:
79 // counter == n.counter, do nothing
80 case 1:
81 n.prev = n.curr
82 n.curr = n.next
83 n.next = n.compute(counter + 1)
84 case 2:
85 n.prev = n.next
86 n.curr = n.compute(counter)
87 n.next = n.compute(counter + 1)
88 default:
89 n.prev = n.compute(counter - 1)
90 n.curr = n.compute(counter)
91 n.next = n.compute(counter + 1)
92 }
93
94 n.counter = counter
95}
96
97func (n *Nonce) NextNonce() string {
98 n.mu.Lock()
99 defer n.mu.Unlock()
100 n.rotate()
101 return n.next
102}
103
104func (n *Nonce) Check(nonce string) bool {
105 n.mu.Lock()
106 defer n.mu.Unlock()
107 n.rotate()
108 return nonce == n.prev || nonce == n.curr || nonce == n.next
109}