1package server
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "time"
9
10 "github.com/Azure/go-autorest/autorest/to"
11 "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/data"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "github.com/bluesky-social/indigo/carstore"
15 "github.com/bluesky-social/indigo/events"
16 lexutil "github.com/bluesky-social/indigo/lex/util"
17 "github.com/bluesky-social/indigo/repo"
18 "github.com/bluesky-social/indigo/util"
19 "github.com/haileyok/cocoon/blockstore"
20 "github.com/haileyok/cocoon/models"
21 blocks "github.com/ipfs/go-block-format"
22 "github.com/ipfs/go-cid"
23 cbor "github.com/ipfs/go-ipld-cbor"
24 "github.com/ipld/go-car"
25 "gorm.io/gorm"
26 "gorm.io/gorm/clause"
27)
28
// RepoMan applies writes to account repositories, maintains the records and
// blobs index tables, and produces record proofs.
type RepoMan struct {
	// db backs the per-account blockstore as well as the records and blobs tables.
	db *gorm.DB
	// s is the owning server; used for its database and firehose event manager.
	s *Server
	// clock generates TIDs used as default rkeys and as record CreatedAt values.
	clock *syntax.TIDClock
}
34
35func NewRepoMan(s *Server) *RepoMan {
36 clock := syntax.NewTIDClock(0)
37
38 return &RepoMan{
39 s: s,
40 db: s.db,
41 clock: &clock,
42 }
43}
44
// OpType identifies the kind of a single com.atproto.repo.applyWrites write.
// Its value is the lexicon $type string of the corresponding write variant.
type OpType string

// Write kinds accepted by applyWrites.
var (
	OpTypeCreate OpType = "com.atproto.repo.applyWrites#create"
	OpTypeUpdate OpType = "com.atproto.repo.applyWrites#update"
	OpTypeDelete OpType = "com.atproto.repo.applyWrites#delete"
)

// String returns the lexicon $type string for the write kind.
func (ot OpType) String() string { return string(ot) }
56
// Op is a single write in an applyWrites request. The JSON tags match the
// lexicon field names of the write objects.
type Op struct {
	// Type selects create, update, or delete.
	Type OpType `json:"$type"`
	// Collection is the collection NSID the record lives in.
	Collection string `json:"collection"`
	// Rkey is the record key; applyWrites generates a TID when it is nil on a create.
	Rkey *string `json:"rkey,omitempty"`
	// Validate is the client's validation request; applyWrites does not read it.
	Validate *bool `json:"validate,omitempty"`
	// SwapRecord is the client's expected current record CID; applyWrites does not read it.
	SwapRecord *string `json:"swapRecord,omitempty"`
	// Record is the record body for creates and updates.
	Record *MarshalableMap `json:"record,omitempty"`
}
65
// MarshalableMap is a generic record body. It has a MarshalCBOR method so it
// can be passed directly as a record to the repo's PutRecord/UpdateRecord.
type MarshalableMap map[string]any
67
// FirehoseOp describes one record operation: its CID, repo path, and action.
// NOTE(review): not referenced in this file; confirm whether it is still used.
type FirehoseOp struct {
	Cid    cid.Cid
	Path   string
	Action string
}
73
74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
75 data, err := data.MarshalCBOR(*mm)
76 if err != nil {
77 return err
78 }
79
80 w.Write(data)
81
82 return nil
83}
84
// ApplyWriteResult is the per-write result returned from applyWrites.
// Delete results carry only Type (and Commit); create/update results also
// carry Uri, Cid, and ValidationStatus.
type ApplyWriteResult struct {
	// Type is the write's $type with "Result" appended.
	Type *string `json:"$type,omitempty"`
	// Uri is the at:// URI of the written record.
	Uri *string `json:"uri,omitempty"`
	// Cid is the CID of the written record.
	Cid *string `json:"cid,omitempty"`
	// Commit identifies the commit that contains this write.
	Commit *RepoCommit `json:"commit,omitempty"`
	// ValidationStatus is currently always "valid"; records are not actually validated.
	ValidationStatus *string `json:"validationStatus,omitempty"`
}
92
// RepoCommit identifies a repo commit by its root CID and revision.
type RepoCommit struct {
	Cid string `json:"cid"`
	Rev string `json:"rev"`
}
97
98// TODO make use of swap commit
99func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
100 rootcid, err := cid.Cast(urepo.Root)
101 if err != nil {
102 return nil, err
103 }
104
105 dbs := blockstore.New(urepo.Did, rm.db)
106 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid)
107
108 entries := []models.Record{}
109 var results []ApplyWriteResult
110
111 for i, op := range writes {
112 if op.Type != OpTypeCreate && op.Rkey == nil {
113 return nil, fmt.Errorf("invalid rkey")
114 } else if op.Rkey == nil {
115 op.Rkey = to.StringPtr(rm.clock.Next().String())
116 writes[i].Rkey = op.Rkey
117 }
118
119 _, err := syntax.ParseRecordKey(*op.Rkey)
120 if err != nil {
121 return nil, err
122 }
123
124 switch op.Type {
125 case OpTypeCreate:
126 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, op.Record)
127 if err != nil {
128 return nil, err
129 }
130
131 d, _ := data.MarshalCBOR(*op.Record)
132 entries = append(entries, models.Record{
133 Did: urepo.Did,
134 CreatedAt: rm.clock.Next().String(),
135 Nsid: op.Collection,
136 Rkey: *op.Rkey,
137 Cid: nc.String(),
138 Value: d,
139 })
140 results = append(results, ApplyWriteResult{
141 Type: to.StringPtr(OpTypeCreate.String()),
142 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
143 Cid: to.StringPtr(nc.String()),
144 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
145 })
146 case OpTypeDelete:
147 var old models.Record
148 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
149 return nil, err
150 }
151 entries = append(entries, models.Record{
152 Did: urepo.Did,
153 Nsid: op.Collection,
154 Rkey: *op.Rkey,
155 Value: old.Value,
156 })
157 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
158 if err != nil {
159 return nil, err
160 }
161 results = append(results, ApplyWriteResult{
162 Type: to.StringPtr(OpTypeDelete.String()),
163 })
164 case OpTypeUpdate:
165 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, op.Record)
166 if err != nil {
167 return nil, err
168 }
169
170 d, _ := data.MarshalCBOR(*op.Record)
171 entries = append(entries, models.Record{
172 Did: urepo.Did,
173 CreatedAt: rm.clock.Next().String(),
174 Nsid: op.Collection,
175 Rkey: *op.Rkey,
176 Cid: nc.String(),
177 Value: d,
178 })
179 results = append(results, ApplyWriteResult{
180 Type: to.StringPtr(OpTypeUpdate.String()),
181 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
182 Cid: to.StringPtr(nc.String()),
183 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
184 })
185 }
186 }
187
188 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor)
189 if err != nil {
190 return nil, err
191 }
192
193 buf := new(bytes.Buffer)
194
195 hb, err := cbor.DumpObject(&car.CarHeader{
196 Roots: []cid.Cid{newroot},
197 Version: 1,
198 })
199
200 if _, err := carstore.LdWrite(buf, hb); err != nil {
201 return nil, err
202 }
203
204 diffops, err := r.DiffSince(context.TODO(), rootcid)
205 if err != nil {
206 return nil, err
207 }
208
209 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops))
210
211 for _, op := range diffops {
212 switch op.Op {
213 case "add", "mut":
214 kind := "create"
215 if op.Op == "mut" {
216 kind = "update"
217 }
218
219 ll := lexutil.LexLink(op.NewCid)
220 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
221 Action: kind,
222 Path: op.Rpath,
223 Cid: &ll,
224 })
225
226 case "del":
227 ll := lexutil.LexLink(op.OldCid)
228 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
229 Action: "delete",
230 Path: op.Rpath,
231 Cid: nil,
232 Prev: &ll,
233 })
234 }
235
236 blk, err := dbs.Get(context.TODO(), op.NewCid)
237 if err != nil {
238 return nil, err
239 }
240
241 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
242 return nil, err
243 }
244 }
245
246 for _, op := range dbs.GetLog() {
247 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
248 return nil, err
249 }
250 }
251
252 var blobs []lexutil.LexLink
253 for _, entry := range entries {
254 var cids []cid.Cid
255 if entry.Cid != "" {
256 if err := rm.s.db.Clauses(clause.OnConflict{
257 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
258 UpdateAll: true,
259 }).Create(&entry).Error; err != nil {
260 return nil, err
261 }
262
263 cids, err = rm.incrementBlobRefs(urepo, entry.Value)
264 if err != nil {
265 return nil, err
266 }
267 } else {
268 if err := rm.s.db.Delete(&entry).Error; err != nil {
269 return nil, err
270 }
271 cids, err = rm.decrementBlobRefs(urepo, entry.Value)
272 if err != nil {
273 return nil, err
274 }
275 }
276
277 for _, c := range cids {
278 blobs = append(blobs, lexutil.LexLink(c))
279 }
280 }
281
282 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
283 RepoCommit: &atproto.SyncSubscribeRepos_Commit{
284 Repo: urepo.Did,
285 Blocks: buf.Bytes(),
286 Blobs: blobs,
287 Rev: rev,
288 Since: &urepo.Rev,
289 Commit: lexutil.LexLink(newroot),
290 Time: time.Now().Format(util.ISO8601),
291 Ops: ops,
292 TooBig: false,
293 },
294 })
295
296 if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil {
297 return nil, err
298 }
299
300 for i := range results {
301 results[i].Type = to.StringPtr(*results[i].Type + "Result")
302 results[i].Commit = &RepoCommit{
303 Cid: newroot.String(),
304 Rev: rev,
305 }
306 }
307
308 return results, nil
309}
310
311func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
312 c, err := cid.Cast(urepo.Root)
313 if err != nil {
314 return cid.Undef, nil, err
315 }
316
317 dbs := blockstore.New(urepo.Did, rm.db)
318 bs := util.NewLoggingBstore(dbs)
319
320 r, err := repo.OpenRepo(context.TODO(), bs, c)
321 if err != nil {
322 return cid.Undef, nil, err
323 }
324
325 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey)
326 if err != nil {
327 return cid.Undef, nil, err
328 }
329
330 return c, bs.GetLoggedBlocks(), nil
331}
332
333func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
334 cids, err := getBlobCidsFromCbor(cbor)
335 if err != nil {
336 return nil, err
337 }
338
339 for _, c := range cids {
340 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", urepo.Did, c.Bytes()).Error; err != nil {
341 return nil, err
342 }
343 }
344
345 return cids, nil
346}
347
348func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
349 cids, err := getBlobCidsFromCbor(cbor)
350 if err != nil {
351 return nil, err
352 }
353
354 for _, c := range cids {
355 var res struct {
356 ID uint
357 Count int
358 }
359 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
360 return nil, err
361 }
362
363 if res.Count == 0 {
364 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", res.ID).Error; err != nil {
365 return nil, err
366 }
367 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", res.ID).Error; err != nil {
368 return nil, err
369 }
370 }
371 }
372
373 return cids, nil
374}
375
376// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional
377// unmarshal here. this will work for now though
378func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
379 var cids []cid.Cid
380
381 decoded, err := data.UnmarshalCBOR(cbor)
382 if err != nil {
383 return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
384 }
385
386 var deepiter func(interface{}) error
387 deepiter = func(item interface{}) error {
388 switch val := item.(type) {
389 case map[string]interface{}:
390 if val["$type"] == "blob" {
391 if ref, ok := val["ref"].(string); ok {
392 c, err := cid.Parse(ref)
393 if err != nil {
394 return err
395 }
396 cids = append(cids, c)
397 }
398 for _, v := range val {
399 return deepiter(v)
400 }
401 }
402 case []interface{}:
403 for _, v := range val {
404 deepiter(v)
405 }
406 }
407
408 return nil
409 }
410
411 if err := deepiter(decoded); err != nil {
412 return nil, err
413 }
414
415 return cids, nil
416}