1package labels
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "time"
11
12 comatproto "github.com/bluesky-social/indigo/api/atproto"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 lexutil "github.com/bluesky-social/indigo/lex/util"
15 "github.com/go-chi/chi/v5"
16
17 "tangled.sh/tangled.sh/core/api/tangled"
18 "tangled.sh/tangled.sh/core/appview/db"
19 "tangled.sh/tangled.sh/core/appview/middleware"
20 "tangled.sh/tangled.sh/core/appview/oauth"
21 "tangled.sh/tangled.sh/core/appview/pages"
22 "tangled.sh/tangled.sh/core/appview/validator"
23 "tangled.sh/tangled.sh/core/appview/xrpcclient"
24 "tangled.sh/tangled.sh/core/log"
25 "tangled.sh/tangled.sh/core/tid"
26)
27
28type Labels struct {
29 oauth *oauth.OAuth
30 pages *pages.Pages
31 db *db.DB
32 logger *slog.Logger
33 validator *validator.Validator
34}
35
36func New(
37 oauth *oauth.OAuth,
38 pages *pages.Pages,
39 db *db.DB,
40 validator *validator.Validator,
41) *Labels {
42 logger := log.New("labels")
43
44 return &Labels{
45 oauth: oauth,
46 pages: pages,
47 db: db,
48 logger: logger,
49 validator: validator,
50 }
51}
52
53func (l *Labels) Router(mw *middleware.Middleware) http.Handler {
54 r := chi.NewRouter()
55
56 r.With(middleware.AuthMiddleware(l.oauth)).Put("/perform", l.PerformLabelOp)
57
58 return r
59}
60
61func (l *Labels) PerformLabelOp(w http.ResponseWriter, r *http.Request) {
62 user := l.oauth.GetUser(r)
63
64 if err := r.ParseForm(); err != nil {
65 l.logger.Error("failed to parse form data", "error", err)
66 http.Error(w, "Invalid form data", http.StatusBadRequest)
67 return
68 }
69
70 did := user.Did
71 rkey := tid.TID()
72 performedAt := time.Now()
73 indexedAt := time.Now()
74 repoAt := r.Form.Get("repo")
75 subjectUri := r.Form.Get("subject")
76 keys := r.Form["operand-key"]
77 vals := r.Form["operand-val"]
78
79 var labelOps []db.LabelOp
80 for i := range len(keys) {
81 op := r.FormValue(fmt.Sprintf("op-%d", i))
82 if op == "" {
83 op = string(db.LabelOperationDel)
84 }
85 key := keys[i]
86 val := vals[i]
87
88 labelOps = append(labelOps, db.LabelOp{
89 Did: did,
90 Rkey: rkey,
91 Subject: syntax.ATURI(subjectUri),
92 Operation: db.LabelOperation(op),
93 OperandKey: key,
94 OperandValue: val,
95 PerformedAt: performedAt,
96 IndexedAt: indexedAt,
97 })
98 }
99
100 // find all the labels that this repo subscribes to
101 repoLabels, err := db.GetRepoLabels(l.db, db.FilterEq("repo_at", repoAt))
102 if err != nil {
103 http.Error(w, "Invalid form data", http.StatusBadRequest)
104 return
105 }
106
107 var labelAts []string
108 for _, rl := range repoLabels {
109 labelAts = append(labelAts, rl.LabelAt.String())
110 }
111
112 actx, err := db.NewLabelApplicationCtx(l.db, db.FilterIn("at_uri", labelAts))
113 if err != nil {
114 http.Error(w, "Invalid form data", http.StatusBadRequest)
115 return
116 }
117
118 for i := range labelOps {
119 def := actx.Defs[labelOps[i].OperandKey]
120 if err := l.validator.ValidateLabelOp(def, &labelOps[i]); err != nil {
121 l.logger.Error("form failed to validate", "err", err)
122 http.Error(w, "Invalid form data", http.StatusBadRequest)
123 return
124 }
125
126 l.logger.Info("value changed to: ", "v", labelOps[i].OperandValue)
127 }
128
129 // calculate the start state by applying already known labels
130 existingOps, err := db.GetLabelOps(l.db, db.FilterEq("subject", subjectUri))
131 if err != nil {
132 http.Error(w, "Invalid form data", http.StatusBadRequest)
133 return
134 }
135
136 labelState := db.NewLabelState()
137 actx.ApplyLabelOps(labelState, existingOps)
138
139 l.logger.Info("state", "state", labelState)
140
141 // next, apply all ops introduced in this request and filter out ones that are no-ops
142 validLabelOps := labelOps[:0]
143 for _, op := range labelOps {
144 if err = actx.ApplyLabelOp(labelState, op); err != db.LabelNoOpError {
145 validLabelOps = append(validLabelOps, op)
146 }
147 }
148
149 // nothing to do
150 if len(validLabelOps) == 0 {
151 l.pages.HxRefresh(w)
152 return
153 }
154
155 // create an atproto record of valid ops
156 record := db.LabelOpsAsRecord(validLabelOps)
157
158 client, err := l.oauth.AuthorizedClient(r)
159 if err != nil {
160 l.logger.Error("failed to create client", "error", err)
161 http.Error(w, "Invalid form data", http.StatusBadRequest)
162 return
163 }
164
165 resp, err := client.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{
166 Collection: tangled.LabelOpNSID,
167 Repo: did,
168 Rkey: rkey,
169 Record: &lexutil.LexiconTypeDecoder{
170 Val: &record,
171 },
172 })
173 if err != nil {
174 l.logger.Error("failed to write to PDS", "error", err)
175 http.Error(w, "failed to write to PDS", http.StatusInternalServerError)
176 return
177 }
178 atUri := resp.Uri
179
180 tx, err := l.db.BeginTx(r.Context(), nil)
181 if err != nil {
182 l.logger.Error("failed to start tx", "error", err)
183 return
184 }
185
186 rollback := func() {
187 err1 := tx.Rollback()
188 err2 := rollbackRecord(context.Background(), atUri, client)
189
190 // ignore txn complete errors, this is okay
191 if errors.Is(err1, sql.ErrTxDone) {
192 err1 = nil
193 }
194
195 if errs := errors.Join(err1, err2); errs != nil {
196 return
197 }
198 }
199 defer rollback()
200
201 for _, o := range validLabelOps {
202 if _, err := db.AddLabelOp(l.db, &o); err != nil {
203 l.logger.Error("failed to add op", "err", err)
204 return
205 }
206
207 l.logger.Info("performed label op", "did", o.Did, "rkey", o.Rkey, "kind", o.Operation, "subjcet", o.Subject, "key", o.OperandKey)
208 }
209
210 err = tx.Commit()
211 if err != nil {
212 return
213 }
214
215 // clear aturi when everything is successful
216 atUri = ""
217
218 l.pages.HxRefresh(w)
219}
220
221// this is used to rollback changes made to the PDS
222//
223// it is a no-op if the provided ATURI is empty
224func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error {
225 if aturi == "" {
226 return nil
227 }
228
229 parsed := syntax.ATURI(aturi)
230
231 collection := parsed.Collection().String()
232 repo := parsed.Authority().String()
233 rkey := parsed.RecordKey().String()
234
235 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{
236 Collection: collection,
237 Repo: repo,
238 Rkey: rkey,
239 })
240 return err
241}