An atproto PDS written in Go

Compare changes

Choose any two refs to compare.

+1
server/handle_account.go
···
func (s *Server) handleAccount(e echo.Context) error {
ctx := e.Request().Context()
+
repo, sess, err := s.getSessionRepoOrErr(e)
if err != nil {
return e.Redirect(303, "/account/signin")
+8 -7
server/handle_import_repo.go
···
import (
"bytes"
-
"context"
"io"
"slices"
"strings"
···
)
func (s *Server) handleRepoImportRepo(e echo.Context) error {
+
ctx := e.Request().Context()
+
urepo := e.Get("repo").(*models.RepoActor)
b, err := io.ReadAll(e.Request().Body)
···
slices.Reverse(orderedBlocks)
-
if err := bs.PutMany(context.TODO(), orderedBlocks); err != nil {
+
if err := bs.PutMany(ctx, orderedBlocks); err != nil {
s.logger.Error("could not insert blocks", "error", err)
return helpers.ServerError(e, nil)
}
-
r, err := repo.OpenRepo(context.TODO(), bs, cs.Header.Roots[0])
+
r, err := repo.OpenRepo(ctx, bs, cs.Header.Roots[0])
if err != nil {
s.logger.Error("could not open repo", "error", err)
return helpers.ServerError(e, nil)
···
clock := syntax.NewTIDClock(0)
-
if err := r.ForEach(context.TODO(), "", func(key string, cid cid.Cid) error {
+
if err := r.ForEach(ctx, "", func(key string, cid cid.Cid) error {
pts := strings.Split(key, "/")
nsid := pts[0]
rkey := pts[1]
cidStr := cid.String()
-
b, err := bs.Get(context.TODO(), cid)
+
b, err := bs.Get(ctx, cid)
if err != nil {
s.logger.Error("record bytes don't exist in blockstore", "error", err)
return helpers.ServerError(e, nil)
···
tx.Commit()
-
root, rev, err := r.Commit(context.TODO(), urepo.SignFor)
+
root, rev, err := r.Commit(ctx, urepo.SignFor)
if err != nil {
s.logger.Error("error committing", "error", err)
return helpers.ServerError(e, nil)
}
-
if err := s.UpdateRepo(context.TODO(), urepo.Repo.Did, root, rev); err != nil {
+
if err := s.UpdateRepo(ctx, urepo.Repo.Did, root, rev); err != nil {
s.logger.Error("error updating repo after commit", "error", err)
return helpers.ServerError(e, nil)
}
+3 -1
server/handle_repo_apply_writes.go
···
}
func (s *Server) handleApplyWrites(e echo.Context) error {
+
ctx := e.Request().Context()
+
repo := e.Get("repo").(*models.RepoActor)
var req ComAtprotoRepoApplyWritesRequest
···
})
}
-
results, err := s.repoman.applyWrites(repo.Repo, ops, req.SwapCommit)
+
results, err := s.repoman.applyWrites(ctx, repo.Repo, ops, req.SwapCommit)
if err != nil {
s.logger.Error("error applying writes", "error", err)
return helpers.ServerError(e, nil)
+3 -1
server/handle_repo_create_record.go
···
}
func (s *Server) handleCreateRecord(e echo.Context) error {
+
ctx := e.Request().Context()
+
repo := e.Get("repo").(*models.RepoActor)
var req ComAtprotoRepoCreateRecordRequest
···
optype = OpTypeUpdate
}
-
results, err := s.repoman.applyWrites(repo.Repo, []Op{
+
results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{
{
Type: optype,
Collection: req.Collection,
+3 -1
server/handle_repo_delete_record.go
···
}
func (s *Server) handleDeleteRecord(e echo.Context) error {
+
ctx := e.Request().Context()
+
repo := e.Get("repo").(*models.RepoActor)
var req ComAtprotoRepoDeleteRecordRequest
···
return helpers.InputError(e, nil)
}
-
results, err := s.repoman.applyWrites(repo.Repo, []Op{
+
results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{
{
Type: OpTypeDelete,
Collection: req.Collection,
+3 -1
server/handle_repo_put_record.go
···
}
func (s *Server) handlePutRecord(e echo.Context) error {
+
ctx := e.Request().Context()
+
repo := e.Get("repo").(*models.RepoActor)
var req ComAtprotoRepoPutRecordRequest
···
optype = OpTypeUpdate
}
-
results, err := s.repoman.applyWrites(repo.Repo, []Op{
+
results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{
{
Type: optype,
Collection: req.Collection,
+46 -15
server/handle_server_check_account_status.go
···
package server
import (
+
"errors"
+
"sync"
+
+
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
"github.com/ipfs/go-cid"
···
func (s *Server) handleServerCheckAccountStatus(e echo.Context) error {
urepo := e.Get("repo").(*models.RepoActor)
+
_, didErr := syntax.ParseDID(urepo.Repo.Did)
+
if didErr != nil {
+
s.logger.Error("error validating did", "err", didErr)
+
}
+
resp := ComAtprotoServerCheckAccountStatusResponse{
Activated: true, // TODO: should allow for deactivation etc.
-
ValidDid: true, // TODO: should probably verify?
+
ValidDid: didErr == nil,
RepoRev: urepo.Rev,
ImportedBlobs: 0, // TODO: ???
}
···
s.logger.Error("error casting cid", "error", err)
return helpers.ServerError(e, nil)
}
+
resp.RepoCommit = rootcid.String()
type CountResp struct {
···
}
var blockCtResp CountResp
-
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
-
s.logger.Error("error getting block count", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
resp.RepoBlocks = blockCtResp.Ct
+
var recCtResp CountResp
+
var blobCtResp CountResp
+
+
var wg sync.WaitGroup
+
var procErr error
+
+
wg.Add(1)
+
go func() {
+
defer wg.Done()
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
+
s.logger.Error("error getting block count", "error", err)
+
procErr = errors.Join(procErr, err)
+
}
+
}()
+
+
wg.Add(1)
+
go func() {
+
defer wg.Done()
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
+
s.logger.Error("error getting record count", "error", err)
+
procErr = errors.Join(procErr, err)
+
}
+
}()
-
var recCtResp CountResp
-
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
-
s.logger.Error("error getting record count", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
resp.IndexedRecords = recCtResp.Ct
+
wg.Add(1)
+
go func() {
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
+
s.logger.Error("error getting expected blobs count", "error", err)
+
procErr = errors.Join(procErr, err)
+
}
+
}()
-
var blobCtResp CountResp
-
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
-
s.logger.Error("error getting record count", "error", err)
+
wg.Wait()
+
if procErr != nil {
return helpers.ServerError(e, nil)
}
+
+
resp.RepoBlocks = blockCtResp.Ct
+
resp.IndexedRecords = recCtResp.Ct
resp.ExpectedBlobs = blobCtResp.Ct
return e.JSON(200, resp)
+3 -1
server/handle_sync_get_record.go
···
)
func (s *Server) handleSyncGetRecord(e echo.Context) error {
+
ctx := e.Request().Context()
+
did := e.QueryParam("did")
collection := e.QueryParam("collection")
rkey := e.QueryParam("rkey")
···
return helpers.ServerError(e, nil)
}
-
root, blocks, err := s.repoman.getRecordProof(urepo, collection, rkey)
+
root, blocks, err := s.repoman.getRecordProof(ctx, urepo, collection, rkey)
if err != nil {
return err
}
+19 -16
server/repo.go
···
}
// TODO make use of swap commit
-
func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
+
func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
rootcid, err := cid.Cast(urepo.Root)
if err != nil {
return nil, err
···
dbs := rm.s.getBlockstore(urepo.Did)
bs := recording_blockstore.New(dbs)
-
r, err := repo.OpenRepo(context.TODO(), bs, rootcid)
+
r, err := repo.OpenRepo(ctx, bs, rootcid)
-
entries := []models.Record{}
-
var results []ApplyWriteResult
+
entries := make([]models.Record, 0, len(writes))
+
results := make([]ApplyWriteResult, 0, len(writes))
for i, op := range writes {
if op.Type != OpTypeCreate && op.Rkey == nil {
return nil, fmt.Errorf("invalid rkey")
} else if op.Type == OpTypeCreate && op.Rkey != nil {
-
_, _, err := r.GetRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
+
_, _, err := r.GetRecord(ctx, op.Collection+"/"+*op.Rkey)
if err == nil {
op.Type = OpTypeUpdate
}
···
mm["$type"] = op.Collection
}
-
nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
+
nc, err := r.PutRecord(ctx, op.Collection+"/"+*op.Rkey, &mm)
if err != nil {
return nil, err
}
···
Rkey: *op.Rkey,
Value: old.Value,
})
-
err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey)
+
err := r.DeleteRecord(ctx, op.Collection+"/"+*op.Rkey)
if err != nil {
return nil, err
}
···
return nil, err
}
mm := MarshalableMap(out)
-
nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
+
nc, err := r.UpdateRecord(ctx, op.Collection+"/"+*op.Rkey, &mm)
if err != nil {
return nil, err
}
···
}
}
-
newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor)
+
newroot, rev, err := r.Commit(ctx, urepo.SignFor)
if err != nil {
return nil, err
}
···
Roots: []cid.Cid{newroot},
Version: 1,
})
+
if err != nil {
+
return nil, err
+
}
if _, err := carstore.LdWrite(buf, hb); err != nil {
return nil, err
}
-
diffops, err := r.DiffSince(context.TODO(), rootcid)
+
diffops, err := r.DiffSince(ctx, rootcid)
if err != nil {
return nil, err
}
···
})
}
-
blk, err := dbs.Get(context.TODO(), c)
+
blk, err := dbs.Get(ctx, c)
if err != nil {
return nil, err
}
···
}
}
-
rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
+
rm.s.evtman.AddEvent(ctx, &events.XRPCStreamEvent{
RepoCommit: &atproto.SyncSubscribeRepos_Commit{
Repo: urepo.Did,
Blocks: buf.Bytes(),
···
},
})
-
if err := rm.s.UpdateRepo(context.TODO(), urepo.Did, newroot, rev); err != nil {
+
if err := rm.s.UpdateRepo(ctx, urepo.Did, newroot, rev); err != nil {
return nil, err
}
···
return results, nil
}
-
func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
+
func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
c, err := cid.Cast(urepo.Root)
if err != nil {
return cid.Undef, nil, err
···
dbs := rm.s.getBlockstore(urepo.Did)
bs := recording_blockstore.New(dbs)
-
r, err := repo.OpenRepo(context.TODO(), bs, c)
+
r, err := repo.OpenRepo(ctx, bs, c)
if err != nil {
return cid.Undef, nil, err
}
-
_, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey)
+
_, _, err = r.GetRecordBytes(ctx, collection+"/"+rkey)
if err != nil {
return cid.Undef, nil, err
}