An atproto PDS written in Go
1package server 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/bluesky-social/indigo/api/atproto" 13 "github.com/bluesky-social/indigo/atproto/data" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "github.com/bluesky-social/indigo/carstore" 16 "github.com/bluesky-social/indigo/events" 17 lexutil "github.com/bluesky-social/indigo/lex/util" 18 "github.com/bluesky-social/indigo/repo" 19 "github.com/bluesky-social/indigo/util" 20 "github.com/haileyok/cocoon/blockstore" 21 "github.com/haileyok/cocoon/models" 22 blocks "github.com/ipfs/go-block-format" 23 "github.com/ipfs/go-cid" 24 cbor "github.com/ipfs/go-ipld-cbor" 25 "github.com/ipld/go-car" 26 "gorm.io/gorm" 27 "gorm.io/gorm/clause" 28) 29 30type RepoMan struct { 31 db *gorm.DB 32 s *Server 33 clock *syntax.TIDClock 34} 35 36func NewRepoMan(s *Server) *RepoMan { 37 clock := syntax.NewTIDClock(0) 38 39 return &RepoMan{ 40 s: s, 41 db: s.db, 42 clock: &clock, 43 } 44} 45 46type OpType string 47 48var ( 49 OpTypeCreate = OpType("com.atproto.repo.applyWrites#create") 50 OpTypeUpdate = OpType("com.atproto.repo.applyWrites#update") 51 OpTypeDelete = OpType("com.atproto.repo.applyWrites#delete") 52) 53 54func (ot OpType) String() string { 55 return string(ot) 56} 57 58type Op struct { 59 Type OpType `json:"$type"` 60 Collection string `json:"collection"` 61 Rkey *string `json:"rkey,omitempty"` 62 Validate *bool `json:"validate,omitempty"` 63 SwapRecord *string `json:"swapRecord,omitempty"` 64 Record *MarshalableMap `json:"record,omitempty"` 65} 66 67type MarshalableMap map[string]any 68 69type FirehoseOp struct { 70 Cid cid.Cid 71 Path string 72 Action string 73} 74 75func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error { 76 data, err := data.MarshalCBOR(*mm) 77 if err != nil { 78 return err 79 } 80 81 w.Write(data) 82 83 return nil 84} 85 86type ApplyWriteResult struct { 87 Type *string `json:"$type,omitempty"` 88 Uri *string `json:"uri,omitempty"` 89 Cid *string `json:"cid,omitempty"` 90 Commit *RepoCommit `json:"commit,omitempty"` 91 ValidationStatus *string `json:"validationStatus,omitempty"` 92} 93 94type RepoCommit struct { 95 Cid string `json:"cid"` 96 Rev string `json:"rev"` 97} 98 99// TODO make use of swap commit 100func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) { 101 rootcid, err := cid.Cast(urepo.Root) 102 if err != nil { 103 return nil, err 104 } 105 106 dbs := blockstore.New(urepo.Did, rm.db) 107 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid) 108 109 entries := []models.Record{} 110 var results []ApplyWriteResult 111 112 for i, op := range writes { 113 if op.Type != OpTypeCreate && op.Rkey == nil { 114 return nil, fmt.Errorf("invalid rkey") 115 } else if op.Rkey == nil { 116 op.Rkey = to.StringPtr(rm.clock.Next().String()) 117 writes[i].Rkey = op.Rkey 118 } 119 120 _, err := syntax.ParseRecordKey(*op.Rkey) 121 if err != nil { 122 return nil, err 123 } 124 125 switch op.Type { 126 case OpTypeCreate: 127 j, err := json.Marshal(*op.Record) 128 if err != nil { 129 return nil, err 130 } 131 out, err := data.UnmarshalJSON(j) 132 if err != nil { 133 return nil, err 134 } 135 mm := MarshalableMap(out) 136 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm) 137 if err != nil { 138 return nil, err 139 } 140 d, err := data.MarshalCBOR(mm) 141 if err != nil { 142 return nil, err 143 } 144 entries = append(entries, models.Record{ 145 Did: urepo.Did, 146 CreatedAt: rm.clock.Next().String(), 147 Nsid: op.Collection, 148 Rkey: *op.Rkey, 149 Cid: nc.String(), 150 Value: d, 151 }) 152 results = append(results, ApplyWriteResult{ 153 Type: to.StringPtr(OpTypeCreate.String()), 154 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 155 Cid: to.StringPtr(nc.String()), 156 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 157 }) 158 case OpTypeDelete: 159 var old models.Record 160 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 161 return nil, err 162 } 163 entries = append(entries, models.Record{ 164 Did: urepo.Did, 165 Nsid: op.Collection, 166 Rkey: *op.Rkey, 167 Value: old.Value, 168 }) 169 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey) 170 if err != nil { 171 return nil, err 172 } 173 results = append(results, ApplyWriteResult{ 174 Type: to.StringPtr(OpTypeDelete.String()), 175 }) 176 case OpTypeUpdate: 177 j, err := json.Marshal(*op.Record) 178 if err != nil { 179 return nil, err 180 } 181 out, err := data.UnmarshalJSON(j) 182 if err != nil { 183 return nil, err 184 } 185 mm := MarshalableMap(out) 186 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm) 187 if err != nil { 188 return nil, err 189 } 190 d, err := data.MarshalCBOR(mm) 191 if err != nil { 192 return nil, err 193 } 194 entries = append(entries, models.Record{ 195 Did: urepo.Did, 196 CreatedAt: rm.clock.Next().String(), 197 Nsid: op.Collection, 198 Rkey: *op.Rkey, 199 Cid: nc.String(), 200 Value: d, 201 }) 202 results = append(results, ApplyWriteResult{ 203 Type: to.StringPtr(OpTypeUpdate.String()), 204 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 205 Cid: to.StringPtr(nc.String()), 206 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 207 }) 208 } 209 } 210 211 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor) 212 if err != nil { 213 return nil, err 214 } 215 216 buf := new(bytes.Buffer) 217 218 hb, err := cbor.DumpObject(&car.CarHeader{ 219 Roots: []cid.Cid{newroot}, 220 Version: 1, 221 }) 222 223 if _, err := carstore.LdWrite(buf, hb); err != nil { 224 return nil, err 225 } 226 227 diffops, err := r.DiffSince(context.TODO(), rootcid) 228 if err != nil { 229 return nil, err 230 } 231 232 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops)) 233 234 for _, op := range diffops { 235 var c cid.Cid 236 switch op.Op { 237 case "add", "mut": 238 kind := "create" 239 if op.Op == "mut" { 240 kind = "update" 241 } 242 243 c = op.NewCid 244 ll := lexutil.LexLink(op.NewCid) 245 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 246 Action: kind, 247 Path: op.Rpath, 248 Cid: &ll, 249 }) 250 251 case "del": 252 c = op.OldCid 253 ll := lexutil.LexLink(op.OldCid) 254 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 255 Action: "delete", 256 Path: op.Rpath, 257 Cid: nil, 258 Prev: &ll, 259 }) 260 } 261 262 blk, err := dbs.Get(context.TODO(), c) 263 if err != nil { 264 return nil, err 265 } 266 267 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 268 return nil, err 269 } 270 } 271 272 for _, op := range dbs.GetLog() { 273 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil { 274 return nil, err 275 } 276 } 277 278 var blobs []lexutil.LexLink 279 for _, entry := range entries { 280 var cids []cid.Cid 281 if entry.Cid != "" { 282 if err := rm.s.db.Clauses(clause.OnConflict{ 283 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 284 UpdateAll: true, 285 }).Create(&entry).Error; err != nil { 286 return nil, err 287 } 288 289 cids, err = rm.incrementBlobRefs(urepo, entry.Value) 290 if err != nil { 291 return nil, err 292 } 293 } else { 294 if err := rm.s.db.Delete(&entry).Error; err != nil { 295 return nil, err 296 } 297 cids, err = rm.decrementBlobRefs(urepo, entry.Value) 298 if err != nil { 299 return nil, err 300 } 301 } 302 303 for _, c := range cids { 304 blobs = append(blobs, lexutil.LexLink(c)) 305 } 306 } 307 308 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 309 RepoCommit: &atproto.SyncSubscribeRepos_Commit{ 310 Repo: urepo.Did, 311 Blocks: buf.Bytes(), 312 Blobs: blobs, 313 Rev: rev, 314 Since: &urepo.Rev, 315 Commit: lexutil.LexLink(newroot), 316 Time: time.Now().Format(util.ISO8601), 317 Ops: ops, 318 TooBig: false, 319 }, 320 }) 321 322 if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil { 323 return nil, err 324 } 325 326 for i := range results { 327 results[i].Type = to.StringPtr(*results[i].Type + "Result") 328 results[i].Commit = &RepoCommit{ 329 Cid: newroot.String(), 330 Rev: rev, 331 } 332 } 333 334 return results, nil 335} 336 337func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) { 338 c, err := cid.Cast(urepo.Root) 339 if err != nil { 340 return cid.Undef, nil, err 341 } 342 343 dbs := blockstore.New(urepo.Did, rm.db) 344 bs := util.NewLoggingBstore(dbs) 345 346 r, err := repo.OpenRepo(context.TODO(), bs, c) 347 if err != nil { 348 return cid.Undef, nil, err 349 } 350 351 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey) 352 if err != nil { 353 return cid.Undef, nil, err 354 } 355 356 return c, bs.GetLoggedBlocks(), nil 357} 358 359func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 360 cids, err := getBlobCidsFromCbor(cbor) 361 if err != nil { 362 return nil, err 363 } 364 365 for _, c := range cids { 366 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", urepo.Did, c.Bytes()).Error; err != nil { 367 return nil, err 368 } 369 } 370 371 return cids, nil 372} 373 374func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 375 cids, err := getBlobCidsFromCbor(cbor) 376 if err != nil { 377 return nil, err 378 } 379 380 for _, c := range cids { 381 var res struct { 382 ID uint 383 Count int 384 } 385 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 386 return nil, err 387 } 388 389 if res.Count == 0 { 390 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", res.ID).Error; err != nil { 391 return nil, err 392 } 393 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", res.ID).Error; err != nil { 394 return nil, err 395 } 396 } 397 } 398 399 return cids, nil 400} 401 402// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional 403// unmarshal here. this will work for now though 404func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) { 405 var cids []cid.Cid 406 407 decoded, err := data.UnmarshalCBOR(cbor) 408 if err != nil { 409 return nil, fmt.Errorf("error unmarshaling cbor: %w", err) 410 } 411 412 var deepiter func(interface{}) error 413 deepiter = func(item interface{}) error { 414 switch val := item.(type) { 415 case map[string]interface{}: 416 if val["$type"] == "blob" { 417 if ref, ok := val["ref"].(string); ok { 418 c, err := cid.Parse(ref) 419 if err != nil { 420 return err 421 } 422 cids = append(cids, c) 423 } 424 for _, v := range val { 425 return deepiter(v) 426 } 427 } 428 case []interface{}: 429 for _, v := range val { 430 deepiter(v) 431 } 432 } 433 434 return nil 435 } 436 437 if err := deepiter(decoded); err != nil { 438 return nil, err 439 } 440 441 return cids, nil 442}