1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "time"
10
11 "github.com/Azure/go-autorest/autorest/to"
12 "github.com/bluesky-social/indigo/api/atproto"
13 "github.com/bluesky-social/indigo/atproto/atdata"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "github.com/bluesky-social/indigo/carstore"
16 "github.com/bluesky-social/indigo/events"
17 lexutil "github.com/bluesky-social/indigo/lex/util"
18 "github.com/bluesky-social/indigo/repo"
19 "github.com/haileyok/cocoon/internal/db"
20 "github.com/haileyok/cocoon/models"
21 "github.com/haileyok/cocoon/recording_blockstore"
22 blocks "github.com/ipfs/go-block-format"
23 "github.com/ipfs/go-cid"
24 cbor "github.com/ipfs/go-ipld-cbor"
25 "github.com/ipld/go-car"
26 "gorm.io/gorm/clause"
27)
28
29type RepoMan struct {
30 db *db.DB
31 s *Server
32 clock *syntax.TIDClock
33}
34
35func NewRepoMan(s *Server) *RepoMan {
36 clock := syntax.NewTIDClock(0)
37
38 return &RepoMan{
39 s: s,
40 db: s.db,
41 clock: &clock,
42 }
43}
44
// OpType is the "$type" discriminator of a single write operation.
type OpType string

// String returns the raw string form of the op type.
func (ot OpType) String() string {
	return string(ot)
}

// The write operation types handled by applyWrites.
var (
	OpTypeCreate OpType = "com.atproto.repo.applyWrites#create"
	OpTypeUpdate OpType = "com.atproto.repo.applyWrites#update"
	OpTypeDelete OpType = "com.atproto.repo.applyWrites#delete"
)
56
57type Op struct {
58 Type OpType `json:"$type"`
59 Collection string `json:"collection"`
60 Rkey *string `json:"rkey,omitempty"`
61 Validate *bool `json:"validate,omitempty"`
62 SwapRecord *string `json:"swapRecord,omitempty"`
63 Record *MarshalableMap `json:"record,omitempty"`
64}
65
// MarshalableMap is a generic record body keyed by field name. Its pointer
// type provides MarshalCBOR so it can be handed to APIs expecting a CBOR
// marshaler.
type MarshalableMap map[string]any
67
68type FirehoseOp struct {
69 Cid cid.Cid
70 Path string
71 Action string
72}
73
74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
75 data, err := atdata.MarshalCBOR(*mm)
76 if err != nil {
77 return err
78 }
79
80 w.Write(data)
81
82 return nil
83}
84
85type ApplyWriteResult struct {
86 Type *string `json:"$type,omitempty"`
87 Uri *string `json:"uri,omitempty"`
88 Cid *string `json:"cid,omitempty"`
89 Commit *RepoCommit `json:"commit,omitempty"`
90 ValidationStatus *string `json:"validationStatus,omitempty"`
91}
92
// RepoCommit identifies a repo commit by the string form of its CID and by
// its revision.
type RepoCommit struct {
	// Cid is the commit CID as a string.
	Cid string `json:"cid"`
	// Rev is the commit revision.
	Rev string `json:"rev"`
}
97
98// TODO make use of swap commit
99func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
100 rootcid, err := cid.Cast(urepo.Root)
101 if err != nil {
102 return nil, err
103 }
104
105 dbs := rm.s.getBlockstore(urepo.Did)
106 bs := recording_blockstore.New(dbs)
107 r, err := repo.OpenRepo(context.TODO(), bs, rootcid)
108
109 entries := []models.Record{}
110 var results []ApplyWriteResult
111
112 for i, op := range writes {
113 if op.Type != OpTypeCreate && op.Rkey == nil {
114 return nil, fmt.Errorf("invalid rkey")
115 } else if op.Type == OpTypeCreate && op.Rkey != nil {
116 _, _, err := r.GetRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
117 if err == nil {
118 op.Type = OpTypeUpdate
119 }
120 } else if op.Rkey == nil {
121 op.Rkey = to.StringPtr(rm.clock.Next().String())
122 writes[i].Rkey = op.Rkey
123 }
124
125 _, err := syntax.ParseRecordKey(*op.Rkey)
126 if err != nil {
127 return nil, err
128 }
129
130 switch op.Type {
131 case OpTypeCreate:
132 j, err := json.Marshal(*op.Record)
133 if err != nil {
134 return nil, err
135 }
136 out, err := atdata.UnmarshalJSON(j)
137 if err != nil {
138 return nil, err
139 }
140 mm := MarshalableMap(out)
141
142 // HACK: if a record doesn't contain a $type, we can manually set it here based on the op's collection
143 if mm["$type"] == "" {
144 mm["$type"] = op.Collection
145 }
146
147 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
148 if err != nil {
149 return nil, err
150 }
151 d, err := atdata.MarshalCBOR(mm)
152 if err != nil {
153 return nil, err
154 }
155 entries = append(entries, models.Record{
156 Did: urepo.Did,
157 CreatedAt: rm.clock.Next().String(),
158 Nsid: op.Collection,
159 Rkey: *op.Rkey,
160 Cid: nc.String(),
161 Value: d,
162 })
163 results = append(results, ApplyWriteResult{
164 Type: to.StringPtr(OpTypeCreate.String()),
165 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
166 Cid: to.StringPtr(nc.String()),
167 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
168 })
169 case OpTypeDelete:
170 var old models.Record
171 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
172 return nil, err
173 }
174 entries = append(entries, models.Record{
175 Did: urepo.Did,
176 Nsid: op.Collection,
177 Rkey: *op.Rkey,
178 Value: old.Value,
179 })
180 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
181 if err != nil {
182 return nil, err
183 }
184 results = append(results, ApplyWriteResult{
185 Type: to.StringPtr(OpTypeDelete.String()),
186 })
187 case OpTypeUpdate:
188 j, err := json.Marshal(*op.Record)
189 if err != nil {
190 return nil, err
191 }
192 out, err := atdata.UnmarshalJSON(j)
193 if err != nil {
194 return nil, err
195 }
196 mm := MarshalableMap(out)
197 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
198 if err != nil {
199 return nil, err
200 }
201 d, err := atdata.MarshalCBOR(mm)
202 if err != nil {
203 return nil, err
204 }
205 entries = append(entries, models.Record{
206 Did: urepo.Did,
207 CreatedAt: rm.clock.Next().String(),
208 Nsid: op.Collection,
209 Rkey: *op.Rkey,
210 Cid: nc.String(),
211 Value: d,
212 })
213 results = append(results, ApplyWriteResult{
214 Type: to.StringPtr(OpTypeUpdate.String()),
215 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
216 Cid: to.StringPtr(nc.String()),
217 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
218 })
219 }
220 }
221
222 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor)
223 if err != nil {
224 return nil, err
225 }
226
227 buf := new(bytes.Buffer)
228
229 hb, err := cbor.DumpObject(&car.CarHeader{
230 Roots: []cid.Cid{newroot},
231 Version: 1,
232 })
233
234 if _, err := carstore.LdWrite(buf, hb); err != nil {
235 return nil, err
236 }
237
238 diffops, err := r.DiffSince(context.TODO(), rootcid)
239 if err != nil {
240 return nil, err
241 }
242
243 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops))
244
245 for _, op := range diffops {
246 var c cid.Cid
247 switch op.Op {
248 case "add", "mut":
249 kind := "create"
250 if op.Op == "mut" {
251 kind = "update"
252 }
253
254 c = op.NewCid
255 ll := lexutil.LexLink(op.NewCid)
256 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
257 Action: kind,
258 Path: op.Rpath,
259 Cid: &ll,
260 })
261
262 case "del":
263 c = op.OldCid
264 ll := lexutil.LexLink(op.OldCid)
265 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
266 Action: "delete",
267 Path: op.Rpath,
268 Cid: nil,
269 Prev: &ll,
270 })
271 }
272
273 blk, err := dbs.Get(context.TODO(), c)
274 if err != nil {
275 return nil, err
276 }
277
278 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
279 return nil, err
280 }
281 }
282
283 for _, op := range bs.GetWriteLog() {
284 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
285 return nil, err
286 }
287 }
288
289 var blobs []lexutil.LexLink
290 for _, entry := range entries {
291 var cids []cid.Cid
292 if entry.Cid != "" {
293 if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{
294 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
295 UpdateAll: true,
296 }}).Error; err != nil {
297 return nil, err
298 }
299
300 cids, err = rm.incrementBlobRefs(urepo, entry.Value)
301 if err != nil {
302 return nil, err
303 }
304 } else {
305 if err := rm.s.db.Delete(&entry, nil).Error; err != nil {
306 return nil, err
307 }
308 cids, err = rm.decrementBlobRefs(urepo, entry.Value)
309 if err != nil {
310 return nil, err
311 }
312 }
313
314 for _, c := range cids {
315 blobs = append(blobs, lexutil.LexLink(c))
316 }
317 }
318
319 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
320 RepoCommit: &atproto.SyncSubscribeRepos_Commit{
321 Repo: urepo.Did,
322 Blocks: buf.Bytes(),
323 Blobs: blobs,
324 Rev: rev,
325 Since: &urepo.Rev,
326 Commit: lexutil.LexLink(newroot),
327 Time: time.Now().Format(time.RFC3339Nano),
328 Ops: ops,
329 TooBig: false,
330 },
331 })
332
333 if err := rm.s.UpdateRepo(context.TODO(), urepo.Did, newroot, rev); err != nil {
334 return nil, err
335 }
336
337 for i := range results {
338 results[i].Type = to.StringPtr(*results[i].Type + "Result")
339 results[i].Commit = &RepoCommit{
340 Cid: newroot.String(),
341 Rev: rev,
342 }
343 }
344
345 return results, nil
346}
347
348func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
349 c, err := cid.Cast(urepo.Root)
350 if err != nil {
351 return cid.Undef, nil, err
352 }
353
354 dbs := rm.s.getBlockstore(urepo.Did)
355 bs := recording_blockstore.New(dbs)
356
357 r, err := repo.OpenRepo(context.TODO(), bs, c)
358 if err != nil {
359 return cid.Undef, nil, err
360 }
361
362 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey)
363 if err != nil {
364 return cid.Undef, nil, err
365 }
366
367 return c, bs.GetReadLog(), nil
368}
369
370func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
371 cids, err := getBlobCidsFromCbor(cbor)
372 if err != nil {
373 return nil, err
374 }
375
376 for _, c := range cids {
377 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
378 return nil, err
379 }
380 }
381
382 return cids, nil
383}
384
385func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
386 cids, err := getBlobCidsFromCbor(cbor)
387 if err != nil {
388 return nil, err
389 }
390
391 for _, c := range cids {
392 var res struct {
393 ID uint
394 Count int
395 }
396 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
397 return nil, err
398 }
399
400 if res.Count == 0 {
401 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
402 return nil, err
403 }
404 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
405 return nil, err
406 }
407 }
408 }
409
410 return cids, nil
411}
412
413// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional
414// unmarshal here. this will work for now though
415func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
416 var cids []cid.Cid
417
418 decoded, err := atdata.UnmarshalCBOR(cbor)
419 if err != nil {
420 return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
421 }
422
423 var deepiter func(any) error
424 deepiter = func(item any) error {
425 switch val := item.(type) {
426 case map[string]any:
427 if val["$type"] == "blob" {
428 if ref, ok := val["ref"].(string); ok {
429 c, err := cid.Parse(ref)
430 if err != nil {
431 return err
432 }
433 cids = append(cids, c)
434 }
435 for _, v := range val {
436 return deepiter(v)
437 }
438 }
439 case []any:
440 for _, v := range val {
441 deepiter(v)
442 }
443 }
444
445 return nil
446 }
447
448 if err := deepiter(decoded); err != nil {
449 return nil, err
450 }
451
452 return cids, nil
453}