An atproto PDS written in Go
at 0.0.3 10 kB view raw
1package server 2 3import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "time" 9 10 "github.com/Azure/go-autorest/autorest/to" 11 "github.com/bluesky-social/indigo/api/atproto" 12 "github.com/bluesky-social/indigo/atproto/data" 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 "github.com/bluesky-social/indigo/carstore" 15 "github.com/bluesky-social/indigo/events" 16 lexutil "github.com/bluesky-social/indigo/lex/util" 17 "github.com/bluesky-social/indigo/repo" 18 "github.com/bluesky-social/indigo/util" 19 "github.com/haileyok/cocoon/blockstore" 20 "github.com/haileyok/cocoon/models" 21 blocks "github.com/ipfs/go-block-format" 22 "github.com/ipfs/go-cid" 23 cbor "github.com/ipfs/go-ipld-cbor" 24 "github.com/ipld/go-car" 25 "gorm.io/gorm" 26 "gorm.io/gorm/clause" 27) 28 29type RepoMan struct { 30 db *gorm.DB 31 s *Server 32 clock *syntax.TIDClock 33} 34 35func NewRepoMan(s *Server) *RepoMan { 36 clock := syntax.NewTIDClock(0) 37 38 return &RepoMan{ 39 s: s, 40 db: s.db, 41 clock: &clock, 42 } 43} 44 45type OpType string 46 47var ( 48 OpTypeCreate = OpType("com.atproto.repo.applyWrites#create") 49 OpTypeUpdate = OpType("com.atproto.repo.applyWrites#update") 50 OpTypeDelete = OpType("com.atproto.repo.applyWrites#delete") 51) 52 53func (ot OpType) String() string { 54 return string(ot) 55} 56 57type Op struct { 58 Type OpType `json:"$type"` 59 Collection string `json:"collection"` 60 Rkey *string `json:"rkey,omitempty"` 61 Validate *bool `json:"validate,omitempty"` 62 SwapRecord *string `json:"swapRecord,omitempty"` 63 Record *MarshalableMap `json:"record,omitempty"` 64} 65 66type MarshalableMap map[string]any 67 68type FirehoseOp struct { 69 Cid cid.Cid 70 Path string 71 Action string 72} 73 74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error { 75 data, err := data.MarshalCBOR(*mm) 76 if err != nil { 77 return err 78 } 79 80 w.Write(data) 81 82 return nil 83} 84 85type ApplyWriteResult struct { 86 Type *string `json:"$type,omitempty"` 87 Uri *string `json:"uri,omitempty"` 88 Cid *string `json:"cid,omitempty"` 89 Commit *RepoCommit `json:"commit,omitempty"` 90 ValidationStatus *string `json:"validationStatus,omitempty"` 91} 92 93type RepoCommit struct { 94 Cid string `json:"cid"` 95 Rev string `json:"rev"` 96} 97 98// TODO make use of swap commit 99func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) { 100 rootcid, err := cid.Cast(urepo.Root) 101 if err != nil { 102 return nil, err 103 } 104 105 dbs := blockstore.New(urepo.Did, rm.db) 106 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid) 107 108 entries := []models.Record{} 109 var results []ApplyWriteResult 110 111 for i, op := range writes { 112 if op.Type != OpTypeCreate && op.Rkey == nil { 113 return nil, fmt.Errorf("invalid rkey") 114 } else if op.Rkey == nil { 115 op.Rkey = to.StringPtr(rm.clock.Next().String()) 116 writes[i].Rkey = op.Rkey 117 } 118 119 _, err := syntax.ParseRecordKey(*op.Rkey) 120 if err != nil { 121 return nil, err 122 } 123 124 switch op.Type { 125 case OpTypeCreate: 126 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, op.Record) 127 if err != nil { 128 return nil, err 129 } 130 131 d, _ := data.MarshalCBOR(*op.Record) 132 entries = append(entries, models.Record{ 133 Did: urepo.Did, 134 CreatedAt: rm.clock.Next().String(), 135 Nsid: op.Collection, 136 Rkey: *op.Rkey, 137 Cid: nc.String(), 138 Value: d, 139 }) 140 results = append(results, ApplyWriteResult{ 141 Type: to.StringPtr(OpTypeCreate.String()), 142 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 143 Cid: to.StringPtr(nc.String()), 144 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 145 }) 146 case OpTypeDelete: 147 var old models.Record 148 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 149 return nil, err 150 } 151 entries = append(entries, models.Record{ 152 Did: urepo.Did, 153 Nsid: op.Collection, 154 Rkey: *op.Rkey, 155 Value: old.Value, 156 }) 157 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey) 158 if err != nil { 159 return nil, err 160 } 161 results = append(results, ApplyWriteResult{ 162 Type: to.StringPtr(OpTypeDelete.String()), 163 }) 164 case OpTypeUpdate: 165 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, op.Record) 166 if err != nil { 167 return nil, err 168 } 169 170 d, _ := data.MarshalCBOR(*op.Record) 171 entries = append(entries, models.Record{ 172 Did: urepo.Did, 173 CreatedAt: rm.clock.Next().String(), 174 Nsid: op.Collection, 175 Rkey: *op.Rkey, 176 Cid: nc.String(), 177 Value: d, 178 }) 179 results = append(results, ApplyWriteResult{ 180 Type: to.StringPtr(OpTypeUpdate.String()), 181 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 182 Cid: to.StringPtr(nc.String()), 183 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 184 }) 185 } 186 } 187 188 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor) 189 if err != nil { 190 return nil, err 191 } 192 193 buf := new(bytes.Buffer) 194 195 hb, err := cbor.DumpObject(&car.CarHeader{ 196 Roots: []cid.Cid{newroot}, 197 Version: 1, 198 }) 199 200 if _, err := carstore.LdWrite(buf, hb); err != nil { 201 return nil, err 202 } 203 204 diffops, err := r.DiffSince(context.TODO(), rootcid) 205 if err != nil { 206 return nil, err 207 } 208 209 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops)) 210 211 for _, op := range diffops { 212 var c cid.Cid 213 switch op.Op { 214 case "add", "mut": 215 kind := "create" 216 if op.Op == "mut" { 217 kind = "update" 218 } 219 220 c = op.NewCid 221 ll := lexutil.LexLink(op.NewCid) 222 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 223 Action: kind, 224 Path: op.Rpath, 225 Cid: &ll, 226 }) 227 228 case "del": 229 c = op.OldCid 230 ll := lexutil.LexLink(op.OldCid) 231 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 232 Action: "delete", 233 Path: op.Rpath, 234 Cid: nil, 235 Prev: &ll, 236 }) 237 } 238 239 blk, err := dbs.Get(context.TODO(), c) 240 if err != nil { 241 return nil, err 242 } 243 244 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 245 return nil, err 246 } 247 } 248 249 for _, op := range dbs.GetLog() { 250 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil { 251 return nil, err 252 } 253 } 254 255 var blobs []lexutil.LexLink 256 for _, entry := range entries { 257 var cids []cid.Cid 258 if entry.Cid != "" { 259 if err := rm.s.db.Clauses(clause.OnConflict{ 260 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 261 UpdateAll: true, 262 }).Create(&entry).Error; err != nil { 263 return nil, err 264 } 265 266 cids, err = rm.incrementBlobRefs(urepo, entry.Value) 267 if err != nil { 268 return nil, err 269 } 270 } else { 271 if err := rm.s.db.Delete(&entry).Error; err != nil { 272 return nil, err 273 } 274 cids, err = rm.decrementBlobRefs(urepo, entry.Value) 275 if err != nil { 276 return nil, err 277 } 278 } 279 280 for _, c := range cids { 281 blobs = append(blobs, lexutil.LexLink(c)) 282 } 283 } 284 285 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 286 RepoCommit: &atproto.SyncSubscribeRepos_Commit{ 287 Repo: urepo.Did, 288 Blocks: buf.Bytes(), 289 Blobs: blobs, 290 Rev: rev, 291 Since: &urepo.Rev, 292 Commit: lexutil.LexLink(newroot), 293 Time: time.Now().Format(util.ISO8601), 294 Ops: ops, 295 TooBig: false, 296 }, 297 }) 298 299 if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil { 300 return nil, err 301 } 302 303 for i := range results { 304 results[i].Type = to.StringPtr(*results[i].Type + "Result") 305 results[i].Commit = &RepoCommit{ 306 Cid: newroot.String(), 307 Rev: rev, 308 } 309 } 310 311 return results, nil 312} 313 314func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) { 315 c, err := cid.Cast(urepo.Root) 316 if err != nil { 317 return cid.Undef, nil, err 318 } 319 320 dbs := blockstore.New(urepo.Did, rm.db) 321 bs := util.NewLoggingBstore(dbs) 322 323 r, err := repo.OpenRepo(context.TODO(), bs, c) 324 if err != nil { 325 return cid.Undef, nil, err 326 } 327 328 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey) 329 if err != nil { 330 return cid.Undef, nil, err 331 } 332 333 return c, bs.GetLoggedBlocks(), nil 334} 335 336func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 337 cids, err := getBlobCidsFromCbor(cbor) 338 if err != nil { 339 return nil, err 340 } 341 342 for _, c := range cids { 343 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", urepo.Did, c.Bytes()).Error; err != nil { 344 return nil, err 345 } 346 } 347 348 return cids, nil 349} 350 351func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 352 cids, err := getBlobCidsFromCbor(cbor) 353 if err != nil { 354 return nil, err 355 } 356 357 for _, c := range cids { 358 var res struct { 359 ID uint 360 Count int 361 } 362 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 363 return nil, err 364 } 365 366 if res.Count == 0 { 367 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", res.ID).Error; err != nil { 368 return nil, err 369 } 370 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", res.ID).Error; err != nil { 371 return nil, err 372 } 373 } 374 } 375 376 return cids, nil 377} 378 379// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional 380// unmarshal here. this will work for now though 381func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) { 382 var cids []cid.Cid 383 384 decoded, err := data.UnmarshalCBOR(cbor) 385 if err != nil { 386 return nil, fmt.Errorf("error unmarshaling cbor: %w", err) 387 } 388 389 var deepiter func(interface{}) error 390 deepiter = func(item interface{}) error { 391 switch val := item.(type) { 392 case map[string]interface{}: 393 if val["$type"] == "blob" { 394 if ref, ok := val["ref"].(string); ok { 395 c, err := cid.Parse(ref) 396 if err != nil { 397 return err 398 } 399 cids = append(cids, c) 400 } 401 for _, v := range val { 402 return deepiter(v) 403 } 404 } 405 case []interface{}: 406 for _, v := range val { 407 deepiter(v) 408 } 409 } 410 411 return nil 412 } 413 414 if err := deepiter(decoded); err != nil { 415 return nil, err 416 } 417 418 return cids, nil 419}