1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "time"
10
11 "github.com/Azure/go-autorest/autorest/to"
12 "github.com/bluesky-social/indigo/api/atproto"
13 "github.com/bluesky-social/indigo/atproto/data"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "github.com/bluesky-social/indigo/carstore"
16 "github.com/bluesky-social/indigo/events"
17 lexutil "github.com/bluesky-social/indigo/lex/util"
18 "github.com/bluesky-social/indigo/repo"
19 "github.com/haileyok/cocoon/internal/db"
20 "github.com/haileyok/cocoon/models"
21 "github.com/haileyok/cocoon/recording_blockstore"
22 blocks "github.com/ipfs/go-block-format"
23 "github.com/ipfs/go-cid"
24 cbor "github.com/ipfs/go-ipld-cbor"
25 "github.com/ipld/go-car"
26 "gorm.io/gorm/clause"
27)
28
29type RepoMan struct {
30 db *db.DB
31 s *Server
32 clock *syntax.TIDClock
33}
34
35func NewRepoMan(s *Server) *RepoMan {
36 clock := syntax.NewTIDClock(0)
37
38 return &RepoMan{
39 s: s,
40 db: s.db,
41 clock: &clock,
42 }
43}
44
// OpType is the "$type" discriminator of a single write operation.
type OpType string

// The write operation types understood by applyWrites.
var (
	OpTypeCreate OpType = "com.atproto.repo.applyWrites#create"
	OpTypeUpdate OpType = "com.atproto.repo.applyWrites#update"
	OpTypeDelete OpType = "com.atproto.repo.applyWrites#delete"
)

// String returns the operation type as a plain string.
func (ot OpType) String() string {
	s := string(ot)
	return s
}
56
57type Op struct {
58 Type OpType `json:"$type"`
59 Collection string `json:"collection"`
60 Rkey *string `json:"rkey,omitempty"`
61 Validate *bool `json:"validate,omitempty"`
62 SwapRecord *string `json:"swapRecord,omitempty"`
63 Record *MarshalableMap `json:"record,omitempty"`
64}
65
66type MarshalableMap map[string]any
67
68type FirehoseOp struct {
69 Cid cid.Cid
70 Path string
71 Action string
72}
73
74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
75 data, err := data.MarshalCBOR(*mm)
76 if err != nil {
77 return err
78 }
79
80 w.Write(data)
81
82 return nil
83}
84
type (
	// ApplyWriteResult is the per-operation outcome returned by applyWrites.
	// Every field is optional and omitted from JSON when nil.
	ApplyWriteResult struct {
		Type             *string     `json:"$type,omitempty"`
		Uri              *string     `json:"uri,omitempty"`
		Cid              *string     `json:"cid,omitempty"`
		Commit           *RepoCommit `json:"commit,omitempty"`
		ValidationStatus *string     `json:"validationStatus,omitempty"`
	}

	// RepoCommit identifies a repo commit by its CID and revision.
	RepoCommit struct {
		Cid string `json:"cid"`
		Rev string `json:"rev"`
	}
)
97
98// TODO make use of swap commit
99func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
100 rootcid, err := cid.Cast(urepo.Root)
101 if err != nil {
102 return nil, err
103 }
104
105 dbs := rm.s.getBlockstore(urepo.Did)
106 bs := recording_blockstore.New(dbs)
107 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid)
108
109 entries := []models.Record{}
110 var results []ApplyWriteResult
111
112 for i, op := range writes {
113 if op.Type != OpTypeCreate && op.Rkey == nil {
114 return nil, fmt.Errorf("invalid rkey")
115 } else if op.Type == OpTypeCreate && op.Rkey != nil {
116 _, _, err := r.GetRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
117 if err == nil {
118 op.Type = OpTypeUpdate
119 }
120 } else if op.Rkey == nil {
121 op.Rkey = to.StringPtr(rm.clock.Next().String())
122 writes[i].Rkey = op.Rkey
123 }
124
125 _, err := syntax.ParseRecordKey(*op.Rkey)
126 if err != nil {
127 return nil, err
128 }
129
130 switch op.Type {
131 case OpTypeCreate:
132 j, err := json.Marshal(*op.Record)
133 if err != nil {
134 return nil, err
135 }
136 out, err := data.UnmarshalJSON(j)
137 if err != nil {
138 return nil, err
139 }
140 mm := MarshalableMap(out)
141 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
142 if err != nil {
143 return nil, err
144 }
145 d, err := data.MarshalCBOR(mm)
146 if err != nil {
147 return nil, err
148 }
149 entries = append(entries, models.Record{
150 Did: urepo.Did,
151 CreatedAt: rm.clock.Next().String(),
152 Nsid: op.Collection,
153 Rkey: *op.Rkey,
154 Cid: nc.String(),
155 Value: d,
156 })
157 results = append(results, ApplyWriteResult{
158 Type: to.StringPtr(OpTypeCreate.String()),
159 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
160 Cid: to.StringPtr(nc.String()),
161 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
162 })
163 case OpTypeDelete:
164 var old models.Record
165 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
166 return nil, err
167 }
168 entries = append(entries, models.Record{
169 Did: urepo.Did,
170 Nsid: op.Collection,
171 Rkey: *op.Rkey,
172 Value: old.Value,
173 })
174 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
175 if err != nil {
176 return nil, err
177 }
178 results = append(results, ApplyWriteResult{
179 Type: to.StringPtr(OpTypeDelete.String()),
180 })
181 case OpTypeUpdate:
182 j, err := json.Marshal(*op.Record)
183 if err != nil {
184 return nil, err
185 }
186 out, err := data.UnmarshalJSON(j)
187 if err != nil {
188 return nil, err
189 }
190 mm := MarshalableMap(out)
191 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
192 if err != nil {
193 return nil, err
194 }
195 d, err := data.MarshalCBOR(mm)
196 if err != nil {
197 return nil, err
198 }
199 entries = append(entries, models.Record{
200 Did: urepo.Did,
201 CreatedAt: rm.clock.Next().String(),
202 Nsid: op.Collection,
203 Rkey: *op.Rkey,
204 Cid: nc.String(),
205 Value: d,
206 })
207 results = append(results, ApplyWriteResult{
208 Type: to.StringPtr(OpTypeUpdate.String()),
209 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
210 Cid: to.StringPtr(nc.String()),
211 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
212 })
213 }
214 }
215
216 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor)
217 if err != nil {
218 return nil, err
219 }
220
221 buf := new(bytes.Buffer)
222
223 hb, err := cbor.DumpObject(&car.CarHeader{
224 Roots: []cid.Cid{newroot},
225 Version: 1,
226 })
227
228 if _, err := carstore.LdWrite(buf, hb); err != nil {
229 return nil, err
230 }
231
232 diffops, err := r.DiffSince(context.TODO(), rootcid)
233 if err != nil {
234 return nil, err
235 }
236
237 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops))
238
239 for _, op := range diffops {
240 var c cid.Cid
241 switch op.Op {
242 case "add", "mut":
243 kind := "create"
244 if op.Op == "mut" {
245 kind = "update"
246 }
247
248 c = op.NewCid
249 ll := lexutil.LexLink(op.NewCid)
250 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
251 Action: kind,
252 Path: op.Rpath,
253 Cid: &ll,
254 })
255
256 case "del":
257 c = op.OldCid
258 ll := lexutil.LexLink(op.OldCid)
259 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
260 Action: "delete",
261 Path: op.Rpath,
262 Cid: nil,
263 Prev: &ll,
264 })
265 }
266
267 blk, err := dbs.Get(context.TODO(), c)
268 if err != nil {
269 return nil, err
270 }
271
272 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
273 return nil, err
274 }
275 }
276
277 for _, op := range bs.GetWriteLog() {
278 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
279 return nil, err
280 }
281 }
282
283 var blobs []lexutil.LexLink
284 for _, entry := range entries {
285 var cids []cid.Cid
286 if entry.Cid != "" {
287 if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{
288 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
289 UpdateAll: true,
290 }}).Error; err != nil {
291 return nil, err
292 }
293
294 cids, err = rm.incrementBlobRefs(urepo, entry.Value)
295 if err != nil {
296 return nil, err
297 }
298 } else {
299 if err := rm.s.db.Delete(&entry, nil).Error; err != nil {
300 return nil, err
301 }
302 cids, err = rm.decrementBlobRefs(urepo, entry.Value)
303 if err != nil {
304 return nil, err
305 }
306 }
307
308 for _, c := range cids {
309 blobs = append(blobs, lexutil.LexLink(c))
310 }
311 }
312
313 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
314 RepoCommit: &atproto.SyncSubscribeRepos_Commit{
315 Repo: urepo.Did,
316 Blocks: buf.Bytes(),
317 Blobs: blobs,
318 Rev: rev,
319 Since: &urepo.Rev,
320 Commit: lexutil.LexLink(newroot),
321 Time: time.Now().Format(time.RFC3339Nano),
322 Ops: ops,
323 TooBig: false,
324 },
325 })
326
327 if err := rm.s.UpdateRepo(context.TODO(), urepo.Did, newroot, rev); err != nil {
328 return nil, err
329 }
330
331 for i := range results {
332 results[i].Type = to.StringPtr(*results[i].Type + "Result")
333 results[i].Commit = &RepoCommit{
334 Cid: newroot.String(),
335 Rev: rev,
336 }
337 }
338
339 return results, nil
340}
341
342func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
343 c, err := cid.Cast(urepo.Root)
344 if err != nil {
345 return cid.Undef, nil, err
346 }
347
348 dbs := rm.s.getBlockstore(urepo.Did)
349 bs := recording_blockstore.New(dbs)
350
351 r, err := repo.OpenRepo(context.TODO(), bs, c)
352 if err != nil {
353 return cid.Undef, nil, err
354 }
355
356 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey)
357 if err != nil {
358 return cid.Undef, nil, err
359 }
360
361 return c, bs.GetReadLog(), nil
362}
363
364func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
365 cids, err := getBlobCidsFromCbor(cbor)
366 if err != nil {
367 return nil, err
368 }
369
370 for _, c := range cids {
371 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
372 return nil, err
373 }
374 }
375
376 return cids, nil
377}
378
379func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
380 cids, err := getBlobCidsFromCbor(cbor)
381 if err != nil {
382 return nil, err
383 }
384
385 for _, c := range cids {
386 var res struct {
387 ID uint
388 Count int
389 }
390 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
391 return nil, err
392 }
393
394 if res.Count == 0 {
395 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
396 return nil, err
397 }
398 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
399 return nil, err
400 }
401 }
402 }
403
404 return cids, nil
405}
406
407// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional
408// unmarshal here. this will work for now though
409func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
410 var cids []cid.Cid
411
412 decoded, err := data.UnmarshalCBOR(cbor)
413 if err != nil {
414 return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
415 }
416
417 var deepiter func(any) error
418 deepiter = func(item any) error {
419 switch val := item.(type) {
420 case map[string]any:
421 if val["$type"] == "blob" {
422 if ref, ok := val["ref"].(string); ok {
423 c, err := cid.Parse(ref)
424 if err != nil {
425 return err
426 }
427 cids = append(cids, c)
428 }
429 for _, v := range val {
430 return deepiter(v)
431 }
432 }
433 case []any:
434 for _, v := range val {
435 deepiter(v)
436 }
437 }
438
439 return nil
440 }
441
442 if err := deepiter(decoded); err != nil {
443 return nil, err
444 }
445
446 return cids, nil
447}