1package labels
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "time"
11
12 comatproto "github.com/bluesky-social/indigo/api/atproto"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 lexutil "github.com/bluesky-social/indigo/lex/util"
15 "github.com/go-chi/chi/v5"
16
17 "tangled.org/core/api/tangled"
18 "tangled.org/core/appview/db"
19 "tangled.org/core/appview/middleware"
20 "tangled.org/core/appview/models"
21 "tangled.org/core/appview/oauth"
22 "tangled.org/core/appview/pages"
23 "tangled.org/core/appview/validator"
24 "tangled.org/core/appview/xrpcclient"
25 "tangled.org/core/log"
26 "tangled.org/core/tid"
27)
28
29type Labels struct {
30 oauth *oauth.OAuth
31 pages *pages.Pages
32 db *db.DB
33 logger *slog.Logger
34 validator *validator.Validator
35}
36
37func New(
38 oauth *oauth.OAuth,
39 pages *pages.Pages,
40 db *db.DB,
41 validator *validator.Validator,
42) *Labels {
43 logger := log.New("labels")
44
45 return &Labels{
46 oauth: oauth,
47 pages: pages,
48 db: db,
49 logger: logger,
50 validator: validator,
51 }
52}
53
54func (l *Labels) Router(mw *middleware.Middleware) http.Handler {
55 r := chi.NewRouter()
56
57 r.Use(middleware.AuthMiddleware(l.oauth))
58 r.Put("/perform", l.PerformLabelOp)
59
60 return r
61}
62
63// this is a tricky handler implementation:
64// - the user selects the new state of all the labels in the label panel and hits save
65// - this handler should calculate the diff in order to create the labelop record
66// - we need the diff in order to maintain a "history" of operations performed by users
67func (l *Labels) PerformLabelOp(w http.ResponseWriter, r *http.Request) {
68 user := l.oauth.GetUser(r)
69
70 noticeId := "add-label-error"
71
72 fail := func(msg string, err error) {
73 l.logger.Error("failed to add label", "err", err)
74 l.pages.Notice(w, noticeId, msg)
75 }
76
77 if err := r.ParseForm(); err != nil {
78 fail("Invalid form.", err)
79 return
80 }
81
82 did := user.Did
83 rkey := tid.TID()
84 performedAt := time.Now()
85 indexedAt := time.Now()
86 repoAt := r.Form.Get("repo")
87 subjectUri := r.Form.Get("subject")
88
89 // find all the labels that this repo subscribes to
90 repoLabels, err := db.GetRepoLabels(l.db, db.FilterEq("repo_at", repoAt))
91 if err != nil {
92 fail("Failed to get labels for this repository.", err)
93 return
94 }
95
96 var labelAts []string
97 for _, rl := range repoLabels {
98 labelAts = append(labelAts, rl.LabelAt.String())
99 }
100
101 actx, err := db.NewLabelApplicationCtx(l.db, db.FilterIn("at_uri", labelAts))
102 if err != nil {
103 fail("Invalid form data.", err)
104 return
105 }
106
107 // calculate the start state by applying already known labels
108 existingOps, err := db.GetLabelOps(l.db, db.FilterEq("subject", subjectUri))
109 if err != nil {
110 fail("Invalid form data.", err)
111 return
112 }
113
114 labelState := models.NewLabelState()
115 actx.ApplyLabelOps(labelState, existingOps)
116
117 var labelOps []models.LabelOp
118
119 // first delete all existing state
120 for key, vals := range labelState.Inner() {
121 for val := range vals {
122 labelOps = append(labelOps, models.LabelOp{
123 Did: did,
124 Rkey: rkey,
125 Subject: syntax.ATURI(subjectUri),
126 Operation: models.LabelOperationDel,
127 OperandKey: key,
128 OperandValue: val,
129 PerformedAt: performedAt,
130 IndexedAt: indexedAt,
131 })
132 }
133 }
134
135 // add all the new state the user specified
136 for key, vals := range r.Form {
137 if _, ok := actx.Defs[key]; !ok {
138 continue
139 }
140
141 for _, val := range vals {
142 labelOps = append(labelOps, models.LabelOp{
143 Did: did,
144 Rkey: rkey,
145 Subject: syntax.ATURI(subjectUri),
146 Operation: models.LabelOperationAdd,
147 OperandKey: key,
148 OperandValue: val,
149 PerformedAt: performedAt,
150 IndexedAt: indexedAt,
151 })
152 }
153 }
154
155 // reduce the opset
156 labelOps = models.ReduceLabelOps(labelOps)
157
158 for i := range labelOps {
159 def := actx.Defs[labelOps[i].OperandKey]
160 if err := l.validator.ValidateLabelOp(def, &labelOps[i]); err != nil {
161 fail(fmt.Sprintf("Invalid form data: %s", err), err)
162 return
163 }
164 }
165
166 // next, apply all ops introduced in this request and filter out ones that are no-ops
167 validLabelOps := labelOps[:0]
168 for _, op := range labelOps {
169 if err = actx.ApplyLabelOp(labelState, op); err != models.LabelNoOpError {
170 validLabelOps = append(validLabelOps, op)
171 }
172 }
173
174 // nothing to do
175 if len(validLabelOps) == 0 {
176 l.pages.HxRefresh(w)
177 return
178 }
179
180 // create an atproto record of valid ops
181 record := models.LabelOpsAsRecord(validLabelOps)
182
183 client, err := l.oauth.AuthorizedClient(r)
184 if err != nil {
185 fail("Failed to authorize user.", err)
186 return
187 }
188
189 resp, err := client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{
190 Collection: tangled.LabelOpNSID,
191 Repo: did,
192 Rkey: rkey,
193 Record: &lexutil.LexiconTypeDecoder{
194 Val: &record,
195 },
196 })
197 if err != nil {
198 fail("Failed to create record on PDS for user.", err)
199 return
200 }
201 atUri := resp.Uri
202
203 tx, err := l.db.BeginTx(r.Context(), nil)
204 if err != nil {
205 fail("Failed to update labels. Try again later.", err)
206 return
207 }
208
209 rollback := func() {
210 err1 := tx.Rollback()
211 err2 := rollbackRecord(context.Background(), atUri, client)
212
213 // ignore txn complete errors, this is okay
214 if errors.Is(err1, sql.ErrTxDone) {
215 err1 = nil
216 }
217
218 if errs := errors.Join(err1, err2); errs != nil {
219 return
220 }
221 }
222 defer rollback()
223
224 for _, o := range validLabelOps {
225 if _, err := db.AddLabelOp(l.db, &o); err != nil {
226 fail("Failed to update labels. Try again later.", err)
227 return
228 }
229 }
230
231 err = tx.Commit()
232 if err != nil {
233 return
234 }
235
236 // clear aturi when everything is successful
237 atUri = ""
238
239 l.pages.HxRefresh(w)
240}
241
242// this is used to rollback changes made to the PDS
243//
244// it is a no-op if the provided ATURI is empty
245func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error {
246 if aturi == "" {
247 return nil
248 }
249
250 parsed := syntax.ATURI(aturi)
251
252 collection := parsed.Collection().String()
253 repo := parsed.Authority().String()
254 rkey := parsed.RecordKey().String()
255
256 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{
257 Collection: collection,
258 Repo: repo,
259 Rkey: rkey,
260 })
261 return err
262}