An atproto PDS written in Go

Compare changes

Choose any two refs to compare.

+58
.github/workflows/docker-image.yml
···
+
name: Docker image
+
+
on:
+
workflow_dispatch:
+
push:
+
+
env:
+
REGISTRY: ghcr.io
+
IMAGE_NAME: ${{ github.repository }}
+
+
jobs:
+
build-and-push-image:
+
runs-on: ubuntu-latest
+
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
+
permissions:
+
contents: read
+
packages: write
+
attestations: write
+
id-token: write
+
#
+
steps:
+
- name: Checkout repository
+
uses: actions/checkout@v4
+
# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
+
- name: Log in to the Container registry
+
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
+
with:
+
registry: ${{ env.REGISTRY }}
+
username: ${{ github.actor }}
+
password: ${{ secrets.GITHUB_TOKEN }}
+
# This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
+
- name: Extract metadata (tags, labels) for Docker
+
id: meta
+
uses: docker/metadata-action@v5
+
with:
+
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+
tags: |
+
type=sha
+
type=sha,format=long
+
# This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
+
# It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
+
# It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
+
- name: Build and push Docker image
+
id: push
+
uses: docker/build-push-action@v5
+
with:
+
context: .
+
push: true
+
tags: ${{ steps.meta.outputs.tags }}
+
labels: ${{ steps.meta.outputs.labels }}
+
+
# This step generates an artifact attestation for the image, which is an unforgeable statement about where and how it was built. It increases supply chain security for people who consume the image. For more information, see "[AUTOTITLE](/actions/security-guides/using-artifact-attestations-to-establish-provenance-for-builds)."
+
- name: Generate artifact attestation
+
uses: actions/attest-build-provenance@v1
+
with:
+
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
+
subject-digest: ${{ steps.push.outputs.digest }}
+
push-to-registry: true
+25
Dockerfile
···
+
### Compile stage
+
FROM golang:1.25.1-bookworm AS build-env
+
+
ADD . /dockerbuild
+
WORKDIR /dockerbuild
+
+
RUN GIT_VERSION=$(git describe --tags --long --always || echo "dev-local") && \
+
go mod tidy && \
+
go build -ldflags "-X main.Version=$GIT_VERSION" -o cocoon ./cmd/cocoon
+
+
### Run stage
+
FROM debian:bookworm-slim AS run
+
+
RUN apt-get update && apt-get install -y dumb-init runit ca-certificates && rm -rf /var/lib/apt/lists/*
+
ENTRYPOINT ["dumb-init", "--"]
+
+
WORKDIR /
+
RUN mkdir -p data/cocoon
+
COPY --from=build-env /dockerbuild/cocoon /
+
+
CMD ["/cocoon", "run"]
+
+
LABEL org.opencontainers.image.source=https://github.com/haileyok/cocoon
+
LABEL org.opencontainers.image.description="Cocoon ATProto PDS"
+
LABEL org.opencontainers.image.licenses=MIT
+4
Makefile
···
.env:
if [ ! -f ".env" ]; then cp example.dev.env .env; fi
+
+
.PHONY: docker-build
+
docker-build:
+
docker build -t cocoon .
+63 -59
README.md
···
Cocoon is a PDS implementation in Go. It is highly experimental, and is not ready for any production use.
-
### Impmlemented Endpoints
+
## Implemented Endpoints
> [!NOTE]
-
Just because something is implemented doesn't mean it is finisehd. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that.
+
Just because something is implemented doesn't mean it is finished. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that.
-
#### Identity
-
- [ ] com.atproto.identity.getRecommendedDidCredentials
-
- [ ] com.atproto.identity.requestPlcOperationSignature
-
- [x] com.atproto.identity.resolveHandle
-
- [ ] com.atproto.identity.signPlcOperation
-
- [ ] com.atproto.identity.submitPlcOperatioin
-
- [x] com.atproto.identity.updateHandle
+
### Identity
-
#### Repo
-
- [x] com.atproto.repo.applyWrites
-
- [x] com.atproto.repo.createRecord
-
- [x] com.atproto.repo.putRecord
-
- [x] com.atproto.repo.deleteRecord
-
- [x] com.atproto.repo.describeRepo
-
- [x] com.atproto.repo.getRecord
-
- [ ] com.atproto.repo.importRepo
-
- [x] com.atproto.repo.listRecords
-
- [ ] com.atproto.repo.listMissingBlobs
+
- [ ] `com.atproto.identity.getRecommendedDidCredentials`
+
- [ ] `com.atproto.identity.requestPlcOperationSignature`
+
- [x] `com.atproto.identity.resolveHandle`
+
- [ ] `com.atproto.identity.signPlcOperation`
+
- [ ] `com.atproto.identity.submitPlcOperation`
+
- [x] `com.atproto.identity.updateHandle`
-
#### Server
-
- [ ] com.atproto.server.activateAccount
-
- [ ] com.atproto.server.checkAccountStatus
-
- [x] com.atproto.server.confirmEmail
-
- [x] com.atproto.server.createAccount
-
- [x] com.atproto.server.createInviteCode
-
- [x] com.atproto.server.createInviteCodes
-
- [ ] com.atproto.server.deactivateAccount
-
- [ ] com.atproto.server.deleteAccount
-
- [x] com.atproto.server.deleteSession
-
- [x] com.atproto.server.describeServer
-
- [ ] com.atproto.server.getAccountInviteCodes
-
- [ ] com.atproto.server.getServiceAuth
-
- ~[ ] com.atproto.server.listAppPasswords~ - not going to add app passwords
-
- [x] com.atproto.server.refreshSession
-
- [ ] com.atproto.server.requestAccountDelete
-
- [x] com.atproto.server.requestEmailConfirmation
-
- [x] com.atproto.server.requestEmailUpdate
-
- [x] com.atproto.server.requestPasswordReset
-
- [ ] com.atproto.server.reserveSigningKey
-
- [x] com.atproto.server.resetPassword
-
- ~[ ] com.atproto.server.revokeAppPassword~ - not going to add app passwords
-
- [x] com.atproto.server.updateEmail
+
### Repo
-
#### Sync
-
- [x] com.atproto.sync.getBlob
-
- [x] com.atproto.sync.getBlocks
-
- [x] com.atproto.sync.getLatestCommit
-
- [x] com.atproto.sync.getRecord
-
- [x] com.atproto.sync.getRepoStatus
-
- [x] com.atproto.sync.getRepo
-
- [x] com.atproto.sync.listBlobs
-
- [x] com.atproto.sync.listRepos
-
- ~[ ] com.atproto.sync.notifyOfUpdate~ - BGS doesn't even have this implemented lol
-
- [x] com.atproto.sync.requestCrawl
-
- [x] com.atproto.sync.subscribeRepos
+
- [x] `com.atproto.repo.applyWrites`
+
- [x] `com.atproto.repo.createRecord`
+
- [x] `com.atproto.repo.putRecord`
+
- [x] `com.atproto.repo.deleteRecord`
+
- [x] `com.atproto.repo.describeRepo`
+
- [x] `com.atproto.repo.getRecord`
+
- [x] `com.atproto.repo.importRepo` (Works "okay". You still have to handle PLC operations on your own when migrating. Use with extreme caution.)
+
- [x] `com.atproto.repo.listRecords`
+
- [ ] `com.atproto.repo.listMissingBlobs`
+
+
### Server
+
+
- [x] `com.atproto.server.activateAccount`
+
- [x] `com.atproto.server.checkAccountStatus`
+
- [x] `com.atproto.server.confirmEmail`
+
- [x] `com.atproto.server.createAccount`
+
- [x] `com.atproto.server.createInviteCode`
+
- [x] `com.atproto.server.createInviteCodes`
+
- [x] `com.atproto.server.deactivateAccount`
+
- [ ] `com.atproto.server.deleteAccount`
+
- [x] `com.atproto.server.deleteSession`
+
- [x] `com.atproto.server.describeServer`
+
- [ ] `com.atproto.server.getAccountInviteCodes`
+
- [ ] `com.atproto.server.getServiceAuth`
+
- ~~[ ] `com.atproto.server.listAppPasswords`~~ - not going to add app passwords
+
- [x] `com.atproto.server.refreshSession`
+
- [ ] `com.atproto.server.requestAccountDelete`
+
- [x] `com.atproto.server.requestEmailConfirmation`
+
- [x] `com.atproto.server.requestEmailUpdate`
+
- [x] `com.atproto.server.requestPasswordReset`
+
- [ ] `com.atproto.server.reserveSigningKey`
+
- [x] `com.atproto.server.resetPassword`
+
- ~~[] `com.atproto.server.revokeAppPassword`~~ - not going to add app passwords
+
- [x] `com.atproto.server.updateEmail`
-
#### Other
-
- [ ] com.atproto.label.queryLabels
-
- [ ] com.atproto.moderation.createReport
-
- [x] app.bsky.actor.getPreferences
-
- [x] app.bsky.actor.putPreferences
+
### Sync
+
- [x] `com.atproto.sync.getBlob`
+
- [x] `com.atproto.sync.getBlocks`
+
- [x] `com.atproto.sync.getLatestCommit`
+
- [x] `com.atproto.sync.getRecord`
+
- [x] `com.atproto.sync.getRepoStatus`
+
- [x] `com.atproto.sync.getRepo`
+
- [x] `com.atproto.sync.listBlobs`
+
- [x] `com.atproto.sync.listRepos`
+
- ~~[ ] `com.atproto.sync.notifyOfUpdate`~~ - BGS doesn't even have this implemented lol
+
- [x] `com.atproto.sync.requestCrawl`
+
- [x] `com.atproto.sync.subscribeRepos`
+
+
### Other
+
+
- [ ] `com.atproto.label.queryLabels`
+
- [x] `com.atproto.moderation.createReport` (Note: this should be handled by proxying, not actually implemented in the PDS)
+
- [x] `app.bsky.actor.getPreferences`
+
- [x] `app.bsky.actor.putPreferences`
## License
-163
blockstore/blockstore.go
···
-
package blockstore
-
-
import (
-
"context"
-
"fmt"
-
-
"github.com/bluesky-social/indigo/atproto/syntax"
-
"github.com/haileyok/cocoon/internal/db"
-
"github.com/haileyok/cocoon/models"
-
blocks "github.com/ipfs/go-block-format"
-
"github.com/ipfs/go-cid"
-
"gorm.io/gorm/clause"
-
)
-
-
type SqliteBlockstore struct {
-
db *db.DB
-
did string
-
readonly bool
-
inserts map[cid.Cid]blocks.Block
-
}
-
-
func New(did string, db *db.DB) *SqliteBlockstore {
-
return &SqliteBlockstore{
-
did: did,
-
db: db,
-
readonly: false,
-
inserts: map[cid.Cid]blocks.Block{},
-
}
-
}
-
-
func NewReadOnly(did string, db *db.DB) *SqliteBlockstore {
-
return &SqliteBlockstore{
-
did: did,
-
db: db,
-
readonly: true,
-
inserts: map[cid.Cid]blocks.Block{},
-
}
-
}
-
-
func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) {
-
var block models.Block
-
-
maybeBlock, ok := bs.inserts[cid]
-
if ok {
-
return maybeBlock, nil
-
}
-
-
if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil {
-
return nil, err
-
}
-
-
b, err := blocks.NewBlockWithCid(block.Value, cid)
-
if err != nil {
-
return nil, err
-
}
-
-
return b, nil
-
}
-
-
func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error {
-
bs.inserts[block.Cid()] = block
-
-
if bs.readonly {
-
return nil
-
}
-
-
b := models.Block{
-
Did: bs.did,
-
Cid: block.Cid().Bytes(),
-
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
-
Value: block.RawData(),
-
}
-
-
if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{
-
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
-
UpdateAll: true,
-
}}).Error; err != nil {
-
return err
-
}
-
-
return nil
-
}
-
-
func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error {
-
panic("not implemented")
-
}
-
-
func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) {
-
panic("not implemented")
-
}
-
-
func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) {
-
panic("not implemented")
-
}
-
-
func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error {
-
tx := bs.db.BeginDangerously()
-
-
for _, block := range blocks {
-
bs.inserts[block.Cid()] = block
-
-
if bs.readonly {
-
continue
-
}
-
-
b := models.Block{
-
Did: bs.did,
-
Cid: block.Cid().Bytes(),
-
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
-
Value: block.RawData(),
-
}
-
-
if err := tx.Clauses(clause.OnConflict{
-
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
-
UpdateAll: true,
-
}).Create(&b).Error; err != nil {
-
tx.Rollback()
-
return err
-
}
-
}
-
-
if bs.readonly {
-
return nil
-
}
-
-
tx.Commit()
-
-
return nil
-
}
-
-
func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) {
-
panic("not implemented")
-
}
-
-
func (bs *SqliteBlockstore) HashOnRead(enabled bool) {
-
panic("not implemented")
-
}
-
-
func (bs *SqliteBlockstore) UpdateRepo(ctx context.Context, root cid.Cid, rev string) error {
-
if err := bs.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, bs.did).Error; err != nil {
-
return err
-
}
-
-
return nil
-
}
-
-
func (bs *SqliteBlockstore) Execute(ctx context.Context) error {
-
if !bs.readonly {
-
return fmt.Errorf("blockstore was not readonly")
-
}
-
-
bs.readonly = false
-
for _, b := range bs.inserts {
-
bs.Put(ctx, b)
-
}
-
bs.readonly = true
-
-
return nil
-
}
-
-
func (bs *SqliteBlockstore) GetLog() map[cid.Cid]blocks.Block {
-
return bs.inserts
-
}
-186
cmd/admin/main.go
···
-
package main
-
-
import (
-
"crypto/ecdsa"
-
"crypto/elliptic"
-
"crypto/rand"
-
"encoding/json"
-
"fmt"
-
"os"
-
"time"
-
-
"github.com/bluesky-social/indigo/atproto/crypto"
-
"github.com/bluesky-social/indigo/atproto/syntax"
-
"github.com/haileyok/cocoon/internal/helpers"
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
"github.com/urfave/cli/v2"
-
"golang.org/x/crypto/bcrypt"
-
"gorm.io/driver/sqlite"
-
"gorm.io/gorm"
-
)
-
-
func main() {
-
app := cli.App{
-
Name: "admin",
-
Commands: cli.Commands{
-
runCreateRotationKey,
-
runCreatePrivateJwk,
-
runCreateInviteCode,
-
runResetPassword,
-
},
-
ErrWriter: os.Stdout,
-
}
-
-
app.Run(os.Args)
-
}
-
-
var runCreateRotationKey = &cli.Command{
-
Name: "create-rotation-key",
-
Usage: "creates a rotation key for your pds",
-
Flags: []cli.Flag{
-
&cli.StringFlag{
-
Name: "out",
-
Required: true,
-
Usage: "output file for your rotation key",
-
},
-
},
-
Action: func(cmd *cli.Context) error {
-
key, err := crypto.GeneratePrivateKeyK256()
-
if err != nil {
-
return err
-
}
-
-
bytes := key.Bytes()
-
-
if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil {
-
return err
-
}
-
-
return nil
-
},
-
}
-
-
var runCreatePrivateJwk = &cli.Command{
-
Name: "create-private-jwk",
-
Usage: "creates a private jwk for your pds",
-
Flags: []cli.Flag{
-
&cli.StringFlag{
-
Name: "out",
-
Required: true,
-
Usage: "output file for your jwk",
-
},
-
},
-
Action: func(cmd *cli.Context) error {
-
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
if err != nil {
-
return err
-
}
-
-
key, err := jwk.FromRaw(privKey)
-
if err != nil {
-
return err
-
}
-
-
kid := fmt.Sprintf("%d", time.Now().Unix())
-
-
if err := key.Set(jwk.KeyIDKey, kid); err != nil {
-
return err
-
}
-
-
b, err := json.Marshal(key)
-
if err != nil {
-
return err
-
}
-
-
if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil {
-
return err
-
}
-
-
return nil
-
},
-
}
-
-
var runCreateInviteCode = &cli.Command{
-
Name: "create-invite-code",
-
Usage: "creates an invite code",
-
Flags: []cli.Flag{
-
&cli.StringFlag{
-
Name: "for",
-
Usage: "optional did to assign the invite code to",
-
},
-
&cli.IntFlag{
-
Name: "uses",
-
Usage: "number of times the invite code can be used",
-
Value: 1,
-
},
-
},
-
Action: func(cmd *cli.Context) error {
-
db, err := newDb()
-
if err != nil {
-
return err
-
}
-
-
forDid := "did:plc:123"
-
if cmd.String("for") != "" {
-
did, err := syntax.ParseDID(cmd.String("for"))
-
if err != nil {
-
return err
-
}
-
-
forDid = did.String()
-
}
-
-
uses := cmd.Int("uses")
-
-
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8))
-
-
if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil {
-
return err
-
}
-
-
fmt.Printf("New invite code created with %d uses: %s\n", uses, code)
-
-
return nil
-
},
-
}
-
-
var runResetPassword = &cli.Command{
-
Name: "reset-password",
-
Usage: "resets a password",
-
Flags: []cli.Flag{
-
&cli.StringFlag{
-
Name: "did",
-
Usage: "did of the user who's password you want to reset",
-
},
-
},
-
Action: func(cmd *cli.Context) error {
-
db, err := newDb()
-
if err != nil {
-
return err
-
}
-
-
didStr := cmd.String("did")
-
did, err := syntax.ParseDID(didStr)
-
if err != nil {
-
return err
-
}
-
-
newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12))
-
hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10)
-
if err != nil {
-
return err
-
}
-
-
if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil {
-
return err
-
}
-
-
fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass)
-
-
return nil
-
},
-
}
-
-
func newDb() (*gorm.DB, error) {
-
return gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
-
}
+193 -9
cmd/cocoon/main.go
···
package main
import (
+
"crypto/ecdsa"
+
"crypto/elliptic"
+
"crypto/rand"
+
"encoding/json"
"fmt"
"os"
+
"time"
+
"github.com/bluesky-social/indigo/atproto/atcrypto"
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/server"
_ "github.com/joho/godotenv/autoload"
+
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/urfave/cli/v2"
+
"golang.org/x/crypto/bcrypt"
+
"gorm.io/driver/sqlite"
+
"gorm.io/gorm"
)
var Version = "dev"
···
Name: "s3-backups-enabled",
EnvVars: []string{"COCOON_S3_BACKUPS_ENABLED"},
},
+
&cli.BoolFlag{
+
Name: "s3-blobstore-enabled",
+
EnvVars: []string{"COCOON_S3_BLOBSTORE_ENABLED"},
+
},
&cli.StringFlag{
Name: "s3-region",
EnvVars: []string{"COCOON_S3_REGION"},
···
Name: "session-secret",
EnvVars: []string{"COCOON_SESSION_SECRET"},
},
+
&cli.StringFlag{
+
Name: "blockstore-variant",
+
EnvVars: []string{"COCOON_BLOCKSTORE_VARIANT"},
+
Value: "sqlite",
+
},
+
&cli.StringFlag{
+
Name: "fallback-proxy",
+
EnvVars: []string{"COCOON_FALLBACK_PROXY"},
+
},
},
Commands: []*cli.Command{
-
run,
+
runServe,
+
runCreateRotationKey,
+
runCreatePrivateJwk,
+
runCreateInviteCode,
+
runResetPassword,
},
ErrWriter: os.Stdout,
Version: Version,
···
}
}
-
var run = &cli.Command{
+
var runServe = &cli.Command{
Name: "run",
Usage: "Start the cocoon PDS",
Flags: []cli.Flag{},
Action: func(cmd *cli.Context) error {
+
s, err := server.New(&server.Args{
Addr: cmd.String("addr"),
DbName: cmd.String("db-name"),
···
SmtpEmail: cmd.String("smtp-email"),
SmtpName: cmd.String("smtp-name"),
S3Config: &server.S3Config{
-
BackupsEnabled: cmd.Bool("s3-backups-enabled"),
-
Region: cmd.String("s3-region"),
-
Bucket: cmd.String("s3-bucket"),
-
Endpoint: cmd.String("s3-endpoint"),
-
AccessKey: cmd.String("s3-access-key"),
-
SecretKey: cmd.String("s3-secret-key"),
+
BackupsEnabled: cmd.Bool("s3-backups-enabled"),
+
BlobstoreEnabled: cmd.Bool("s3-blobstore-enabled"),
+
Region: cmd.String("s3-region"),
+
Bucket: cmd.String("s3-bucket"),
+
Endpoint: cmd.String("s3-endpoint"),
+
AccessKey: cmd.String("s3-access-key"),
+
SecretKey: cmd.String("s3-secret-key"),
},
-
SessionSecret: cmd.String("session-secret"),
+
SessionSecret: cmd.String("session-secret"),
+
BlockstoreVariant: server.MustReturnBlockstoreVariant(cmd.String("blockstore-variant")),
+
FallbackProxy: cmd.String("fallback-proxy"),
})
if err != nil {
fmt.Printf("error creating cocoon: %v", err)
···
return nil
},
}
+
+
var runCreateRotationKey = &cli.Command{
+
Name: "create-rotation-key",
+
Usage: "creates a rotation key for your pds",
+
Flags: []cli.Flag{
+
&cli.StringFlag{
+
Name: "out",
+
Required: true,
+
Usage: "output file for your rotation key",
+
},
+
},
+
Action: func(cmd *cli.Context) error {
+
key, err := atcrypto.GeneratePrivateKeyK256()
+
if err != nil {
+
return err
+
}
+
+
bytes := key.Bytes()
+
+
if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil {
+
return err
+
}
+
+
return nil
+
},
+
}
+
+
var runCreatePrivateJwk = &cli.Command{
+
Name: "create-private-jwk",
+
Usage: "creates a private jwk for your pds",
+
Flags: []cli.Flag{
+
&cli.StringFlag{
+
Name: "out",
+
Required: true,
+
Usage: "output file for your jwk",
+
},
+
},
+
Action: func(cmd *cli.Context) error {
+
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+
if err != nil {
+
return err
+
}
+
+
key, err := jwk.FromRaw(privKey)
+
if err != nil {
+
return err
+
}
+
+
kid := fmt.Sprintf("%d", time.Now().Unix())
+
+
if err := key.Set(jwk.KeyIDKey, kid); err != nil {
+
return err
+
}
+
+
b, err := json.Marshal(key)
+
if err != nil {
+
return err
+
}
+
+
if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil {
+
return err
+
}
+
+
return nil
+
},
+
}
+
+
var runCreateInviteCode = &cli.Command{
+
Name: "create-invite-code",
+
Usage: "creates an invite code",
+
Flags: []cli.Flag{
+
&cli.StringFlag{
+
Name: "for",
+
Usage: "optional did to assign the invite code to",
+
},
+
&cli.IntFlag{
+
Name: "uses",
+
Usage: "number of times the invite code can be used",
+
Value: 1,
+
},
+
},
+
Action: func(cmd *cli.Context) error {
+
db, err := newDb()
+
if err != nil {
+
return err
+
}
+
+
forDid := "did:plc:123"
+
if cmd.String("for") != "" {
+
did, err := syntax.ParseDID(cmd.String("for"))
+
if err != nil {
+
return err
+
}
+
+
forDid = did.String()
+
}
+
+
uses := cmd.Int("uses")
+
+
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8))
+
+
if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil {
+
return err
+
}
+
+
fmt.Printf("New invite code created with %d uses: %s\n", uses, code)
+
+
return nil
+
},
+
}
+
+
var runResetPassword = &cli.Command{
+
Name: "reset-password",
+
Usage: "resets a password",
+
Flags: []cli.Flag{
+
&cli.StringFlag{
+
Name: "did",
+
Usage: "did of the user who's password you want to reset",
+
},
+
},
+
Action: func(cmd *cli.Context) error {
+
db, err := newDb()
+
if err != nil {
+
return err
+
}
+
+
didStr := cmd.String("did")
+
did, err := syntax.ParseDID(didStr)
+
if err != nil {
+
return err
+
}
+
+
newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12))
+
hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10)
+
if err != nil {
+
return err
+
}
+
+
if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil {
+
return err
+
}
+
+
fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass)
+
+
return nil
+
},
+
}
+
+
func newDb() (*gorm.DB, error) {
+
return gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
+
}
+45
cspell.json
···
+
{
+
"version": "0.2",
+
"language": "en",
+
"words": [
+
"atproto",
+
"bsky",
+
"Cocoon",
+
"PDS",
+
"Plc",
+
"plc",
+
"repo",
+
"InviteCodes",
+
"InviteCode",
+
"Invite",
+
"Signin",
+
"Signout",
+
"JWKS",
+
"dpop",
+
"BGS",
+
"pico",
+
"picocss",
+
"par",
+
"blobs",
+
"blob",
+
"did",
+
"DID",
+
"OAuth",
+
"oauth",
+
"par",
+
"Cocoon",
+
"memcache",
+
"db",
+
"helpers",
+
"middleware",
+
"repo",
+
"static",
+
"pico",
+
"picocss",
+
"MIT",
+
"Go"
+
],
+
"ignorePaths": [
+
"server/static/pico.css"
+
]
+
}
+3 -2
go.mod
···
require (
github.com/Azure/go-autorest/autorest/to v0.4.1
github.com/aws/aws-sdk-go v1.55.7
-
github.com/bluesky-social/indigo v0.0.0-20250414202759-826fcdeaa36b
+
github.com/bluesky-social/indigo v0.0.0-20251009212240-20524de167fe
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792
github.com/domodwyer/mailyak/v3 v3.6.2
github.com/go-pkgz/expirable-cache/v3 v3.0.0
···
github.com/google/uuid v1.4.0
github.com/gorilla/sessions v1.4.0
github.com/gorilla/websocket v1.5.1
+
github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/ipfs/go-block-format v0.2.0
github.com/ipfs/go-cid v0.4.1
+
github.com/ipfs/go-ipfs-blockstore v1.3.1
github.com/ipfs/go-ipld-cbor v0.1.0
github.com/ipld/go-car v0.6.1-0.20230509095817-92d28eb23ba4
github.com/joho/godotenv v1.5.1
···
github.com/ipfs/bbloom v0.0.4 // indirect
github.com/ipfs/go-blockservice v0.5.2 // indirect
github.com/ipfs/go-datastore v0.6.0 // indirect
-
github.com/ipfs/go-ipfs-blockstore v1.3.1 // indirect
github.com/ipfs/go-ipfs-ds-help v1.1.1 // indirect
github.com/ipfs/go-ipfs-exchange-interface v0.2.1 // indirect
github.com/ipfs/go-ipfs-util v0.0.3 // indirect
+4 -4
go.sum
···
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
-
github.com/bluesky-social/indigo v0.0.0-20250414202759-826fcdeaa36b h1:elwfbe+W7GkUmPKFX1h7HaeHvC/kC0XJWfiEHC62xPg=
-
github.com/bluesky-social/indigo v0.0.0-20250414202759-826fcdeaa36b/go.mod h1:yjdhLA1LkK8VDS/WPUoYPo25/Hq/8rX38Ftr67EsqKY=
+
github.com/bluesky-social/indigo v0.0.0-20251009212240-20524de167fe h1:VBhaqE5ewQgXbY5SfSWFZC/AwHFo7cHxZKFYi2ce9Yo=
+
github.com/bluesky-social/indigo v0.0.0-20251009212240-20524de167fe/go.mod h1:RuQVrCGm42QNsgumKaR6se+XkFKfCPNwdCiTvqKRUck=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 h1:R8vQdOQdZ9Y3SkEwmHoWBmX1DNXhXZqlTpq6s4tyJGc=
···
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
+
github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b h1:wDUNC2eKiL35DbLvsDhiblTUXHxcOPwQSCzi7xpQUN4=
+
github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b/go.mod h1:VzxiSdG6j1pi7rwGm/xYI5RbtpBgM8sARDXlvEvxlu0=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI=
···
github.com/ipfs/go-block-format v0.2.0/go.mod h1:+jpL11nFx5A/SPpsoBn6Bzkra/zaArfSmsknbPMYgzM=
github.com/ipfs/go-blockservice v0.5.2 h1:in9Bc+QcXwd1apOVM7Un9t8tixPKdaHQFdLSUM1Xgk8=
github.com/ipfs/go-blockservice v0.5.2/go.mod h1:VpMblFEqG67A/H2sHKAemeH9vlURVavlysbdUI632yk=
-
github.com/ipfs/go-bs-sqlite3 v0.0.0-20221122195556-bfcee1be620d h1:9V+GGXCuOfDiFpdAHz58q9mKLg447xp0cQKvqQrAwYE=
-
github.com/ipfs/go-bs-sqlite3 v0.0.0-20221122195556-bfcee1be620d/go.mod h1:pMbnFyNAGjryYCLCe59YDLRv/ujdN+zGJBT1umlvYRM=
github.com/ipfs/go-cid v0.4.1 h1:A/T3qGvxi4kpKWWcPC/PgbvDA2bjVLO7n4UeVwnbs/s=
github.com/ipfs/go-cid v0.4.1/go.mod h1:uQHwDeX4c6CtyrFwdqyhpNcxVewur1M7l7fNU7LKwZk=
github.com/ipfs/go-datastore v0.6.0 h1:JKyz+Gvz1QEZw0LsX1IBn+JFCJQH4SJVFtM4uWU0Myk=
+74 -55
identity/identity.go
···
"github.com/bluesky-social/indigo/util"
)
-
func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) {
-
if cli == nil {
-
cli = util.RobustHTTPClient()
+
func ResolveHandleFromTXT(ctx context.Context, handle string) (string, error) {
+
name := fmt.Sprintf("_atproto.%s", handle)
+
recs, err := net.LookupTXT(name)
+
if err != nil {
+
return "", fmt.Errorf("handle could not be resolved via txt: %w", err)
+
}
+
+
for _, rec := range recs {
+
if strings.HasPrefix(rec, "did=") {
+
maybeDid := strings.Split(rec, "did=")[1]
+
if _, err := syntax.ParseDID(maybeDid); err == nil {
+
return maybeDid, nil
+
}
+
}
+
}
+
+
return "", fmt.Errorf("handle could not be resolved via txt: no record found")
+
}
+
+
func ResolveHandleFromWellKnown(ctx context.Context, cli *http.Client, handle string) (string, error) {
+
ustr := fmt.Sprintf("https://%s/.well-known/atproto-did", handle)
+
req, err := http.NewRequestWithContext(
+
ctx,
+
"GET",
+
ustr,
+
nil,
+
)
+
if err != nil {
+
return "", fmt.Errorf("handle could not be resolved via web: %w", err)
}
-
var did string
+
resp, err := cli.Do(req)
+
if err != nil {
+
return "", fmt.Errorf("handle could not be resolved via web: %w", err)
+
}
+
defer resp.Body.Close()
-
_, err := syntax.ParseHandle(handle)
+
b, err := io.ReadAll(resp.Body)
if err != nil {
-
return "", err
+
return "", fmt.Errorf("handle could not be resolved via web: %w", err)
}
-
recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle))
-
if err == nil {
-
for _, rec := range recs {
-
if strings.HasPrefix(rec, "did=") {
-
did = strings.Split(rec, "did=")[1]
-
break
-
}
-
}
-
} else {
-
fmt.Printf("erorr getting txt records: %v\n", err)
+
if resp.StatusCode != http.StatusOK {
+
return "", fmt.Errorf("handle could not be resolved via web: invalid status code %d", resp.StatusCode)
}
-
if did == "" {
-
req, err := http.NewRequestWithContext(
-
ctx,
-
"GET",
-
fmt.Sprintf("https://%s/.well-known/atproto-did", handle),
-
nil,
-
)
-
if err != nil {
-
return "", nil
-
}
+
maybeDid := string(b)
-
resp, err := http.DefaultClient.Do(req)
-
if err != nil {
-
return "", nil
-
}
-
defer resp.Body.Close()
+
if _, err := syntax.ParseDID(maybeDid); err != nil {
+
return "", fmt.Errorf("handle could not be resolved via web: invalid did in document")
+
}
-
if resp.StatusCode != http.StatusOK {
-
io.Copy(io.Discard, resp.Body)
-
return "", fmt.Errorf("unable to resolve handle")
-
}
+
return maybeDid, nil
+
}
-
b, err := io.ReadAll(resp.Body)
-
if err != nil {
-
return "", err
-
}
+
func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) {
+
if cli == nil {
+
cli = util.RobustHTTPClient()
+
}
-
maybeDid := string(b)
+
_, err := syntax.ParseHandle(handle)
+
if err != nil {
+
return "", err
+
}
-
if _, err := syntax.ParseDID(maybeDid); err != nil {
-
return "", fmt.Errorf("unable to resolve handle")
-
}
+
if maybeDidFromTxt, err := ResolveHandleFromTXT(ctx, handle); err == nil {
+
return maybeDidFromTxt, nil
+
}
-
did = maybeDid
+
if maybeDidFromWeb, err := ResolveHandleFromWellKnown(ctx, cli, handle); err == nil {
+
return maybeDidFromWeb, nil
}
-
return did, nil
+
return "", fmt.Errorf("handle could not be resolved")
+
}
+
+
func DidToDocUrl(did string) (string, error) {
+
if strings.HasPrefix(did, "did:plc:") {
+
return fmt.Sprintf("https://plc.directory/%s", did), nil
+
} else if after, ok := strings.CutPrefix(did, "did:web:"); ok {
+
return fmt.Sprintf("https://%s/.well-known/did.json", after), nil
+
} else {
+
return "", fmt.Errorf("did was not a supported did type")
+
}
}
func FetchDidDoc(ctx context.Context, cli *http.Client, did string) (*DidDoc, error) {
···
cli = util.RobustHTTPClient()
}
-
var ustr string
-
if strings.HasPrefix(did, "did:plc:") {
-
ustr = fmt.Sprintf("https://plc.directory/%s", did)
-
} else if strings.HasPrefix(did, "did:web:") {
-
ustr = fmt.Sprintf("https://%s/.well-known/did.json", strings.TrimPrefix(did, "did:web:"))
-
} else {
-
return nil, fmt.Errorf("did was not a supported did type")
+
ustr, err := DidToDocUrl(did)
+
if err != nil {
+
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil)
···
return nil, err
}
-
resp, err := http.DefaultClient.Do(req)
+
resp, err := cli.Do(req)
if err != nil {
return nil, err
}
···
if resp.StatusCode != 200 {
io.Copy(io.Discard, resp.Body)
-
return nil, fmt.Errorf("could not find identity in plc registry")
+
return nil, fmt.Errorf("unable to find did doc at url. did: %s. url: %s", did, ustr)
}
var diddoc DidDoc
···
return nil, err
}
-
resp, err := http.DefaultClient.Do(req)
+
resp, err := cli.Do(req)
if err != nil {
return nil, err
}
+15 -5
identity/passport.go
···
type Passport struct {
h *http.Client
bc BackingCache
-
lk sync.Mutex
+
mu sync.RWMutex
}
func NewPassport(h *http.Client, bc BackingCache) *Passport {
···
return &Passport{
h: h,
bc: bc,
-
lk: sync.Mutex{},
}
}
···
skipCache, _ := ctx.Value("skip-cache").(bool)
if !skipCache {
+
p.mu.RLock()
cached, ok := p.bc.GetDoc(did)
+
p.mu.RUnlock()
+
if ok {
return cached, nil
}
}
-
p.lk.Lock() // this is pretty pathetic, and i should rethink this. but for now, fuck it
-
defer p.lk.Unlock()
-
doc, err := FetchDidDoc(ctx, p.h, did)
if err != nil {
return nil, err
}
+
p.mu.Lock()
p.bc.PutDoc(did, doc)
+
p.mu.Unlock()
return doc, nil
}
···
skipCache, _ := ctx.Value("skip-cache").(bool)
if !skipCache {
+
p.mu.RLock()
cached, ok := p.bc.GetDid(handle)
+
p.mu.RUnlock()
+
if ok {
return cached, nil
}
···
return "", err
}
+
p.mu.Lock()
p.bc.PutDid(handle, did)
+
p.mu.Unlock()
return did, nil
}
func (p *Passport) BustDoc(ctx context.Context, did string) error {
+
p.mu.Lock()
+
defer p.mu.Unlock()
return p.bc.BustDoc(did)
}
func (p *Passport) BustDid(ctx context.Context, handle string) error {
+
p.mu.Lock()
+
defer p.mu.Unlock()
return p.bc.BustDid(handle)
}
+13
internal/helpers/helpers.go
···
"math/rand"
"net/url"
+
"github.com/Azure/go-autorest/autorest/to"
"github.com/labstack/echo/v4"
"github.com/lestrrat-go/jwx/v2/jwk"
)
···
msg += ". " + *suffix
}
return genericError(e, 400, msg)
+
}
+
+
func InvalidTokenError(e echo.Context) error {
+
return InputError(e, to.StringPtr("InvalidToken"))
+
}
+
+
func ExpiredTokenError(e echo.Context) error {
+
// WARN: See https://github.com/bluesky-social/atproto/discussions/3319
+
return e.JSON(400, map[string]string{
+
"error": "ExpiredToken",
+
"message": "*",
+
})
}
func genericError(e echo.Context, code int, msg string) error {
+17 -2
models/models.go
···
"context"
"time"
-
"github.com/bluesky-social/indigo/atproto/crypto"
+
"github.com/Azure/go-autorest/autorest/to"
+
"github.com/bluesky-social/indigo/atproto/atcrypto"
)
type Repo struct {
···
Rev string
Root []byte
Preferences []byte
+
Deactivated bool
}
func (r *Repo) SignFor(ctx context.Context, did string, msg []byte) ([]byte, error) {
-
k, err := crypto.ParsePrivateBytesK256(r.SigningKey)
+
k, err := atcrypto.ParsePrivateBytesK256(r.SigningKey)
if err != nil {
return nil, err
}
···
}
return sig, nil
+
}
+
+
func (r *Repo) Status() *string {
+
var status *string
+
if r.Deactivated {
+
status = to.StringPtr("deactivated")
+
}
+
return status
+
}
+
+
func (r *Repo) Active() bool {
+
return r.Status() == nil
}
type Actor struct {
···
Did string `gorm:"index;index:idx_blob_did_cid"`
Cid []byte `gorm:"index;index:idx_blob_did_cid"`
RefCount int
+
Storage string `gorm:"default:sqlite;check:storage in ('sqlite', 's3')"`
}
type BlobPart struct {
+8
oauth/client/client.go
···
+
package client
+
+
import "github.com/lestrrat-go/jwx/v2/jwk"
+
+
type Client struct {
+
Metadata *Metadata
+
JWKS jwk.Key
+
}
+397
oauth/client/manager.go
···
+
package client
+
+
import (
+
"context"
+
"encoding/json"
+
"errors"
+
"fmt"
+
"io"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"slices"
+
"strings"
+
"time"
+
+
cache "github.com/go-pkgz/expirable-cache/v3"
+
"github.com/haileyok/cocoon/internal/helpers"
+
"github.com/lestrrat-go/jwx/v2/jwk"
+
)
+
+
type Manager struct {
+
cli *http.Client
+
logger *slog.Logger
+
jwksCache cache.Cache[string, jwk.Key]
+
metadataCache cache.Cache[string, Metadata]
+
}
+
+
type ManagerArgs struct {
+
Cli *http.Client
+
Logger *slog.Logger
+
}
+
+
func NewManager(args ManagerArgs) *Manager {
+
if args.Logger == nil {
+
args.Logger = slog.Default()
+
}
+
+
if args.Cli == nil {
+
args.Cli = http.DefaultClient
+
}
+
+
jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
+
metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
+
+
return &Manager{
+
cli: args.Cli,
+
logger: args.Logger,
+
jwksCache: jwksCache,
+
metadataCache: metadataCache,
+
}
+
}
+
+
func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) {
+
metadata, err := cm.getClientMetadata(ctx, clientId)
+
if err != nil {
+
return nil, err
+
}
+
+
var jwks jwk.Key
+
if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 {
+
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
+
// make sure we use the right one
+
b, err := json.Marshal(metadata.JWKS.Keys[0])
+
if err != nil {
+
return nil, err
+
}
+
+
k, err := helpers.ParseJWKFromBytes(b)
+
if err != nil {
+
return nil, err
+
}
+
+
jwks = k
+
} else if metadata.JWKSURI != nil {
+
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
+
if err != nil {
+
return nil, err
+
}
+
+
jwks = maybeJwks
+
} else {
+
return nil, fmt.Errorf("no valid jwks found in oauth client metadata")
+
}
+
+
return &Client{
+
Metadata: metadata,
+
JWKS: jwks,
+
}, nil
+
}
+
+
func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) {
+
metadataCached, ok := cm.metadataCache.Get(clientId)
+
if !ok {
+
req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
+
if err != nil {
+
return nil, err
+
}
+
+
resp, err := cm.cli.Do(req)
+
if err != nil {
+
return nil, err
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
io.Copy(io.Discard, resp.Body)
+
return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode)
+
}
+
+
b, err := io.ReadAll(resp.Body)
+
if err != nil {
+
return nil, fmt.Errorf("error reading bytes from client response: %w", err)
+
}
+
+
validated, err := validateAndParseMetadata(clientId, b)
+
if err != nil {
+
return nil, err
+
}
+
+
return validated, nil
+
} else {
+
return &metadataCached, nil
+
}
+
}
+
+
func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) {
+
jwks, ok := cm.jwksCache.Get(clientId)
+
if !ok {
+
req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil)
+
if err != nil {
+
return nil, err
+
}
+
+
resp, err := cm.cli.Do(req)
+
if err != nil {
+
return nil, err
+
}
+
defer resp.Body.Close()
+
+
if resp.StatusCode != http.StatusOK {
+
io.Copy(io.Discard, resp.Body)
+
return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode)
+
}
+
+
type Keys struct {
+
Keys []map[string]any `json:"keys"`
+
}
+
+
var keys Keys
+
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
+
return nil, fmt.Errorf("error unmarshaling keys response: %w", err)
+
}
+
+
if len(keys.Keys) == 0 {
+
return nil, errors.New("no keys in jwks response")
+
}
+
+
// TODO: this is again bad, we should be figuring out which one we need to use...
+
b, err := json.Marshal(keys.Keys[0])
+
if err != nil {
+
return nil, fmt.Errorf("could not marshal key: %w", err)
+
}
+
+
k, err := helpers.ParseJWKFromBytes(b)
+
if err != nil {
+
return nil, err
+
}
+
+
jwks = k
+
}
+
+
return jwks, nil
+
}
+
+
func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) {
+
var metadataMap map[string]any
+
if err := json.Unmarshal(b, &metadataMap); err != nil {
+
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
+
}
+
+
_, jwksOk := metadataMap["jwks"].(string)
+
_, jwksUriOk := metadataMap["jwks_uri"].(string)
+
if jwksOk && jwksUriOk {
+
return nil, errors.New("jwks_uri and jwks are mutually exclusive")
+
}
+
+
for _, k := range []string{
+
"default_max_age",
+
"userinfo_signed_response_alg",
+
"id_token_signed_response_alg",
+
"userinfo_encryhpted_response_alg",
+
"authorization_encrypted_response_enc",
+
"authorization_encrypted_response_alg",
+
"tls_client_certificate_bound_access_tokens",
+
} {
+
_, kOk := metadataMap[k]
+
if kOk {
+
return nil, fmt.Errorf("unsupported `%s` parameter", k)
+
}
+
}
+
+
var metadata Metadata
+
if err := json.Unmarshal(b, &metadata); err != nil {
+
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
+
}
+
+
u, err := url.Parse(metadata.ClientURI)
+
if err != nil {
+
return nil, fmt.Errorf("unable to parse client uri: %w", err)
+
}
+
+
if isLocalHostname(u.Hostname()) {
+
return nil, errors.New("`client_uri` hostname is invalid")
+
}
+
+
if metadata.Scope == "" {
+
return nil, errors.New("missing `scopes` scope")
+
}
+
+
scopes := strings.Split(metadata.Scope, " ")
+
if !slices.Contains(scopes, "atproto") {
+
return nil, errors.New("missing `atproto` scope")
+
}
+
+
scopesMap := map[string]bool{}
+
for _, scope := range scopes {
+
if scopesMap[scope] {
+
return nil, fmt.Errorf("duplicate scope `%s`", scope)
+
}
+
+
// TODO: check for unsupported scopes
+
+
scopesMap[scope] = true
+
}
+
+
grantTypesMap := map[string]bool{}
+
for _, gt := range metadata.GrantTypes {
+
if grantTypesMap[gt] {
+
return nil, fmt.Errorf("duplicate grant type `%s`", gt)
+
}
+
+
switch gt {
+
case "implicit":
+
return nil, errors.New("grantg type `implicit` is not allowed")
+
case "authorization_code", "refresh_token":
+
// TODO check if this grant type is supported
+
default:
+
return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt)
+
}
+
+
grantTypesMap[gt] = true
+
}
+
+
if metadata.ClientID != clientId {
+
return nil, errors.New("`client_id` does not match")
+
}
+
+
subjectType, subjectTypeOk := metadataMap["subject_type"].(string)
+
if subjectTypeOk && subjectType != "public" {
+
return nil, errors.New("only public `subject_type` is supported")
+
}
+
+
switch metadata.TokenEndpointAuthMethod {
+
case "none":
+
if metadata.TokenEndpointAuthSigningAlg != "" {
+
return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg")
+
}
+
case "private_key_jwt":
+
if metadata.JWKS == nil && metadata.JWKSURI == nil {
+
return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri")
+
}
+
+
if metadata.JWKS != nil && len(metadata.JWKS.Keys) == 0 {
+
return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks")
+
}
+
+
if metadata.TokenEndpointAuthSigningAlg == "" {
+
return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata")
+
}
+
default:
+
return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod)
+
}
+
+
if !metadata.DpopBoundAccessTokens {
+
return nil, errors.New("dpop_bound_access_tokens must be true")
+
}
+
+
if !slices.Contains(metadata.ResponseTypes, "code") {
+
return nil, errors.New("response_types must inclue `code`")
+
}
+
+
if !slices.Contains(metadata.GrantTypes, "authorization_code") {
+
return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`")
+
}
+
+
if len(metadata.RedirectURIs) == 0 {
+
return nil, errors.New("at least one `redirect_uri` is required")
+
}
+
+
if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" {
+
return nil, errors.New("native clients must authenticate using `none` method")
+
}
+
+
if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") {
+
for _, ruri := range metadata.RedirectURIs {
+
u, err := url.Parse(ruri)
+
if err != nil {
+
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
+
}
+
+
if u.Scheme != "https" {
+
return nil, errors.New("web clients must use https redirect uris")
+
}
+
+
if u.Hostname() == "localhost" {
+
return nil, errors.New("web clients must not use localhost as the hostname")
+
}
+
}
+
}
+
+
for _, ruri := range metadata.RedirectURIs {
+
u, err := url.Parse(ruri)
+
if err != nil {
+
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
+
}
+
+
if u.User != nil {
+
if u.User.Username() != "" {
+
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
+
}
+
+
if _, hasPass := u.User.Password(); hasPass {
+
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
+
}
+
}
+
+
switch true {
+
case u.Hostname() == "localhost":
+
return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)")
+
case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]":
+
if metadata.ApplicationType != "native" {
+
return nil, errors.New("loopback redirect uris are only allowed for native apps")
+
}
+
+
if u.Port() != "" {
+
// reference impl doesn't do anything with this?
+
}
+
+
if u.Scheme != "http" {
+
return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
+
}
+
+
break
+
case u.Scheme == "http":
+
return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
+
case u.Scheme == "https":
+
if isLocalHostname(u.Hostname()) {
+
return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
+
}
+
break
+
case strings.Contains(u.Scheme, "."):
+
if metadata.ApplicationType != "native" {
+
return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")
+
}
+
+
revdomain := reverseDomain(u.Scheme)
+
+
if isLocalHostname(revdomain) {
+
return nil, errors.New("private use uri scheme redirect uris must not be local hostnames")
+
}
+
+
if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" {
+
return nil, fmt.Errorf("private use uri scheme must be in the form ")
+
}
+
default:
+
return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme)
+
}
+
}
+
+
return &metadata, nil
+
}
+
+
func isLocalHostname(hostname string) bool {
+
pts := strings.Split(hostname, ".")
+
if len(pts) < 2 {
+
return true
+
}
+
+
tld := strings.ToLower(pts[len(pts)-1])
+
return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example"
+
}
+
+
func reverseDomain(domain string) string {
+
pts := strings.Split(domain, ".")
+
slices.Reverse(pts)
+
return strings.Join(pts, ".")
+
}
+24
oauth/client/metadata.go
···
+
package client
+
+
type Metadata struct {
+
ClientID string `json:"client_id"`
+
ClientName string `json:"client_name"`
+
ClientURI string `json:"client_uri"`
+
LogoURI string `json:"logo_uri"`
+
TOSURI string `json:"tos_uri"`
+
PolicyURI string `json:"policy_uri"`
+
RedirectURIs []string `json:"redirect_uris"`
+
GrantTypes []string `json:"grant_types"`
+
ResponseTypes []string `json:"response_types"`
+
ApplicationType string `json:"application_type"`
+
DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"`
+
JWKSURI *string `json:"jwks_uri,omitempty"`
+
JWKS *MetadataJwks `json:"jwks,omitempty"`
+
Scope string `json:"scope"`
+
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
+
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
+
}
+
+
type MetadataJwks struct {
+
Keys []any `json:"keys"`
+
}
-8
oauth/client.go
···
-
package oauth
-
-
import "github.com/lestrrat-go/jwx/v2/jwk"
-
-
type Client struct {
-
Metadata *ClientMetadata
-
JWKS jwk.Key
-
}
-390
oauth/client_manager/client_manager.go
···
-
package client_manager
-
-
import (
-
"context"
-
"encoding/json"
-
"errors"
-
"fmt"
-
"io"
-
"log/slog"
-
"net/http"
-
"net/url"
-
"slices"
-
"strings"
-
"time"
-
-
cache "github.com/go-pkgz/expirable-cache/v3"
-
"github.com/haileyok/cocoon/internal/helpers"
-
"github.com/haileyok/cocoon/oauth"
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
)
-
-
type ClientManager struct {
-
cli *http.Client
-
logger *slog.Logger
-
jwksCache cache.Cache[string, jwk.Key]
-
metadataCache cache.Cache[string, oauth.ClientMetadata]
-
}
-
-
type Args struct {
-
Cli *http.Client
-
Logger *slog.Logger
-
}
-
-
func New(args Args) *ClientManager {
-
if args.Logger == nil {
-
args.Logger = slog.Default()
-
}
-
-
if args.Cli == nil {
-
args.Cli = http.DefaultClient
-
}
-
-
jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
-
metadataCache := cache.NewCache[string, oauth.ClientMetadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
-
-
return &ClientManager{
-
cli: args.Cli,
-
logger: args.Logger,
-
jwksCache: jwksCache,
-
metadataCache: metadataCache,
-
}
-
}
-
-
func (cm *ClientManager) GetClient(ctx context.Context, clientId string) (*oauth.Client, error) {
-
metadata, err := cm.getClientMetadata(ctx, clientId)
-
if err != nil {
-
return nil, err
-
}
-
-
var jwks jwk.Key
-
if metadata.JWKS != nil {
-
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
-
// make sure we use the right one
-
k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0])
-
if err != nil {
-
return nil, err
-
}
-
jwks = k
-
} else if metadata.JWKSURI != nil {
-
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
-
if err != nil {
-
return nil, err
-
}
-
-
jwks = maybeJwks
-
}
-
-
return &oauth.Client{
-
Metadata: metadata,
-
JWKS: jwks,
-
}, nil
-
}
-
-
func (cm *ClientManager) getClientMetadata(ctx context.Context, clientId string) (*oauth.ClientMetadata, error) {
-
metadataCached, ok := cm.metadataCache.Get(clientId)
-
if !ok {
-
req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
-
if err != nil {
-
return nil, err
-
}
-
-
resp, err := cm.cli.Do(req)
-
if err != nil {
-
return nil, err
-
}
-
defer resp.Body.Close()
-
-
if resp.StatusCode != http.StatusOK {
-
io.Copy(io.Discard, resp.Body)
-
return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode)
-
}
-
-
b, err := io.ReadAll(resp.Body)
-
if err != nil {
-
return nil, fmt.Errorf("error reading bytes from client response: %w", err)
-
}
-
-
validated, err := validateAndParseMetadata(clientId, b)
-
if err != nil {
-
return nil, err
-
}
-
-
return validated, nil
-
} else {
-
return &metadataCached, nil
-
}
-
}
-
-
func (cm *ClientManager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) {
-
jwks, ok := cm.jwksCache.Get(clientId)
-
if !ok {
-
req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil)
-
if err != nil {
-
return nil, err
-
}
-
-
resp, err := cm.cli.Do(req)
-
if err != nil {
-
return nil, err
-
}
-
defer resp.Body.Close()
-
-
if resp.StatusCode != http.StatusOK {
-
io.Copy(io.Discard, resp.Body)
-
return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode)
-
}
-
-
type Keys struct {
-
Keys []map[string]any `json:"keys"`
-
}
-
-
var keys Keys
-
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
-
return nil, fmt.Errorf("error unmarshaling keys response: %w", err)
-
}
-
-
if len(keys.Keys) == 0 {
-
return nil, errors.New("no keys in jwks response")
-
}
-
-
// TODO: this is again bad, we should be figuring out which one we need to use...
-
b, err := json.Marshal(keys.Keys[0])
-
if err != nil {
-
return nil, fmt.Errorf("could not marshal key: %w", err)
-
}
-
-
k, err := helpers.ParseJWKFromBytes(b)
-
if err != nil {
-
return nil, err
-
}
-
-
jwks = k
-
}
-
-
return jwks, nil
-
}
-
-
func validateAndParseMetadata(clientId string, b []byte) (*oauth.ClientMetadata, error) {
-
var metadataMap map[string]any
-
if err := json.Unmarshal(b, &metadataMap); err != nil {
-
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
-
}
-
-
_, jwksOk := metadataMap["jwks"].(string)
-
_, jwksUriOk := metadataMap["jwks_uri"].(string)
-
if jwksOk && jwksUriOk {
-
return nil, errors.New("jwks_uri and jwks are mutually exclusive")
-
}
-
-
for _, k := range []string{
-
"default_max_age",
-
"userinfo_signed_response_alg",
-
"id_token_signed_response_alg",
-
"userinfo_encryhpted_response_alg",
-
"authorization_encrypted_response_enc",
-
"authorization_encrypted_response_alg",
-
"tls_client_certificate_bound_access_tokens",
-
} {
-
_, kOk := metadataMap[k]
-
if kOk {
-
return nil, fmt.Errorf("unsupported `%s` parameter", k)
-
}
-
}
-
-
var metadata oauth.ClientMetadata
-
if err := json.Unmarshal(b, &metadata); err != nil {
-
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
-
}
-
-
u, err := url.Parse(metadata.ClientURI)
-
if err != nil {
-
return nil, fmt.Errorf("unable to parse client uri: %w", err)
-
}
-
-
if isLocalHostname(u.Hostname()) {
-
return nil, errors.New("`client_uri` hostname is invalid")
-
}
-
-
if metadata.Scope == "" {
-
return nil, errors.New("missing `scopes` scope")
-
}
-
-
scopes := strings.Split(metadata.Scope, " ")
-
if !slices.Contains(scopes, "atproto") {
-
return nil, errors.New("missing `atproto` scope")
-
}
-
-
scopesMap := map[string]bool{}
-
for _, scope := range scopes {
-
if scopesMap[scope] {
-
return nil, fmt.Errorf("duplicate scope `%s`", scope)
-
}
-
-
// TODO: check for unsupported scopes
-
-
scopesMap[scope] = true
-
}
-
-
grantTypesMap := map[string]bool{}
-
for _, gt := range metadata.GrantTypes {
-
if grantTypesMap[gt] {
-
return nil, fmt.Errorf("duplicate grant type `%s`", gt)
-
}
-
-
switch gt {
-
case "implicit":
-
return nil, errors.New("grantg type `implicit` is not allowed")
-
case "authorization_code", "refresh_token":
-
// TODO check if this grant type is supported
-
default:
-
return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt)
-
}
-
-
grantTypesMap[gt] = true
-
}
-
-
if metadata.ClientID != clientId {
-
return nil, errors.New("`client_id` does not match")
-
}
-
-
subjectType, subjectTypeOk := metadataMap["subject_type"].(string)
-
if subjectTypeOk && subjectType != "public" {
-
return nil, errors.New("only public `subject_type` is supported")
-
}
-
-
switch metadata.TokenEndpointAuthMethod {
-
case "none":
-
if metadata.TokenEndpointAuthSigningAlg != "" {
-
return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg")
-
}
-
case "private_key_jwt":
-
if metadata.JWKS == nil && metadata.JWKSURI == nil {
-
return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri")
-
}
-
-
if metadata.JWKS != nil && len(*metadata.JWKS) == 0 {
-
return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks")
-
}
-
-
if metadata.TokenEndpointAuthSigningAlg == "" {
-
return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata")
-
}
-
default:
-
return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod)
-
}
-
-
if !metadata.DpopBoundAccessTokens {
-
return nil, errors.New("dpop_bound_access_tokens must be true")
-
}
-
-
if !slices.Contains(metadata.ResponseTypes, "code") {
-
return nil, errors.New("response_types must inclue `code`")
-
}
-
-
if !slices.Contains(metadata.GrantTypes, "authorization_code") {
-
return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`")
-
}
-
-
if len(metadata.RedirectURIs) == 0 {
-
return nil, errors.New("at least one `redirect_uri` is required")
-
}
-
-
if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod == "none" {
-
return nil, errors.New("native clients must authenticate using `none` method")
-
}
-
-
if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") {
-
for _, ruri := range metadata.RedirectURIs {
-
u, err := url.Parse(ruri)
-
if err != nil {
-
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
-
}
-
-
if u.Scheme != "https" {
-
return nil, errors.New("web clients must use https redirect uris")
-
}
-
-
if u.Hostname() == "localhost" {
-
return nil, errors.New("web clients must not use localhost as the hostname")
-
}
-
}
-
}
-
-
for _, ruri := range metadata.RedirectURIs {
-
u, err := url.Parse(ruri)
-
if err != nil {
-
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
-
}
-
-
if u.User != nil {
-
if u.User.Username() != "" {
-
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
-
}
-
-
if _, hasPass := u.User.Password(); hasPass {
-
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
-
}
-
}
-
-
switch true {
-
case u.Hostname() == "localhost":
-
return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)")
-
case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]":
-
if metadata.ApplicationType != "native" {
-
return nil, errors.New("loopback redirect uris are only allowed for native apps")
-
}
-
-
if u.Port() != "" {
-
// reference impl doesn't do anything with this?
-
}
-
-
if u.Scheme != "http" {
-
return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
-
}
-
-
break
-
case u.Scheme == "http":
-
return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
-
case u.Scheme == "https":
-
if isLocalHostname(u.Hostname()) {
-
return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
-
}
-
break
-
case strings.Contains(u.Scheme, "."):
-
if metadata.ApplicationType != "native" {
-
return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")
-
}
-
-
revdomain := reverseDomain(u.Scheme)
-
-
if isLocalHostname(revdomain) {
-
return nil, errors.New("private use uri scheme redirect uris must not be local hostnames")
-
}
-
-
if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" {
-
return nil, fmt.Errorf("private use uri scheme must be in the form ")
-
}
-
default:
-
return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme)
-
}
-
}
-
-
return &metadata, nil
-
}
-
-
func isLocalHostname(hostname string) bool {
-
pts := strings.Split(hostname, ".")
-
if len(pts) < 2 {
-
return true
-
}
-
-
tld := strings.ToLower(pts[len(pts)-1])
-
return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example"
-
}
-
-
func reverseDomain(domain string) string {
-
pts := strings.Split(domain, ".")
-
slices.Reverse(pts)
-
return strings.Join(pts, ".")
-
}
-20
oauth/client_metadata.go
···
-
package oauth
-
-
type ClientMetadata struct {
-
ClientID string `json:"client_id"`
-
ClientName string `json:"client_name"`
-
ClientURI string `json:"client_uri"`
-
LogoURI string `json:"logo_uri"`
-
TOSURI string `json:"tos_uri"`
-
PolicyURI string `json:"policy_uri"`
-
RedirectURIs []string `json:"redirect_uris"`
-
GrantTypes []string `json:"grant_types"`
-
ResponseTypes []string `json:"response_types"`
-
ApplicationType string `json:"application_type"`
-
DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"`
-
JWKSURI *string `json:"jwks_uri,omitempty"`
-
JWKS *[][]byte `json:"jwks,omitempty"`
-
Scope string `json:"scope"`
-
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
-
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
-
}
-251
oauth/dpop/dpop_manager/dpop_manager.go
···
-
package dpop_manager
-
-
import (
-
"crypto"
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/json"
-
"errors"
-
"fmt"
-
"log/slog"
-
"net/http"
-
"net/url"
-
"strings"
-
"time"
-
-
"github.com/golang-jwt/jwt/v4"
-
"github.com/haileyok/cocoon/internal/helpers"
-
"github.com/haileyok/cocoon/oauth/constants"
-
"github.com/haileyok/cocoon/oauth/dpop"
-
"github.com/haileyok/cocoon/oauth/dpop/nonce"
-
"github.com/lestrrat-go/jwx/v2/jwa"
-
"github.com/lestrrat-go/jwx/v2/jwk"
-
)
-
-
type DpopManager struct {
-
nonce *nonce.Nonce
-
jtiCache *jtiCache
-
logger *slog.Logger
-
hostname string
-
}
-
-
type Args struct {
-
NonceSecret []byte
-
NonceRotationInterval time.Duration
-
OnNonceSecretCreated func([]byte)
-
JTICacheSize int
-
Logger *slog.Logger
-
Hostname string
-
}
-
-
func New(args Args) *DpopManager {
-
if args.Logger == nil {
-
args.Logger = slog.Default()
-
}
-
-
if args.JTICacheSize == 0 {
-
args.JTICacheSize = 100_000
-
}
-
-
if args.NonceSecret == nil {
-
args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.")
-
}
-
-
return &DpopManager{
-
nonce: nonce.NewNonce(nonce.Args{
-
RotationInterval: args.NonceRotationInterval,
-
Secret: args.NonceSecret,
-
OnSecretCreated: args.OnNonceSecretCreated,
-
}),
-
jtiCache: newJTICache(args.JTICacheSize),
-
logger: args.Logger,
-
hostname: args.Hostname,
-
}
-
}
-
-
func (dm *DpopManager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*dpop.Proof, error) {
-
if reqMethod == "" {
-
return nil, errors.New("HTTP method is required")
-
}
-
-
if !strings.HasPrefix(reqUrl, "https://") {
-
reqUrl = "https://" + dm.hostname + reqUrl
-
}
-
-
proof := extractProof(headers)
-
-
if proof == "" {
-
return nil, nil
-
}
-
-
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
-
var token *jwt.Token
-
-
token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{})
-
if err != nil {
-
return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err)
-
}
-
-
typ, _ := token.Header["typ"].(string)
-
if typ != "dpop+jwt" {
-
return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`)
-
}
-
-
dpopJwk, jwkOk := token.Header["jwk"].(map[string]any)
-
if !jwkOk {
-
return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`)
-
}
-
-
jwkb, err := json.Marshal(dpopJwk)
-
if err != nil {
-
return nil, fmt.Errorf("failed to marshal jwk: %w", err)
-
}
-
-
key, err := jwk.ParseKey(jwkb)
-
if err != nil {
-
return nil, fmt.Errorf("failed to parse jwk: %w", err)
-
}
-
-
var pubKey any
-
if err := key.Raw(&pubKey); err != nil {
-
return nil, fmt.Errorf("failed to get raw public key: %w", err)
-
}
-
-
token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) {
-
alg := t.Header["alg"].(string)
-
-
switch key.KeyType() {
-
case jwa.EC:
-
if !strings.HasPrefix(alg, "ES") {
-
return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg)
-
}
-
case jwa.RSA:
-
if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") {
-
return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg)
-
}
-
case jwa.OKP:
-
if alg != "EdDSA" {
-
return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg)
-
}
-
}
-
-
return pubKey, nil
-
}, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"}))
-
if err != nil {
-
return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err)
-
}
-
-
if !token.Valid {
-
return nil, errors.New("dpop proof jwt is invalid")
-
}
-
-
claims, ok := token.Claims.(jwt.MapClaims)
-
if !ok {
-
return nil, errors.New("no claims in dpop proof jwt")
-
}
-
-
iat, iatOk := claims["iat"].(float64)
-
if !iatOk {
-
return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`)
-
}
-
-
iatTime := time.Unix(int64(iat), 0)
-
now := time.Now()
-
-
if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance {
-
return nil, errors.New("dpop proof too old")
-
}
-
-
if iatTime.Sub(now) > constants.DpopCheckTolerance {
-
return nil, errors.New("dpop proof iat is in the future")
-
}
-
-
jti, _ := claims["jti"].(string)
-
if jti == "" {
-
return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`)
-
}
-
-
if dm.jtiCache.add(jti) {
-
return nil, errors.New("dpop proof replay detected")
-
}
-
-
htm, _ := claims["htm"].(string)
-
if htm == "" {
-
return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`)
-
}
-
-
if htm != reqMethod {
-
return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`)
-
}
-
-
htu, _ := claims["htu"].(string)
-
if htu == "" {
-
return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`)
-
}
-
-
parsedHtu, err := helpers.OauthParseHtu(htu)
-
if err != nil {
-
return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`)
-
}
-
-
u, _ := url.Parse(reqUrl)
-
if parsedHtu != helpers.OauthNormalizeHtu(u) {
-
return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u))
-
}
-
-
nonce, _ := claims["nonce"].(string)
-
if nonce == "" {
-
// WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request
-
return nil, errors.New("use_dpop_nonce")
-
}
-
-
if nonce != "" && !dm.nonce.Check(nonce) {
-
// WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce
-
return nil, errors.New("use_dpop_nonce")
-
}
-
-
ath, _ := claims["ath"].(string)
-
-
if accessToken != nil && *accessToken != "" {
-
if ath == "" {
-
return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`)
-
}
-
-
hash := sha256.Sum256([]byte(*accessToken))
-
if ath != base64.RawURLEncoding.EncodeToString(hash[:]) {
-
return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`)
-
}
-
} else if ath != "" {
-
return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`)
-
}
-
-
thumbBytes, err := key.Thumbprint(crypto.SHA256)
-
if err != nil {
-
return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
-
}
-
-
thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
-
-
return &dpop.Proof{
-
JTI: jti,
-
JKT: thumb,
-
HTM: htm,
-
HTU: htu,
-
}, nil
-
}
-
-
func extractProof(headers http.Header) string {
-
dpopHeaders := headers["Dpop"]
-
switch len(dpopHeaders) {
-
case 0:
-
return ""
-
case 1:
-
return dpopHeaders[0]
-
default:
-
return ""
-
}
-
}
-
-
func (dm *DpopManager) NextNonce() string {
-
return dm.nonce.NextNonce()
-
}
-28
oauth/dpop/dpop_manager/jti_cache.go
···
-
package dpop_manager
-
-
import (
-
"sync"
-
"time"
-
-
cache "github.com/go-pkgz/expirable-cache/v3"
-
"github.com/haileyok/cocoon/oauth/constants"
-
)
-
-
type jtiCache struct {
-
mu sync.Mutex
-
cache cache.Cache[string, bool]
-
}
-
-
func newJTICache(size int) *jtiCache {
-
cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl)
-
return &jtiCache{
-
cache: cache,
-
mu: sync.Mutex{},
-
}
-
}
-
-
func (c *jtiCache) add(jti string) bool {
-
c.mu.Lock()
-
defer c.mu.Unlock()
-
return c.cache.Add(jti, true)
-
}
+28
oauth/dpop/jti_cache.go
···
+
package dpop
+
+
import (
+
"sync"
+
"time"
+
+
cache "github.com/go-pkgz/expirable-cache/v3"
+
"github.com/haileyok/cocoon/oauth/constants"
+
)
+
+
type jtiCache struct {
+
mu sync.Mutex
+
cache cache.Cache[string, bool]
+
}
+
+
func newJTICache(size int) *jtiCache {
+
cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl)
+
return &jtiCache{
+
cache: cache,
+
mu: sync.Mutex{},
+
}
+
}
+
+
func (c *jtiCache) add(jti string) bool {
+
c.mu.Lock()
+
defer c.mu.Unlock()
+
return c.cache.Add(jti, true)
+
}
+253
oauth/dpop/manager.go
···
+
package dpop
+
+
import (
+
"crypto"
+
"crypto/sha256"
+
"encoding/base64"
+
"encoding/json"
+
"errors"
+
"fmt"
+
"log/slog"
+
"net/http"
+
"net/url"
+
"strings"
+
"time"
+
+
"github.com/golang-jwt/jwt/v4"
+
"github.com/haileyok/cocoon/internal/helpers"
+
"github.com/haileyok/cocoon/oauth/constants"
+
"github.com/lestrrat-go/jwx/v2/jwa"
+
"github.com/lestrrat-go/jwx/v2/jwk"
+
)
+
+
type Manager struct {
+
nonce *Nonce
+
jtiCache *jtiCache
+
logger *slog.Logger
+
hostname string
+
}
+
+
type ManagerArgs struct {
+
NonceSecret []byte
+
NonceRotationInterval time.Duration
+
OnNonceSecretCreated func([]byte)
+
JTICacheSize int
+
Logger *slog.Logger
+
Hostname string
+
}
+
+
var (
+
ErrUseDpopNonce = errors.New("use_dpop_nonce")
+
)
+
+
func NewManager(args ManagerArgs) *Manager {
+
if args.Logger == nil {
+
args.Logger = slog.Default()
+
}
+
+
if args.JTICacheSize == 0 {
+
args.JTICacheSize = 100_000
+
}
+
+
if args.NonceSecret == nil {
+
args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.")
+
}
+
+
return &Manager{
+
nonce: NewNonce(NonceArgs{
+
RotationInterval: args.NonceRotationInterval,
+
Secret: args.NonceSecret,
+
OnSecretCreated: args.OnNonceSecretCreated,
+
}),
+
jtiCache: newJTICache(args.JTICacheSize),
+
logger: args.Logger,
+
hostname: args.Hostname,
+
}
+
}
+
+
func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) {
+
if reqMethod == "" {
+
return nil, errors.New("HTTP method is required")
+
}
+
+
if !strings.HasPrefix(reqUrl, "https://") {
+
reqUrl = "https://" + dm.hostname + reqUrl
+
}
+
+
proof := extractProof(headers)
+
+
if proof == "" {
+
return nil, nil
+
}
+
+
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
+
var token *jwt.Token
+
+
token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{})
+
if err != nil {
+
return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err)
+
}
+
+
typ, _ := token.Header["typ"].(string)
+
if typ != "dpop+jwt" {
+
return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`)
+
}
+
+
dpopJwk, jwkOk := token.Header["jwk"].(map[string]any)
+
if !jwkOk {
+
return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`)
+
}
+
+
jwkb, err := json.Marshal(dpopJwk)
+
if err != nil {
+
return nil, fmt.Errorf("failed to marshal jwk: %w", err)
+
}
+
+
key, err := jwk.ParseKey(jwkb)
+
if err != nil {
+
return nil, fmt.Errorf("failed to parse jwk: %w", err)
+
}
+
+
var pubKey any
+
if err := key.Raw(&pubKey); err != nil {
+
return nil, fmt.Errorf("failed to get raw public key: %w", err)
+
}
+
+
token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) {
+
alg := t.Header["alg"].(string)
+
+
switch key.KeyType() {
+
case jwa.EC:
+
if !strings.HasPrefix(alg, "ES") {
+
return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg)
+
}
+
case jwa.RSA:
+
if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") {
+
return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg)
+
}
+
case jwa.OKP:
+
if alg != "EdDSA" {
+
return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg)
+
}
+
}
+
+
return pubKey, nil
+
}, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"}))
+
if err != nil {
+
return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err)
+
}
+
+
if !token.Valid {
+
return nil, errors.New("dpop proof jwt is invalid")
+
}
+
+
claims, ok := token.Claims.(jwt.MapClaims)
+
if !ok {
+
return nil, errors.New("no claims in dpop proof jwt")
+
}
+
+
iat, iatOk := claims["iat"].(float64)
+
if !iatOk {
+
return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`)
+
}
+
+
iatTime := time.Unix(int64(iat), 0)
+
now := time.Now()
+
+
if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance {
+
return nil, errors.New("dpop proof too old")
+
}
+
+
if iatTime.Sub(now) > constants.DpopCheckTolerance {
+
return nil, errors.New("dpop proof iat is in the future")
+
}
+
+
jti, _ := claims["jti"].(string)
+
if jti == "" {
+
return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`)
+
}
+
+
if dm.jtiCache.add(jti) {
+
return nil, errors.New("dpop proof replay detected")
+
}
+
+
htm, _ := claims["htm"].(string)
+
if htm == "" {
+
return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`)
+
}
+
+
if htm != reqMethod {
+
return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`)
+
}
+
+
htu, _ := claims["htu"].(string)
+
if htu == "" {
+
return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`)
+
}
+
+
parsedHtu, err := helpers.OauthParseHtu(htu)
+
if err != nil {
+
return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`)
+
}
+
+
u, _ := url.Parse(reqUrl)
+
if parsedHtu != helpers.OauthNormalizeHtu(u) {
+
return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u))
+
}
+
+
nonce, _ := claims["nonce"].(string)
+
if nonce == "" {
+
// WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request
+
return nil, ErrUseDpopNonce
+
}
+
+
if nonce != "" && !dm.nonce.Check(nonce) {
+
// WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce
+
return nil, ErrUseDpopNonce
+
}
+
+
ath, _ := claims["ath"].(string)
+
+
if accessToken != nil && *accessToken != "" {
+
if ath == "" {
+
return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`)
+
}
+
+
hash := sha256.Sum256([]byte(*accessToken))
+
if ath != base64.RawURLEncoding.EncodeToString(hash[:]) {
+
return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`)
+
}
+
} else if ath != "" {
+
return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`)
+
}
+
+
thumbBytes, err := key.Thumbprint(crypto.SHA256)
+
if err != nil {
+
return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
+
}
+
+
thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
+
+
return &Proof{
+
JTI: jti,
+
JKT: thumb,
+
HTM: htm,
+
HTU: htu,
+
}, nil
+
}
+
+
func extractProof(headers http.Header) string {
+
dpopHeaders := headers["Dpop"]
+
switch len(dpopHeaders) {
+
case 0:
+
return ""
+
case 1:
+
return dpopHeaders[0]
+
default:
+
return ""
+
}
+
}
+
+
func (dm *Manager) NextNonce() string {
+
return dm.nonce.NextNonce()
+
}
-108
oauth/dpop/nonce/nonce.go
···
-
package nonce
-
-
import (
-
"crypto/hmac"
-
"crypto/sha256"
-
"encoding/base64"
-
"encoding/binary"
-
"sync"
-
"time"
-
-
"github.com/haileyok/cocoon/internal/helpers"
-
"github.com/haileyok/cocoon/oauth/constants"
-
)
-
-
type Nonce struct {
-
rotationInterval time.Duration
-
secret []byte
-
-
mu sync.RWMutex
-
-
counter int64
-
prev string
-
curr string
-
next string
-
}
-
-
type Args struct {
-
RotationInterval time.Duration
-
Secret []byte
-
OnSecretCreated func([]byte)
-
}
-
-
func NewNonce(args Args) *Nonce {
-
if args.RotationInterval == 0 {
-
args.RotationInterval = constants.NonceMaxRotationInterval / 3
-
}
-
-
if args.RotationInterval > constants.NonceMaxRotationInterval {
-
args.RotationInterval = constants.NonceMaxRotationInterval
-
}
-
-
if args.Secret == nil {
-
args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength)
-
args.OnSecretCreated(args.Secret)
-
}
-
-
n := &Nonce{
-
rotationInterval: args.RotationInterval,
-
secret: args.Secret,
-
mu: sync.RWMutex{},
-
}
-
-
n.counter = n.currentCounter()
-
n.prev = n.compute(n.counter - 1)
-
n.curr = n.compute(n.counter)
-
n.next = n.compute(n.counter + 1)
-
-
return n
-
}
-
-
func (n *Nonce) currentCounter() int64 {
-
return time.Now().UnixNano() / int64(n.rotationInterval)
-
}
-
-
func (n *Nonce) compute(counter int64) string {
-
h := hmac.New(sha256.New, n.secret)
-
counterBytes := make([]byte, 8)
-
binary.BigEndian.PutUint64(counterBytes, uint64(counter))
-
h.Write(counterBytes)
-
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
-
}
-
-
func (n *Nonce) rotate() {
-
counter := n.currentCounter()
-
diff := counter - n.counter
-
-
switch diff {
-
case 0:
-
// counter == n.counter, do nothing
-
case 1:
-
n.prev = n.curr
-
n.curr = n.next
-
n.next = n.compute(counter + 1)
-
case 2:
-
n.prev = n.next
-
n.curr = n.compute(counter)
-
n.next = n.compute(counter + 1)
-
default:
-
n.prev = n.compute(counter - 1)
-
n.curr = n.compute(counter)
-
n.next = n.compute(counter + 1)
-
}
-
-
n.counter = counter
-
}
-
-
func (n *Nonce) NextNonce() string {
-
n.mu.Lock()
-
defer n.mu.Unlock()
-
n.rotate()
-
return n.next
-
}
-
-
func (n *Nonce) Check(nonce string) bool {
-
n.mu.RLock()
-
defer n.mu.RUnlock()
-
return nonce == n.prev || nonce == n.curr || nonce == n.next
-
}
+108
oauth/dpop/nonce.go
···
+
package dpop
+
+
import (
+
"crypto/hmac"
+
"crypto/sha256"
+
"encoding/base64"
+
"encoding/binary"
+
"sync"
+
"time"
+
+
"github.com/haileyok/cocoon/internal/helpers"
+
"github.com/haileyok/cocoon/oauth/constants"
+
)
+
+
type Nonce struct {
+
rotationInterval time.Duration
+
secret []byte
+
+
mu sync.RWMutex
+
+
counter int64
+
prev string
+
curr string
+
next string
+
}
+
+
type NonceArgs struct {
+
RotationInterval time.Duration
+
Secret []byte
+
OnSecretCreated func([]byte)
+
}
+
+
func NewNonce(args NonceArgs) *Nonce {
+
if args.RotationInterval == 0 {
+
args.RotationInterval = constants.NonceMaxRotationInterval / 3
+
}
+
+
if args.RotationInterval > constants.NonceMaxRotationInterval {
+
args.RotationInterval = constants.NonceMaxRotationInterval
+
}
+
+
if args.Secret == nil {
+
args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength)
+
args.OnSecretCreated(args.Secret)
+
}
+
+
n := &Nonce{
+
rotationInterval: args.RotationInterval,
+
secret: args.Secret,
+
mu: sync.RWMutex{},
+
}
+
+
n.counter = n.currentCounter()
+
n.prev = n.compute(n.counter - 1)
+
n.curr = n.compute(n.counter)
+
n.next = n.compute(n.counter + 1)
+
+
return n
+
}
+
+
func (n *Nonce) currentCounter() int64 {
+
return time.Now().UnixNano() / int64(n.rotationInterval)
+
}
+
+
func (n *Nonce) compute(counter int64) string {
+
h := hmac.New(sha256.New, n.secret)
+
counterBytes := make([]byte, 8)
+
binary.BigEndian.PutUint64(counterBytes, uint64(counter))
+
h.Write(counterBytes)
+
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
+
}
+
+
func (n *Nonce) rotate() {
+
counter := n.currentCounter()
+
diff := counter - n.counter
+
+
switch diff {
+
case 0:
+
// counter == n.counter, do nothing
+
case 1:
+
n.prev = n.curr
+
n.curr = n.next
+
n.next = n.compute(counter + 1)
+
case 2:
+
n.prev = n.next
+
n.curr = n.compute(counter)
+
n.next = n.compute(counter + 1)
+
default:
+
n.prev = n.compute(counter - 1)
+
n.curr = n.compute(counter)
+
n.next = n.compute(counter + 1)
+
}
+
+
n.counter = counter
+
}
+
+
func (n *Nonce) NextNonce() string {
+
n.mu.Lock()
+
defer n.mu.Unlock()
+
n.rotate()
+
return n.next
+
}
+
+
func (n *Nonce) Check(nonce string) bool {
+
n.mu.RLock()
+
defer n.mu.RUnlock()
+
return nonce == n.prev || nonce == n.curr || nonce == n.next
+
}
+32
oauth/helpers.go
···
"errors"
"fmt"
"net/url"
+
"time"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/oauth/constants"
+
"github.com/haileyok/cocoon/oauth/provider"
)
func GenerateCode() string {
···
return reqId, nil
}
+
+
type SessionAgeResult struct {
+
SessionAge time.Duration
+
RefreshAge time.Duration
+
SessionExpired bool
+
RefreshExpired bool
+
}
+
+
func GetSessionAgeFromToken(t provider.OauthToken) SessionAgeResult {
+
sessionLifetime := constants.PublicClientSessionLifetime
+
refreshLifetime := constants.PublicClientRefreshLifetime
+
if t.ClientAuth.Method != "none" {
+
sessionLifetime = constants.ConfidentialClientSessionLifetime
+
refreshLifetime = constants.ConfidentialClientRefreshLifetime
+
}
+
+
res := SessionAgeResult{}
+
+
res.SessionAge = time.Since(t.CreatedAt)
+
if res.SessionAge > sessionLifetime {
+
res.SessionExpired = true
+
}
+
+
refreshAge := time.Since(t.UpdatedAt)
+
if refreshAge > refreshLifetime {
+
res.RefreshExpired = true
+
}
+
+
return res
+
}
+3 -26
oauth/provider/client_auth.go
···
import (
"context"
"crypto"
-
"database/sql/driver"
"encoding/base64"
-
"encoding/json"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v4"
-
"github.com/haileyok/cocoon/oauth"
+
"github.com/haileyok/cocoon/oauth/client"
"github.com/haileyok/cocoon/oauth/constants"
"github.com/haileyok/cocoon/oauth/dpop"
)
-
type ClientAuth struct {
-
Method string
-
Alg string
-
Kid string
-
Jkt string
-
Jti string
-
Exp *float64
-
}
-
-
func (ca *ClientAuth) Scan(value any) error {
-
b, ok := value.([]byte)
-
if !ok {
-
return fmt.Errorf("failed to unmarshal OauthParRequest value")
-
}
-
return json.Unmarshal(b, ca)
-
}
-
-
func (ca ClientAuth) Value() (driver.Value, error) {
-
return json.Marshal(ca)
-
}
-
type AuthenticateClientOptions struct {
AllowMissingDpopProof bool
}
···
ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"`
}
-
func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*oauth.Client, *ClientAuth, error) {
+
func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) {
client, err := p.ClientManager.GetClient(ctx, req.ClientID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get client: %w", err)
···
return client, clientAuth, nil
}
-
func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *oauth.Client) (*ClientAuth, error) {
+
func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *client.Client) (*ClientAuth, error) {
metadata := client.Metadata
if metadata.TokenEndpointAuthMethod == "none" {
+83
oauth/provider/models.go
···
+
package provider
+
+
import (
+
"database/sql/driver"
+
"encoding/json"
+
"fmt"
+
"time"
+
+
"gorm.io/gorm"
+
)
+
+
type ClientAuth struct {
+
Method string
+
Alg string
+
Kid string
+
Jkt string
+
Jti string
+
Exp *float64
+
}
+
+
func (ca *ClientAuth) Scan(value any) error {
+
b, ok := value.([]byte)
+
if !ok {
+
return fmt.Errorf("failed to unmarshal OauthParRequest value")
+
}
+
return json.Unmarshal(b, ca)
+
}
+
+
func (ca ClientAuth) Value() (driver.Value, error) {
+
return json.Marshal(ca)
+
}
+
+
type ParRequest struct {
+
AuthenticateClientRequestBase
+
ResponseType string `form:"response_type" json:"response_type" validate:"required"`
+
CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"`
+
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"`
+
State string `form:"state" json:"state" validate:"required"`
+
RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"`
+
Scope string `form:"scope" json:"scope" validate:"required"`
+
LoginHint *string `form:"login_hint" json:"login_hint,omitempty"`
+
DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"`
+
}
+
+
func (opr *ParRequest) Scan(value any) error {
+
b, ok := value.([]byte)
+
if !ok {
+
return fmt.Errorf("failed to unmarshal OauthParRequest value")
+
}
+
return json.Unmarshal(b, opr)
+
}
+
+
func (opr ParRequest) Value() (driver.Value, error) {
+
return json.Marshal(opr)
+
}
+
+
type OauthToken struct {
+
gorm.Model
+
ClientId string `gorm:"index"`
+
ClientAuth ClientAuth `gorm:"type:json"`
+
Parameters ParRequest `gorm:"type:json"`
+
ExpiresAt time.Time `gorm:"index"`
+
DeviceId string
+
Sub string `gorm:"index"`
+
Code string `gorm:"index"`
+
Token string `gorm:"uniqueIndex"`
+
RefreshToken string `gorm:"uniqueIndex"`
+
Ip string
+
}
+
+
type OauthAuthorizationRequest struct {
+
gorm.Model
+
RequestId string `gorm:"primaryKey"`
+
ClientId string `gorm:"index"`
+
ClientAuth ClientAuth `gorm:"type:json"`
+
Parameters ParRequest `gorm:"type:json"`
+
ExpiresAt time.Time `gorm:"index"`
+
DeviceId *string
+
Sub *string
+
Code *string
+
Accepted *bool
+
Ip string
+
}
+8 -64
oauth/provider/provider.go
···
package provider
import (
-
"database/sql/driver"
-
"encoding/json"
-
"fmt"
-
"time"
-
-
"github.com/haileyok/cocoon/oauth/client_manager"
-
"github.com/haileyok/cocoon/oauth/dpop/dpop_manager"
-
"gorm.io/gorm"
+
"github.com/haileyok/cocoon/oauth/client"
+
"github.com/haileyok/cocoon/oauth/dpop"
)
type Provider struct {
-
ClientManager *client_manager.ClientManager
-
DpopManager *dpop_manager.DpopManager
+
ClientManager *client.Manager
+
DpopManager *dpop.Manager
hostname string
}
type Args struct {
Hostname string
-
ClientManagerArgs client_manager.Args
-
DpopManagerArgs dpop_manager.Args
+
ClientManagerArgs client.ManagerArgs
+
DpopManagerArgs dpop.ManagerArgs
}
func NewProvider(args Args) *Provider {
return &Provider{
-
ClientManager: client_manager.New(args.ClientManagerArgs),
-
DpopManager: dpop_manager.New(args.DpopManagerArgs),
+
ClientManager: client.NewManager(args.ClientManagerArgs),
+
DpopManager: dpop.NewManager(args.DpopManagerArgs),
hostname: args.Hostname,
}
}
···
func (p *Provider) NextNonce() string {
return p.DpopManager.NextNonce()
}
-
-
type ParRequest struct {
-
AuthenticateClientRequestBase
-
ResponseType string `form:"response_type" json:"response_type" validate:"required"`
-
CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"`
-
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"`
-
State string `form:"state" json:"state" validate:"required"`
-
RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"`
-
Scope string `form:"scope" json:"scope" validate:"required"`
-
LoginHint *string `form:"login_hint" json:"login_hint,omitempty"`
-
DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"`
-
}
-
-
func (opr *ParRequest) Scan(value any) error {
-
b, ok := value.([]byte)
-
if !ok {
-
return fmt.Errorf("failed to unmarshal OauthParRequest value")
-
}
-
return json.Unmarshal(b, opr)
-
}
-
-
func (opr ParRequest) Value() (driver.Value, error) {
-
return json.Marshal(opr)
-
}
-
-
type OauthToken struct {
-
gorm.Model
-
ClientId string `gorm:"index"`
-
ClientAuth ClientAuth `gorm:"type:json"`
-
Parameters ParRequest `gorm:"type:json"`
-
ExpiresAt time.Time `gorm:"index"`
-
DeviceId string
-
Sub string `gorm:"index"`
-
Code string `gorm:"index"`
-
Token string `gorm:"uniqueIndex"`
-
RefreshToken string `gorm:"uniqueIndex"`
-
}
-
-
type OauthAuthorizationRequest struct {
-
gorm.Model
-
RequestId string `gorm:"primaryKey"`
-
ClientId string `gorm:"index"`
-
ClientAuth ClientAuth `gorm:"type:json"`
-
Parameters ParRequest `gorm:"type:json"`
-
ExpiresAt time.Time `gorm:"index"`
-
DeviceId *string
-
Sub *string
-
Code *string
-
Accepted *bool
-
}
+5 -5
plc/client.go
···
"net/url"
"strings"
-
"github.com/bluesky-social/indigo/atproto/crypto"
+
"github.com/bluesky-social/indigo/atproto/atcrypto"
"github.com/bluesky-social/indigo/util"
"github.com/haileyok/cocoon/identity"
)
···
h *http.Client
service string
pdsHostname string
-
rotationKey *crypto.PrivateKeyK256
+
rotationKey *atcrypto.PrivateKeyK256
}
type ClientArgs struct {
···
args.H = util.RobustHTTPClient()
}
-
rk, err := crypto.ParsePrivateBytesK256([]byte(args.RotationKey))
+
rk, err := atcrypto.ParsePrivateBytesK256([]byte(args.RotationKey))
if err != nil {
return nil, err
}
···
}, nil
}
-
func (c *Client) CreateDID(sigkey *crypto.PrivateKeyK256, recovery string, handle string) (string, *Operation, error) {
+
func (c *Client) CreateDID(sigkey *atcrypto.PrivateKeyK256, recovery string, handle string) (string, *Operation, error) {
pubsigkey, err := sigkey.PublicKey()
if err != nil {
return "", nil, err
···
return did, &op, nil
}
-
func (c *Client) SignOp(sigkey *crypto.PrivateKeyK256, op *Operation) error {
+
func (c *Client) SignOp(sigkey *atcrypto.PrivateKeyK256, op *Operation) error {
b, err := op.MarshalCBOR()
if err != nil {
return err
+2 -2
plc/types.go
···
import (
"encoding/json"
-
"github.com/bluesky-social/indigo/atproto/data"
+
"github.com/bluesky-social/indigo/atproto/atdata"
"github.com/haileyok/cocoon/identity"
cbg "github.com/whyrusleeping/cbor-gen"
)
···
return nil, err
}
-
b, err = data.MarshalCBOR(m)
+
b, err = atdata.MarshalCBOR(m)
if err != nil {
return nil, err
}
+85
recording_blockstore/recording_blockstore.go
···
+
package recording_blockstore
+
+
import (
+
"context"
+
"fmt"
+
+
blockformat "github.com/ipfs/go-block-format"
+
"github.com/ipfs/go-cid"
+
blockstore "github.com/ipfs/go-ipfs-blockstore"
+
)
+
+
type RecordingBlockstore struct {
+
base blockstore.Blockstore
+
+
inserts map[cid.Cid]blockformat.Block
+
reads map[cid.Cid]blockformat.Block
+
}
+
+
func New(base blockstore.Blockstore) *RecordingBlockstore {
+
return &RecordingBlockstore{
+
base: base,
+
inserts: make(map[cid.Cid]blockformat.Block),
+
reads: make(map[cid.Cid]blockformat.Block),
+
}
+
}
+
+
func (bs *RecordingBlockstore) Has(ctx context.Context, c cid.Cid) (bool, error) {
+
return bs.base.Has(ctx, c)
+
}
+
+
func (bs *RecordingBlockstore) Get(ctx context.Context, c cid.Cid) (blockformat.Block, error) {
+
b, err := bs.base.Get(ctx, c)
+
if err != nil {
+
return nil, err
+
}
+
bs.reads[c] = b
+
return b, nil
+
}
+
+
func (bs *RecordingBlockstore) GetSize(ctx context.Context, c cid.Cid) (int, error) {
+
return bs.base.GetSize(ctx, c)
+
}
+
+
func (bs *RecordingBlockstore) DeleteBlock(ctx context.Context, c cid.Cid) error {
+
return bs.base.DeleteBlock(ctx, c)
+
}
+
+
func (bs *RecordingBlockstore) Put(ctx context.Context, block blockformat.Block) error {
+
if err := bs.base.Put(ctx, block); err != nil {
+
return err
+
}
+
bs.inserts[block.Cid()] = block
+
return nil
+
}
+
+
func (bs *RecordingBlockstore) PutMany(ctx context.Context, blocks []blockformat.Block) error {
+
if err := bs.base.PutMany(ctx, blocks); err != nil {
+
return err
+
}
+
+
for _, b := range blocks {
+
bs.inserts[b.Cid()] = b
+
}
+
+
return nil
+
}
+
+
func (bs *RecordingBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) {
+
return nil, fmt.Errorf("iteration not allowed on recording blockstore")
+
}
+
+
func (bs *RecordingBlockstore) HashOnRead(enabled bool) {
+
}
+
+
func (bs *RecordingBlockstore) GetWriteLog() map[cid.Cid]blockformat.Block {
+
return bs.inserts
+
}
+
+
func (bs *RecordingBlockstore) GetReadLog() []blockformat.Block {
+
var blocks []blockformat.Block
+
for _, b := range bs.reads {
+
blocks = append(blocks, b)
+
}
+
return blocks
+
}
+30
server/blockstore_variant.go
···
+
package server
+
+
import (
+
"github.com/haileyok/cocoon/sqlite_blockstore"
+
blockstore "github.com/ipfs/go-ipfs-blockstore"
+
)
+
+
type BlockstoreVariant int
+
+
const (
+
BlockstoreVariantSqlite = iota
+
)
+
+
func MustReturnBlockstoreVariant(maybeBsv string) BlockstoreVariant {
+
switch maybeBsv {
+
case "sqlite":
+
return BlockstoreVariantSqlite
+
default:
+
panic("invalid blockstore variant provided")
+
}
+
}
+
+
func (s *Server) getBlockstore(did string) blockstore.Blockstore {
+
switch s.config.BlockstoreVariant {
+
case BlockstoreVariantSqlite:
+
return sqlite_blockstore.New(did, s.db)
+
default:
+
return sqlite_blockstore.New(did, s.db)
+
}
+
}
+37 -7
server/handle_account.go
···
import (
"time"
+
"github.com/haileyok/cocoon/oauth"
+
"github.com/haileyok/cocoon/oauth/constants"
"github.com/haileyok/cocoon/oauth/provider"
+
"github.com/hako/durafmt"
"github.com/labstack/echo/v4"
)
func (s *Server) handleAccount(e echo.Context) error {
+
ctx := e.Request().Context()
repo, sess, err := s.getSessionRepoOrErr(e)
if err != nil {
return e.Redirect(303, "/account/signin")
}
-
now := time.Now()
+
oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime)
var tokens []provider.OauthToken
-
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND expires_at >= ? ORDER BY created_at ASC", nil, repo.Repo.Did, now).Scan(&tokens).Error; err != nil {
+
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err)
sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error")
sess.Save(e.Request(), e.Response())
···
})
}
+
var filtered []provider.OauthToken
+
for _, t := range tokens {
+
ageRes := oauth.GetSessionAgeFromToken(t)
+
if ageRes.SessionExpired {
+
continue
+
}
+
filtered = append(filtered, t)
+
}
+
+
now := time.Now()
+
tokenInfo := []map[string]string{}
for _, t := range tokens {
+
ageRes := oauth.GetSessionAgeFromToken(t)
+
maxTime := constants.PublicClientSessionLifetime
+
if t.ClientAuth.Method != "none" {
+
maxTime = constants.ConfidentialClientSessionLifetime
+
}
+
+
var clientName string
+
metadata, err := s.oauthProvider.ClientManager.GetClient(ctx, t.ClientId)
+
if err != nil {
+
clientName = t.ClientId
+
} else {
+
clientName = metadata.Metadata.ClientName
+
}
+
tokenInfo = append(tokenInfo, map[string]string{
-
"ClientId": t.ClientId,
-
"CreatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"),
-
"UpdatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"),
-
"ExpiresAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"),
-
"Token": t.Token,
+
"ClientName": clientName,
+
"Age": durafmt.Parse(ageRes.SessionAge).LimitFirstN(2).String(),
+
"LastUpdated": durafmt.Parse(now.Sub(t.UpdatedAt)).LimitFirstN(2).String(),
+
"ExpiresIn": durafmt.Parse(now.Add(maxTime).Sub(now)).LimitFirstN(2).String(),
+
"Token": t.Token,
+
"Ip": t.Ip,
})
}
+1 -1
server/handle_actor_get_preferences.go
···
err := json.Unmarshal(repo.Preferences, &prefs)
if err != nil || prefs["preferences"] == nil {
prefs = map[string]any{
-
"preferences": map[string]any{},
+
"preferences": []any{},
}
}
+2 -11
server/handle_identity_update_handle.go
···
"github.com/Azure/go-autorest/autorest/to"
"github.com/bluesky-social/indigo/api/atproto"
-
"github.com/bluesky-social/indigo/atproto/crypto"
+
"github.com/bluesky-social/indigo/atproto/atcrypto"
"github.com/bluesky-social/indigo/events"
"github.com/bluesky-social/indigo/util"
"github.com/haileyok/cocoon/identity"
···
Prev: &latest.Cid,
}
-
k, err := crypto.ParsePrivateBytesK256(repo.SigningKey)
+
k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey)
if err != nil {
s.logger.Error("error parsing signing key", "error", err)
return helpers.ServerError(e, nil)
···
if err := s.passport.BustDoc(context.TODO(), repo.Repo.Did); err != nil {
s.logger.Warn("error busting did doc", "error", err)
}
-
-
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
-
RepoHandle: &atproto.SyncSubscribeRepos_Handle{
-
Did: repo.Repo.Did,
-
Handle: req.Handle,
-
Seq: time.Now().UnixMicro(), // TODO: no
-
Time: time.Now().Format(util.ISO8601),
-
},
-
})
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
RepoIdentity: &atproto.SyncSubscribeRepos_Identity{
+2 -3
server/handle_import_repo.go
···
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/bluesky-social/indigo/repo"
-
"github.com/haileyok/cocoon/blockstore"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
blocks "github.com/ipfs/go-block-format"
···
return helpers.ServerError(e, nil)
}
-
bs := blockstore.New(urepo.Repo.Did, s.db)
+
bs := s.getBlockstore(urepo.Repo.Did)
cs, err := car.NewCarReader(bytes.NewReader(b))
if err != nil {
···
return helpers.ServerError(e, nil)
}
-
if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil {
+
if err := s.UpdateRepo(context.TODO(), urepo.Repo.Did, root, rev); err != nil {
s.logger.Error("error updating repo after commit", "error", err)
return helpers.ServerError(e, nil)
}
+1 -1
server/handle_oauth_authorize.go
···
code := oauth.GenerateCode()
-
if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, reqId).Error; err != nil {
+
if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
s.logger.Error("error updating authorization request", "error", err)
return helpers.ServerError(e, nil)
}
+9 -2
server/handle_oauth_par.go
···
package server
import (
+
"errors"
"time"
"github.com/Azure/go-autorest/autorest/to"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/oauth"
"github.com/haileyok/cocoon/oauth/constants"
+
"github.com/haileyok/cocoon/oauth/dpop"
"github.com/haileyok/cocoon/oauth/provider"
"github.com/labstack/echo/v4"
)
···
// TODO: this seems wrong. should be a way to get the entire request url i believe, but this will work for now
dpopProof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, nil)
if err != nil {
+
if errors.Is(err, dpop.ErrUseDpopNonce) {
+
return e.JSON(400, map[string]string{
+
"error": "use_dpop_nonce",
+
})
+
}
s.logger.Error("error getting dpop proof", "error", err)
-
return helpers.InputError(e, to.StringPtr(err.Error()))
+
return helpers.InputError(e, nil)
}
client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), parRequest.AuthenticateClientRequestBase, dpopProof, &provider.AuthenticateClientOptions{
···
AllowMissingDpopProof: true,
})
if err != nil {
-
s.logger.Error("error authenticating client", "error", err)
+
s.logger.Error("error authenticating client", "client_id", parRequest.ClientID, "error", err)
return helpers.InputError(e, to.StringPtr(err.Error()))
}
+13 -12
server/handle_oauth_token.go
···
"bytes"
"crypto/sha256"
"encoding/base64"
+
"errors"
"fmt"
"slices"
"time"
···
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/oauth"
"github.com/haileyok/cocoon/oauth/constants"
+
"github.com/haileyok/cocoon/oauth/dpop"
"github.com/haileyok/cocoon/oauth/provider"
"github.com/labstack/echo/v4"
)
···
proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil)
if err != nil {
+
if errors.Is(err, dpop.ErrUseDpopNonce) {
+
return e.JSON(400, map[string]string{
+
"error": "use_dpop_nonce",
+
})
+
}
s.logger.Error("error getting dpop proof", "error", err)
-
return helpers.InputError(e, to.StringPtr(err.Error()))
+
return helpers.InputError(e, nil)
}
client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
AllowMissingDpopProof: true,
})
if err != nil {
-
s.logger.Error("error authenticating client", "error", err)
+
s.logger.Error("error authenticating client", "client_id", req.ClientID, "error", err)
return helpers.InputError(e, to.StringPtr(err.Error()))
}
···
Code: *authReq.Code,
Token: accessString,
RefreshToken: refreshToken,
+
Ip: authReq.Ip,
}, nil).Error; err != nil {
s.logger.Error("error creating token in db", "error", err)
return helpers.ServerError(e, nil)
···
return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
}
-
sessionLifetime := constants.PublicClientSessionLifetime
-
refreshLifetime := constants.PublicClientRefreshLifetime
-
if clientAuth.Method != "none" {
-
sessionLifetime = constants.ConfidentialClientSessionLifetime
-
refreshLifetime = constants.ConfidentialClientRefreshLifetime
-
}
+
ageRes := oauth.GetSessionAgeFromToken(oauthToken)
-
sessionAge := time.Since(oauthToken.CreatedAt)
-
if sessionAge > sessionLifetime {
+
if ageRes.SessionExpired {
return helpers.InputError(e, to.StringPtr("Session expired"))
}
-
refreshAge := time.Since(oauthToken.UpdatedAt)
-
if refreshAge > refreshLifetime {
+
if ageRes.RefreshExpired {
return helpers.InputError(e, to.StringPtr("Refresh token expired"))
}
+28 -16
server/handle_proxy.go
···
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
)
-
func (s *Server) handleProxy(e echo.Context) error {
-
repo, isAuthed := e.Get("repo").(*models.RepoActor)
-
-
pts := strings.Split(e.Request().URL.Path, "/")
-
if len(pts) != 3 {
-
return fmt.Errorf("incorrect number of parts")
-
}
-
+
func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) {
svc := e.Request().Header.Get("atproto-proxy")
-
if svc == "" {
-
svc = "did:web:api.bsky.app#bsky_appview" // TODO: should be a config var probably
+
if svc == "" && s.config.FallbackProxy != "" {
+
svc = s.config.FallbackProxy
}
svcPts := strings.Split(svc, "#")
if len(svcPts) != 2 {
-
return fmt.Errorf("invalid service header")
+
return "", "", fmt.Errorf("invalid service header")
}
svcDid := svcPts[0]
···
doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
if err != nil {
-
return err
+
return "", "", err
}
var endpoint string
···
}
}
+
return endpoint, svcDid, nil
+
}
+
+
func (s *Server) handleProxy(e echo.Context) error {
+
lgr := s.logger.With("handler", "handleProxy")
+
+
repo, isAuthed := e.Get("repo").(*models.RepoActor)
+
+
pts := strings.Split(e.Request().URL.Path, "/")
+
if len(pts) != 3 {
+
return fmt.Errorf("incorrect number of parts")
+
}
+
+
endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e)
+
if err != nil {
+
lgr.Error("could not get atproto proxy", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
requrl := e.Request().URL
requrl.Host = strings.TrimPrefix(endpoint, "https://")
requrl.Scheme = "https"
···
}
hj, err := json.Marshal(header)
if err != nil {
-
s.logger.Error("error marshaling header", "error", err)
+
lgr.Error("error marshaling header", "error", err)
return helpers.ServerError(e, nil)
}
···
}
pj, err := json.Marshal(payload)
if err != nil {
-
s.logger.Error("error marashaling payload", "error", err)
+
lgr.Error("error marashaling payload", "error", err)
return helpers.ServerError(e, nil)
}
···
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
if err != nil {
-
s.logger.Error("can't load private key", "error", err)
+
lgr.Error("can't load private key", "error", err)
return err
}
R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
if err != nil {
-
s.logger.Error("error signing", "error", err)
+
lgr.Error("error signing", "error", err)
}
rBytes := R.Bytes()
+2 -2
server/handle_repo_get_record.go
···
package server
import (
-
"github.com/bluesky-social/indigo/atproto/data"
+
"github.com/bluesky-social/indigo/atproto/atdata"
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/haileyok/cocoon/models"
"github.com/labstack/echo/v4"
···
return err
}
-
val, err := data.UnmarshalCBOR(record.Value)
+
val, err := atdata.UnmarshalCBOR(record.Value)
if err != nil {
return s.handleProxy(e) // TODO: this should be getting handled like...if we don't find it in the db. why doesn't it throw error up there?
}
+2 -2
server/handle_repo_list_records.go
···
"strconv"
"github.com/Azure/go-autorest/autorest/to"
-
"github.com/bluesky-social/indigo/atproto/data"
+
"github.com/bluesky-social/indigo/atproto/atdata"
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
···
items := []ComAtprotoRepoListRecordsRecordItem{}
for _, r := range records {
-
val, err := data.UnmarshalCBOR(r.Value)
+
val, err := atdata.UnmarshalCBOR(r.Value)
if err != nil {
return err
}
+2 -2
server/handle_repo_list_repos.go
···
Did: r.Did,
Head: c.String(),
Rev: r.Rev,
-
Active: true,
-
Status: nil,
+
Active: r.Active(),
+
Status: r.Status(),
})
}
+50 -8
server/handle_repo_upload_blob.go
···
import (
"bytes"
+
"fmt"
"io"
+
"github.com/aws/aws-sdk-go/aws"
+
"github.com/aws/aws-sdk-go/aws/credentials"
+
"github.com/aws/aws-sdk-go/aws/session"
+
"github.com/aws/aws-sdk-go/service/s3"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
"github.com/ipfs/go-cid"
···
mime = "application/octet-stream"
}
+
storage := "sqlite"
+
s3Upload := s.s3Config != nil && s.s3Config.BlobstoreEnabled
+
if s3Upload {
+
storage = "s3"
+
}
blob := models.Blob{
Did: urepo.Repo.Did,
RefCount: 0,
CreatedAt: s.repoman.clock.Next().String(),
+
Storage: storage,
}
if err := s.db.Create(&blob, nil).Error; err != nil {
···
read += n
fulldata.Write(data)
-
blobPart := models.BlobPart{
-
BlobID: blob.ID,
-
Idx: part,
-
Data: data,
-
}
+
if !s3Upload {
+
blobPart := models.BlobPart{
+
BlobID: blob.ID,
+
Idx: part,
+
Data: data,
+
}
-
if err := s.db.Create(&blobPart, nil).Error; err != nil {
-
s.logger.Error("error adding blob part to db", "error", err)
-
return helpers.ServerError(e, nil)
+
if err := s.db.Create(&blobPart, nil).Error; err != nil {
+
s.logger.Error("error adding blob part to db", "error", err)
+
return helpers.ServerError(e, nil)
+
}
}
part++
···
if err != nil {
s.logger.Error("error creating cid prefix", "error", err)
return helpers.ServerError(e, nil)
+
}
+
+
if s3Upload {
+
config := &aws.Config{
+
Region: aws.String(s.s3Config.Region),
+
Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
+
}
+
+
if s.s3Config.Endpoint != "" {
+
config.Endpoint = aws.String(s.s3Config.Endpoint)
+
config.S3ForcePathStyle = aws.Bool(true)
+
}
+
+
sess, err := session.NewSession(config)
+
if err != nil {
+
s.logger.Error("error creating aws session", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
svc := s3.New(sess)
+
+
if _, err := svc.PutObject(&s3.PutObjectInput{
+
Bucket: aws.String(s.s3Config.Bucket),
+
Key: aws.String(fmt.Sprintf("blobs/%s/%s", urepo.Repo.Did, c.String())),
+
Body: bytes.NewReader(fulldata.Bytes()),
+
}); err != nil {
+
s.logger.Error("error uploading blob to s3", "error", err)
+
return helpers.ServerError(e, nil)
+
}
}
if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil {
+45
server/handle_server_activate_account.go
···
+
package server
+
+
import (
+
"context"
+
"time"
+
+
"github.com/bluesky-social/indigo/api/atproto"
+
"github.com/bluesky-social/indigo/events"
+
"github.com/bluesky-social/indigo/util"
+
"github.com/haileyok/cocoon/internal/helpers"
+
"github.com/haileyok/cocoon/models"
+
"github.com/labstack/echo/v4"
+
)
+
+
type ComAtprotoServerActivateAccountRequest struct {
+
// NOTE: this implementation will not pay attention to this value
+
DeleteAfter time.Time `json:"deleteAfter"`
+
}
+
+
func (s *Server) handleServerActivateAccount(e echo.Context) error {
+
var req ComAtprotoServerDeactivateAccountRequest
+
if err := e.Bind(&req); err != nil {
+
s.logger.Error("error binding", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
urepo := e.Get("repo").(*models.RepoActor)
+
+
if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil {
+
s.logger.Error("error updating account status to deactivated", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
+
RepoAccount: &atproto.SyncSubscribeRepos_Account{
+
Active: true,
+
Did: urepo.Repo.Did,
+
Status: nil,
+
Seq: time.Now().UnixMicro(), // TODO: bad puppy
+
Time: time.Now().Format(util.ISO8601),
+
},
+
})
+
+
return e.NoContent(200)
+
}
+65
server/handle_server_check_account_status.go
···
+
package server
+
+
import (
+
"github.com/haileyok/cocoon/internal/helpers"
+
"github.com/haileyok/cocoon/models"
+
"github.com/ipfs/go-cid"
+
"github.com/labstack/echo/v4"
+
)
+
+
type ComAtprotoServerCheckAccountStatusResponse struct {
+
Activated bool `json:"activated"`
+
ValidDid bool `json:"validDid"`
+
RepoCommit string `json:"repoCommit"`
+
RepoRev string `json:"repoRev"`
+
RepoBlocks int64 `json:"repoBlocks"`
+
IndexedRecords int64 `json:"indexedRecords"`
+
PrivateStateValues int64 `json:"privateStateValues"`
+
ExpectedBlobs int64 `json:"expectedBlobs"`
+
ImportedBlobs int64 `json:"importedBlobs"`
+
}
+
+
func (s *Server) handleServerCheckAccountStatus(e echo.Context) error {
+
urepo := e.Get("repo").(*models.RepoActor)
+
+
resp := ComAtprotoServerCheckAccountStatusResponse{
+
Activated: true, // TODO: should allow for deactivation etc.
+
ValidDid: true, // TODO: should probably verify?
+
RepoRev: urepo.Rev,
+
ImportedBlobs: 0, // TODO: ???
+
}
+
+
rootcid, err := cid.Cast(urepo.Root)
+
if err != nil {
+
s.logger.Error("error casting cid", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
resp.RepoCommit = rootcid.String()
+
+
type CountResp struct {
+
Ct int64
+
}
+
+
var blockCtResp CountResp
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
+
s.logger.Error("error getting block count", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
resp.RepoBlocks = blockCtResp.Ct
+
+
var recCtResp CountResp
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
+
s.logger.Error("error getting record count", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
resp.IndexedRecords = recCtResp.Ct
+
+
var blobCtResp CountResp
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
+
s.logger.Error("error getting record count", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
resp.ExpectedBlobs = blobCtResp.Ct
+
+
return e.JSON(200, resp)
+
}
+2 -2
server/handle_server_confirm_email.go
···
}
if urepo.EmailVerificationCode == nil || urepo.EmailVerificationCodeExpiresAt == nil {
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
+
return helpers.ExpiredTokenError(e)
}
if *urepo.EmailVerificationCode != req.Token {
···
}
if time.Now().UTC().After(*urepo.EmailVerificationCodeExpiresAt) {
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
+
return helpers.ExpiredTokenError(e)
}
now := time.Now().UTC()
+4 -14
server/handle_server_create_account.go
···
"github.com/Azure/go-autorest/autorest/to"
"github.com/bluesky-social/indigo/api/atproto"
-
"github.com/bluesky-social/indigo/atproto/crypto"
+
"github.com/bluesky-social/indigo/atproto/atcrypto"
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/bluesky-social/indigo/events"
"github.com/bluesky-social/indigo/repo"
"github.com/bluesky-social/indigo/util"
-
"github.com/haileyok/cocoon/blockstore"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
"github.com/labstack/echo/v4"
···
// TODO: unsupported domains
-
k, err := crypto.GeneratePrivateKeyK256()
+
k, err := atcrypto.GeneratePrivateKeyK256()
if err != nil {
s.logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err)
return helpers.ServerError(e, nil)
···
}
if customDidHeader == "" {
-
bs := blockstore.New(signupDid, s.db)
+
bs := s.getBlockstore(signupDid)
r := repo.NewRepo(context.TODO(), signupDid, bs)
root, rev, err := r.Commit(context.TODO(), urepo.SignFor)
···
return helpers.ServerError(e, nil)
}
-
if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil {
+
if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil {
s.logger.Error("error updating repo after commit", "error", err)
return helpers.ServerError(e, nil)
}
-
-
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
-
RepoHandle: &atproto.SyncSubscribeRepos_Handle{
-
Did: urepo.Did,
-
Handle: request.Handle,
-
Seq: time.Now().UnixMicro(), // TODO: no
-
Time: time.Now().Format(util.ISO8601),
-
},
-
})
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
RepoIdentity: &atproto.SyncSubscribeRepos_Identity{
+2 -2
server/handle_server_create_session.go
···
Email: repo.Email,
EmailConfirmed: repo.EmailConfirmedAt != nil,
EmailAuthFactor: false,
-
Active: true, // TODO: eventually do takedowns
-
Status: nil, // TODO eventually do takedowns
+
Active: repo.Active(),
+
Status: repo.Status(),
})
}
+46
server/handle_server_deactivate_account.go
···
+
package server
+
+
import (
+
"context"
+
"time"
+
+
"github.com/Azure/go-autorest/autorest/to"
+
"github.com/bluesky-social/indigo/api/atproto"
+
"github.com/bluesky-social/indigo/events"
+
"github.com/bluesky-social/indigo/util"
+
"github.com/haileyok/cocoon/internal/helpers"
+
"github.com/haileyok/cocoon/models"
+
"github.com/labstack/echo/v4"
+
)
+
+
type ComAtprotoServerDeactivateAccountRequest struct {
+
// NOTE: this implementation will not pay attention to this value
+
DeleteAfter time.Time `json:"deleteAfter"`
+
}
+
+
func (s *Server) handleServerDeactivateAccount(e echo.Context) error {
+
var req ComAtprotoServerDeactivateAccountRequest
+
if err := e.Bind(&req); err != nil {
+
s.logger.Error("error binding", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
urepo := e.Get("repo").(*models.RepoActor)
+
+
if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil {
+
s.logger.Error("error updating account status to deactivated", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
+
RepoAccount: &atproto.SyncSubscribeRepos_Account{
+
Active: false,
+
Did: urepo.Repo.Did,
+
Status: to.StringPtr("deactivated"),
+
Seq: time.Now().UnixMicro(), // TODO: bad puppy
+
Time: time.Now().Format(util.ISO8601),
+
},
+
})
+
+
return e.NoContent(200)
+
}
+8 -6
server/handle_server_get_service_auth.go
···
type ServerGetServiceAuthRequest struct {
Aud string `query:"aud" validate:"required,atproto-did"`
-
Exp int64 `query:"exp"`
-
Lxm string `query:"lxm" validate:"required,atproto-nsid"`
+
// exp should be a float, as some clients will send a non-integer expiration
+
Exp float64 `query:"exp"`
+
Lxm string `query:"lxm" validate:"required,atproto-nsid"`
}
func (s *Server) handleServerGetServiceAuth(e echo.Context) error {
···
return helpers.InputError(e, nil)
}
+
exp := int64(req.Exp)
now := time.Now().Unix()
-
if req.Exp == 0 {
-
req.Exp = now + 60 // default
+
if exp == 0 {
+
exp = now + 60 // default
}
if req.Lxm == "com.atproto.server.getServiceAuth" {
···
}
maxExp := now + (60 * 30)
-
if req.Exp > maxExp {
+
if exp > maxExp {
return helpers.InputError(e, to.StringPtr("expiration too big. smoller please"))
}
···
"aud": req.Aud,
"lxm": req.Lxm,
"jti": uuid.NewString(),
-
"exp": req.Exp,
+
"exp": exp,
"iat": now,
}
pj, err := json.Marshal(payload)
+2 -2
server/handle_server_get_session.go
···
Email: repo.Email,
EmailConfirmed: repo.EmailConfirmedAt != nil,
EmailAuthFactor: false, // TODO: todo todo
-
Active: true,
-
Status: nil,
+
Active: repo.Active(),
+
Status: repo.Status(),
})
}
+2 -2
server/handle_server_refresh_session.go
···
RefreshJwt: sess.RefreshToken,
Handle: repo.Handle,
Did: repo.Repo.Did,
-
Active: true,
-
Status: nil,
+
Active: repo.Active(),
+
Status: repo.Status(),
})
}
+2 -2
server/handle_server_reset_password.go
···
}
if *urepo.PasswordResetCode != req.Token {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
+
return helpers.InvalidTokenError(e)
}
if time.Now().UTC().After(*urepo.PasswordResetCodeExpiresAt) {
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
+
return helpers.ExpiredTokenError(e)
}
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10)
+3 -4
server/handle_server_update_email.go
···
import (
"time"
-
"github.com/Azure/go-autorest/autorest/to"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
"github.com/labstack/echo/v4"
···
}
if urepo.EmailUpdateCode == nil || urepo.EmailUpdateCodeExpiresAt == nil {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
+
return helpers.InvalidTokenError(e)
}
if *urepo.EmailUpdateCode != req.Token {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
+
return helpers.InvalidTokenError(e)
}
if time.Now().UTC().After(*urepo.EmailUpdateCodeExpiresAt) {
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
+
return helpers.ExpiredTokenError(e)
}
if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil {
+79 -8
server/handle_sync_get_blob.go
···
import (
"bytes"
+
"fmt"
+
"io"
+
"github.com/Azure/go-autorest/autorest/to"
+
"github.com/aws/aws-sdk-go/aws"
+
"github.com/aws/aws-sdk-go/aws/credentials"
+
"github.com/aws/aws-sdk-go/aws/session"
+
"github.com/aws/aws-sdk-go/service/s3"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
"github.com/ipfs/go-cid"
···
return helpers.InputError(e, nil)
}
+
urepo, err := s.getRepoActorByDid(did)
+
if err != nil {
+
s.logger.Error("could not find user for requested blob", "error", err)
+
return helpers.InputError(e, nil)
+
}
+
+
status := urepo.Status()
+
if status != nil {
+
if *status == "deactivated" {
+
return helpers.InputError(e, to.StringPtr("RepoDeactivated"))
+
}
+
}
+
var blob models.Blob
if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil {
s.logger.Error("error looking up blob", "error", err)
···
buf := new(bytes.Buffer)
-
var parts []models.BlobPart
-
if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil {
-
s.logger.Error("error getting blob parts", "error", err)
+
if blob.Storage == "sqlite" {
+
var parts []models.BlobPart
+
if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil {
+
s.logger.Error("error getting blob parts", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
// TODO: we can just stream this, don't need to make a buffer
+
for _, p := range parts {
+
buf.Write(p.Data)
+
}
+
} else if blob.Storage == "s3" && s.s3Config != nil && s.s3Config.BlobstoreEnabled {
+
config := &aws.Config{
+
Region: aws.String(s.s3Config.Region),
+
Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
+
}
+
+
if s.s3Config.Endpoint != "" {
+
config.Endpoint = aws.String(s.s3Config.Endpoint)
+
config.S3ForcePathStyle = aws.Bool(true)
+
}
+
+
sess, err := session.NewSession(config)
+
if err != nil {
+
s.logger.Error("error creating aws session", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
svc := s3.New(sess)
+
if result, err := svc.GetObject(&s3.GetObjectInput{
+
Bucket: aws.String(s.s3Config.Bucket),
+
Key: aws.String(fmt.Sprintf("blobs/%s/%s", urepo.Repo.Did, c.String())),
+
}); err != nil {
+
s.logger.Error("error getting blob from s3", "error", err)
+
return helpers.ServerError(e, nil)
+
} else {
+
read := 0
+
part := 0
+
partBuf := make([]byte, 0x10000)
+
+
for {
+
n, err := io.ReadFull(result.Body, partBuf)
+
if err == io.ErrUnexpectedEOF || err == io.EOF {
+
if n == 0 {
+
break
+
}
+
} else if err != nil && err != io.ErrUnexpectedEOF {
+
s.logger.Error("error reading blob", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
data := partBuf[:n]
+
read += n
+
buf.Write(data)
+
part++
+
}
+
}
+
} else {
+
s.logger.Error("unknown storage", "storage", blob.Storage)
return helpers.ServerError(e, nil)
-
}
-
-
// TODO: we can just stream this, don't need to make a buffer
-
for _, p := range parts {
-
buf.Write(p.Data)
}
e.Response().Header().Set(echo.HeaderContentDisposition, "attachment; filename="+c.String())
+14 -12
server/handle_sync_get_blocks.go
···
import (
"bytes"
-
"context"
-
"strings"
"github.com/bluesky-social/indigo/carstore"
-
"github.com/haileyok/cocoon/blockstore"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/ipfs/go-cid"
cbor "github.com/ipfs/go-ipld-cbor"
···
"github.com/labstack/echo/v4"
)
+
type ComAtprotoSyncGetBlocksRequest struct {
+
Did string `query:"did"`
+
Cids []string `query:"cids"`
+
}
+
func (s *Server) handleGetBlocks(e echo.Context) error {
-
did := e.QueryParam("did")
-
cidsstr := e.QueryParam("cids")
-
if did == "" {
+
ctx := e.Request().Context()
+
+
var req ComAtprotoSyncGetBlocksRequest
+
if err := e.Bind(&req); err != nil {
return helpers.InputError(e, nil)
}
-
cidstrs := strings.Split(cidsstr, ",")
-
cids := []cid.Cid{}
+
var cids []cid.Cid
-
for _, cs := range cidstrs {
+
for _, cs := range req.Cids {
c, err := cid.Cast([]byte(cs))
if err != nil {
return err
···
cids = append(cids, c)
}
-
urepo, err := s.getRepoActorByDid(did)
+
urepo, err := s.getRepoActorByDid(req.Did)
if err != nil {
return helpers.ServerError(e, nil)
}
···
return helpers.ServerError(e, nil)
}
-
bs := blockstore.New(urepo.Repo.Did, s.db)
+
bs := s.getBlockstore(urepo.Repo.Did)
for _, c := range cids {
-
b, err := bs.Get(context.TODO(), c)
+
b, err := bs.Get(ctx, c)
if err != nil {
return err
}
+2 -2
server/handle_sync_get_repo_status.go
···
return e.JSON(200, ComAtprotoSyncGetRepoStatusResponse{
Did: urepo.Repo.Did,
-
Active: true,
-
Status: nil,
+
Active: urepo.Active(),
+
Status: urepo.Status(),
Rev: &urepo.Rev,
})
}
+14
server/handle_sync_list_blobs.go
···
package server
import (
+
"github.com/Azure/go-autorest/autorest/to"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
"github.com/ipfs/go-cid"
···
cursorquery = "AND created_at < ?"
}
params = append(params, limit)
+
+
urepo, err := s.getRepoActorByDid(did)
+
if err != nil {
+
s.logger.Error("could not find user for requested blobs", "error", err)
+
return helpers.InputError(e, nil)
+
}
+
+
status := urepo.Status()
+
if status != nil {
+
if *status == "deactivated" {
+
return helpers.InputError(e, to.StringPtr("RepoDeactivated"))
+
}
+
}
var blobs []models.Blob
if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil {
-18
server/handle_sync_subscribe_repos.go
···
import (
"fmt"
-
"net/http"
"github.com/bluesky-social/indigo/events"
"github.com/bluesky-social/indigo/lex/util"
"github.com/btcsuite/websocket"
"github.com/labstack/echo/v4"
)
-
-
var upgrader = websocket.Upgrader{
-
ReadBufferSize: 1024,
-
WriteBufferSize: 1024,
-
CheckOrigin: func(r *http.Request) bool {
-
return true
-
},
-
}
func (s *Server) handleSyncSubscribeRepos(e echo.Context) error {
conn, err := websocket.Upgrade(e.Response().Writer, e.Request(), e.Response().Header(), 1<<10, 1<<10)
···
case evt.RepoCommit != nil:
header.MsgType = "#commit"
obj = evt.RepoCommit
-
case evt.RepoHandle != nil:
-
header.MsgType = "#handle"
-
obj = evt.RepoHandle
case evt.RepoIdentity != nil:
header.MsgType = "#identity"
obj = evt.RepoIdentity
···
case evt.RepoInfo != nil:
header.MsgType = "#info"
obj = evt.RepoInfo
-
case evt.RepoMigrate != nil:
-
header.MsgType = "#migrate"
-
obj = evt.RepoMigrate
-
case evt.RepoTombstone != nil:
-
header.MsgType = "#tombstone"
-
obj = evt.RepoTombstone
default:
return fmt.Errorf("unrecognized event kind")
}
+275
server/middleware.go
···
+
package server
+
+
import (
+
"crypto/sha256"
+
"encoding/base64"
+
"errors"
+
"fmt"
+
"strings"
+
"time"
+
+
"github.com/Azure/go-autorest/autorest/to"
+
"github.com/golang-jwt/jwt/v4"
+
"github.com/haileyok/cocoon/internal/helpers"
+
"github.com/haileyok/cocoon/models"
+
"github.com/haileyok/cocoon/oauth/dpop"
+
"github.com/haileyok/cocoon/oauth/provider"
+
"github.com/labstack/echo/v4"
+
"gitlab.com/yawning/secp256k1-voi"
+
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
+
"gorm.io/gorm"
+
)
+
+
func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
+
return func(e echo.Context) error {
+
username, password, ok := e.Request().BasicAuth()
+
if !ok || username != "admin" || password != s.config.AdminPassword {
+
return helpers.InputError(e, to.StringPtr("Unauthorized"))
+
}
+
+
if err := next(e); err != nil {
+
e.Error(err)
+
}
+
+
return nil
+
}
+
}
+
+
func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
+
return func(e echo.Context) error {
+
authheader := e.Request().Header.Get("authorization")
+
if authheader == "" {
+
return e.JSON(401, map[string]string{"error": "Unauthorized"})
+
}
+
+
pts := strings.Split(authheader, " ")
+
if len(pts) != 2 {
+
return helpers.ServerError(e, nil)
+
}
+
+
// move on to oauth session middleware if this is a dpop token
+
if pts[0] == "DPoP" {
+
return next(e)
+
}
+
+
tokenstr := pts[1]
+
token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
+
claims, ok := token.Claims.(jwt.MapClaims)
+
if !ok {
+
return helpers.InvalidTokenError(e)
+
}
+
+
var did string
+
var repo *models.RepoActor
+
+
// service auth tokens
+
lxm, hasLxm := claims["lxm"]
+
if hasLxm {
+
pts := strings.Split(e.Request().URL.String(), "/")
+
if lxm != pts[len(pts)-1] {
+
s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
+
return helpers.InputError(e, nil)
+
}
+
+
maybeDid, ok := claims["iss"].(string)
+
if !ok {
+
s.logger.Error("no iss in service auth token", "error", err)
+
return helpers.InputError(e, nil)
+
}
+
did = maybeDid
+
+
maybeRepo, err := s.getRepoActorByDid(did)
+
if err != nil {
+
s.logger.Error("error fetching repo", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
repo = maybeRepo
+
}
+
+
if token.Header["alg"] != "ES256K" {
+
token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
+
if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
+
return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
+
}
+
return s.privateKey.Public(), nil
+
})
+
if err != nil {
+
s.logger.Error("error parsing jwt", "error", err)
+
return helpers.ExpiredTokenError(e)
+
}
+
+
if !token.Valid {
+
return helpers.InvalidTokenError(e)
+
}
+
} else {
+
kpts := strings.Split(tokenstr, ".")
+
signingInput := kpts[0] + "." + kpts[1]
+
hash := sha256.Sum256([]byte(signingInput))
+
sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
+
if err != nil {
+
s.logger.Error("error decoding signature bytes", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
if len(sigBytes) != 64 {
+
s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
+
return helpers.ServerError(e, nil)
+
}
+
+
rBytes := sigBytes[:32]
+
sBytes := sigBytes[32:]
+
rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
+
ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
+
+
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
+
if err != nil {
+
s.logger.Error("can't load private key", "error", err)
+
return err
+
}
+
+
pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
+
if !ok {
+
s.logger.Error("error getting public key from sk")
+
return helpers.ServerError(e, nil)
+
}
+
+
verified := pubKey.VerifyRaw(hash[:], rr, ss)
+
if !verified {
+
s.logger.Error("error verifying", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
}
+
+
isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
+
scope, _ := claims["scope"].(string)
+
+
if isRefresh && scope != "com.atproto.refresh" {
+
return helpers.InvalidTokenError(e)
+
} else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
+
return helpers.InvalidTokenError(e)
+
}
+
+
table := "tokens"
+
if isRefresh {
+
table = "refresh_tokens"
+
}
+
+
if isRefresh {
+
type Result struct {
+
Found bool
+
}
+
var result Result
+
if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
+
if err == gorm.ErrRecordNotFound {
+
return helpers.InvalidTokenError(e)
+
}
+
+
s.logger.Error("error getting token from db", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
if !result.Found {
+
return helpers.InvalidTokenError(e)
+
}
+
}
+
+
exp, ok := claims["exp"].(float64)
+
if !ok {
+
s.logger.Error("error getting iat from token")
+
return helpers.ServerError(e, nil)
+
}
+
+
if exp < float64(time.Now().UTC().Unix()) {
+
return helpers.ExpiredTokenError(e)
+
}
+
+
if repo == nil {
+
maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
+
if err != nil {
+
s.logger.Error("error fetching repo", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
repo = maybeRepo
+
did = repo.Repo.Did
+
}
+
+
e.Set("repo", repo)
+
e.Set("did", did)
+
e.Set("token", tokenstr)
+
+
if err := next(e); err != nil {
+
return helpers.InvalidTokenError(e)
+
}
+
+
return nil
+
}
+
}
+
+
func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
+
return func(e echo.Context) error {
+
authheader := e.Request().Header.Get("authorization")
+
if authheader == "" {
+
return e.JSON(401, map[string]string{"error": "Unauthorized"})
+
}
+
+
pts := strings.Split(authheader, " ")
+
if len(pts) != 2 {
+
return helpers.ServerError(e, nil)
+
}
+
+
if pts[0] != "DPoP" {
+
return next(e)
+
}
+
+
accessToken := pts[1]
+
+
nonce := s.oauthProvider.NextNonce()
+
if nonce != "" {
+
e.Response().Header().Set("DPoP-Nonce", nonce)
+
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
+
}
+
+
proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
+
if err != nil {
+
if errors.Is(err, dpop.ErrUseDpopNonce) {
+
return e.JSON(400, map[string]string{
+
"error": "use_dpop_nonce",
+
})
+
}
+
s.logger.Error("invalid dpop proof", "error", err)
+
return helpers.InputError(e, nil)
+
}
+
+
var oauthToken provider.OauthToken
+
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
+
s.logger.Error("error finding access token in db", "error", err)
+
return helpers.InputError(e, nil)
+
}
+
+
if oauthToken.Token == "" {
+
return helpers.InvalidTokenError(e)
+
}
+
+
if *oauthToken.Parameters.DpopJkt != proof.JKT {
+
s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
+
return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
+
}
+
+
if time.Now().After(oauthToken.ExpiresAt) {
+
return helpers.ExpiredTokenError(e)
+
}
+
+
repo, err := s.getRepoActorByDid(oauthToken.Sub)
+
if err != nil {
+
s.logger.Error("could not find actor in db", "error", err)
+
return helpers.ServerError(e, nil)
+
}
+
+
e.Set("repo", repo)
+
e.Set("did", repo.Repo.Did)
+
e.Set("token", accessToken)
+
e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
+
+
return next(e)
+
}
+
}
+27 -21
server/repo.go
···
"github.com/Azure/go-autorest/autorest/to"
"github.com/bluesky-social/indigo/api/atproto"
-
"github.com/bluesky-social/indigo/atproto/data"
+
"github.com/bluesky-social/indigo/atproto/atdata"
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/bluesky-social/indigo/carstore"
"github.com/bluesky-social/indigo/events"
lexutil "github.com/bluesky-social/indigo/lex/util"
"github.com/bluesky-social/indigo/repo"
-
"github.com/bluesky-social/indigo/util"
-
"github.com/haileyok/cocoon/blockstore"
"github.com/haileyok/cocoon/internal/db"
"github.com/haileyok/cocoon/models"
+
"github.com/haileyok/cocoon/recording_blockstore"
blocks "github.com/ipfs/go-block-format"
"github.com/ipfs/go-cid"
cbor "github.com/ipfs/go-ipld-cbor"
···
}
func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
-
data, err := data.MarshalCBOR(*mm)
+
data, err := atdata.MarshalCBOR(*mm)
if err != nil {
return err
}
···
return nil, err
}
-
dbs := blockstore.New(urepo.Did, rm.db)
-
r, err := repo.OpenRepo(context.TODO(), dbs, rootcid)
+
dbs := rm.s.getBlockstore(urepo.Did)
+
bs := recording_blockstore.New(dbs)
+
r, err := repo.OpenRepo(context.TODO(), bs, rootcid)
entries := []models.Record{}
var results []ApplyWriteResult
···
if err != nil {
return nil, err
}
-
out, err := data.UnmarshalJSON(j)
+
out, err := atdata.UnmarshalJSON(j)
if err != nil {
return nil, err
}
mm := MarshalableMap(out)
+
+
// HACK: if a record doesn't contain a $type, we can manually set it here based on the op's collection
+
if mm["$type"] == "" {
+
mm["$type"] = op.Collection
+
}
+
nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm)
if err != nil {
return nil, err
}
-
d, err := data.MarshalCBOR(mm)
+
d, err := atdata.MarshalCBOR(mm)
if err != nil {
return nil, err
}
···
if err != nil {
return nil, err
}
-
out, err := data.UnmarshalJSON(j)
+
out, err := atdata.UnmarshalJSON(j)
if err != nil {
return nil, err
}
···
if err != nil {
return nil, err
}
-
d, err := data.MarshalCBOR(mm)
+
d, err := atdata.MarshalCBOR(mm)
if err != nil {
return nil, err
}
···
}
}
-
for _, op := range dbs.GetLog() {
+
for _, op := range bs.GetWriteLog() {
if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
return nil, err
}
···
Rev: rev,
Since: &urepo.Rev,
Commit: lexutil.LexLink(newroot),
-
Time: time.Now().Format(util.ISO8601),
+
Time: time.Now().Format(time.RFC3339Nano),
Ops: ops,
TooBig: false,
},
})
-
if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil {
+
if err := rm.s.UpdateRepo(context.TODO(), urepo.Did, newroot, rev); err != nil {
return nil, err
}
···
return cid.Undef, nil, err
}
-
dbs := blockstore.New(urepo.Did, rm.db)
-
bs := util.NewLoggingBstore(dbs)
+
dbs := rm.s.getBlockstore(urepo.Did)
+
bs := recording_blockstore.New(dbs)
r, err := repo.OpenRepo(context.TODO(), bs, c)
if err != nil {
···
return cid.Undef, nil, err
}
-
return c, bs.GetLoggedBlocks(), nil
+
return c, bs.GetReadLog(), nil
}
func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
···
func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
var cids []cid.Cid
-
decoded, err := data.UnmarshalCBOR(cbor)
+
decoded, err := atdata.UnmarshalCBOR(cbor)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
}
-
var deepiter func(interface{}) error
-
deepiter = func(item interface{}) error {
+
var deepiter func(any) error
+
deepiter = func(item any) error {
switch val := item.(type) {
-
case map[string]interface{}:
+
case map[string]any:
if val["$type"] == "blob" {
if ref, ok := val["ref"].(string); ok {
c, err := cid.Parse(ref)
···
return deepiter(v)
}
}
-
case []interface{}:
+
case []any:
for _, v := range val {
deepiter(v)
}
+55 -291
server/server.go
···
"bytes"
"context"
"crypto/ecdsa"
-
"crypto/sha256"
"embed"
-
"encoding/base64"
"errors"
"fmt"
"io"
···
"net/smtp"
"os"
"path/filepath"
-
"strings"
"sync"
"text/template"
"time"
-
"github.com/Azure/go-autorest/autorest/to"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
···
"github.com/bluesky-social/indigo/xrpc"
"github.com/domodwyer/mailyak/v3"
"github.com/go-playground/validator"
-
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/sessions"
"github.com/haileyok/cocoon/identity"
"github.com/haileyok/cocoon/internal/db"
"github.com/haileyok/cocoon/internal/helpers"
"github.com/haileyok/cocoon/models"
-
"github.com/haileyok/cocoon/oauth/client_manager"
+
"github.com/haileyok/cocoon/oauth/client"
"github.com/haileyok/cocoon/oauth/constants"
-
"github.com/haileyok/cocoon/oauth/dpop/dpop_manager"
+
"github.com/haileyok/cocoon/oauth/dpop"
"github.com/haileyok/cocoon/oauth/provider"
"github.com/haileyok/cocoon/plc"
+
"github.com/ipfs/go-cid"
echo_session "github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
slogecho "github.com/samber/slog-echo"
-
"gitlab.com/yawning/secp256k1-voi"
-
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
···
)
type S3Config struct {
-
BackupsEnabled bool
-
Endpoint string
-
Region string
-
Bucket string
-
AccessKey string
-
SecretKey string
+
BackupsEnabled bool
+
BlobstoreEnabled bool
+
Endpoint string
+
Region string
+
Bucket string
+
AccessKey string
+
SecretKey string
}
type Server struct {
···
oauthProvider *provider.Provider
evtman *events.EventManager
passport *identity.Passport
+
fallbackProxy string
dbName string
s3Config *S3Config
···
S3Config *S3Config
SessionSecret string
+
+
BlockstoreVariant BlockstoreVariant
+
FallbackProxy string
}
type config struct {
-
Version string
-
Did string
-
Hostname string
-
ContactEmail string
-
EnforcePeering bool
-
Relays []string
-
AdminPassword string
-
SmtpEmail string
-
SmtpName string
+
Version string
+
Did string
+
Hostname string
+
ContactEmail string
+
EnforcePeering bool
+
Relays []string
+
AdminPassword string
+
SmtpEmail string
+
SmtpName string
+
BlockstoreVariant BlockstoreVariant
+
FallbackProxy string
}
type CustomValidator struct {
···
return t.templates.ExecuteTemplate(w, name, data)
}
-
func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
-
return func(e echo.Context) error {
-
username, password, ok := e.Request().BasicAuth()
-
if !ok || username != "admin" || password != s.config.AdminPassword {
-
return helpers.InputError(e, to.StringPtr("Unauthorized"))
-
}
-
-
if err := next(e); err != nil {
-
e.Error(err)
-
}
-
-
return nil
-
}
-
}
-
-
func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
-
return func(e echo.Context) error {
-
authheader := e.Request().Header.Get("authorization")
-
if authheader == "" {
-
return e.JSON(401, map[string]string{"error": "Unauthorized"})
-
}
-
-
pts := strings.Split(authheader, " ")
-
if len(pts) != 2 {
-
return helpers.ServerError(e, nil)
-
}
-
-
// move on to oauth session middleware if this is a dpop token
-
if pts[0] == "DPoP" {
-
return next(e)
-
}
-
-
tokenstr := pts[1]
-
token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
-
claims, ok := token.Claims.(jwt.MapClaims)
-
if !ok {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
-
}
-
-
var did string
-
var repo *models.RepoActor
-
-
// service auth tokens
-
lxm, hasLxm := claims["lxm"]
-
if hasLxm {
-
pts := strings.Split(e.Request().URL.String(), "/")
-
if lxm != pts[len(pts)-1] {
-
s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
-
return helpers.InputError(e, nil)
-
}
-
-
maybeDid, ok := claims["iss"].(string)
-
if !ok {
-
s.logger.Error("no iss in service auth token", "error", err)
-
return helpers.InputError(e, nil)
-
}
-
did = maybeDid
-
-
maybeRepo, err := s.getRepoActorByDid(did)
-
if err != nil {
-
s.logger.Error("error fetching repo", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
repo = maybeRepo
-
}
-
-
if token.Header["alg"] != "ES256K" {
-
token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
-
if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
-
return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
-
}
-
return s.privateKey.Public(), nil
-
})
-
if err != nil {
-
s.logger.Error("error parsing jwt", "error", err)
-
// NOTE: https://github.com/bluesky-social/atproto/discussions/3319
-
return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
-
}
-
-
if !token.Valid {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
-
}
-
} else {
-
kpts := strings.Split(tokenstr, ".")
-
signingInput := kpts[0] + "." + kpts[1]
-
hash := sha256.Sum256([]byte(signingInput))
-
sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
-
if err != nil {
-
s.logger.Error("error decoding signature bytes", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
-
if len(sigBytes) != 64 {
-
s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
-
return helpers.ServerError(e, nil)
-
}
-
-
rBytes := sigBytes[:32]
-
sBytes := sigBytes[32:]
-
rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
-
ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
-
-
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
-
if err != nil {
-
s.logger.Error("can't load private key", "error", err)
-
return err
-
}
-
-
pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
-
if !ok {
-
s.logger.Error("error getting public key from sk")
-
return helpers.ServerError(e, nil)
-
}
-
-
verified := pubKey.VerifyRaw(hash[:], rr, ss)
-
if !verified {
-
s.logger.Error("error verifying", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
}
-
-
isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
-
scope, _ := claims["scope"].(string)
-
-
if isRefresh && scope != "com.atproto.refresh" {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
-
} else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
-
}
-
-
table := "tokens"
-
if isRefresh {
-
table = "refresh_tokens"
-
}
-
-
if isRefresh {
-
type Result struct {
-
Found bool
-
}
-
var result Result
-
if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
-
if err == gorm.ErrRecordNotFound {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
-
}
-
-
s.logger.Error("error getting token from db", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
-
if !result.Found {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
-
}
-
}
-
-
exp, ok := claims["exp"].(float64)
-
if !ok {
-
s.logger.Error("error getting iat from token")
-
return helpers.ServerError(e, nil)
-
}
-
-
if exp < float64(time.Now().UTC().Unix()) {
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
-
}
-
-
if repo == nil {
-
maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
-
if err != nil {
-
s.logger.Error("error fetching repo", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
repo = maybeRepo
-
did = repo.Repo.Did
-
}
-
-
e.Set("repo", repo)
-
e.Set("did", did)
-
e.Set("token", tokenstr)
-
-
if err := next(e); err != nil {
-
e.Error(err)
-
}
-
-
return nil
-
}
-
}
-
-
func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
-
return func(e echo.Context) error {
-
authheader := e.Request().Header.Get("authorization")
-
if authheader == "" {
-
return e.JSON(401, map[string]string{"error": "Unauthorized"})
-
}
-
-
pts := strings.Split(authheader, " ")
-
if len(pts) != 2 {
-
return helpers.ServerError(e, nil)
-
}
-
-
if pts[0] != "DPoP" {
-
return next(e)
-
}
-
-
accessToken := pts[1]
-
-
proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
-
if err != nil {
-
s.logger.Error("invalid dpop proof", "error", err)
-
return helpers.InputError(e, to.StringPtr(err.Error()))
-
}
-
-
var oauthToken provider.OauthToken
-
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
-
s.logger.Error("error finding access token in db", "error", err)
-
return helpers.InputError(e, nil)
-
}
-
-
if oauthToken.Token == "" {
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
-
}
-
-
if *oauthToken.Parameters.DpopJkt != proof.JKT {
-
s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
-
return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
-
}
-
-
if time.Now().After(oauthToken.ExpiresAt) {
-
return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
-
}
-
-
repo, err := s.getRepoActorByDid(oauthToken.Sub)
-
if err != nil {
-
s.logger.Error("could not find actor in db", "error", err)
-
return helpers.ServerError(e, nil)
-
}
-
-
nonce := s.oauthProvider.NextNonce()
-
if nonce != "" {
-
e.Response().Header().Set("DPoP-Nonce", nonce)
-
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
-
}
-
-
e.Set("repo", repo)
-
e.Set("did", repo.Repo.Did)
-
e.Set("token", accessToken)
-
e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
-
-
return next(e)
-
}
-
}
-
func New(args *Args) (*Server, error) {
if args.Addr == "" {
return nil, fmt.Errorf("addr must be set")
···
IdleTimeout: 5 * time.Minute,
}
-
gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
+
gdb, err := gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
if err != nil {
return nil, err
}
···
plcClient: plcClient,
privateKey: &pkey,
config: &config{
-
Version: args.Version,
-
Did: args.Did,
-
Hostname: args.Hostname,
-
ContactEmail: args.ContactEmail,
-
EnforcePeering: false,
-
Relays: args.Relays,
-
AdminPassword: args.AdminPassword,
-
SmtpName: args.SmtpName,
-
SmtpEmail: args.SmtpEmail,
+
Version: args.Version,
+
Did: args.Did,
+
Hostname: args.Hostname,
+
ContactEmail: args.ContactEmail,
+
EnforcePeering: false,
+
Relays: args.Relays,
+
AdminPassword: args.AdminPassword,
+
SmtpName: args.SmtpName,
+
SmtpEmail: args.SmtpEmail,
+
BlockstoreVariant: args.BlockstoreVariant,
+
FallbackProxy: args.FallbackProxy,
},
evtman: events.NewEventManager(events.NewMemPersister()),
passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
···
oauthProvider: provider.NewProvider(provider.Args{
Hostname: args.Hostname,
-
ClientManagerArgs: client_manager.Args{
+
ClientManagerArgs: client.ManagerArgs{
Cli: oauthCli,
Logger: args.Logger,
},
-
DpopManagerArgs: dpop_manager.Args{
+
DpopManagerArgs: dpop.ManagerArgs{
NonceSecret: nonceSecret,
NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
OnNonceSecretCreated: func(newNonce []byte) {
···
// TODO: should validate these args
if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
-
args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
+
args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
} else {
mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
mail.From(s.config.SmtpEmail)
···
s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
+
s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
+
s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
+
s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
// repo
s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
···
s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
-
// are there any routes that we should be allowing without auth? i dont think so but idk
-
s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
-
s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
-
// admin routes
s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
+
+
// are there any routes that we should be allowing without auth? i dont think so but idk
+
s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
+
s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
}
func (s *Server) Serve(ctx context.Context) error {
···
go s.doBackup()
}
}
+
+
func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
+
if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
+
return err
+
}
+
+
return nil
+
}
+5 -4
server/templates/account.html
···
</div>
{{ else }} {{ range .Tokens }}
<div class="base-container">
-
<h4>{{ .ClientId }}</h4>
-
<p>Created: {{ .CreatedAt }}</p>
-
<p>Updated: {{ .UpdatedAt }}</p>
-
<p>Expires: {{ .ExpiresAt }}</p>
+
<h4>{{ .ClientName }}</h4>
+
<p>Session Age: {{ .Age}}</p>
+
<p>Last Updated: {{ .LastUpdated }} ago</p>
+
<p>Expires In: {{ .ExpiresIn }}</p>
+
<p>IP Address: {{ .Ip }}</p>
<form action="/account/revoke" method="post">
<input type="hidden" name="token" value="{{ .Token }}" />
<button type="submit" value="">Revoke</button>
+137
sqlite_blockstore/sqlite_blockstore.go
···
+
package sqlite_blockstore
+
+
import (
+
"context"
+
"fmt"
+
+
"github.com/bluesky-social/indigo/atproto/syntax"
+
"github.com/haileyok/cocoon/internal/db"
+
"github.com/haileyok/cocoon/models"
+
blocks "github.com/ipfs/go-block-format"
+
"github.com/ipfs/go-cid"
+
"gorm.io/gorm/clause"
+
)
+
+
type SqliteBlockstore struct {
+
db *db.DB
+
did string
+
readonly bool
+
inserts map[cid.Cid]blocks.Block
+
}
+
+
func New(did string, db *db.DB) *SqliteBlockstore {
+
return &SqliteBlockstore{
+
did: did,
+
db: db,
+
readonly: false,
+
inserts: map[cid.Cid]blocks.Block{},
+
}
+
}
+
+
func NewReadOnly(did string, db *db.DB) *SqliteBlockstore {
+
return &SqliteBlockstore{
+
did: did,
+
db: db,
+
readonly: true,
+
inserts: map[cid.Cid]blocks.Block{},
+
}
+
}
+
+
func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) {
+
var block models.Block
+
+
maybeBlock, ok := bs.inserts[cid]
+
if ok {
+
return maybeBlock, nil
+
}
+
+
if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil {
+
return nil, err
+
}
+
+
b, err := blocks.NewBlockWithCid(block.Value, cid)
+
if err != nil {
+
return nil, err
+
}
+
+
return b, nil
+
}
+
+
func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error {
+
bs.inserts[block.Cid()] = block
+
+
if bs.readonly {
+
return nil
+
}
+
+
b := models.Block{
+
Did: bs.did,
+
Cid: block.Cid().Bytes(),
+
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
+
Value: block.RawData(),
+
}
+
+
if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{
+
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
+
UpdateAll: true,
+
}}).Error; err != nil {
+
return err
+
}
+
+
return nil
+
}
+
+
func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error {
+
panic("not implemented")
+
}
+
+
func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) {
+
panic("not implemented")
+
}
+
+
func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) {
+
panic("not implemented")
+
}
+
+
func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error {
+
tx := bs.db.BeginDangerously()
+
+
for _, block := range blocks {
+
bs.inserts[block.Cid()] = block
+
+
if bs.readonly {
+
continue
+
}
+
+
b := models.Block{
+
Did: bs.did,
+
Cid: block.Cid().Bytes(),
+
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
+
Value: block.RawData(),
+
}
+
+
if err := tx.Clauses(clause.OnConflict{
+
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
+
UpdateAll: true,
+
}).Create(&b).Error; err != nil {
+
tx.Rollback()
+
return err
+
}
+
}
+
+
if bs.readonly {
+
return nil
+
}
+
+
tx.Commit()
+
+
return nil
+
}
+
+
func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) {
+
return nil, fmt.Errorf("iteration not allowed on sqlite blockstore")
+
}
+
+
func (bs *SqliteBlockstore) HashOnRead(enabled bool) {
+
panic("not implemented")
+
}