1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "time"
10
11 "github.com/Azure/go-autorest/autorest/to"
12 "github.com/bluesky-social/indigo/api/atproto"
13 "github.com/bluesky-social/indigo/atproto/data"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "github.com/bluesky-social/indigo/carstore"
16 "github.com/bluesky-social/indigo/events"
17 lexutil "github.com/bluesky-social/indigo/lex/util"
18 "github.com/bluesky-social/indigo/repo"
19 "github.com/bluesky-social/indigo/util"
20 "github.com/haileyok/cocoon/blockstore"
21 "github.com/haileyok/cocoon/models"
22 blocks "github.com/ipfs/go-block-format"
23 "github.com/ipfs/go-cid"
24 cbor "github.com/ipfs/go-ipld-cbor"
25 "github.com/ipld/go-car"
26 "gorm.io/gorm"
27 "gorm.io/gorm/clause"
28)
29
// RepoMan applies record writes to user repositories and produces record
// proofs, reading and writing repo blocks and the records/blobs tables through
// the server's database.
type RepoMan struct {
	// db is the database handle; NewRepoMan sets it to the server's db.
	db *gorm.DB
	// s is the owning server, used for its db and its event manager (evtman).
	s *Server
	// clock generates TIDs used as auto-assigned record keys and as the
	// CreatedAt value of stored records.
	clock *syntax.TIDClock
}
35
36func NewRepoMan(s *Server) *RepoMan {
37 clock := syntax.NewTIDClock(0)
38
39 return &RepoMan{
40 s: s,
41 db: s.db,
42 clock: &clock,
43 }
44}
45
// OpType is the "$type" discriminator of a single applyWrites operation.
type OpType string

// Operation kinds understood by applyWrites. Each value is the lexicon $type
// string for the corresponding com.atproto.repo.applyWrites operation.
var (
	OpTypeCreate OpType = "com.atproto.repo.applyWrites#create"
	OpTypeUpdate OpType = "com.atproto.repo.applyWrites#update"
	OpTypeDelete OpType = "com.atproto.repo.applyWrites#delete"
)

// String returns the raw $type value of the operation kind.
func (t OpType) String() string { return string(t) }
57
// Op is one write operation of a com.atproto.repo.applyWrites request, as
// decoded from JSON. It is consumed by RepoMan.applyWrites.
type Op struct {
	// Type selects create, update, or delete (see the OpType* values).
	Type OpType `json:"$type"`
	// Collection is the collection NSID of the target record.
	Collection string `json:"collection"`
	// Rkey is the record key. It may be nil only for create ops, in which
	// case applyWrites assigns a TID-based key.
	Rkey *string `json:"rkey,omitempty"`
	// Validate is not consulted by applyWrites in this file.
	Validate *bool `json:"validate,omitempty"`
	// SwapRecord is not consulted by applyWrites in this file.
	SwapRecord *string `json:"swapRecord,omitempty"`
	// Record is the record body for create and update ops; delete ops do
	// not read it.
	Record *MarshalableMap `json:"record,omitempty"`
}
66
// MarshalableMap is a generic record body that can be CBOR-encoded (see its
// MarshalCBOR method) so it can be stored in the repo MST.
type MarshalableMap map[string]any
68
// FirehoseOp describes a single repo operation (CID, record path, action).
// NOTE(review): it is not referenced anywhere in this file; applyWrites builds
// atproto.SyncSubscribeRepos_RepoOp values instead. Confirm whether it is
// still used elsewhere.
type FirehoseOp struct {
	Cid    cid.Cid
	Path   string
	Action string
}
74
75func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
76 data, err := data.MarshalCBOR(*mm)
77 if err != nil {
78 return err
79 }
80
81 w.Write(data)
82
83 return nil
84}
85
// ApplyWriteResult is the per-operation result returned by applyWrites, in the
// same order as the input writes.
type ApplyWriteResult struct {
	// Type is the operation's $type with "Result" appended.
	Type *string `json:"$type,omitempty"`
	// Uri is the at:// URI of the written record (create/update only).
	Uri *string `json:"uri,omitempty"`
	// Cid is the CID of the written record (create/update only).
	Cid *string `json:"cid,omitempty"`
	// Commit identifies the commit that contains this write.
	Commit *RepoCommit `json:"commit,omitempty"`
	// ValidationStatus is currently always "valid" for create/update; no
	// lexicon validation is performed in this file.
	ValidationStatus *string `json:"validationStatus,omitempty"`
}
93
// RepoCommit identifies a repo commit by its root CID (string form) and its
// revision string.
type RepoCommit struct {
	Cid string `json:"cid"`
	Rev string `json:"rev"`
}
98
99// TODO make use of swap commit
100func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
101 rootcid, err := cid.Cast(urepo.Root)
102 if err != nil {
103 return nil, err
104 }
105
106 dbs := blockstore.New(urepo.Did, rm.db)
107 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid)
108
109 entries := []models.Record{}
110 var results []ApplyWriteResult
111
112 for i, op := range writes {
113 if op.Type != OpTypeCreate && op.Rkey == nil {
114 return nil, fmt.Errorf("invalid rkey")
115 } else if op.Rkey == nil {
116 op.Rkey = to.StringPtr(rm.clock.Next().String())
117 writes[i].Rkey = op.Rkey
118 }
119
120 _, err := syntax.ParseRecordKey(*op.Rkey)
121 if err != nil {
122 return nil, err
123 }
124
125 switch op.Type {
126 case OpTypeCreate:
127 j, err := json.Marshal(*op.Record)
128 if err != nil {
129 return nil, err
130 }
131 out, err := data.UnmarshalJSON(j)
132 if err != nil {
133 return nil, err
134 }
135 mm := MarshalableMap(out)
136 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
137 if err != nil {
138 return nil, err
139 }
140 d, err := data.MarshalCBOR(mm)
141 if err != nil {
142 return nil, err
143 }
144 entries = append(entries, models.Record{
145 Did: urepo.Did,
146 CreatedAt: rm.clock.Next().String(),
147 Nsid: op.Collection,
148 Rkey: *op.Rkey,
149 Cid: nc.String(),
150 Value: d,
151 })
152 results = append(results, ApplyWriteResult{
153 Type: to.StringPtr(OpTypeCreate.String()),
154 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
155 Cid: to.StringPtr(nc.String()),
156 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
157 })
158 case OpTypeDelete:
159 var old models.Record
160 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
161 return nil, err
162 }
163 entries = append(entries, models.Record{
164 Did: urepo.Did,
165 Nsid: op.Collection,
166 Rkey: *op.Rkey,
167 Value: old.Value,
168 })
169 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
170 if err != nil {
171 return nil, err
172 }
173 results = append(results, ApplyWriteResult{
174 Type: to.StringPtr(OpTypeDelete.String()),
175 })
176 case OpTypeUpdate:
177 j, err := json.Marshal(*op.Record)
178 if err != nil {
179 return nil, err
180 }
181 out, err := data.UnmarshalJSON(j)
182 if err != nil {
183 return nil, err
184 }
185 mm := MarshalableMap(out)
186 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
187 if err != nil {
188 return nil, err
189 }
190 d, err := data.MarshalCBOR(mm)
191 if err != nil {
192 return nil, err
193 }
194 entries = append(entries, models.Record{
195 Did: urepo.Did,
196 CreatedAt: rm.clock.Next().String(),
197 Nsid: op.Collection,
198 Rkey: *op.Rkey,
199 Cid: nc.String(),
200 Value: d,
201 })
202 results = append(results, ApplyWriteResult{
203 Type: to.StringPtr(OpTypeUpdate.String()),
204 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
205 Cid: to.StringPtr(nc.String()),
206 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
207 })
208 }
209 }
210
211 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor)
212 if err != nil {
213 return nil, err
214 }
215
216 buf := new(bytes.Buffer)
217
218 hb, err := cbor.DumpObject(&car.CarHeader{
219 Roots: []cid.Cid{newroot},
220 Version: 1,
221 })
222
223 if _, err := carstore.LdWrite(buf, hb); err != nil {
224 return nil, err
225 }
226
227 diffops, err := r.DiffSince(context.TODO(), rootcid)
228 if err != nil {
229 return nil, err
230 }
231
232 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops))
233
234 for _, op := range diffops {
235 var c cid.Cid
236 switch op.Op {
237 case "add", "mut":
238 kind := "create"
239 if op.Op == "mut" {
240 kind = "update"
241 }
242
243 c = op.NewCid
244 ll := lexutil.LexLink(op.NewCid)
245 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
246 Action: kind,
247 Path: op.Rpath,
248 Cid: &ll,
249 })
250
251 case "del":
252 c = op.OldCid
253 ll := lexutil.LexLink(op.OldCid)
254 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
255 Action: "delete",
256 Path: op.Rpath,
257 Cid: nil,
258 Prev: &ll,
259 })
260 }
261
262 blk, err := dbs.Get(context.TODO(), c)
263 if err != nil {
264 return nil, err
265 }
266
267 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
268 return nil, err
269 }
270 }
271
272 for _, op := range dbs.GetLog() {
273 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
274 return nil, err
275 }
276 }
277
278 var blobs []lexutil.LexLink
279 for _, entry := range entries {
280 var cids []cid.Cid
281 if entry.Cid != "" {
282 if err := rm.s.db.Clauses(clause.OnConflict{
283 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
284 UpdateAll: true,
285 }).Create(&entry).Error; err != nil {
286 return nil, err
287 }
288
289 cids, err = rm.incrementBlobRefs(urepo, entry.Value)
290 if err != nil {
291 return nil, err
292 }
293 } else {
294 if err := rm.s.db.Delete(&entry).Error; err != nil {
295 return nil, err
296 }
297 cids, err = rm.decrementBlobRefs(urepo, entry.Value)
298 if err != nil {
299 return nil, err
300 }
301 }
302
303 for _, c := range cids {
304 blobs = append(blobs, lexutil.LexLink(c))
305 }
306 }
307
308 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
309 RepoCommit: &atproto.SyncSubscribeRepos_Commit{
310 Repo: urepo.Did,
311 Blocks: buf.Bytes(),
312 Blobs: blobs,
313 Rev: rev,
314 Since: &urepo.Rev,
315 Commit: lexutil.LexLink(newroot),
316 Time: time.Now().Format(util.ISO8601),
317 Ops: ops,
318 TooBig: false,
319 },
320 })
321
322 if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil {
323 return nil, err
324 }
325
326 for i := range results {
327 results[i].Type = to.StringPtr(*results[i].Type + "Result")
328 results[i].Commit = &RepoCommit{
329 Cid: newroot.String(),
330 Rev: rev,
331 }
332 }
333
334 return results, nil
335}
336
337func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
338 c, err := cid.Cast(urepo.Root)
339 if err != nil {
340 return cid.Undef, nil, err
341 }
342
343 dbs := blockstore.New(urepo.Did, rm.db)
344 bs := util.NewLoggingBstore(dbs)
345
346 r, err := repo.OpenRepo(context.TODO(), bs, c)
347 if err != nil {
348 return cid.Undef, nil, err
349 }
350
351 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey)
352 if err != nil {
353 return cid.Undef, nil, err
354 }
355
356 return c, bs.GetLoggedBlocks(), nil
357}
358
359func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
360 cids, err := getBlobCidsFromCbor(cbor)
361 if err != nil {
362 return nil, err
363 }
364
365 for _, c := range cids {
366 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", urepo.Did, c.Bytes()).Error; err != nil {
367 return nil, err
368 }
369 }
370
371 return cids, nil
372}
373
374func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
375 cids, err := getBlobCidsFromCbor(cbor)
376 if err != nil {
377 return nil, err
378 }
379
380 for _, c := range cids {
381 var res struct {
382 ID uint
383 Count int
384 }
385 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
386 return nil, err
387 }
388
389 if res.Count == 0 {
390 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", res.ID).Error; err != nil {
391 return nil, err
392 }
393 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", res.ID).Error; err != nil {
394 return nil, err
395 }
396 }
397 }
398
399 return cids, nil
400}
401
402// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional
403// unmarshal here. this will work for now though
404func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
405 var cids []cid.Cid
406
407 decoded, err := data.UnmarshalCBOR(cbor)
408 if err != nil {
409 return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
410 }
411
412 var deepiter func(interface{}) error
413 deepiter = func(item interface{}) error {
414 switch val := item.(type) {
415 case map[string]interface{}:
416 if val["$type"] == "blob" {
417 if ref, ok := val["ref"].(string); ok {
418 c, err := cid.Parse(ref)
419 if err != nil {
420 return err
421 }
422 cids = append(cids, c)
423 }
424 for _, v := range val {
425 return deepiter(v)
426 }
427 }
428 case []interface{}:
429 for _, v := range val {
430 deepiter(v)
431 }
432 }
433
434 return nil
435 }
436
437 if err := deepiter(decoded); err != nil {
438 return nil, err
439 }
440
441 return cids, nil
442}