1package server
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "time"
9
10 "github.com/Azure/go-autorest/autorest/to"
11 "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/data"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "github.com/bluesky-social/indigo/carstore"
15 "github.com/bluesky-social/indigo/events"
16 lexutil "github.com/bluesky-social/indigo/lex/util"
17 "github.com/bluesky-social/indigo/repo"
18 "github.com/bluesky-social/indigo/util"
19 "github.com/haileyok/cocoon/blockstore"
20 "github.com/haileyok/cocoon/models"
21 blocks "github.com/ipfs/go-block-format"
22 "github.com/ipfs/go-cid"
23 cbor "github.com/ipfs/go-ipld-cbor"
24 "github.com/ipld/go-car"
25 "gorm.io/gorm"
26 "gorm.io/gorm/clause"
27)
28
29type RepoMan struct {
30 db *gorm.DB
31 s *Server
32 clock *syntax.TIDClock
33}
34
35func NewRepoMan(s *Server) *RepoMan {
36 clock := syntax.NewTIDClock(0)
37
38 return &RepoMan{
39 s: s,
40 db: s.db,
41 clock: &clock,
42 }
43}
44
// OpType is the "$type" discriminator of a single applyWrites operation.
type OpType string

// Operation kinds understood by applyWrites.
var (
	OpTypeCreate OpType = "com.atproto.repo.applyWrites#create"
	OpTypeUpdate OpType = "com.atproto.repo.applyWrites#update"
	OpTypeDelete OpType = "com.atproto.repo.applyWrites#delete"
)

// String returns the operation type as a plain string.
func (ot OpType) String() string {
	raw := string(ot)
	return raw
}
56
57type Op struct {
58 Type OpType `json:"$type"`
59 Collection string `json:"collection"`
60 Rkey *string `json:"rkey,omitempty"`
61 Validate *bool `json:"validate,omitempty"`
62 SwapRecord *string `json:"swapRecord,omitempty"`
63 Record *MarshalableMap `json:"record,omitempty"`
64}
65
// MarshalableMap is a generic record body. Its MarshalCBOR method lets it be
// passed wherever a CBOR-marshalable value is required.
type MarshalableMap map[string]any
67
68type FirehoseOp struct {
69 Cid cid.Cid
70 Path string
71 Action string
72}
73
74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
75 data, err := data.MarshalCBOR(*mm)
76 if err != nil {
77 return err
78 }
79
80 w.Write(data)
81
82 return nil
83}
84
85type ApplyWriteResult struct {
86 Type *string `json:"$type,omitempty"`
87 Uri *string `json:"uri,omitempty"`
88 Cid *string `json:"cid,omitempty"`
89 Commit *RepoCommit `json:"commit,omitempty"`
90 ValidationStatus *string `json:"validationStatus,omitempty"`
91}
92
// RepoCommit identifies a repository commit by its CID and revision string.
type RepoCommit struct {
	// Cid is the string form of the commit's CID.
	Cid string `json:"cid"`
	// Rev is the revision returned when the commit was created.
	Rev string `json:"rev"`
}
97
98// TODO make use of swap commit
99func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
100 rootcid, err := cid.Cast(urepo.Root)
101 if err != nil {
102 return nil, err
103 }
104
105 dbs := blockstore.New(urepo.Did, rm.db)
106 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid)
107
108 entries := []models.Record{}
109 var results []ApplyWriteResult
110
111 for i, op := range writes {
112 if op.Type != OpTypeCreate && op.Rkey == nil {
113 return nil, fmt.Errorf("invalid rkey")
114 } else if op.Rkey == nil {
115 op.Rkey = to.StringPtr(rm.clock.Next().String())
116 writes[i].Rkey = op.Rkey
117 }
118
119 _, err := syntax.ParseRecordKey(*op.Rkey)
120 if err != nil {
121 return nil, err
122 }
123
124 switch op.Type {
125 case OpTypeCreate:
126 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, op.Record)
127 if err != nil {
128 return nil, err
129 }
130
131 d, _ := data.MarshalCBOR(*op.Record)
132 entries = append(entries, models.Record{
133 Did: urepo.Did,
134 CreatedAt: rm.clock.Next().String(),
135 Nsid: op.Collection,
136 Rkey: *op.Rkey,
137 Cid: nc.String(),
138 Value: d,
139 })
140 results = append(results, ApplyWriteResult{
141 Type: to.StringPtr(OpTypeCreate.String()),
142 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
143 Cid: to.StringPtr(nc.String()),
144 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
145 })
146 case OpTypeDelete:
147 var old models.Record
148 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
149 return nil, err
150 }
151 entries = append(entries, models.Record{
152 Did: urepo.Did,
153 Nsid: op.Collection,
154 Rkey: *op.Rkey,
155 Value: old.Value,
156 })
157 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
158 if err != nil {
159 return nil, err
160 }
161 results = append(results, ApplyWriteResult{
162 Type: to.StringPtr(OpTypeDelete.String()),
163 })
164 case OpTypeUpdate:
165 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, op.Record)
166 if err != nil {
167 return nil, err
168 }
169
170 d, _ := data.MarshalCBOR(*op.Record)
171 entries = append(entries, models.Record{
172 Did: urepo.Did,
173 CreatedAt: rm.clock.Next().String(),
174 Nsid: op.Collection,
175 Rkey: *op.Rkey,
176 Cid: nc.String(),
177 Value: d,
178 })
179 results = append(results, ApplyWriteResult{
180 Type: to.StringPtr(OpTypeUpdate.String()),
181 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
182 Cid: to.StringPtr(nc.String()),
183 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
184 })
185 }
186 }
187
188 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor)
189 if err != nil {
190 return nil, err
191 }
192
193 buf := new(bytes.Buffer)
194
195 hb, err := cbor.DumpObject(&car.CarHeader{
196 Roots: []cid.Cid{newroot},
197 Version: 1,
198 })
199
200 if _, err := carstore.LdWrite(buf, hb); err != nil {
201 return nil, err
202 }
203
204 diffops, err := r.DiffSince(context.TODO(), rootcid)
205 if err != nil {
206 return nil, err
207 }
208
209 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops))
210
211 for _, op := range diffops {
212 var c cid.Cid
213 switch op.Op {
214 case "add", "mut":
215 kind := "create"
216 if op.Op == "mut" {
217 kind = "update"
218 }
219
220 c = op.NewCid
221 ll := lexutil.LexLink(op.NewCid)
222 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
223 Action: kind,
224 Path: op.Rpath,
225 Cid: &ll,
226 })
227
228 case "del":
229 c = op.OldCid
230 ll := lexutil.LexLink(op.OldCid)
231 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
232 Action: "delete",
233 Path: op.Rpath,
234 Cid: nil,
235 Prev: &ll,
236 })
237 }
238
239 blk, err := dbs.Get(context.TODO(), c)
240 if err != nil {
241 return nil, err
242 }
243
244 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
245 return nil, err
246 }
247 }
248
249 for _, op := range dbs.GetLog() {
250 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
251 return nil, err
252 }
253 }
254
255 var blobs []lexutil.LexLink
256 for _, entry := range entries {
257 var cids []cid.Cid
258 if entry.Cid != "" {
259 if err := rm.s.db.Clauses(clause.OnConflict{
260 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
261 UpdateAll: true,
262 }).Create(&entry).Error; err != nil {
263 return nil, err
264 }
265
266 cids, err = rm.incrementBlobRefs(urepo, entry.Value)
267 if err != nil {
268 return nil, err
269 }
270 } else {
271 if err := rm.s.db.Delete(&entry).Error; err != nil {
272 return nil, err
273 }
274 cids, err = rm.decrementBlobRefs(urepo, entry.Value)
275 if err != nil {
276 return nil, err
277 }
278 }
279
280 for _, c := range cids {
281 blobs = append(blobs, lexutil.LexLink(c))
282 }
283 }
284
285 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
286 RepoCommit: &atproto.SyncSubscribeRepos_Commit{
287 Repo: urepo.Did,
288 Blocks: buf.Bytes(),
289 Blobs: blobs,
290 Rev: rev,
291 Since: &urepo.Rev,
292 Commit: lexutil.LexLink(newroot),
293 Time: time.Now().Format(util.ISO8601),
294 Ops: ops,
295 TooBig: false,
296 },
297 })
298
299 if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil {
300 return nil, err
301 }
302
303 for i := range results {
304 results[i].Type = to.StringPtr(*results[i].Type + "Result")
305 results[i].Commit = &RepoCommit{
306 Cid: newroot.String(),
307 Rev: rev,
308 }
309 }
310
311 return results, nil
312}
313
314func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
315 c, err := cid.Cast(urepo.Root)
316 if err != nil {
317 return cid.Undef, nil, err
318 }
319
320 dbs := blockstore.New(urepo.Did, rm.db)
321 bs := util.NewLoggingBstore(dbs)
322
323 r, err := repo.OpenRepo(context.TODO(), bs, c)
324 if err != nil {
325 return cid.Undef, nil, err
326 }
327
328 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey)
329 if err != nil {
330 return cid.Undef, nil, err
331 }
332
333 return c, bs.GetLoggedBlocks(), nil
334}
335
336func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
337 cids, err := getBlobCidsFromCbor(cbor)
338 if err != nil {
339 return nil, err
340 }
341
342 for _, c := range cids {
343 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", urepo.Did, c.Bytes()).Error; err != nil {
344 return nil, err
345 }
346 }
347
348 return cids, nil
349}
350
351func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
352 cids, err := getBlobCidsFromCbor(cbor)
353 if err != nil {
354 return nil, err
355 }
356
357 for _, c := range cids {
358 var res struct {
359 ID uint
360 Count int
361 }
362 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
363 return nil, err
364 }
365
366 if res.Count == 0 {
367 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", res.ID).Error; err != nil {
368 return nil, err
369 }
370 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", res.ID).Error; err != nil {
371 return nil, err
372 }
373 }
374 }
375
376 return cids, nil
377}
378
379// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional
380// unmarshal here. this will work for now though
381func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
382 var cids []cid.Cid
383
384 decoded, err := data.UnmarshalCBOR(cbor)
385 if err != nil {
386 return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
387 }
388
389 var deepiter func(interface{}) error
390 deepiter = func(item interface{}) error {
391 switch val := item.(type) {
392 case map[string]interface{}:
393 if val["$type"] == "blob" {
394 if ref, ok := val["ref"].(string); ok {
395 c, err := cid.Parse(ref)
396 if err != nil {
397 return err
398 }
399 cids = append(cids, c)
400 }
401 for _, v := range val {
402 return deepiter(v)
403 }
404 }
405 case []interface{}:
406 for _, v := range val {
407 deepiter(v)
408 }
409 }
410
411 return nil
412 }
413
414 if err := deepiter(decoded); err != nil {
415 return nil, err
416 }
417
418 return cids, nil
419}