1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "time"
10
11 "github.com/Azure/go-autorest/autorest/to"
12 "github.com/bluesky-social/indigo/api/atproto"
13 "github.com/bluesky-social/indigo/atproto/atdata"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "github.com/bluesky-social/indigo/carstore"
16 "github.com/bluesky-social/indigo/events"
17 lexutil "github.com/bluesky-social/indigo/lex/util"
18 "github.com/bluesky-social/indigo/repo"
19 "github.com/haileyok/cocoon/internal/db"
20 "github.com/haileyok/cocoon/models"
21 "github.com/haileyok/cocoon/recording_blockstore"
22 blocks "github.com/ipfs/go-block-format"
23 "github.com/ipfs/go-cid"
24 cbor "github.com/ipfs/go-ipld-cbor"
25 "github.com/ipld/go-car"
26 "gorm.io/gorm/clause"
27)
28
29type RepoMan struct {
30 db *db.DB
31 s *Server
32 clock *syntax.TIDClock
33}
34
35func NewRepoMan(s *Server) *RepoMan {
36 clock := syntax.NewTIDClock(0)
37
38 return &RepoMan{
39 s: s,
40 db: s.db,
41 clock: &clock,
42 }
43}
44
// OpType names the kind of write carried by an Op, spelled as the
// com.atproto.repo.applyWrites type string.
type OpType string

// Write operation kinds understood by applyWrites.
var (
	OpTypeCreate OpType = "com.atproto.repo.applyWrites#create"
	OpTypeUpdate OpType = "com.atproto.repo.applyWrites#update"
	OpTypeDelete OpType = "com.atproto.repo.applyWrites#delete"
)

// String returns the underlying type string unchanged.
func (ot OpType) String() string {
	raw := string(ot)
	return raw
}
56
57type Op struct {
58 Type OpType `json:"$type"`
59 Collection string `json:"collection"`
60 Rkey *string `json:"rkey,omitempty"`
61 Validate *bool `json:"validate,omitempty"`
62 SwapRecord *string `json:"swapRecord,omitempty"`
63 Record *MarshalableMap `json:"record,omitempty"`
64}
65
// MarshalableMap is a generic record body: a string-keyed map of arbitrary
// values that can be handed to the repo as a CBOR-marshalable record.
type MarshalableMap map[string]any
67
68type FirehoseOp struct {
69 Cid cid.Cid
70 Path string
71 Action string
72}
73
74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
75 data, err := atdata.MarshalCBOR(*mm)
76 if err != nil {
77 return err
78 }
79
80 w.Write(data)
81
82 return nil
83}
84
// ApplyWriteResult is the per-operation outcome returned by applyWrites.
// Every field is optional and omitted from JSON when nil.
type ApplyWriteResult struct {
	// Type is the result type string for the operation.
	Type *string `json:"$type,omitempty"`
	// Uri is the at:// URI of the written record (set for creates and updates).
	Uri *string `json:"uri,omitempty"`
	// Cid is the CID of the written record (set for creates and updates).
	Cid *string `json:"cid,omitempty"`
	// Commit identifies the repo commit that contains the operation.
	Commit *RepoCommit `json:"commit,omitempty"`
	// ValidationStatus reports the record validation outcome.
	ValidationStatus *string `json:"validationStatus,omitempty"`
}

// RepoCommit identifies a repo commit by its CID and revision.
type RepoCommit struct {
	Cid string `json:"cid"`
	Rev string `json:"rev"`
}
97
98// TODO make use of swap commit
99func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
100 rootcid, err := cid.Cast(urepo.Root)
101 if err != nil {
102 return nil, err
103 }
104
105 dbs := rm.s.getBlockstore(urepo.Did)
106 bs := recording_blockstore.New(dbs)
107 r, err := repo.OpenRepo(ctx, bs, rootcid)
108
109 entries := make([]models.Record, 0, len(writes))
110 results := make([]ApplyWriteResult, 0, len(writes))
111
112 for i, op := range writes {
113 if op.Type != OpTypeCreate && op.Rkey == nil {
114 return nil, fmt.Errorf("invalid rkey")
115 } else if op.Type == OpTypeCreate && op.Rkey != nil {
116 _, _, err := r.GetRecord(ctx, op.Collection+"/"+*op.Rkey)
117 if err == nil {
118 op.Type = OpTypeUpdate
119 }
120 } else if op.Rkey == nil {
121 op.Rkey = to.StringPtr(rm.clock.Next().String())
122 writes[i].Rkey = op.Rkey
123 }
124
125 _, err := syntax.ParseRecordKey(*op.Rkey)
126 if err != nil {
127 return nil, err
128 }
129
130 switch op.Type {
131 case OpTypeCreate:
132 j, err := json.Marshal(*op.Record)
133 if err != nil {
134 return nil, err
135 }
136 out, err := atdata.UnmarshalJSON(j)
137 if err != nil {
138 return nil, err
139 }
140 mm := MarshalableMap(out)
141
142 // HACK: if a record doesn't contain a $type, we can manually set it here based on the op's collection
143 if mm["$type"] == "" {
144 mm["$type"] = op.Collection
145 }
146
147 nc, err := r.PutRecord(ctx, op.Collection+"/"+*op.Rkey, &mm)
148 if err != nil {
149 return nil, err
150 }
151 d, err := atdata.MarshalCBOR(mm)
152 if err != nil {
153 return nil, err
154 }
155 entries = append(entries, models.Record{
156 Did: urepo.Did,
157 CreatedAt: rm.clock.Next().String(),
158 Nsid: op.Collection,
159 Rkey: *op.Rkey,
160 Cid: nc.String(),
161 Value: d,
162 })
163 results = append(results, ApplyWriteResult{
164 Type: to.StringPtr(OpTypeCreate.String()),
165 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
166 Cid: to.StringPtr(nc.String()),
167 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
168 })
169 case OpTypeDelete:
170 var old models.Record
171 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
172 return nil, err
173 }
174 entries = append(entries, models.Record{
175 Did: urepo.Did,
176 Nsid: op.Collection,
177 Rkey: *op.Rkey,
178 Value: old.Value,
179 })
180 err := r.DeleteRecord(ctx, op.Collection+"/"+*op.Rkey)
181 if err != nil {
182 return nil, err
183 }
184 results = append(results, ApplyWriteResult{
185 Type: to.StringPtr(OpTypeDelete.String()),
186 })
187 case OpTypeUpdate:
188 j, err := json.Marshal(*op.Record)
189 if err != nil {
190 return nil, err
191 }
192 out, err := atdata.UnmarshalJSON(j)
193 if err != nil {
194 return nil, err
195 }
196 mm := MarshalableMap(out)
197 nc, err := r.UpdateRecord(ctx, op.Collection+"/"+*op.Rkey, &mm)
198 if err != nil {
199 return nil, err
200 }
201 d, err := atdata.MarshalCBOR(mm)
202 if err != nil {
203 return nil, err
204 }
205 entries = append(entries, models.Record{
206 Did: urepo.Did,
207 CreatedAt: rm.clock.Next().String(),
208 Nsid: op.Collection,
209 Rkey: *op.Rkey,
210 Cid: nc.String(),
211 Value: d,
212 })
213 results = append(results, ApplyWriteResult{
214 Type: to.StringPtr(OpTypeUpdate.String()),
215 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
216 Cid: to.StringPtr(nc.String()),
217 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
218 })
219 }
220 }
221
222 newroot, rev, err := r.Commit(ctx, urepo.SignFor)
223 if err != nil {
224 return nil, err
225 }
226
227 buf := new(bytes.Buffer)
228
229 hb, err := cbor.DumpObject(&car.CarHeader{
230 Roots: []cid.Cid{newroot},
231 Version: 1,
232 })
233 if err != nil {
234 return nil, err
235 }
236
237 if _, err := carstore.LdWrite(buf, hb); err != nil {
238 return nil, err
239 }
240
241 diffops, err := r.DiffSince(ctx, rootcid)
242 if err != nil {
243 return nil, err
244 }
245
246 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops))
247
248 for _, op := range diffops {
249 var c cid.Cid
250 switch op.Op {
251 case "add", "mut":
252 kind := "create"
253 if op.Op == "mut" {
254 kind = "update"
255 }
256
257 c = op.NewCid
258 ll := lexutil.LexLink(op.NewCid)
259 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
260 Action: kind,
261 Path: op.Rpath,
262 Cid: &ll,
263 })
264
265 case "del":
266 c = op.OldCid
267 ll := lexutil.LexLink(op.OldCid)
268 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
269 Action: "delete",
270 Path: op.Rpath,
271 Cid: nil,
272 Prev: &ll,
273 })
274 }
275
276 blk, err := dbs.Get(ctx, c)
277 if err != nil {
278 return nil, err
279 }
280
281 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
282 return nil, err
283 }
284 }
285
286 for _, op := range bs.GetWriteLog() {
287 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
288 return nil, err
289 }
290 }
291
292 var blobs []lexutil.LexLink
293 for _, entry := range entries {
294 var cids []cid.Cid
295 if entry.Cid != "" {
296 if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{
297 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
298 UpdateAll: true,
299 }}).Error; err != nil {
300 return nil, err
301 }
302
303 cids, err = rm.incrementBlobRefs(urepo, entry.Value)
304 if err != nil {
305 return nil, err
306 }
307 } else {
308 if err := rm.s.db.Delete(&entry, nil).Error; err != nil {
309 return nil, err
310 }
311 cids, err = rm.decrementBlobRefs(urepo, entry.Value)
312 if err != nil {
313 return nil, err
314 }
315 }
316
317 for _, c := range cids {
318 blobs = append(blobs, lexutil.LexLink(c))
319 }
320 }
321
322 rm.s.evtman.AddEvent(ctx, &events.XRPCStreamEvent{
323 RepoCommit: &atproto.SyncSubscribeRepos_Commit{
324 Repo: urepo.Did,
325 Blocks: buf.Bytes(),
326 Blobs: blobs,
327 Rev: rev,
328 Since: &urepo.Rev,
329 Commit: lexutil.LexLink(newroot),
330 Time: time.Now().Format(time.RFC3339Nano),
331 Ops: ops,
332 TooBig: false,
333 },
334 })
335
336 if err := rm.s.UpdateRepo(ctx, urepo.Did, newroot, rev); err != nil {
337 return nil, err
338 }
339
340 for i := range results {
341 results[i].Type = to.StringPtr(*results[i].Type + "Result")
342 results[i].Commit = &RepoCommit{
343 Cid: newroot.String(),
344 Rev: rev,
345 }
346 }
347
348 return results, nil
349}
350
351func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
352 c, err := cid.Cast(urepo.Root)
353 if err != nil {
354 return cid.Undef, nil, err
355 }
356
357 dbs := rm.s.getBlockstore(urepo.Did)
358 bs := recording_blockstore.New(dbs)
359
360 r, err := repo.OpenRepo(ctx, bs, c)
361 if err != nil {
362 return cid.Undef, nil, err
363 }
364
365 _, _, err = r.GetRecordBytes(ctx, collection+"/"+rkey)
366 if err != nil {
367 return cid.Undef, nil, err
368 }
369
370 return c, bs.GetReadLog(), nil
371}
372
373func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
374 cids, err := getBlobCidsFromCbor(cbor)
375 if err != nil {
376 return nil, err
377 }
378
379 for _, c := range cids {
380 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
381 return nil, err
382 }
383 }
384
385 return cids, nil
386}
387
388func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
389 cids, err := getBlobCidsFromCbor(cbor)
390 if err != nil {
391 return nil, err
392 }
393
394 for _, c := range cids {
395 var res struct {
396 ID uint
397 Count int
398 }
399 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
400 return nil, err
401 }
402
403 if res.Count == 0 {
404 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
405 return nil, err
406 }
407 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
408 return nil, err
409 }
410 }
411 }
412
413 return cids, nil
414}
415
416// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional
417// unmarshal here. this will work for now though
418func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
419 var cids []cid.Cid
420
421 decoded, err := atdata.UnmarshalCBOR(cbor)
422 if err != nil {
423 return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
424 }
425
426 var deepiter func(any) error
427 deepiter = func(item any) error {
428 switch val := item.(type) {
429 case map[string]any:
430 if val["$type"] == "blob" {
431 if ref, ok := val["ref"].(string); ok {
432 c, err := cid.Parse(ref)
433 if err != nil {
434 return err
435 }
436 cids = append(cids, c)
437 }
438 for _, v := range val {
439 return deepiter(v)
440 }
441 }
442 case []any:
443 for _, v := range val {
444 deepiter(v)
445 }
446 }
447
448 return nil
449 }
450
451 if err := deepiter(decoded); err != nil {
452 return nil, err
453 }
454
455 return cids, nil
456}