From fb64beba9c0bb4b9bf3b22d22f74c751def3ec4c Mon Sep 17 00:00:00 2001 From: Tom Sherman Date: Mon, 11 Aug 2025 11:41:01 +0100 Subject: [PATCH] Initial implementation sketch --- api/tangled/cbor_gen.go | 244 ++++++++++++++++++++++- api/tangled/tangledpipeline.go | 12 +- cmd/gen.go | 1 + flake.nix | 2 +- lexicons/pipeline/pipeline.json | 17 ++ spindle/engine/engine.go | 18 +- spindle/models/pipeline.go | 14 ++ spindle/oidc/oidc.go | 329 ++++++++++++++++++++++++++++++++ spindle/server.go | 15 +- 9 files changed, 640 insertions(+), 12 deletions(-) create mode 100644 spindle/oidc/oidc.go diff --git a/api/tangled/cbor_gen.go b/api/tangled/cbor_gen.go index eba0337..6320bf7 100644 --- a/api/tangled/cbor_gen.go +++ b/api/tangled/cbor_gen.go @@ -3923,12 +3923,16 @@ func (t *Pipeline_Step) MarshalCBOR(w io.Writer) error { } cw := cbg.NewCborWriter(w) - fieldCount := 3 + fieldCount := 4 if t.Environment == nil { fieldCount-- } + if t.Oidcs_tokens == nil { + fieldCount-- + } + if _, err := cw.Write(cbg.CborEncodeMajorType(cbg.MajMap, uint64(fieldCount))); err != nil { return err } @@ -4007,6 +4011,35 @@ func (t *Pipeline_Step) MarshalCBOR(w io.Writer) error { } } + + // t.Oidcs_tokens ([]*tangled.Pipeline_Step_Oidcs_tokens_Elem) (slice) + if t.Oidcs_tokens != nil { + + if len("oidcs_tokens") > 1000000 { + return xerrors.Errorf("Value in field \"oidcs_tokens\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("oidcs_tokens"))); err != nil { + return err + } + if _, err := cw.WriteString(string("oidcs_tokens")); err != nil { + return err + } + + if len(t.Oidcs_tokens) > 8192 { + return xerrors.Errorf("Slice value in field t.Oidcs_tokens was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len(t.Oidcs_tokens))); err != nil { + return err + } + for _, v := range t.Oidcs_tokens { + if err := v.MarshalCBOR(cw); err != nil { + return err + } + + } + } return nil } @@ -4035,7 +4068,7 @@ func (t *Pipeline_Step) UnmarshalCBOR(r io.Reader) (err error) { n := extra - nameBuf := make([]byte, 11) + nameBuf := make([]byte, 12) for i := uint64(0); i < n; i++ { nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 1000000) if err != nil { @@ -4122,6 +4155,213 @@ func (t *Pipeline_Step) UnmarshalCBOR(r io.Reader) (err error) { } } + // t.Oidcs_tokens ([]*tangled.Pipeline_Step_Oidcs_tokens_Elem) (slice) + case "oidcs_tokens": + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > 8192 { + return fmt.Errorf("t.Oidcs_tokens: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + t.Oidcs_tokens = make([]*Pipeline_Step_Oidcs_tokens_Elem, extra) + } + + for i := 0; i < int(extra); i++ { + { + var maj byte + var extra uint64 + var err error + _ = maj + _ = extra + _ = err + + { + + b, err := cr.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := cr.UnreadByte(); err != nil { + return err + } + t.Oidcs_tokens[i] = new(Pipeline_Step_Oidcs_tokens_Elem) + if err := t.Oidcs_tokens[i].UnmarshalCBOR(cr); err != nil { + return xerrors.Errorf("unmarshaling t.Oidcs_tokens[i] pointer: %w", err) + } + } + + } + + } + } + + default: + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(r, func(cid.Cid) {}); err != nil { + return err + } + } + } + + return nil +} +func (t *Pipeline_Step_Oidcs_tokens_Elem) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + fieldCount := 2 + + if t.Aud == nil { + fieldCount-- + } + + if _, err := cw.Write(cbg.CborEncodeMajorType(cbg.MajMap, uint64(fieldCount))); err != nil { + return err + } + + // t.Aud (string) (string) + if t.Aud != nil { + + if len("aud") > 1000000 { + return xerrors.Errorf("Value in field \"aud\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("aud"))); err != nil { + return err + } + if _, err := cw.WriteString(string("aud")); err != nil { + return err + } + + if t.Aud == nil { + if _, err := cw.Write(cbg.CborNull); err != nil { + return err + } + } else { + if len(*t.Aud) > 1000000 { + return xerrors.Errorf("Value in field t.Aud was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(*t.Aud))); err != nil { + return err + } + if _, err := cw.WriteString(string(*t.Aud)); err != nil { + return err + } + } + } + + // t.Name (string) (string) + if len("name") > 1000000 { + return xerrors.Errorf("Value in field \"name\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("name"))); err != nil { + return err + } + if _, err := cw.WriteString(string("name")); err != nil { + return err + } + + if len(t.Name) > 1000000 { + return xerrors.Errorf("Value in field t.Name was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Name))); err != nil { + return err + } + if _, err := cw.WriteString(string(t.Name)); err != nil { + return err + } + return nil +} + +func (t *Pipeline_Step_Oidcs_tokens_Elem) UnmarshalCBOR(r io.Reader) (err error) { + *t = Pipeline_Step_Oidcs_tokens_Elem{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("Pipeline_Step_Oidcs_tokens_Elem: map struct too large (%d)", extra) + } + + n := extra + + nameBuf := make([]byte, 4) + for i := uint64(0); i < n; i++ { + nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 1000000) + if err != nil { + return err + } + + if !ok { + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(cr, func(cid.Cid) {}); err != nil { + return err + } + continue + } + + switch string(nameBuf[:nameLen]) { + // t.Aud (string) (string) + case "aud": + + { + b, err := cr.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := cr.UnreadByte(); err != nil { + return err + } + + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.Aud = (*string)(&sval) + } + } + // t.Name (string) (string) + case "name": + + { + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.Name = string(sval) + } default: // Field doesn't exist on this type, so ignore it diff --git a/api/tangled/tangledpipeline.go b/api/tangled/tangledpipeline.go index ae625c4..8611534 100644 --- a/api/tangled/tangledpipeline.go +++ b/api/tangled/tangledpipeline.go @@ -63,9 +63,15 @@ type Pipeline_PushTriggerData struct { // Pipeline_Step is a "step" in the sh.tangled.pipeline schema. type Pipeline_Step struct { - Command string `json:"command" cborgen:"command"` - Environment []*Pipeline_Pair `json:"environment,omitempty" cborgen:"environment,omitempty"` - Name string `json:"name" cborgen:"name"` + Command string `json:"command" cborgen:"command"` + Environment []*Pipeline_Pair `json:"environment,omitempty" cborgen:"environment,omitempty"` + Name string `json:"name" cborgen:"name"` + Oidcs_tokens []*Pipeline_Step_Oidcs_tokens_Elem `json:"oidcs_tokens,omitempty" cborgen:"oidcs_tokens,omitempty"` +} + +type Pipeline_Step_Oidcs_tokens_Elem struct { + Aud *string `json:"aud,omitempty" cborgen:"aud,omitempty"` + Name string `json:"name" cborgen:"name"` } // Pipeline_TriggerMetadata is a "triggerMetadata" in the sh.tangled.pipeline schema. diff --git a/cmd/gen.go b/cmd/gen.go index 2e07c0e..a33c6da 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -34,6 +34,7 @@ func main() { tangled.Pipeline_PushTriggerData{}, tangled.PipelineStatus{}, tangled.Pipeline_Step{}, + tangled.Pipeline_Step_Oidcs_tokens_Elem{}, tangled.Pipeline_TriggerMetadata{}, tangled.Pipeline_TriggerRepo{}, tangled.Pipeline_Workflow{}, diff --git a/flake.nix b/flake.nix index db11dd2..739ef64 100644 --- a/flake.nix +++ b/flake.nix @@ -116,7 +116,7 @@ stdenv = pkgs.pkgsStatic.stdenv; }; in { - default = staticShell { + default = pkgs.mkShell { nativeBuildInputs = [ pkgs.go pkgs.air diff --git a/lexicons/pipeline/pipeline.json b/lexicons/pipeline/pipeline.json index d9e9872..b0894b8 100644 --- a/lexicons/pipeline/pipeline.json +++ b/lexicons/pipeline/pipeline.json @@ -241,6 +241,23 @@ "type": "ref", "ref": "#pair" } + }, + "oidcs_tokens": { + "type": "array", + "items": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string" + }, + "aud": { + "type": "string" + } + } + } } } }, diff --git a/spindle/engine/engine.go b/spindle/engine/engine.go index 65e32eb..2308699 100644 --- a/spindle/engine/engine.go +++ b/spindle/engine/engine.go @@ -25,6 +25,7 @@ import ( "tangled.sh/tangled.sh/core/spindle/config" "tangled.sh/tangled.sh/core/spindle/db" "tangled.sh/tangled.sh/core/spindle/models" + "tangled.sh/tangled.sh/core/spindle/oidc" "tangled.sh/tangled.sh/core/spindle/secrets" ) @@ -41,12 +42,13 @@ type Engine struct { n *notifier.Notifier cfg *config.Config vault secrets.Manager + oidc oidc.OidcTokenGenerator cleanupMu sync.Mutex cleanup map[string][]cleanupFunc } -func New(ctx context.Context, cfg *config.Config, db *db.DB, n *notifier.Notifier, vault secrets.Manager) (*Engine, error) { +func New(ctx context.Context, cfg *config.Config, db *db.DB, n *notifier.Notifier, vault secrets.Manager, oidc *oidc.OidcTokenGenerator) (*Engine, error) { dcli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { return nil, err @@ -61,6 +63,7 @@ func New(ctx context.Context, cfg *config.Config, db *db.DB, n *notifier.Notifie n: n, cfg: cfg, vault: vault, + oidc: *oidc, } e.cleanup = make(map[string][]cleanupFunc) @@ -124,7 +127,7 @@ func (e *Engine) StartWorkflows(ctx context.Context, pipeline *models.Pipeline, ctx, cancel := context.WithTimeout(ctx, workflowTimeout) defer cancel() - err = e.StartSteps(ctx, wid, w, allSecrets) + err = e.StartSteps(ctx, wid, w, allSecrets, pipeline, pipelineId) if err != nil { if errors.Is(err, ErrTimedOut) { dbErr := e.db.StatusTimeout(wid, e.n) @@ -202,7 +205,7 @@ func (e *Engine) SetupWorkflow(ctx context.Context, wid models.WorkflowId) error // ONLY marks pipeline as failed if container's exit code is non-zero. // All other errors are bubbled up. // Fixed version of the step execution logic -func (e *Engine) StartSteps(ctx context.Context, wid models.WorkflowId, w models.Workflow, secrets []secrets.UnlockedSecret) error { +func (e *Engine) StartSteps(ctx context.Context, wid models.WorkflowId, w models.Workflow, secrets []secrets.UnlockedSecret, pipeline *models.Pipeline, pipelineId models.PipelineId) error { workflowEnvs := ConstructEnvs(w.Environment) for _, s := range secrets { workflowEnvs.AddEnv(s.Key, s.Value) @@ -222,6 +225,15 @@ func (e *Engine) StartSteps(ctx context.Context, wid models.WorkflowId, w models envs.AddEnv("HOME", workspaceDir) e.l.Debug("envs for step", "step", step.Name, "envs", envs.Slice()) + for _, t := range step.OidcTokens { + token, err := e.oidc.CreateToken(t, pipelineId, pipeline.RepoOwner, pipeline.RepoName) + if err != nil { + e.l.Error("failed to get OIDC token", "error", err, "token", t.Name) + return fmt.Errorf("getting OIDC token: %w", err) + } + envs.AddEnv(t.Name, token) + } + hostConfig := hostConfig(wid) resp, err := e.docker.ContainerCreate(ctx, &container.Config{ Image: w.Image, diff --git a/spindle/models/pipeline.go b/spindle/models/pipeline.go index 8561b21..85619e9 100644 --- a/spindle/models/pipeline.go +++ b/spindle/models/pipeline.go @@ -18,6 +18,7 @@ type Step struct { Name string Environment map[string]string Kind StepKind + OidcTokens []OidcToken } type StepKind int @@ -29,6 +30,11 @@ const ( StepKindUser ) +type OidcToken struct { + Name string + Aud *string +} + type Workflow struct { Steps []Step Environment map[string]string @@ -60,6 +66,14 @@ func ToPipeline(pl tangled.Pipeline, cfg config.Config) *Pipeline { sstep.Name = tstep.Name sstep.Kind = StepKindUser swf.Steps = append(swf.Steps, sstep) + + sstep.OidcTokens = make([]OidcToken, 0, len(tstep.Oidcs_tokens)) + for _, ttoken := range tstep.Oidcs_tokens { + sstep.OidcTokens = append(sstep.OidcTokens, OidcToken{ + Name: ttoken.Name, + Aud: ttoken.Aud, + }) + } } swf.Name = twf.Name swf.Environment = workflowEnvToMap(twf.Environment) diff --git a/spindle/oidc/oidc.go b/spindle/oidc/oidc.go new file mode 100644 index 0000000..3434660 --- /dev/null +++ b/spindle/oidc/oidc.go @@ -0,0 +1,329 @@ +package oidc + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "tangled.sh/tangled.sh/core/spindle/models" +) + +const JWKSPath = "/.well-known/jwks.json" + +// OidcKeyPair represents an OIDC key pair with both private and public keys +type OidcKeyPair struct { + privateKey *ecdsa.PrivateKey + publicKey *ecdsa.PublicKey + keyID string + jwkKey jwk.Key +} + +// OidcTokenGenerator handles OIDC token generation and key management with rotation +type OidcTokenGenerator struct { + currentKeyPair *OidcKeyPair + nextKeyPair *OidcKeyPair + l *slog.Logger + issuer string +} + +// NewOidcTokenGenerator creates a new OIDC token generator with in-memory key management +func NewOidcTokenGenerator(issuer string) (*OidcTokenGenerator, error) { + // Create new keys + currentKeyPair, err := NewOidcKeyPair() + if err != nil { + return nil, fmt.Errorf("failed to generate initial current key pair: %w", err) + } + + return &OidcTokenGenerator{ + issuer: issuer, + currentKeyPair: currentKeyPair, + }, nil +} + +// NewOidcKeyPair generates a new ECDSA key pair for OIDC token signing +func NewOidcKeyPair() (*OidcKeyPair, error) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate ECDSA key: %w", err) + } + + keyID := fmt.Sprintf("spindle-%d", time.Now().Unix()) + + // Create JWK from the private key + jwkKey, err := jwk.FromRaw(privKey) + if err != nil { + return nil, fmt.Errorf("failed to create JWK from private key: %w", err) + } + + // Set the key ID + if err := jwkKey.Set(jwk.KeyIDKey, keyID); err != nil { + return nil, fmt.Errorf("failed to set key ID: %w", err) + } + + // Set algorithm + if err := jwkKey.Set(jwk.AlgorithmKey, jwa.ES256); err != nil { + return nil, fmt.Errorf("failed to set algorithm: %w", err) + } + + // Set usage + if err := jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil { + return nil, fmt.Errorf("failed to set key usage: %w", err) + } + + return &OidcKeyPair{ + privateKey: privKey, + publicKey: &privKey.PublicKey, + keyID: keyID, + jwkKey: jwkKey, + }, nil +} + +// LoadOidcKeyPair loads an existing key pair from JWK JSON +func LoadOidcKeyPair(jwkJSON []byte) (*OidcKeyPair, error) { + jwkKey, err := jwk.ParseKey(jwkJSON) + if err != nil { + return nil, fmt.Errorf("failed to parse JWK: %w", err) + } + + var privKey *ecdsa.PrivateKey + if err := jwkKey.Raw(&privKey); err != nil { + return nil, fmt.Errorf("failed to extract private key: %w", err) + } + + keyID, ok := jwkKey.Get(jwk.KeyIDKey) + if !ok { + return nil, fmt.Errorf("JWK missing key ID") + } + + keyIDStr, ok := keyID.(string) + if !ok { + return nil, fmt.Errorf("JWK key ID is not a string") + } + + return &OidcKeyPair{ + privateKey: privKey, + publicKey: &privKey.PublicKey, + keyID: keyIDStr, + jwkKey: jwkKey, + }, nil +} + +// GetKeyID returns the key ID +func (k *OidcKeyPair) GetKeyID() string { + return k.keyID +} + +// RotateKeys performs key rotation: generates new next key, moves next to current +func (g *OidcTokenGenerator) RotateKeys() error { + // Generate a new key pair for the next key + newNextKeyPair, err := NewOidcKeyPair() + if err != nil { + return fmt.Errorf("failed to generate new next key pair: %w", err) + } + + // Perform rotation: next becomes current, new key becomes next + g.currentKeyPair = g.nextKeyPair + g.nextKeyPair = newNextKeyPair + + // If we don't have a current key (first time setup), use the new key + if g.currentKeyPair == nil { + g.currentKeyPair = newNextKeyPair + // Generate another new key for next + g.nextKeyPair, err = NewOidcKeyPair() + if err != nil { + return fmt.Errorf("failed to generate next key pair for first setup: %w", err) + } + } + + return nil +} + +func (g *OidcTokenGenerator) GetCurrentKeyID() string { + if g.currentKeyPair == nil { + return "" + } + return g.currentKeyPair.GetKeyID() +} + +// GetNextKeyID returns the next key's ID +func (g *OidcTokenGenerator) GetNextKeyID() string { + if g.nextKeyPair == nil { + return "" + } + return g.nextKeyPair.GetKeyID() +} + +// HasKeys returns true if the generator has at least a current key +func (g *OidcTokenGenerator) HasKeys() bool { + return g.currentKeyPair != nil +} + +// OidcClaims represents the claims in an OIDC token +type OidcClaims struct { + // Standard JWT claims + Issuer string `json:"iss"` + Subject string `json:"sub"` + Audience string `json:"aud"` + ExpiresAt int64 `json:"exp"` + NotBefore int64 `json:"nbf"` + IssuedAt int64 `json:"iat"` + JWTID string `json:"jti"` +} + +// CreateToken creates a signed JWT token for the given OidcToken and pipeline context +func (g *OidcTokenGenerator) CreateToken( + oidcToken models.OidcToken, + pipelineId models.PipelineId, + repoOwner, repoName string, +) (string, error) { + now := time.Now() + exp := now.Add(5 * time.Minute) + + // Determine audience - use the provided audience or default to issuer + audience := fmt.Sprintf(g.issuer) + if oidcToken.Aud != nil && *oidcToken.Aud != "" { + audience = *oidcToken.Aud + } + + pipelineUri := pipelineId.AtUri() + + // Create claims + claims := OidcClaims{ + Issuer: g.issuer, + // Hardcode the did as did:web of the issuer. At some point knots will have their own DIDs which will be used here + Subject: pipelineUri.String(), + Audience: audience, + ExpiresAt: exp.Unix(), + NotBefore: now.Unix(), + IssuedAt: now.Unix(), + // Repo owner, name, and id should be global unique but we add timestamp to ensure uniqueness + JWTID: fmt.Sprintf("%s/%s-%s-%d", repoOwner, repoName, pipelineUri.RecordKey(), now.Unix()), + } + + // Create JWT token + token := jwt.New() + + // Set all claims + if err := token.Set(jwt.IssuerKey, claims.Issuer); err != nil { + return "", fmt.Errorf("failed to set issuer: %w", err) + } + if err := token.Set(jwt.SubjectKey, claims.Subject); err != nil { + return "", fmt.Errorf("failed to set subject: %w", err) + } + if err := token.Set(jwt.AudienceKey, claims.Audience); err != nil { + return "", fmt.Errorf("failed to set audience: %w", err) + } + if err := token.Set(jwt.ExpirationKey, claims.ExpiresAt); err != nil { + return "", fmt.Errorf("failed to set expiration: %w", err) + } + if err := token.Set(jwt.NotBeforeKey, claims.NotBefore); err != nil { + return "", fmt.Errorf("failed to set not before: %w", err) + } + if err := token.Set(jwt.IssuedAtKey, claims.IssuedAt); err != nil { + return "", fmt.Errorf("failed to set issued at: %w", err) + } + if err := token.Set(jwt.JwtIDKey, claims.JWTID); err != nil { + return "", fmt.Errorf("failed to set JWT ID: %w", err) + } + + // Sign the token with the current key + if g.currentKeyPair == nil { + return "", fmt.Errorf("no current key pair available for signing") + } + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256, g.currentKeyPair.jwkKey)) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + return string(signedToken), nil +} + +// JWKSHandler serves the JWKS endpoint as an HTTP handler +func (g *OidcTokenGenerator) JWKSHandler(w http.ResponseWriter, r *http.Request) { + var keys []jwk.Key + + // Add current key if available + if g.currentKeyPair != nil { + pubJWK, err := jwk.PublicKeyOf(g.currentKeyPair.jwkKey) + if err != nil { + http.Error(w, fmt.Sprintf("failed to extract current public key from JWK: %v", err), http.StatusInternalServerError) + return + } + keys = append(keys, pubJWK) + } + + // Add next key if available + if g.nextKeyPair != nil { + pubJWK, err := jwk.PublicKeyOf(g.nextKeyPair.jwkKey) + if err != nil { + http.Error(w, fmt.Sprintf("failed to extract next public key from JWK: %v", err), http.StatusInternalServerError) + return + } + keys = append(keys, pubJWK) + } + + if len(keys) == 0 { + http.Error(w, "no keys available for JWKS", http.StatusInternalServerError) + return + } + + jwks := map[string]interface{}{ + "keys": keys, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(jwks); err != nil { + http.Error(w, fmt.Sprintf("failed to encode JWKS: %v", err), http.StatusInternalServerError) + } +} + +// DiscoveryHandler serves the OIDC discovery endpoint for JWKS +func (g *OidcTokenGenerator) DiscoveryHandler(w http.ResponseWriter, r *http.Request) { + claimsSupported := []string{ + "iss", + "sub", + "aud", + "exp", + "nbf", + "iat", + "jti", + } + + responseTypesSupported := []string{ + "id_token", + } + + subjectTypesSupported := []string{ + "public", + } + + idTokenSigningAlgValuesSupported := []string{ + jwa.RS256.String(), + } + + scopesSupported := []string{ + "openid", + } + + discovery := map[string]interface{}{ + "issuer": g.issuer, + "jwks_uri": fmt.Sprintf("%s%s", g.issuer, JWKSPath), + "claims_supported": claimsSupported, + "response_types_supported": responseTypesSupported, + "subject_types_supported": subjectTypesSupported, + "id_token_signing_alg_values_supported": idTokenSigningAlgValuesSupported, + "scopes_supported": scopesSupported, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(discovery); err != nil { + http.Error(w, fmt.Sprintf("failed to encode discovery document: %v", err), http.StatusInternalServerError) + } +} diff --git a/spindle/server.go b/spindle/server.go index f04b206..e68741c 100644 --- a/spindle/server.go +++ b/spindle/server.go @@ -21,6 +21,7 @@ import ( "tangled.sh/tangled.sh/core/spindle/db" "tangled.sh/tangled.sh/core/spindle/engine" "tangled.sh/tangled.sh/core/spindle/models" + "tangled.sh/tangled.sh/core/spindle/oidc" "tangled.sh/tangled.sh/core/spindle/queue" "tangled.sh/tangled.sh/core/spindle/secrets" "tangled.sh/tangled.sh/core/spindle/xrpc" @@ -93,7 +94,12 @@ func Run(ctx context.Context) error { return fmt.Errorf("unknown secrets provider: %s", cfg.Server.Secrets.Provider) } - eng, err := engine.New(ctx, cfg, d, &n, vault) + oidc, err := oidc.NewOidcTokenGenerator(cfg.Server.Hostname) + if err != nil { + return fmt.Errorf("failed to create OIDC token generator: %w", err) + } + + eng, err := engine.New(ctx, cfg, d, &n, vault, oidc) if err != nil { return err } @@ -188,12 +194,12 @@ func Run(ctx context.Context) error { }() logger.Info("starting spindle server", "address", cfg.Server.ListenAddr) - logger.Error("server error", "error", http.ListenAndServe(cfg.Server.ListenAddr, spindle.Router())) + logger.Error("server error", "error", http.ListenAndServe(cfg.Server.ListenAddr, spindle.Router(oidc))) return nil } -func (s *Spindle) Router() http.Handler { +func (s *Spindle) Router(oidcg *oidc.OidcTokenGenerator) http.Handler { mux := chi.NewRouter() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -204,6 +210,9 @@ func (s *Spindle) Router() http.Handler { w.Write([]byte(s.cfg.Server.Owner)) }) mux.HandleFunc("/logs/{knot}/{rkey}/{name}", s.Logs) + mux.HandleFunc(oidc.JWKSPath, oidcg.JWKSHandler) + mux.HandleFunc("/.well-known/oidc-configuration", oidcg.DiscoveryHandler) + // TODO: Do we need webfinger issuer discovery? mux.Mount("/xrpc", s.XrpcRouter()) return mux -- 2.43.0