An atproto PDS written in Go
1package server 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/bluesky-social/indigo/api/atproto" 13 "github.com/bluesky-social/indigo/atproto/atdata" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "github.com/bluesky-social/indigo/carstore" 16 "github.com/bluesky-social/indigo/events" 17 lexutil "github.com/bluesky-social/indigo/lex/util" 18 "github.com/bluesky-social/indigo/repo" 19 "github.com/haileyok/cocoon/internal/db" 20 "github.com/haileyok/cocoon/models" 21 "github.com/haileyok/cocoon/recording_blockstore" 22 blocks "github.com/ipfs/go-block-format" 23 "github.com/ipfs/go-cid" 24 cbor "github.com/ipfs/go-ipld-cbor" 25 "github.com/ipld/go-car" 26 "gorm.io/gorm/clause" 27) 28 29type RepoMan struct { 30 db *db.DB 31 s *Server 32 clock *syntax.TIDClock 33} 34 35func NewRepoMan(s *Server) *RepoMan { 36 clock := syntax.NewTIDClock(0) 37 38 return &RepoMan{ 39 s: s, 40 db: s.db, 41 clock: &clock, 42 } 43} 44 45type OpType string 46 47var ( 48 OpTypeCreate = OpType("com.atproto.repo.applyWrites#create") 49 OpTypeUpdate = OpType("com.atproto.repo.applyWrites#update") 50 OpTypeDelete = OpType("com.atproto.repo.applyWrites#delete") 51) 52 53func (ot OpType) String() string { 54 return string(ot) 55} 56 57type Op struct { 58 Type OpType `json:"$type"` 59 Collection string `json:"collection"` 60 Rkey *string `json:"rkey,omitempty"` 61 Validate *bool `json:"validate,omitempty"` 62 SwapRecord *string `json:"swapRecord,omitempty"` 63 Record *MarshalableMap `json:"record,omitempty"` 64} 65 66type MarshalableMap map[string]any 67 68type FirehoseOp struct { 69 Cid cid.Cid 70 Path string 71 Action string 72} 73 74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error { 75 data, err := atdata.MarshalCBOR(*mm) 76 if err != nil { 77 return err 78 } 79 80 w.Write(data) 81 82 return nil 83} 84 85type ApplyWriteResult struct { 86 Type *string `json:"$type,omitempty"` 87 Uri *string `json:"uri,omitempty"` 88 Cid *string `json:"cid,omitempty"` 89 Commit *RepoCommit `json:"commit,omitempty"` 90 ValidationStatus *string `json:"validationStatus,omitempty"` 91} 92 93type RepoCommit struct { 94 Cid string `json:"cid"` 95 Rev string `json:"rev"` 96} 97 98// TODO make use of swap commit 99func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) { 100 rootcid, err := cid.Cast(urepo.Root) 101 if err != nil { 102 return nil, err 103 } 104 105 dbs := rm.s.getBlockstore(urepo.Did) 106 bs := recording_blockstore.New(dbs) 107 r, err := repo.OpenRepo(ctx, bs, rootcid) 108 109 entries := make([]models.Record, 0, len(writes)) 110 results := make([]ApplyWriteResult, 0, len(writes)) 111 112 for i, op := range writes { 113 if op.Type != OpTypeCreate && op.Rkey == nil { 114 return nil, fmt.Errorf("invalid rkey") 115 } else if op.Type == OpTypeCreate && op.Rkey != nil { 116 _, _, err := r.GetRecord(ctx, op.Collection+"/"+*op.Rkey) 117 if err == nil { 118 op.Type = OpTypeUpdate 119 } 120 } else if op.Rkey == nil { 121 op.Rkey = to.StringPtr(rm.clock.Next().String()) 122 writes[i].Rkey = op.Rkey 123 } 124 125 _, err := syntax.ParseRecordKey(*op.Rkey) 126 if err != nil { 127 return nil, err 128 } 129 130 switch op.Type { 131 case OpTypeCreate: 132 j, err := json.Marshal(*op.Record) 133 if err != nil { 134 return nil, err 135 } 136 out, err := atdata.UnmarshalJSON(j) 137 if err != nil { 138 return nil, err 139 } 140 mm := MarshalableMap(out) 141 142 // HACK: if a record doesn't contain a $type, we can manually set it here based on the op's collection 143 if mm["$type"] == "" { 144 mm["$type"] = op.Collection 145 } 146 147 nc, err := r.PutRecord(ctx, op.Collection+"/"+*op.Rkey, &mm) 148 if err != nil { 149 return nil, err 150 } 151 d, err := atdata.MarshalCBOR(mm) 152 if err != nil { 153 return nil, err 154 } 155 entries = append(entries, models.Record{ 156 Did: urepo.Did, 157 CreatedAt: rm.clock.Next().String(), 158 Nsid: op.Collection, 159 Rkey: *op.Rkey, 160 Cid: nc.String(), 161 Value: d, 162 }) 163 results = append(results, ApplyWriteResult{ 164 Type: to.StringPtr(OpTypeCreate.String()), 165 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 166 Cid: to.StringPtr(nc.String()), 167 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 168 }) 169 case OpTypeDelete: 170 var old models.Record 171 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 172 return nil, err 173 } 174 entries = append(entries, models.Record{ 175 Did: urepo.Did, 176 Nsid: op.Collection, 177 Rkey: *op.Rkey, 178 Value: old.Value, 179 }) 180 err := r.DeleteRecord(ctx, op.Collection+"/"+*op.Rkey) 181 if err != nil { 182 return nil, err 183 } 184 results = append(results, ApplyWriteResult{ 185 Type: to.StringPtr(OpTypeDelete.String()), 186 }) 187 case OpTypeUpdate: 188 j, err := json.Marshal(*op.Record) 189 if err != nil { 190 return nil, err 191 } 192 out, err := atdata.UnmarshalJSON(j) 193 if err != nil { 194 return nil, err 195 } 196 mm := MarshalableMap(out) 197 nc, err := r.UpdateRecord(ctx, op.Collection+"/"+*op.Rkey, &mm) 198 if err != nil { 199 return nil, err 200 } 201 d, err := atdata.MarshalCBOR(mm) 202 if err != nil { 203 return nil, err 204 } 205 entries = append(entries, models.Record{ 206 Did: urepo.Did, 207 CreatedAt: rm.clock.Next().String(), 208 Nsid: op.Collection, 209 Rkey: *op.Rkey, 210 Cid: nc.String(), 211 Value: d, 212 }) 213 results = append(results, ApplyWriteResult{ 214 Type: to.StringPtr(OpTypeUpdate.String()), 215 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 216 Cid: to.StringPtr(nc.String()), 217 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 218 }) 219 } 220 } 221 222 newroot, rev, err := r.Commit(ctx, urepo.SignFor) 223 if err != nil { 224 return nil, err 225 } 226 227 buf := new(bytes.Buffer) 228 229 hb, err := cbor.DumpObject(&car.CarHeader{ 230 Roots: []cid.Cid{newroot}, 231 Version: 1, 232 }) 233 if err != nil { 234 return nil, err 235 } 236 237 if _, err := carstore.LdWrite(buf, hb); err != nil { 238 return nil, err 239 } 240 241 diffops, err := r.DiffSince(ctx, rootcid) 242 if err != nil { 243 return nil, err 244 } 245 246 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops)) 247 248 for _, op := range diffops { 249 var c cid.Cid 250 switch op.Op { 251 case "add", "mut": 252 kind := "create" 253 if op.Op == "mut" { 254 kind = "update" 255 } 256 257 c = op.NewCid 258 ll := lexutil.LexLink(op.NewCid) 259 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 260 Action: kind, 261 Path: op.Rpath, 262 Cid: &ll, 263 }) 264 265 case "del": 266 c = op.OldCid 267 ll := lexutil.LexLink(op.OldCid) 268 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 269 Action: "delete", 270 Path: op.Rpath, 271 Cid: nil, 272 Prev: &ll, 273 }) 274 } 275 276 blk, err := dbs.Get(ctx, c) 277 if err != nil { 278 return nil, err 279 } 280 281 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 282 return nil, err 283 } 284 } 285 286 for _, op := range bs.GetWriteLog() { 287 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil { 288 return nil, err 289 } 290 } 291 292 var blobs []lexutil.LexLink 293 for _, entry := range entries { 294 var cids []cid.Cid 295 if entry.Cid != "" { 296 if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{ 297 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 298 UpdateAll: true, 299 }}).Error; err != nil { 300 return nil, err 301 } 302 303 cids, err = rm.incrementBlobRefs(urepo, entry.Value) 304 if err != nil { 305 return nil, err 306 } 307 } else { 308 if err := rm.s.db.Delete(&entry, nil).Error; err != nil { 309 return nil, err 310 } 311 cids, err = rm.decrementBlobRefs(urepo, entry.Value) 312 if err != nil { 313 return nil, err 314 } 315 } 316 317 for _, c := range cids { 318 blobs = append(blobs, lexutil.LexLink(c)) 319 } 320 } 321 322 rm.s.evtman.AddEvent(ctx, &events.XRPCStreamEvent{ 323 RepoCommit: &atproto.SyncSubscribeRepos_Commit{ 324 Repo: urepo.Did, 325 Blocks: buf.Bytes(), 326 Blobs: blobs, 327 Rev: rev, 328 Since: &urepo.Rev, 329 Commit: lexutil.LexLink(newroot), 330 Time: time.Now().Format(time.RFC3339Nano), 331 Ops: ops, 332 TooBig: false, 333 }, 334 }) 335 336 if err := rm.s.UpdateRepo(ctx, urepo.Did, newroot, rev); err != nil { 337 return nil, err 338 } 339 340 for i := range results { 341 results[i].Type = to.StringPtr(*results[i].Type + "Result") 342 results[i].Commit = &RepoCommit{ 343 Cid: newroot.String(), 344 Rev: rev, 345 } 346 } 347 348 return results, nil 349} 350 351func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) { 352 c, err := cid.Cast(urepo.Root) 353 if err != nil { 354 return cid.Undef, nil, err 355 } 356 357 dbs := rm.s.getBlockstore(urepo.Did) 358 bs := recording_blockstore.New(dbs) 359 360 r, err := repo.OpenRepo(ctx, bs, c) 361 if err != nil { 362 return cid.Undef, nil, err 363 } 364 365 _, _, err = r.GetRecordBytes(ctx, collection+"/"+rkey) 366 if err != nil { 367 return cid.Undef, nil, err 368 } 369 370 return c, bs.GetReadLog(), nil 371} 372 373func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 374 cids, err := getBlobCidsFromCbor(cbor) 375 if err != nil { 376 return nil, err 377 } 378 379 for _, c := range cids { 380 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 381 return nil, err 382 } 383 } 384 385 return cids, nil 386} 387 388func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 389 cids, err := getBlobCidsFromCbor(cbor) 390 if err != nil { 391 return nil, err 392 } 393 394 for _, c := range cids { 395 var res struct { 396 ID uint 397 Count int 398 } 399 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 400 return nil, err 401 } 402 403 if res.Count == 0 { 404 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 405 return nil, err 406 } 407 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 408 return nil, err 409 } 410 } 411 } 412 413 return cids, nil 414} 415 416// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional 417// unmarshal here. this will work for now though 418func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) { 419 var cids []cid.Cid 420 421 decoded, err := atdata.UnmarshalCBOR(cbor) 422 if err != nil { 423 return nil, fmt.Errorf("error unmarshaling cbor: %w", err) 424 } 425 426 var deepiter func(any) error 427 deepiter = func(item any) error { 428 switch val := item.(type) { 429 case map[string]any: 430 if val["$type"] == "blob" { 431 if ref, ok := val["ref"].(string); ok { 432 c, err := cid.Parse(ref) 433 if err != nil { 434 return err 435 } 436 cids = append(cids, c) 437 } 438 for _, v := range val { 439 return deepiter(v) 440 } 441 } 442 case []any: 443 for _, v := range val { 444 deepiter(v) 445 } 446 } 447 448 return nil 449 } 450 451 if err := deepiter(decoded); err != nil { 452 return nil, err 453 } 454 455 return cids, nil 456}