A URL shortener service that uses ATProto to allow self-hosting and to ensure that users own their data.
1package database
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9
10 "github.com/bluesky-social/indigo/atproto/auth/oauth"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12)
13
// createOauthRequestsTable creates the oauthrequests table if it does not
// already exist. The table holds the fields of oauth.AuthRequestData written
// by SaveAuthRequestInfo; "state" is declared UNIQUE because it is the key
// records are looked up and deleted by.
//
// The DDL is executed directly rather than through db.Prepare: a one-shot
// statement gains nothing from preparation, and a *sql.Stmt that is never
// closed keeps its driver statement alive until db itself is closed.
//
// It returns a wrapped error if the CREATE TABLE statement fails.
func createOauthRequestsTable(db *sql.DB) error {
	const createOauthRequestsTableSQL = `CREATE TABLE IF NOT EXISTS oauthrequests (
	"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
	"state" TEXT,
	"authServerURL" TEXT,
	"accountDID" TEXT,
	"scope" TEXT,
	"requestURI" TEXT,
	"authServerTokenEndpoint" TEXT,
	"pkceVerifier" TEXT,
	"dpopAuthserverNonce" TEXT,
	"dpopPrivateKeyMultibase" TEXT,
	UNIQUE(state)
);`

	slog.Info("Create oauthrequests table...")
	if _, err := db.Exec(createOauthRequestsTableSQL); err != nil {
		return fmt.Errorf("exec sql statement to create oauthrequests table: %w", err)
	}
	slog.Info("oauthrequests table created")

	return nil
}
42
43func (d *DB) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
44 did := ""
45 if info.AccountDID != nil {
46 did = info.AccountDID.String()
47 }
48
49 scopes, err := json.Marshal(info.Scopes)
50 if err != nil {
51 return fmt.Errorf("encoding scopes to JSON: %w", err)
52 }
53
54 sql := `INSERT INTO oauthrequests (state, authServerURL, accountDID, scope, requestURI, authServerTokenEndpoint, pkceVerifier, dpopAuthserverNonce, dpopPrivateKeyMultibase) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(state) DO NOTHING;`
55 _, err = d.db.Exec(sql, info.State, info.AuthServerURL, did, string(scopes), info.RequestURI, info.AuthServerTokenEndpoint, info.PKCEVerifier, info.DPoPAuthServerNonce, info.DPoPPrivateKeyMultibase)
56 if err != nil {
57 slog.Error("saving auth request info", "error", err)
58 return fmt.Errorf("exec insert oauth request: %w", err)
59 }
60
61 return nil
62}
63
64func (d *DB) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
65 var oauthRequest oauth.AuthRequestData
66 sql := "SELECT state, authServerURL, accountDID, scope, requestURI, authServerTokenEndpoint, pkceVerifier, dpopAuthserverNonce, dpopPrivateKeyMultibase FROM oauthrequests where state = ?;"
67 rows, err := d.db.Query(sql, state)
68 if err != nil {
69 return nil, fmt.Errorf("run query to get oauth request: %w", err)
70 }
71 defer rows.Close()
72
73 var did string
74 var scopesStr string
75
76 for rows.Next() {
77 if err := rows.Scan(&oauthRequest.State, &oauthRequest.AuthServerURL, &did, &scopesStr, &oauthRequest.RequestURI, &oauthRequest.AuthServerTokenEndpoint, &oauthRequest.PKCEVerifier, &oauthRequest.DPoPAuthServerNonce, &oauthRequest.DPoPPrivateKeyMultibase); err != nil {
78 return nil, fmt.Errorf("scan row: %w", err)
79 }
80
81 if did != "" {
82 parsedDID, err := syntax.ParseDID(did)
83 if err != nil {
84 return nil, fmt.Errorf("invalid DID stored in record: %w", err)
85 }
86 oauthRequest.AccountDID = &parsedDID
87 }
88
89 if scopesStr != "" {
90 var scopes []string
91 err = json.Unmarshal([]byte(scopesStr), &scopes)
92 if err != nil {
93 return nil, fmt.Errorf("decode scopes in record: %w", err)
94 }
95 oauthRequest.Scopes = scopes
96 }
97
98 return &oauthRequest, nil
99 }
100 return nil, fmt.Errorf("not found")
101}
102
103func (d *DB) DeleteAuthRequestInfo(ctx context.Context, state string) error {
104 sql := "DELETE FROM oauthrequests WHERE state = ?;"
105 _, err := d.db.Exec(sql, state)
106 if err != nil {
107 return fmt.Errorf("exec delete oauth request: %w", err)
108 }
109 return nil
110}