Merge pull request #10439 from olemarkus/remove-node-authorizer

Remove node-authorization
This commit is contained in:
Kubernetes Prow Robot 2021-01-11 11:14:25 -08:00 committed by GitHub
commit 789877984b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
392 changed files with 13 additions and 63123 deletions

View File

@ -1,9 +1,11 @@
### **Node Authorization Service**
:warning: The node authorization service is deprecated.
As of Kubernetes 1.19 kOps will, on AWS, ignore the `nodeAuthorization` field of the cluster spec and
**The node authorization service is deprecated.**
As of kOps 1.19 with AWS and Kubernetes 1.19, the `nodeAuthorization` field of the cluster spec will be ignored and
worker nodes will obtain client certificates for kubelet and other purposes through kops-controller.
As of kOps 1.20, setting `nodeAuthorization` is forbidden for any Kubernetes version and is replaced by the bootstrapping using kops-controller.
The [node authorization service] is an experimental service which in the absence of a kops-apiserver provides the distribution of tokens to the worker nodes. Bootstrap tokens provide worker nodes a short-time credential to request access kubeconfig certificate. A gist of the flow is;
- a secret of type `bootstrap.kubernetes.io/token` is created on behalf of a node in the kube-system namespace.

View File

@ -46,6 +46,8 @@ See the [documentation](/cluster_spec/#load-balancer-class) for more info.
* Allow users to partially compress user-data, check the instance groups docs for more details.
* Worker nodes on AWS will now be bootstrapped using kops-controller.
### CLI
* The `kops update cluster` command will now refuse to run on a cluster that
@ -140,6 +142,8 @@ has been updated by a newer version of kOps unless it is given the `--allow-kops
* The [manifest based cluster autoscaler addon](https://github.com/kubernetes/kops/tree/master/addons/cluster-autoscaler) has been deprecated in favour of a configurable addon.
* The experimental node authorizor is now ignored if you are using kubernetes 1.19. The feature will be removed in 1.20. Worker nodes will instead be authorized using kops-controller.
# Full change list since 1.18.0 release
## v1.18.0-alpha.3 to v1.19.0-alpha.1

View File

@ -36,4 +36,6 @@
* The `node-role.kubernetes.io/master` and `kubernetes.io/role` labels are deprecated and will be removed from control plane nodes in kOps 1.22
* The experimental node-authorizer that could be enabled using `nodeAuthorization` has been removed. Setting this value is now forbidden.
# Full change list since 1.19.0 release

5
go.mod
View File

@ -66,7 +66,6 @@ require (
github.com/digitalocean/godo v1.54.0
github.com/docker/docker v1.4.2-0.20200309214505-aa6a9891b09c
github.com/docker/spdystream v0.0.0-20181023171402-6480d4af844c // indirect
github.com/fullsailor/pkcs7 v0.0.0-20180422025557-ae226422660e
github.com/go-bindata/go-bindata/v3 v3.1.3
github.com/go-ini/ini v1.62.0
github.com/go-logr/logr v0.2.1-0.20200730175230-ee2de8da5be6
@ -74,11 +73,9 @@ require (
github.com/google/go-cmp v0.5.2
github.com/google/uuid v1.1.2
github.com/gophercloud/gophercloud v0.15.0
github.com/gorilla/mux v1.7.3
github.com/hashicorp/hcl/v2 v2.7.0
github.com/hashicorp/vault/api v1.0.4
github.com/jacksontj/memberlistmesh v0.0.0-20190905163944-93462b9d2bb7
github.com/jpillora/backoff v0.0.0-20170918002102-8eab2debe79d
github.com/miekg/coredns v0.0.0-20161111164017-20e25559d5ea
github.com/mitchellh/mapstructure v1.1.2
github.com/pelletier/go-toml v1.8.1
@ -90,11 +87,9 @@ require (
github.com/spf13/viper v1.7.0
github.com/spotinst/spotinst-sdk-go v1.75.0
github.com/stretchr/testify v1.6.1
github.com/urfave/cli v1.22.2
github.com/weaveworks/mesh v0.0.0-20170419100114-1f158d31de55
github.com/zclconf/go-cty v1.3.1
go.etcd.io/etcd v0.5.0-alpha.5.0.20200910180754-dd1b699fc489
go.uber.org/zap v1.13.0
golang.org/x/crypto v0.0.0-20201208171446-5f87f3452ae9
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d

4
go.sum
View File

@ -324,8 +324,6 @@ github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fullsailor/pkcs7 v0.0.0-20180422025557-ae226422660e h1:qt5qtzBGD2AoRIxNNxJZr2dC4ei+pyolhbho9knAI1Q=
github.com/fullsailor/pkcs7 v0.0.0-20180422025557-ae226422660e/go.mod h1:KnogPXtdwXqoenmZCw6S+25EAm2MkxbG0deNDu4cbSA=
github.com/fvbommel/sortorder v1.0.1/go.mod h1:uk88iVf1ovNn1iLfgUVU2F9o5eO30ui720w+kxuqRs0=
github.com/garyburd/redigo v0.0.0-20150301180006-535138d7bcd7/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
@ -621,8 +619,6 @@ github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhB
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/jpillora/backoff v0.0.0-20170918002102-8eab2debe79d h1:ix3WmphUvN0GDd0DO9MH0v6/5xTv+Xm1bPN+1UJn58k=
github.com/jpillora/backoff v0.0.0-20170918002102-8eab2debe79d/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0=
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=

View File

@ -1,3 +0,0 @@
*.swp
bin/
tests/

View File

@ -1,28 +0,0 @@
# Copyright 2019 The Kubernetes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM alpine:3.11 as ssl
RUN apk --no-cache add ca-certificates && \
apk --no-cache upgrade
FROM scratch
LABEL Name=node-authorizer \
Release=https://github.com/kubernetes/kops \
Url=https://github.com/kubernetes/kops \
Help=https://github.com/kubernetes/kops\issues
ADD bin/node-authorizer /node-authorizer
COPY --from=ssl /etc/ssl /etc/ssl
ENTRYPOINT ["/node-authorizer"]

View File

@ -1,26 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
go_library(
name = "go_default_library",
srcs = [
"client.go",
"main.go",
"server.go",
],
importpath = "k8s.io/kops/node-authorizer/cmd/node-authorizer",
visibility = ["//visibility:private"],
deps = [
"//node-authorizer/pkg/authorizers/alwaysallow:go_default_library",
"//node-authorizer/pkg/authorizers/aws:go_default_library",
"//node-authorizer/pkg/client:go_default_library",
"//node-authorizer/pkg/server:go_default_library",
"//node-authorizer/pkg/utils:go_default_library",
"//vendor/github.com/urfave/cli:go_default_library",
],
)
go_binary(
name = "node-authorizer",
embed = [":go_default_library"],
visibility = ["//visibility:public"],
)

View File

@ -1,104 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"time"
"k8s.io/kops/node-authorizer/pkg/client"
"github.com/urfave/cli"
)
// addClientCommand creates and returns a client command
func addClientCommand() cli.Command {
return cli.Command{
Name: "client",
Usage: "starts the service in a client mode and attempts to acquire a bootstrap token",
Flags: []cli.Flag{
cli.StringFlag{
Name: "authorizer",
Usage: "provider we should use to authorize the node registration `NAME`",
EnvVar: "AUTHORIZER",
Value: "aws",
},
cli.StringFlag{
Name: "node-url",
Usage: "the url for the node authorizer service `URL`",
EnvVar: "NODE_AUTHORIZER_URL",
},
cli.StringFlag{
Name: "kubeapi-url",
Usage: "the url for the kubernetes api `URL`",
EnvVar: "KUBEAPI_URL",
},
cli.StringFlag{
Name: "kubeconfig",
Usage: "location to write bootstrap token config `PATH`",
EnvVar: "KUBECONFIG_BOOTSTRAP",
Value: "/var/lib/kubelet/bootstrap-kubeconfig",
},
cli.StringFlag{
Name: "tls-client-ca",
Usage: "file containing the certificate authority used to verify node endpoint `PATH`",
EnvVar: "TLS_CLIENT_CA",
},
cli.StringFlag{
Name: "tls-cert",
Usage: "file containing the client certificate `PATH`",
EnvVar: "TLS_CERT",
},
cli.StringFlag{
Name: "tls-private-key",
Usage: "file containing the client private key `PATH`",
EnvVar: "TLS_PRIVATE_KEY",
},
cli.DurationFlag{
Name: "interval",
Usage: "an interval to wait between failed request `DURATION`",
EnvVar: "INTERVAL",
Value: 3 * time.Second,
},
cli.DurationFlag{
Name: "timeout",
Usage: "the max time to wait before timing out `DURATION`",
EnvVar: "TIMEOUT",
Value: 30 * time.Second,
},
},
Action: func(ctx *cli.Context) error {
return actionClientCommand(ctx)
},
}
}
// actionClientCommand is the client action
func actionClientCommand(ctx *cli.Context) error {
return client.New(&client.Config{
Authorizer: ctx.String("authorizer"),
Interval: ctx.Duration("interval"),
KubeAPI: ctx.String("kubeapi-url"),
KubeConfigPath: ctx.String("kubeconfig"),
NodeURL: ctx.String("node-url"),
TLSCertPath: ctx.String("tls-cert"),
TLSClientCAPath: ctx.String("tls-client-ca"),
TLSPrivateKeyPath: ctx.String("tls-private-key"),
Timeout: ctx.Duration("timeout"),
})
}

View File

@ -1,44 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"fmt"
"os"
"k8s.io/kops/node-authorizer/pkg/server"
"github.com/urfave/cli"
)
func main() {
app := cli.NewApp()
app.Authors = []cli.Author{
{
Name: "Rohith Jayawardene",
Email: "gambol99@gmail.com",
},
}
app.Commands = []cli.Command{addServerCommand(), addClientCommand()}
app.Usage = "used to provision the bootstrap tokens for node registration"
app.Version = server.Version
if err := app.Run(os.Args); err != nil {
fmt.Fprintf(os.Stderr, "[error] %s\n", err)
os.Exit(1)
}
}

View File

@ -1,198 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"errors"
"fmt"
"time"
"k8s.io/kops/node-authorizer/pkg/authorizers/alwaysallow"
"k8s.io/kops/node-authorizer/pkg/authorizers/aws"
"k8s.io/kops/node-authorizer/pkg/server"
"k8s.io/kops/node-authorizer/pkg/utils"
"github.com/urfave/cli"
)
// addServerCommand creates and returns a server command
func addServerCommand() cli.Command {
return cli.Command{
Name: "server",
Usage: "starts the node-authorizer in server mode",
Flags: []cli.Flag{
cli.StringFlag{
Name: "authorizer",
Usage: "provider we should use to authorize the node registration `NAME`",
EnvVar: "AUTHORIZER",
},
cli.StringFlag{
Name: "listen",
Usage: "interface to bind the service `INTERFACE`",
EnvVar: "LISTEN",
Value: ":10443",
},
cli.StringFlag{
Name: "tls-client-ca",
Usage: "file containing the client certificate authority, required for mutual tls `PATH`",
EnvVar: "TLS_CLIENT_CA",
},
cli.StringFlag{
Name: "tls-cert",
Usage: "file containing the certificate `PATH`",
EnvVar: "TLS_CERT",
Value: "/certs/node-authorizer.pem",
},
cli.StringFlag{
Name: "tls-private-key",
Usage: "file containing the private key `PATH`",
EnvVar: "TLS_PRIVATE_KEY",
Value: "/certs/node-authorizer-key.pem",
},
cli.StringFlag{
Name: "cluster-name",
Usage: "name of the kubernetes cluster we are provisioning `NAME`",
EnvVar: "CLUSTER_NAME",
},
cli.StringFlag{
Name: "cluster-tag",
Usage: "name of the cloud tag used to identify the cluster name `NAME`",
EnvVar: "CLUSTER_TAG",
Value: "KubernetesCluster",
},
cli.StringSliceFlag{
Name: "feature",
Usage: "enables or disables a feature in the chosen authorizer `NAME`",
},
cli.DurationFlag{
Name: "token-ttl",
Usage: "expiration on created bootstrap token `DURATION`",
EnvVar: "TOKEN_TTL",
Value: 5 * time.Minute,
},
cli.StringFlag{
Name: "client-common-name",
Usage: "the common name of the client certificate when use mutual tls `NAME`",
EnvVar: "CLIENT_COMMON_NAME",
Value: "node-authorizer-client",
},
cli.DurationFlag{
Name: "certificate-ttl",
Usage: "check the certificates exist and if not wait for x period `DURATION`",
EnvVar: "CERTIFICATE_TTL",
Value: 1 * time.Hour,
},
cli.DurationFlag{
Name: "authorization-timeout",
Usage: "max time permitted for a authorization `DURATION`",
EnvVar: "AUTHORIZATION_TIMEOUT",
Value: 15 * time.Second,
},
},
Action: func(ctx *cli.Context) error {
return actionServerCommand(ctx)
},
}
}
// actionServerCommand is responsible for performing the server action
func actionServerCommand(ctx *cli.Context) error {
config := &server.Config{
AuthorizationTimeout: ctx.Duration("authorization-timeout"),
ClientCommonName: ctx.String("client-common-name"),
ClusterName: ctx.String("cluster-name"),
ClusterTag: ctx.String("cluster-tag"),
Features: ctx.StringSlice("feature"),
Listen: ctx.String("listen"),
TLSCertPath: ctx.String("tls-cert"),
TLSClientCAPath: ctx.String("tls-client-ca"),
TLSPrivateKeyPath: ctx.String("tls-private-key"),
TokenDuration: ctx.Duration("token-ttl"),
}
if ctx.String("authorizer") == "" {
return errors.New("no authorizer specified")
}
// @step: should we wait for the certificates to appear
if ctx.Duration("certificate-ttl") > 0 {
var files = []string{ctx.String("tls-cert"), ctx.String("tls-client-ca"), ctx.String("tls-private-key")}
var timeout = ctx.Duration("certificate-ttl")
if err := waitForCertificates(files, timeout); err != nil {
return err
}
}
// @step: create the authorizers
auth, err := createAuthorizer(ctx.String("authorizer"), config)
if err != nil {
return fmt.Errorf("failed to create authorizer: %v", err)
}
svc, err := server.New(config, auth)
if err != nil {
return err
}
return svc.Run()
}
// waitForCertificates is responsible for waiting for the certificates to appear
func waitForCertificates(files []string, timeout time.Duration) error {
doneCh := make(chan struct{})
go func() {
expires := time.Now().Add(timeout)
// @step: iterate the file we are looking for
for _, x := range files {
if x == "" {
continue
}
// @step: iterate until we find the file
for {
if utils.FileExists(x) {
break
}
fmt.Printf("waiting for file: %s to appear, timeouts in %s\n", x, time.Until(expires))
time.Sleep(5 * time.Second)
}
}
doneCh <- struct{}{}
}()
select {
case <-doneCh:
return nil
case <-time.After(timeout):
return fmt.Errorf("unable to find the certificates after %s timeout", timeout)
}
}
// createAuthorizer creates and returns a authorizer
func createAuthorizer(name string, config *server.Config) (server.Authorizer, error) {
switch name {
case "alwaysallow":
return alwaysallow.NewAuthorizer()
case "aws":
return aws.NewAuthorizer(config)
}
return nil, errors.New("unknown authorizer")
}

View File

@ -1,15 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"authorizer.go",
"verifier.go",
],
importpath = "k8s.io/kops/node-authorizer/pkg/authorizers/alwaysallow",
visibility = ["//visibility:public"],
deps = [
"//node-authorizer/pkg/server:go_default_library",
"//node-authorizer/pkg/utils:go_default_library",
],
)

View File

@ -1,51 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package alwaysallow
import (
"context"
"k8s.io/kops/node-authorizer/pkg/server"
"k8s.io/kops/node-authorizer/pkg/utils"
)
// alwaysAllowAuth is the implementation for a node authozier
type alwaysAllowAuth struct{}
// NewAuthorizer creates and returns a alwaysAllow node authorizer
func NewAuthorizer() (server.Authorizer, error) {
utils.Logger.Warn("note the alwaysallow authorizer performs no authoritative checks and should only be used in test environments")
return &alwaysAllowAuth{}, nil
}
// Authorize is responsible for accepting the request
func (a *alwaysAllowAuth) Authorize(_ context.Context, req *server.NodeRegistration) error {
req.Status.Allowed = true
return nil
}
// Name returns the name of the authorizer
func (a *alwaysAllowAuth) Name() string {
return "alwaysAllow"
}
// Close is called when the authorizer is shutting down
func (a *alwaysAllowAuth) Close() error {
return nil
}

View File

@ -1,39 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package alwaysallow
import (
"context"
"k8s.io/kops/node-authorizer/pkg/server"
"k8s.io/kops/node-authorizer/pkg/utils"
)
// alwaysallowVerifier implements the verifier
type alwaysallowVerifier struct{}
// NewVerifier creates and returns a new Verifier
func NewVerifier() (server.Verifier, error) {
utils.Logger.Warn("note the alwaysallow authorizer performs no authoritative checks and should only be used in test environments")
return &alwaysallowVerifier{}, nil
}
// VerifyIdentity is responsible for providing proof of identity
func (a *alwaysallowVerifier) VerifyIdentity(context.Context) ([]byte, error) {
return []byte{}, nil
}

View File

@ -1,36 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = [
"authorizer.go",
"types.go",
"verifier.go",
],
importpath = "k8s.io/kops/node-authorizer/pkg/authorizers/aws",
visibility = ["//visibility:public"],
deps = [
"//node-authorizer/pkg/server:go_default_library",
"//node-authorizer/pkg/utils:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/session:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/autoscaling:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/ec2:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/ec2/ec2iface:go_default_library",
"//vendor/github.com/fullsailor/pkcs7:go_default_library",
"//vendor/go.uber.org/zap:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["authorizer_test.go"],
embed = [":go_default_library"],
deps = [
"//node-authorizer/pkg/server:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)

View File

@ -1,350 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"time"
"k8s.io/kops/node-authorizer/pkg/server"
"k8s.io/kops/node-authorizer/pkg/utils"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/fullsailor/pkcs7"
"go.uber.org/zap"
)
var (
// a collection of aws public signing certificates
publicCertificates []*x509.Certificate
)
var (
// CheckIAMProfile indicates we should validate the IAM profile
CheckIAMProfile = "verify-iam-profile"
// CheckIPAddress indicates we should validate the client ip address
CheckIPAddress = "verify-ip"
// CheckSignature indicates we validate the signature of the document
CheckSignature = "verify-signature"
)
// awsNodeAuthorizer is the implementation for a node authorizer
type awsNodeAuthorizer struct {
// client is the ec2 interface
client ec2iface.EC2API
// asgc is the autoscaling client
asgc autoscalingiface.AutoScalingAPI
// config is the service configuration
config *server.Config
// identity is the identity document for the instance we are running on
identity ec2metadata.EC2InstanceIdentityDocument
// instance is the ec2 instance we are running on
instance *ec2.Instance
// vpcID is our vpc id
vpcID string
}
// NewAuthorizer creates and returns a aws node authorizer
func NewAuthorizer(config *server.Config) (server.Authorizer, error) {
// @step: load the public certificates
if err := GetPublicCertificates(); err != nil {
return nil, err
}
// @step: get the identity document for the instance we are running
document, err := getInstanceIdentityDocument()
if err != nil {
return nil, err
}
utils.Logger.Info("running node authorizer on instance",
zap.String("instance-id", document.InstanceID),
zap.String("region", document.Region))
// @step: we create a ec2 and autoscaling client
sess, err := session.NewSession(&aws.Config{
Region: aws.String(document.Region),
})
if err != nil {
return nil, err
}
client := ec2.New(sess)
asgc := autoscaling.New(sess)
// @step: get information on the instance we are running
instance, err := getInstance(client, document.InstanceID)
if err != nil {
return nil, err
}
return &awsNodeAuthorizer{
client: client,
asgc: asgc,
config: config,
identity: document,
instance: instance,
vpcID: aws.StringValue(instance.VpcId),
}, nil
}
// Authorize is responsible for accepting the request
func (a *awsNodeAuthorizer) Authorize(ctx context.Context, r *server.NodeRegistration) error {
identity := &ec2metadata.EC2InstanceIdentityDocument{}
// @step: decode the request
request, err := decodeRequest(r.Spec.Request)
if err != nil {
return err
}
// @step: extract and validate the document
if reason, err := func() (string, error) {
if a.config.UseFeature(CheckSignature) {
if reason, err := a.validateIdentityDocument(ctx, request.Document, identity); err != nil {
return "", err
} else if reason != "" {
return reason, nil
}
}
if reason, err := a.validateNodeInstance(ctx, identity, r); err != nil {
return "", err
} else if reason != "" {
return reason, nil
}
r.Status.Allowed = true
return "", nil
}(); err != nil {
return err
} else if reason != "" {
r.Deny(reason)
}
return nil
}
// validateNodeInstance is responsible for checking the instance exists and it part of the cluster
func (a *awsNodeAuthorizer) validateNodeInstance(ctx context.Context, doc *ec2metadata.EC2InstanceIdentityDocument, spec *server.NodeRegistration) (string, error) {
// @check we are in the same account
if a.identity.AccountID != doc.AccountID {
return "instance running in different account id", nil
}
// @check we found some instances
instance, err := getInstance(a.client, doc.InstanceID)
if err != nil {
return "", err
}
if aws.StringValue(instance.State.Name) != ec2.InstanceStateNameRunning {
return "instance is not running", nil
}
// @check the instance is running in our vpc
if aws.StringValue(instance.VpcId) != a.vpcID {
return "instance is not running in our VPC", nil
}
// @check the instance is tagged with our kubernetes cluster id
if !hasInstanceTags(a.config.ClusterTag, a.config.ClusterName, instance.Tags) {
return "missing cluster tag", nil
}
// @check the instance has access to the nodes IAM profile
if a.config.UseFeature(CheckIAMProfile) {
if instance.IamInstanceProfile == nil {
return "instance does not have an instance profile", nil
}
if aws.StringValue(instance.IamInstanceProfile.Arn) == "" {
return "instance profile arn is empty", nil
}
expectedArn := fmt.Sprintf("arn:aws:iam::%s:role/nodes.%s", a.identity.AccountID, a.config.ClusterName)
if expectedArn != aws.StringValue(instance.IamInstanceProfile.Arn) {
return fmt.Sprintf("invalid iam instance role, expected: %s, found: %s", expectedArn, aws.StringValue(instance.IamInstanceProfile.Arn)), nil
}
}
// @check the requester is as expected
if a.config.UseFeature(CheckIPAddress) {
if spec.Spec.RemoteAddr != aws.StringValue(instance.PrivateIpAddress) {
return fmt.Sprintf("ip address conflict, expected: %s, got: %s", aws.StringValue(instance.PrivateIpAddress), spec.Spec.RemoteAddr), nil
}
}
return "", nil
}
// validateIdentityDocument is responsible for validate the aws identity document
func (a *awsNodeAuthorizer) validateIdentityDocument(_ context.Context, signed []byte, document interface{}) (string, error) {
// @step: decode the signed document
decoded, err := base64.StdEncoding.DecodeString(string(signed))
if err != nil {
return "", err
}
// @step: get the digest
for _, x := range publicCertificates {
parsed, err := pkcs7.Parse(decoded)
if err != nil {
return "", err
}
parsed.Certificates = []*x509.Certificate{x}
if err := parsed.Verify(); err != nil {
utils.Logger.Warn("identity document not validated by certificates",
zap.String("common-name", x.Subject.CommonName),
zap.Error(err))
} else {
return "", json.NewDecoder(bytes.NewReader(parsed.Content)).Decode(document)
}
}
return "invalid signature", nil
}
// validateNodeRegistrationRequest is responsible for validating the request itself
func validateNodeRegistrationRequest(request *Request) error {
err := func() error {
if len(request.Document) <= 0 {
return errors.New("missing identity document")
}
return nil
}()
if err != nil {
return fmt.Errorf("invalid verification request: %s", err)
}
return nil
}
// decodeRequest is responsible for decoding the request
func decodeRequest(in []byte) (*Request, error) {
request := &Request{}
if err := json.NewDecoder(bytes.NewReader(in)).Decode(request); err != nil {
return nil, err
}
// @step: validate the node request
if err := validateNodeRegistrationRequest(request); err != nil {
return nil, err
}
return request, nil
}
func (a *awsNodeAuthorizer) Close() error {
return nil
}
// Name returns the name of the authozier
func (a *awsNodeAuthorizer) Name() string {
return "aws"
}
// hasInstanceTags checks the tags exists on the cluster
func hasInstanceTags(name, value string, tags []*ec2.Tag) bool {
for _, x := range tags {
if aws.StringValue(x.Key) == name && aws.StringValue(x.Value) == value {
return true
}
}
return false
}
// getInstanceIdentityDocument is responsible for retrieving the instance identity document
func getInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error) {
var document ec2metadata.EC2InstanceIdentityDocument
sess, err := session.NewSession()
if err != nil {
return document, err
}
client := ec2metadata.New(sess)
maxInterval := 500 * time.Millisecond
maxTime := 5 * time.Second
err = utils.Retry(context.TODO(), maxInterval, maxTime, func() error {
x, err := client.GetInstanceIdentityDocument()
if err != nil {
return err
}
document = x
return nil
})
return document, err
}
// getInstance is responsible for getting the instance
func getInstance(client ec2iface.EC2API, instanceID string) (*ec2.Instance, error) {
// @step: describe the instance
resp, err := client.DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: []*string{aws.String(instanceID)},
})
if err != nil {
return nil, err
}
// @check we found some instances
if len(resp.Reservations) <= 0 || len(resp.Reservations[0].Instances) <= 0 {
return nil, fmt.Errorf("missing instance id: %s", instanceID)
}
if len(resp.Reservations[0].Instances) > 1 {
return nil, fmt.Errorf("found multiple instances with instance id: %s", instanceID)
}
// @check the instance is running
instance := resp.Reservations[0].Instances[0]
if instance.State == nil {
return nil, errors.New("missing instance status")
}
return instance, nil
}
// GetPublicCertificates loads the certificates
func GetPublicCertificates() error {
for i := range awsCertificates {
block, _ := pem.Decode([]byte(awsCertificates[i]))
c, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return err
}
publicCertificates = append(publicCertificates, c)
}
return nil
}

View File

@ -1,72 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
import (
"context"
"testing"
"k8s.io/kops/node-authorizer/pkg/server"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/stretchr/testify/assert"
)
func newTestAuthorizer(t *testing.T, config *server.Config) *awsNodeAuthorizer {
if config == nil {
config = &server.Config{}
}
c := &awsNodeAuthorizer{
config: config,
vpcID: "test",
}
if err := GetPublicCertificates(); err != nil {
t.Errorf("unable to parse public certificates: %s", err)
t.FailNow()
}
return c
}
func TestValidateIdentityDocument(t *testing.T) {
c := newTestAuthorizer(t, nil)
request := &Request{
Document: []byte(`MIAGCSqGSIb3DQEHAqCAMIACAQExCzAJBgUrDgMCGgUAMIAGCSqGSIb3DQEHAaCAJIAEggHUewog
ICJtYXJrZXRwbGFjZVByb2R1Y3RDb2RlcyIgOiBudWxsLAogICJkZXZwYXlQcm9kdWN0Q29kZXMi
IDogbnVsbCwKICAicHJpdmF0ZUlwIiA6ICIxMC4yNTAuMTAxLjE3IiwKICAidmVyc2lvbiIgOiAi
MjAxNy0wOS0zMCIsCiAgImluc3RhbmNlSWQiIDogImktMDJhYTA0MDdmNDAwYmJmY2YiLAogICJi
aWxsaW5nUHJvZHVjdHMiIDogbnVsbCwKICAiaW5zdGFuY2VUeXBlIiA6ICJtNC5sYXJnZSIsCiAg
ImF2YWlsYWJpbGl0eVpvbmUiIDogImV1LXdlc3QtMmIiLAogICJrZXJuZWxJZCIgOiBudWxsLAog
ICJyYW1kaXNrSWQiIDogbnVsbCwKICAiYWNjb3VudElkIiA6ICI2NzA5MzA2NDYxMDMiLAogICJh
cmNoaXRlY3R1cmUiIDogIng4Nl82NCIsCiAgImltYWdlSWQiIDogImFtaS1iNTMwZDFkMiIsCiAg
InBlbmRpbmdUaW1lIiA6ICIyMDE4LTA2LTA5VDEzOjE5OjQzWiIsCiAgInJlZ2lvbiIgOiAiZXUt
d2VzdC0yIgp9AAAAAAAAMYIBFzCCARMCAQEwaTBcMQswCQYDVQQGEwJVUzEZMBcGA1UECBMQV2Fz
aGluZ3RvbiBTdGF0ZTEQMA4GA1UEBxMHU2VhdHRsZTEgMB4GA1UEChMXQW1hem9uIFdlYiBTZXJ2
aWNlcyBMTEMCCQCWukjZ5V4aZzAJBgUrDgMCGgUAoF0wGAYJKoZIhvcNAQkDMQsGCSqGSIb3DQEH
ATAcBgkqhkiG9w0BCQUxDxcNMTgwNjA5MTMxOTQ3WjAjBgkqhkiG9w0BCQQxFgQUzenf7yQR02cW
A4t1ZTpGNjz7490wCQYHKoZIzjgEAwQuMCwCFAZ1VFJSd81PnuG6+1sDFOgr/3tyAhRH14cLYMTN
Uce4CDpektlneHOeLQAAAAAAAA==`),
}
identity := &ec2metadata.EC2InstanceIdentityDocument{}
reason, err := c.validateIdentityDocument(context.TODO(), request.Document, identity)
assert.NoError(t, err)
assert.NotNil(t, identity)
assert.Empty(t, reason)
assert.Equal(t, "i-02aa0407f400bbfcf", identity.InstanceID)
}

View File

@ -1,67 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
// Request is the request the node authorizer
type Request struct {
// Document is the PKCS7 signed identity document
Document []byte
}
var (
// awsCertificates is a collection of AWS public certificates used to sign the identity documents
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
awsCertificates = []string{
// AWS Public Certificate
`-----BEGIN CERTIFICATE-----
MIIC7TCCAq0CCQCWukjZ5V4aZzAJBgcqhkjOOAQDMFwxCzAJBgNVBAYTAlVTMRkw
FwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0dGxlMSAwHgYD
VQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAeFw0xMjAxMDUxMjU2MTJaFw0z
ODAxMDUxMjU2MTJaMFwxCzAJBgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9u
IFN0YXRlMRAwDgYDVQQHEwdTZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNl
cnZpY2VzIExMQzCCAbcwggEsBgcqhkjOOAQBMIIBHwKBgQCjkvcS2bb1VQ4yt/5e
ih5OO6kK/n1Lzllr7D8ZwtQP8fOEpp5E2ng+D6Ud1Z1gYipr58Kj3nssSNpI6bX3
VyIQzK7wLclnd/YozqNNmgIyZecN7EglK9ITHJLP+x8FtUpt3QbyYXJdmVMegN6P
hviYt5JH/nYl4hh3Pa1HJdskgQIVALVJ3ER11+Ko4tP6nwvHwh6+ERYRAoGBAI1j
k+tkqMVHuAFcvAGKocTgsjJem6/5qomzJuKDmbJNu9Qxw3rAotXau8Qe+MBcJl/U
hhy1KHVpCGl9fueQ2s6IL0CaO/buycU1CiYQk40KNHCcHfNiZbdlx1E9rpUp7bnF
lRa2v1ntMX3caRVDdbtPEWmdxSCYsYFDk4mZrOLBA4GEAAKBgEbmeve5f8LIE/Gf
MNmP9CM5eovQOGx5ho8WqD+aTebs+k2tn92BBPqeZqpWRa5P/+jrdKml1qx4llHW
MXrs3IgIb6+hUIB+S8dz8/mmO0bpr76RoZVCXYab2CZedFut7qc3WUH9+EUAH5mw
vSeDCOUMYQR7R9LINYwouHIziqQYMAkGByqGSM44BAMDLwAwLAIUWXBlk40xTwSw
7HX32MxXYruse9ACFBNGmdX2ZBrVNGrN9N2f6ROk0k9K
-----END CERTIFICATE-----`,
// AWS GovCloud (US) region
`-----BEGIN CERTIFICATE-----
MIICuzCCAiQCCQDrSGnlRgvSazANBgkqhkiG9w0BAQUFADCBoTELMAkGA1UEBhMC
VVMxCzAJBgNVBAgTAldBMRAwDgYDVQQHEwdTZWF0dGxlMRMwEQYDVQQKEwpBbWF6
b24uY29tMRYwFAYDVQQLEw1FQzIgQXV0aG9yaXR5MRowGAYDVQQDExFFQzIgQU1J
IEF1dGhvcml0eTEqMCgGCSqGSIb3DQEJARYbZWMyLWluc3RhbmNlLWlpZEBhbWF6
b24uY29tMB4XDTExMDgxMjE3MTgwNVoXDTIxMDgwOTE3MTgwNVowgaExCzAJBgNV
BAYTAlVTMQswCQYDVQQIEwJXQTEQMA4GA1UEBxMHU2VhdHRsZTETMBEGA1UEChMK
QW1hem9uLmNvbTEWMBQGA1UECxMNRUMyIEF1dGhvcml0eTEaMBgGA1UEAxMRRUMy
IEFNSSBBdXRob3JpdHkxKjAoBgkqhkiG9w0BCQEWG2VjMi1pbnN0YW5jZS1paWRA
YW1hem9uLmNvbTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAqaIcGFFTx/SO
1W5G91jHvyQdGP25n1Y91aXCuOOWAUTvSvNGpXrI4AXNrQF+CmIOC4beBASnHCx0
82jYudWBBl9Wiza0psYc9flrczSzVLMmN8w/c78F/95NfiQdnUQPpvgqcMeJo82c
gHkLR7XoFWgMrZJqrcUK0gnsQcb6kakCAwEAATANBgkqhkiG9w0BAQUFAAOBgQDF
VH0+UGZr1LCQ78PbBH0GreiDqMFfa+W8xASDYUZrMvY3kcIelkoIazvi4VtPO7Qc
yAiLr6nkk69Tr/MITnmmsZJZPetshqBndRyL+DaTRnF0/xvBQXj5tEh+AmRjvGtp
6iS1rQoNanN8oEcT2j4b48rmCmnDhRoBcFHwCYs/3w==
-----END CERTIFICATE-----`,
}
)

View File

@ -1,79 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
import (
"context"
"encoding/json"
"k8s.io/kops/node-authorizer/pkg/server"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
)
type awsNodeVerifier struct{}
// NewVerifier creates and returns a verifier
func NewVerifier() (server.Verifier, error) {
return &awsNodeVerifier{}, nil
}
// Verify is responsible for build a identification document
func (a *awsNodeVerifier) VerifyIdentity(ctx context.Context) ([]byte, error) {
errs := make(chan error)
doneCh := make(chan []byte)
go func() {
encoded, err := func() ([]byte, error) {
// @step: create a metadata client
sess, err := session.NewSession()
if err != nil {
return []byte{}, err
}
client := ec2metadata.New(sess)
// @step: get the pkcs7 signature from the metadata service
signature, err := client.GetDynamicData("/instance-identity/pkcs7")
if err != nil {
return []byte{}, err
}
// @step: construct request for the request
request := &Request{
Document: []byte(signature),
}
return json.Marshal(request)
}()
if err != nil {
errs <- err
return
}
doneCh <- encoded
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errs:
return nil, err
case req := <-doneCh:
return req, nil
}
}

View File

@ -1,20 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"client.go",
"helper.go",
"type.go",
],
importpath = "k8s.io/kops/node-authorizer/pkg/client",
visibility = ["//visibility:public"],
deps = [
"//node-authorizer/pkg/authorizers/alwaysallow:go_default_library",
"//node-authorizer/pkg/authorizers/aws:go_default_library",
"//node-authorizer/pkg/server:go_default_library",
"//node-authorizer/pkg/utils:go_default_library",
"//vendor/go.uber.org/zap:go_default_library",
"//vendor/k8s.io/client-go/tools/clientcmd/api/v1:go_default_library",
],
)

View File

@ -1,130 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package client
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/url"
"os"
"path/filepath"
"k8s.io/kops/node-authorizer/pkg/utils"
"go.uber.org/zap"
)
// New returns a client verifier
func New(config *Config) error {
if err := config.IsValid(); err != nil {
return err
}
utils.Logger.Info("attempting to acquire a node bootstrap configuration",
zap.Duration("timeout", config.Timeout),
zap.String("authorizer", config.Authorizer),
zap.String("kubeapi-url", config.KubeAPI),
zap.String("kubeconfig", config.KubeConfigPath),
zap.String("registration-url", config.NodeURL))
// @step: if we have a kubeconfig already we can skip it
if utils.FileExists(config.KubeConfigPath) {
utils.Logger.Info("skipping the client authorization as kubeconfig found",
zap.String("kubeconfig", config.KubeConfigPath))
return nil
}
// @step: create the verifier
verifier, err := newNodeVerifier(config.Authorizer)
if err != nil {
return err
}
hc, err := makeHTTPClient(config)
if err != nil {
return err
}
// @step: attempt to get the token
err = utils.Retry(context.TODO(), config.Interval, config.Timeout, func() error {
token, err := makeTokenRequest(context.TODO(), hc, verifier, config)
if err != nil {
utils.Logger.Error("failed to request bootstrap token from node authorizer service", zap.Error(err))
return err
}
// @check if we have been refused
if !token.IsAllowed() {
utils.Logger.Error("node has been refused registration",
zap.String("reason", token.Status.Reason))
os.Exit(1)
}
utils.Logger.Info("successfully requested bootstrap token from service")
utils.Logger.Info("attempting to write bootstrap configuration")
kubeconfig, err := makeKubeconfig(context.TODO(), config, token.Status.Token)
if err != nil {
utils.Logger.Error("failed to generate the bootstrap token configuration",
zap.String("path", config.KubeConfigPath),
zap.Error(err))
return err
}
dirname := filepath.Dir(config.KubeConfigPath)
if err := os.MkdirAll(dirname, os.FileMode(0770)); err != nil {
return err
}
return ioutil.WriteFile(config.KubeConfigPath, kubeconfig, os.FileMode(0640))
})
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %s\n", err)
os.Exit(1)
}
utils.Logger.Info("successfully wrote bootstrap configuration")
return nil
}
// IsValid validates the client configuration
func (c *Config) IsValid() error {
if c.Authorizer == "" {
return errors.New("no authorizer")
}
if c.KubeAPI == "" {
return errors.New("no kubeapi url")
}
if c.KubeConfigPath == "" {
return errors.New("no bootstrap kubeconfig path")
}
if c.Timeout <= 0 {
return errors.New("timeout must be greater than zero")
}
if _, err := url.Parse(c.KubeAPI); err != nil {
return fmt.Errorf("invalid kubeapi url: %s", err)
}
return nil
}

View File

@ -1,172 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package client
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"strings"
"time"
"k8s.io/kops/node-authorizer/pkg/authorizers/alwaysallow"
"k8s.io/kops/node-authorizer/pkg/authorizers/aws"
"k8s.io/kops/node-authorizer/pkg/server"
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
)
// makeHTTPClient is responsible for making a http client
func makeHTTPClient(config *Config) (*http.Client, error) {
tlsConfig := &tls.Config{}
if config.TLSClientCAPath != "" {
ca, err := ioutil.ReadFile(config.TLSClientCAPath)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(ca)
tlsConfig.RootCAs = caCertPool
}
if config.TLSCertPath != "" && config.TLSPrivateKeyPath != "" {
certs, err := tls.LoadX509KeyPair(config.TLSCertPath, config.TLSPrivateKeyPath)
if err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{certs}
}
return &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
Dial: (&net.Dialer{Timeout: 5 * time.Second}).Dial,
TLSHandshakeTimeout: 5 * time.Second,
TLSClientConfig: tlsConfig,
},
}, nil
}
// makeKubeconfig is responsible for generating a bootstrap config
func makeKubeconfig(ctx context.Context, config *Config, token string) ([]byte, error) {
// @step: load the certificate authority
content, err := ioutil.ReadFile(config.TLSClientCAPath)
if err != nil {
return []byte{}, err
}
// @step: we need to write out the kubeconfig
name := "bootstrap-context"
clusterName := "cluster"
cfg := &v1.Config{
APIVersion: "v1",
Kind: "Config",
AuthInfos: []v1.NamedAuthInfo{
{
Name: name,
AuthInfo: v1.AuthInfo{
Token: token,
},
},
},
Clusters: []v1.NamedCluster{
{
Name: clusterName,
Cluster: v1.Cluster{
Server: config.KubeAPI,
CertificateAuthorityData: content,
},
},
},
Contexts: []v1.NamedContext{
{
Name: name,
Context: v1.Context{
Cluster: clusterName,
AuthInfo: name,
},
},
},
CurrentContext: name,
}
return json.MarshalIndent(cfg, "", " ")
}
// makeTokenRequest makes a request for a bootstrap token
func makeTokenRequest(ctx context.Context, client *http.Client, verifier server.Verifier, config *Config) (*server.NodeRegistration, error) {
registration := &server.NodeRegistration{}
// @step: create the request payload
req, err := verifier.VerifyIdentity(ctx)
if err != nil {
return nil, err
}
hostname, err := getHostname()
if err != nil {
return nil, err
}
// @step: make the request to the node-authozier
url := fmt.Sprintf("%s/authorize/%s", strings.TrimSuffix(config.NodeURL, "/authorize"), hostname)
resp, err := client.Post(url, "application/json", bytes.NewReader(req))
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("invalid response code: %d", resp.StatusCode)
}
// @step: read in the response and decode
content, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if err := json.Unmarshal(content, registration); err != nil {
return nil, err
}
return registration, nil
}
// getHostname attempts to return the instance hostname
func getHostname() (string, error) {
return os.Hostname()
}
// newNodeVerifier returns a new verifier
func newNodeVerifier(name string) (server.Verifier, error) {
switch name {
case "aws":
return aws.NewVerifier()
case "alwaysallow":
return alwaysallow.NewVerifier()
}
return nil, errors.New("unsupported authorizer")
}

View File

@ -1,43 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package client
import "time"
// Config is the configuration for the service
type Config struct {
// Authorizer is the name of the verifier to use
Authorizer string
// Interval is the pause between failed attempts
Interval time.Duration
// KubeAPI is the url for the kubernetes api
KubeAPI string
// KubeConfigPath is the location to write the bootstrap token config
KubeConfigPath string
// NodeURL is the url for the node authozier service
NodeURL string
// Timeout is the time will are willing to wait
Timeout time.Duration
// TLSCertPath is the path to the server TLS certificate
TLSCertPath string
// TLSClientCAPath is the path to a certificate authority
TLSClientCAPath string
// TLSPrivateKeyPath is the path to the private key
TLSPrivateKeyPath string
// Verbose indicate verbose logging
Verbose bool
}

View File

@ -1,27 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"admission.go",
"handlers.go",
"helper.go",
"metrics.go",
"middleware.go",
"server.go",
"token.go",
"types.go",
],
importpath = "k8s.io/kops/node-authorizer/pkg/server",
visibility = ["//visibility:public"],
deps = [
"//node-authorizer/pkg/utils:go_default_library",
"//vendor/github.com/gorilla/mux:go_default_library",
"//vendor/github.com/prometheus/client_golang/prometheus:go_default_library",
"//vendor/github.com/prometheus/client_golang/prometheus/promhttp:go_default_library",
"//vendor/go.uber.org/zap:go_default_library",
"//vendor/k8s.io/api/core/v1:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library",
"//vendor/k8s.io/client-go/kubernetes:go_default_library",
],
)

View File

@ -1,219 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"context"
"fmt"
"time"
"k8s.io/kops/node-authorizer/pkg/utils"
"go.uber.org/zap"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
var (
// CheckRegistration indicates we should validate the node is not regestered
CheckRegistration = "verify-registration"
)
// authorizeNodeRequest is responsible for handling the incoming authorization request
func (n *NodeAuthorizer) authorizeNodeRequest(ctx context.Context, request *NodeRegistration) error {
doneCh := make(chan error)
// @step: create a context to run under
ctx, cancel := context.WithTimeout(ctx, n.config.AuthorizationTimeout)
defer cancel()
// @step: background the request and wait for either a timeout or a token
go func() {
doneCh <- func() error {
// @step: check if the node request is authorized
if err := n.safelyAuthorizeNode(ctx, request); err != nil {
return err
}
if request.IsAllowed() {
return n.safelyProvisionBootstrapToken(ctx, request)
}
return nil
}()
}()
// @step: we either wait for the context to timeout or cancel, or we receive a done signal
select {
case <-ctx.Done():
if err := ctx.Err(); err != nil {
utils.Logger.Error("operation has either timed out or been cancelled",
zap.String("client", request.Spec.RemoteAddr),
zap.String("node", request.Spec.NodeName),
zap.Error(err))
}
return nil
case err := <-doneCh:
if err != nil {
utils.Logger.Error("failed to provision a bootstrap token",
zap.String("client", request.Spec.RemoteAddr),
zap.String("node", request.Spec.NodeName),
zap.Error(err))
}
}
if !request.IsAllowed() {
utils.Logger.Error("the node has been denied authorization",
zap.String("client", request.Spec.RemoteAddr),
zap.String("node", request.Spec.NodeName),
zap.String("reason", request.Status.Reason))
nodeAuthorizationMetric.WithLabelValues("denied").Inc()
return nil
}
utils.Logger.Info("node has been authorized access",
zap.String("client", request.Spec.RemoteAddr),
zap.String("node", request.Spec.NodeName))
nodeAuthorizationMetric.WithLabelValues("allowed").Inc()
return nil
}
// safelyAuthorizeNode checks if the request is permitted
func (n *NodeAuthorizer) safelyAuthorizeNode(ctx context.Context, request *NodeRegistration) error {
// @step: attempt to authorize the request
now := time.Now()
if err := n.authorizer.Authorize(ctx, request); err != nil {
authorizerErrorMetric.Inc()
return err
}
authorizerLatencyMetric.Observe(time.Since(now).Seconds())
// @check if the node is registered already
if n.config.UseFeature(CheckRegistration) {
if found, err := isNodeRegistered(ctx, n.client, request.Spec.NodeName); err != nil {
return fmt.Errorf("unable to check node registration status: %s", err)
} else if found {
request.Deny(fmt.Sprintf("node %s already registered", request.Spec.NodeName))
}
}
return nil
}
// safelyProvisionBootstrapToken is responsible for generating a bootstrap token for us
func (n *NodeAuthorizer) safelyProvisionBootstrapToken(ctx context.Context, request *NodeRegistration) error {
maxInterval := 500 * time.Millisecond
maxTime := 10 * time.Second
usages := []string{"authentication", "signing"}
now := time.Now()
if err := utils.Retry(ctx, maxInterval, maxTime, func() error {
token, err := n.createToken(n.config.TokenDuration, usages)
if err != nil {
return err
}
request.Status.Token = token.String()
return err
}); err != nil {
return err
}
tokenLatencyMetric.Observe(time.Since(now).Seconds())
return nil
}
// createToken generates a token for the instance
func (n *NodeAuthorizer) createToken(expiration time.Duration, usages []string) (*Token, error) {
var err error
var token *Token
ctx := context.TODO()
err = utils.Retry(ctx, 2000*time.Millisecond, 10*time.Second, func() error {
// @step: generate a random token for them
if token, err = NewToken(); err != nil {
return err
}
// @step: check if the token already exist, remote but a possibility
if found, err := n.hasToken(ctx, token); err != nil {
return err
} else if found {
return fmt.Errorf("duplicate token found: %s, skipping", token.ID)
}
// @step: add the secret to the namespace
v1secret := &v1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: token.Name(),
Labels: map[string]string{
"name": token.Name(),
},
},
Type: v1.SecretType(secretTypeBootstrapToken),
Data: encodeTokenSecretData(token, usages, expiration),
}
if _, err := n.client.CoreV1().Secrets(tokenNamespace).Create(ctx, v1secret, metav1.CreateOptions{}); err != nil {
return err
}
return nil
})
return token, err
}
// hasToken checks if the tokens already exists
func (n *NodeAuthorizer) hasToken(ctx context.Context, token *Token) (bool, error) {
resp, err := n.client.CoreV1().Secrets(tokenNamespace).List(ctx, metav1.ListOptions{
LabelSelector: "name=" + token.Name(),
Limit: 1,
})
if err != nil {
return false, err
}
return len(resp.Items) > 0, nil
}
// encodeTokenSecretData takes the token discovery object and an optional duration and returns the .Data for the Secret
func encodeTokenSecretData(token *Token, usages []string, ttl time.Duration) map[string][]byte {
data := map[string][]byte{
bootstrapTokenIDKey: []byte(token.ID),
bootstrapTokenSecretKey: []byte(token.Secret),
}
if ttl > 0 {
expire := time.Now().Add(ttl).Format(time.RFC3339)
data[bootstrapTokenExpirationKey] = []byte(expire)
}
for _, usage := range usages {
data[bootstrapTokenUsagePrefix+usage] = []byte("true")
}
return data
}

View File

@ -1,84 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"encoding/json"
"io/ioutil"
"net/http"
"go.uber.org/zap"
"k8s.io/kops/node-authorizer/pkg/utils"
"github.com/gorilla/mux"
)
// authorizeHandler is responsible for handling the authorization requests
func (n *NodeAuthorizer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
err := func() error {
// @check we have a body to read in
if r.Body == nil {
w.WriteHeader(http.StatusBadRequest)
return nil
}
// @step: read in the request body
content, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
if err != nil {
return err
}
address, err := getClientAddress(r.RemoteAddr)
if err != nil {
return err
}
// @step: construct the node registration request
req := &NodeRegistration{
Spec: NodeRegistrationSpec{
NodeName: mux.Vars(r)["name"],
RemoteAddr: address,
Request: content,
},
}
// @step: attempt to authorise the request
if err := n.authorizeNodeRequest(r.Context(), req); err != nil {
return err
}
// @check if the node was denied and if so, 403 it
if !req.Status.Allowed {
w.WriteHeader(http.StatusForbidden)
return nil
}
return json.NewEncoder(w).Encode(req)
}()
if err != nil {
utils.Logger.Info("failed to handle node request", zap.Error(err))
w.WriteHeader(http.StatusInternalServerError)
}
}
// healthHandler is responsible for providing health
func (n *NodeAuthorizer) healthHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Node-Authorizer-Version", Version)
w.Write([]byte("OK\n"))
}

View File

@ -1,67 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"context"
"fmt"
"net"
"time"
"k8s.io/kops/node-authorizer/pkg/utils"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
)
// getClientAddress returns the client address
func getClientAddress(address string) (string, error) {
host, _, err := net.SplitHostPort(address)
return host, err
}
// isNodeRegistered checks if the node is already registered with kubernetes
func isNodeRegistered(ctx context.Context, client kubernetes.Interface, nodename string) (bool, error) {
var registered bool
maxInterval := 1000 * time.Millisecond
maxTime := 10 * time.Second
// @lets try multiple times
err := utils.Retry(ctx, maxInterval, maxTime, func() error {
// @step: get a lit of nodes
nodes, err := client.CoreV1().Nodes().List(ctx, metav1.ListOptions{
LabelSelector: fmt.Sprintf("kubernetes.io/hostname=%s", nodename),
})
if err != nil {
return err
}
// @check if we found a registered node
if len(nodes.Items) > 0 {
registered = true
}
return nil
})
if err != nil {
return false, err
}
return registered, nil
}

View File

@ -1,63 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"github.com/prometheus/client_golang/prometheus"
)
var (
authorizerErrorMetric = prometheus.NewCounter(
prometheus.CounterOpts{
Name: "node_authorizer_error_counter",
Help: "The number of errors encountered by the authorizer",
},
)
authorizeRequestLatencyMetric = prometheus.NewSummary(
prometheus.SummaryOpts{
Name: "node_request_latency_seconds",
Help: "A summary of the latency for incoming node authorization requests",
},
)
authorizerLatencyMetric = prometheus.NewSummary(
prometheus.SummaryOpts{
Name: "node_authorizer_latency_seconds",
Help: "A summary of the latency for the authorizer in seconds",
},
)
nodeAuthorizationMetric = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "node_authorizer_counter",
Help: "A counter of number node authorizations broken down by denied and allowed",
},
[]string{"action"},
)
tokenLatencyMetric = prometheus.NewSummary(
prometheus.SummaryOpts{
Name: "token_latency_seconds",
Help: "A summary of the latency experienced when creating bootstrap tokens in seconds",
},
)
)
func init() {
prometheus.MustRegister(authorizeRequestLatencyMetric)
prometheus.MustRegister(authorizerErrorMetric)
prometheus.MustRegister(authorizerLatencyMetric)
prometheus.MustRegister(nodeAuthorizationMetric)
prometheus.MustRegister(tokenLatencyMetric)
}

View File

@ -1,69 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"fmt"
"net/http"
"k8s.io/kops/node-authorizer/pkg/utils"
"go.uber.org/zap"
)
// recovery is responsible for ensuring we don't exit on a panic
func recovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if err := recover(); err != nil {
w.WriteHeader(http.StatusInternalServerError)
utils.Logger.Error("failed to handle request, threw exception",
zap.String("error", fmt.Sprintf("%v", err)))
}
}()
next.ServeHTTP(w, req)
})
}
// authorized is responsible for validating the client certificate
func authorized(next http.HandlerFunc, commonName string, requireAuth bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if requireAuth {
found := func() bool {
for _, x := range r.TLS.PeerCertificates {
if x.Subject.CommonName == commonName {
return true
}
}
return false
}()
if !found {
utils.Logger.Error("invalid client certificate",
zap.String("client", r.RemoteAddr))
w.WriteHeader(http.StatusForbidden)
return
}
}
next.ServeHTTP(w, r)
})
}

View File

@ -1,132 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"os/signal"
"syscall"
"k8s.io/kops/node-authorizer/pkg/utils"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
"k8s.io/client-go/kubernetes"
)
const (
// Version is the server version
Version = "v0.0.4"
// the namespace to place the secrets
tokenNamespace = "kube-system"
)
// NodeAuthorizer retains the authorizer state
type NodeAuthorizer struct {
// authorizer is a collection of authorizers
authorizer Authorizer
// client is the kubernetes api client
client kubernetes.Interface
// config is the configuration
config *Config
}
// New creates and returns a node authorizer
func New(config *Config, authorizer Authorizer) (*NodeAuthorizer, error) {
utils.Logger.Info("starting the node authorization service",
zap.String("listen", config.Listen),
zap.String("version", Version))
if authorizer == nil {
return nil, errors.New("no authorizer")
}
if err := config.IsValid(); err != nil {
return nil, fmt.Errorf("configuration error: %s", err)
}
return &NodeAuthorizer{
authorizer: authorizer,
config: config,
}, nil
}
// Run is responsible for starting the node authorizer service
func (n *NodeAuthorizer) Run() error {
// @step: create the kubernetes client
client, err := utils.GetKubernetesClient()
if err != nil {
return err
}
n.client = client
// @step: create the http service
server := &http.Server{
Addr: n.config.Listen,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
},
}
// @step: are we using mutual tls?
if n.config.TLSClientCAPath != "" {
// a client certificate is not required, but if given we need to verify it
server.TLSConfig.ClientAuth = tls.VerifyClientCertIfGiven
caCert, err := ioutil.ReadFile(n.config.TLSClientCAPath)
if err != nil {
return err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
server.TLSConfig.ClientCAs = caCertPool
}
// @step: add the routing
r := mux.NewRouter()
r.Handle("/authorize/{name}", authorized(n.authorizeHandler, n.config.ClientCommonName, n.useMutualTLS())).Methods(http.MethodPost)
r.Handle("/metrics", promhttp.Handler()).Methods(http.MethodGet)
r.HandleFunc("/health", n.healthHandler).Methods(http.MethodGet)
server.Handler = recovery(r)
// @step: wait for either an error or a termination signal
errs := make(chan error, 2)
go func() {
errs <- server.ListenAndServeTLS(n.config.TLSCertPath, n.config.TLSPrivateKeyPath)
}()
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT)
errs <- fmt.Errorf("received termination signal: %s", <-c)
}()
return <-errs
}
// useMutualTLS checks if we are using mutual tls
func (n *NodeAuthorizer) useMutualTLS() bool {
return n.config.TLSClientCAPath != ""
}

View File

@ -1,55 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"fmt"
"k8s.io/kops/node-authorizer/pkg/utils"
)
const (
// the size of the id token
tokenIDBytes = 3
// the size of the secret
tokenSecretBytes = 8
)
// NewToken creates and returns a new token
func NewToken() (*Token, error) {
id, err := utils.RandomBytes(tokenIDBytes)
if err != nil {
return nil, err
}
secret, err := utils.RandomBytes(tokenSecretBytes)
if err != nil {
return nil, err
}
return &Token{ID: id, Secret: secret}, nil
}
// Name returns the secret name
func (t *Token) Name() string {
return fmt.Sprintf("%s%s", bootstrapTokenSecretPrefix, t.ID)
}
// String returns the encoded secret
func (t *Token) String() string {
return fmt.Sprintf("%s.%s", t.ID, t.Secret)
}

View File

@ -1,191 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"context"
"errors"
"strconv"
"strings"
"time"
v1 "k8s.io/api/core/v1"
)
const (
// bootstrapTokenIDKey is the id of this token. This can be transmitted in the
// clear and encoded in the name of the secret. It must be a random 6 character
// string that matches the regexp `^([a-z0-9]{6})$`. Required.
bootstrapTokenIDKey = "token-id"
// bootstrapTokenSecretKey is the actual secret. It must be a random 16 character
// string that matches the regexp `^([a-z0-9]{16})$`. Required.
bootstrapTokenSecretKey = "token-secret"
// bootstrapTokenExpirationKey is when this token should be expired and no
// longer used. A controller will delete this resource after this time. This
// is an absolute UTC time using RFC3339. If this cannot be parsed, the token
// should be considered invalid. Optional.
bootstrapTokenExpirationKey = "expiration"
// bootstrapTokenUsagePrefix is the prefix for the other usage constants that specifies different
// functions of a bootstrap token
bootstrapTokenUsagePrefix = "usage-bootstrap-"
// bootstrapTokenSecretPrefix is the prefix for bootstrap token names.
// Bootstrap tokens secrets must be named in the form
// `bootstrap-token-<token-id>`. This is the prefix to be used before the
// token ID.
bootstrapTokenSecretPrefix = "bootstrap-token-"
// secretTypeBootstrapToken is used during the automated bootstrap process (first
// implemented by kubeadm). It stores tokens that are used to sign well known
// ConfigMaps. They may also eventually be used for authentication.
secretTypeBootstrapToken v1.SecretType = "bootstrap.kubernetes.io/token"
)
// Config is the configuration for the service
type Config struct {
// AuthorizationTimeout is the max duration for a authorization
AuthorizationTimeout time.Duration
// ClusterTag is the cloud tag key used to identity the cluster
ClusterTag string
// Features is arbitrary feature set for a authorizer
Features []string
// EnableVerbose indicate verbose logging
EnableVerbose bool
// ClientCommonName is the common name on the client certificate if mutual tls is enabled
ClientCommonName string
// ClusterName is the name of the kubernetes cluster
ClusterName string
// Listen is the interacted to bind to
Listen string
// TokenDuration is the expiration of a bootstrap token
TokenDuration time.Duration
// TLSCertPath is the path to the server TLS certificate
TLSCertPath string
// TLSClientCAPath is the path to a certificate authority
TLSClientCAPath string
// TLSPrivateKeyPath is the path to the private key
TLSPrivateKeyPath string
}
// UseFeature indicates a feature is in use
func (c *Config) UseFeature(name string) bool {
if len(c.Features) <= 0 {
return false
}
for _, x := range c.Features {
items := strings.Split(x, "=")
if items[0] != name {
continue
}
if len(items) == 1 {
return true
}
v, err := strconv.ParseBool(items[1])
if err != nil {
return false
}
return v
}
return false
}
// NodeRegistration is an incoming request
type NodeRegistration struct {
// Spec is the request specification
Spec NodeRegistrationSpec
// Status is the result of a admission
Status NodeRegistrationStatus
}
// Deny marks the request as denied and adds the reason why
func (n *NodeRegistration) Deny(reason string) {
n.Status.Allowed = false
n.Status.Reason = reason
}
// IsAllowed checks if the request if allowed
func (n *NodeRegistration) IsAllowed() bool {
return n.Status.Allowed
}
// Token defines a bootstrap token
type Token struct {
// ID is the id of the token
ID string
// Secret is the secret of the token
Secret string
}
// NodeRegistrationSpec is the node request specification
type NodeRegistrationSpec struct {
// NodeName is the name of the node
NodeName string
// RemoteAddr is the address of the requester
RemoteAddr string
// Request is the request body
Request []byte
}
// NodeRegistrationStatus is result of a authorization
type NodeRegistrationStatus struct {
// Allowed indicates the request is permitted
Allowed bool
// Token is the bootstrap token
Token string
// Reason is the reason for the error if any
Reason string
}
// Authorizer is the generic means to authorize the incoming node request
type Authorizer interface {
// Admit is responsible for checking if the request is permitted
Authorize(context.Context, *NodeRegistration) error
// Close provides a signal to close of resources
Close() error
// Name returns the name of the authorizer
Name() string
}
// Verifier is the client side of authorizer
type Verifier interface {
// VerifyIdentity is responsible for constructing the parameters for a request
VerifyIdentity(context.Context) ([]byte, error)
}
// IsValid checks the configuration options
func (c *Config) IsValid() error {
if c.ClusterName == "" {
return errors.New("no cluster name")
}
if c.Listen == "" {
return errors.New("no interface to bind specified")
}
if c.TLSCertPath == "" {
return errors.New("no tls certificate")
}
if c.TLSPrivateKeyPath == "" {
return errors.New("no private key")
}
return nil
}

View File

@ -1,18 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"logger.go",
"misc.go",
"retry.go",
],
importpath = "k8s.io/kops/node-authorizer/pkg/utils",
visibility = ["//visibility:public"],
deps = [
"//vendor/github.com/jpillora/backoff:go_default_library",
"//vendor/go.uber.org/zap:go_default_library",
"//vendor/k8s.io/client-go/kubernetes:go_default_library",
"//vendor/k8s.io/client-go/rest:go_default_library",
],
)

View File

@ -1,34 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"go.uber.org/zap"
)
var (
// Logger is the default logger
Logger *zap.Logger
)
func init() {
l, err := zap.NewProduction()
if err != nil {
panic("failed to create the default logger: " + err.Error())
}
Logger = l
}

View File

@ -1,56 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
crypto_rand "crypto/rand"
"encoding/hex"
"os"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
)
// GetKubernetesClient returns a kubernetes api client for us
func GetKubernetesClient() (kubernetes.Interface, error) {
config, err := rest.InClusterConfig()
if err != nil {
return nil, err
}
return kubernetes.NewForConfig(config)
}
// FileExists checks if the file exists
func FileExists(filename string) bool {
if _, err := os.Stat(filename); err != nil {
return false
}
return true
}
// RandomBytes generates some random bytes
func RandomBytes(length int) (string, error) {
b := make([]byte, length)
_, err := crypto_rand.Read(b)
if err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

View File

@ -1,50 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"context"
"errors"
"time"
"github.com/jpillora/backoff"
"go.uber.org/zap"
)
// Retry attempts to perform an operation for x time
func Retry(ctx context.Context, interval, timeout time.Duration, fn func() error) (err error) {
j := &backoff.Backoff{Min: interval / 2, Max: interval, Factor: 1, Jitter: true}
to := time.NewTimer(timeout)
defer to.Stop()
for {
if err = fn(); err == nil {
return nil
}
Logger.Error("operation failed to execute", zap.Error(err))
// @check if the context has been cancelled or wait
select {
case <-ctx.Done():
return ctx.Err()
case <-to.C:
return errors.New("operation timed out")
case <-time.After(j.Duration()):
}
}
}

View File

@ -27,7 +27,6 @@ go_library(
"logrotate.go",
"manifests.go",
"miscutils.go",
"node_authorizer.go",
"ntp.go",
"packages.go",
"protokube.go",

View File

@ -374,20 +374,6 @@ func (c *NodeupModelContext) UseKopsControllerForNodeBootstrap() bool {
return model.UseKopsControllerForNodeBootstrap(c.Cluster)
}
// UseNodeAuthorization checks if have a node authorization policy
func (c *NodeupModelContext) UseNodeAuthorization() bool {
return c.Cluster.Spec.NodeAuthorization != nil
}
// UseNodeAuthorizer checks if node authorization is enabled
func (c *NodeupModelContext) UseNodeAuthorizer() bool {
if !c.UseNodeAuthorization() || !c.UseBootstrapTokens() {
return false
}
return c.Cluster.Spec.NodeAuthorization.NodeAuthorizer != nil
}
// UsesSecondaryIP checks if the CNI in use attaches secondary interfaces to the host.
func (c *NodeupModelContext) UsesSecondaryIP() bool {
return (c.Cluster.Spec.Networking.CNI != nil && c.Cluster.Spec.Networking.CNI.UsesSecondaryIP) || c.Cluster.Spec.Networking.AmazonVPC != nil || c.Cluster.Spec.Networking.LyftVPC != nil ||

View File

@ -1,142 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package model
import (
"fmt"
"path"
"path/filepath"
"strings"
"time"
"k8s.io/kops/pkg/systemd"
"k8s.io/kops/upup/pkg/fi"
"k8s.io/kops/upup/pkg/fi/nodeup/nodetasks"
"k8s.io/klog/v2"
)
// NodeAuthorizationBuilder is responsible for node authorization
type NodeAuthorizationBuilder struct {
*NodeupModelContext
}
var _ fi.ModelBuilder = &NodeAuthorizationBuilder{}
// Build is responsible for handling the node authorization client
func (b *NodeAuthorizationBuilder) Build(c *fi.ModelBuilderContext) error {
// @check if we are a master and download the certificates for the node-authozier
if b.UseBootstrapTokens() && b.IsMaster {
name := "node-authorizer"
// creates /src/kubernetes/node-authorizer/{tls,tls-key}.pem
if err := b.BuildCertificatePairTask(c, name, name, "tls", nil); err != nil {
return err
}
// creates /src/kubernetes/node-authorizer/ca.pem
if err := b.BuildCertificateTask(c, fi.CertificateIDCA, filepath.Join(name, "ca.pem"), nil); err != nil {
return err
}
}
authorizerDir := "node-authorizer"
// @check if bootstrap tokens are enabled and download client certificates for nodes
if b.UseBootstrapTokens() && !b.IsMaster {
if err := b.BuildCertificatePairTask(c, "node-authorizer-client", authorizerDir, "tls", nil); err != nil {
return err
}
if err := b.BuildCertificateTask(c, fi.CertificateIDCA, authorizerDir+"/ca.pem", nil); err != nil {
return err
}
}
klog.V(3).Infof("bootstrap: %t, node authorization: %t, node authorizer: %t", b.UseBootstrapTokens(),
b.UseNodeAuthorization(), b.UseNodeAuthorizer())
// @check if the NodeAuthorizer provision the client service for nodes
if b.UseNodeAuthorizer() && !b.IsMaster {
na := b.Cluster.Spec.NodeAuthorization.NodeAuthorizer
klog.V(3).Infof("node authorization service is enabled, authorizer: %s", na.Authorizer)
klog.V(3).Infof("node authorization url: %s", na.NodeURL)
// @step: create the systemd unit to run the node authorization client
man := &systemd.Manifest{}
man.Set("Unit", "Description", "Node Authorization Client")
man.Set("Unit", "Documentation", "https://github.com/kubernetes/kops")
man.Set("Unit", "Before", kubeletService)
switch b.Cluster.Spec.ContainerRuntime {
case "docker":
man.Set("Unit", "After", "docker.service")
case "containerd":
man.Set("Unit", "After", "containerd.service")
default:
klog.Warningf("unknown container runtime %q", b.Cluster.Spec.ContainerRuntime)
}
clientCert := filepath.Join(b.PathSrvKubernetes(), authorizerDir, "tls.pem")
man.Set("Service", "Type", "oneshot")
man.Set("Service", "RemainAfterExit", "yes")
man.Set("Service", "EnvironmentFile", "/etc/environment")
man.Set("Service", "ExecStartPre", "/bin/mkdir -p /var/lib/kubelet")
man.Set("Service", "ExecStartPre", "/usr/bin/docker pull "+na.Image)
man.Set("Service", "ExecStartPre", "/bin/bash -c 'while [ ! -f "+clientCert+" ]; do sleep 5; done; sleep 5'")
interval := 10 * time.Second
if na.Interval != nil {
interval = na.Interval.Duration
}
timeout := 5 * time.Minute
if na.Timeout != nil {
timeout = na.Timeout.Duration
}
// @node: using a string array just to make it easier to read
dockerCmd := []string{
"/usr/bin/docker",
"run",
"--rm",
"--net=host",
"--volume=" + path.Dir(b.KubeletBootstrapKubeconfig()) + ":/var/lib/kubelet",
"--volume=" + filepath.Join(b.PathSrvKubernetes(), authorizerDir) + ":/config:ro",
na.Image,
"client",
"--authorizer=" + na.Authorizer,
"--interval=" + interval.String(),
"--kubeapi-url=" + fmt.Sprintf("https://%s", b.Cluster.Spec.MasterInternalName),
"--kubeconfig=" + b.KubeletBootstrapKubeconfig(),
"--node-url=" + na.NodeURL,
"--timeout=" + timeout.String(),
"--tls-client-ca=/config/ca.pem",
"--tls-cert=/config/tls.pem",
"--tls-private-key=/config/tls-key.pem",
}
man.Set("Service", "ExecStart", strings.Join(dockerCmd, " "))
// @step: add the service task
c.AddTask(&nodetasks.Service{
Name: "node-authorizer.service",
Definition: s(man.Render()),
Enabled: fi.Bool(true),
ManageState: fi.Bool(true),
Running: fi.Bool(true),
SmartRestart: fi.Bool(true),
})
}
return nil
}

View File

@ -148,7 +148,7 @@ func validateClusterSpec(spec *kops.ClusterSpec, c *kops.Cluster, fieldPath *fie
}
if spec.NodeAuthorization != nil {
allErrs = append(allErrs, validateNodeAuthorization(spec.NodeAuthorization, c, fieldPath.Child("nodeAuthorization"))...)
allErrs = append(allErrs, field.Forbidden(fieldPath.Child("nodeAuthorization"), "NodeAuthorization must be empty. The functionality has been reimplemented and is enabled on kubernetes >= 1.19.0."))
}
if spec.ClusterAutoscaler != nil {
@ -507,41 +507,6 @@ func validateKubelet(k *kops.KubeletConfigSpec, c *kops.Cluster, kubeletPath *fi
return allErrs
}
func validateNodeAuthorization(n *kops.NodeAuthorizationSpec, c *kops.Cluster, fldPath *field.Path) field.ErrorList {
allErrs := field.ErrorList{}
// @check the feature gate is enabled for this
if !featureflag.EnableNodeAuthorization.Enabled() {
return field.ErrorList{field.Forbidden(fldPath, "node authorization is experimental feature; set `export KOPS_FEATURE_FLAGS=EnableNodeAuthorization`")}
}
authorizerPath := fldPath.Child("nodeAuthorizer")
if c.Spec.NodeAuthorization.NodeAuthorizer == nil {
allErrs = append(allErrs, field.Required(authorizerPath, "no node authorization policy has been set"))
} else {
if c.Spec.NodeAuthorization.NodeAuthorizer.Port < 0 || n.NodeAuthorizer.Port >= 65535 {
allErrs = append(allErrs, field.Invalid(authorizerPath.Child("port"), n.NodeAuthorizer.Port, "invalid port"))
}
if c.Spec.NodeAuthorization.NodeAuthorizer.Timeout != nil && n.NodeAuthorizer.Timeout.Duration <= 0 {
allErrs = append(allErrs, field.Invalid(authorizerPath.Child("timeout"), n.NodeAuthorizer.Timeout, "must be greater than zero"))
}
if c.Spec.NodeAuthorization.NodeAuthorizer.TokenTTL != nil && n.NodeAuthorizer.TokenTTL.Duration < 0 {
allErrs = append(allErrs, field.Invalid(authorizerPath.Child("tokenTTL"), n.NodeAuthorizer.TokenTTL, "must be greater than or equal to zero"))
}
// @question: we could probably just default these settings in the model when the node-authorizer is enabled??
if c.Spec.KubeAPIServer == nil {
allErrs = append(allErrs, field.Required(field.NewPath("spec", "kubeAPIServer"), "bootstrap token authentication is not enabled in the kube-apiserver"))
} else if c.Spec.KubeAPIServer.EnableBootstrapAuthToken == nil {
allErrs = append(allErrs, field.Required(field.NewPath("spec", "kubeAPIServer", "enableBootstrapAuthToken"), "kube-apiserver has not been configured to use bootstrap tokens"))
} else if !fi.BoolValue(c.Spec.KubeAPIServer.EnableBootstrapAuthToken) {
allErrs = append(allErrs, field.Forbidden(field.NewPath("spec", "kubeAPIServer", "enableBootstrapAuthToken"), "bootstrap tokens in the kube-apiserver has been disabled"))
}
}
return allErrs
}
func validateNetworking(cluster *kops.Cluster, v *kops.NetworkingSpec, fldPath *field.Path) field.ErrorList {
c := &cluster.Spec
allErrs := field.ErrorList{}

View File

@ -58,8 +58,6 @@ var (
EnableExternalCloudController = New("EnableExternalCloudController", Bool(false))
// EnableExternalDNS enables external DNS
EnableExternalDNS = New("EnableExternalDNS", Bool(false))
// EnableNodeAuthorization enables the node authorization features
EnableNodeAuthorization = New("EnableNodeAuthorization", Bool(false))
// EnableSeparateConfigBase allows a config-base that is different from the state store
EnableSeparateConfigBase = New("EnableSeparateConfigBase", Bool(false))
// ExperimentalClusterDNS allows for setting the kubelet dns flag to experimental values.

View File

@ -293,9 +293,6 @@ func (b *BootstrapScript) Run(c *fi.Context) error {
spec["kubeProxy"] = cs.KubeProxy
spec["kubelet"] = cs.Kubelet
if cs.NodeAuthorization != nil {
spec["nodeAuthorization"] = cs.NodeAuthorization
}
if cs.KubeAPIServer != nil && cs.KubeAPIServer.EnableBootstrapAuthToken != nil {
spec["kubeAPIServer"] = map[string]interface{}{
"enableBootstrapAuthToken": cs.KubeAPIServer.EnableBootstrapAuthToken,

View File

@ -1,14 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = ["options.go"],
importpath = "k8s.io/kops/pkg/model/components/node-authorizer",
visibility = ["//visibility:public"],
deps = [
"//pkg/apis/kops:go_default_library",
"//pkg/model/components:go_default_library",
"//upup/pkg/fi/loader:go_default_library",
"//vendor/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library",
],
)

View File

@ -1,104 +0,0 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package nodeauthorizer
import (
"errors"
"fmt"
"os"
"time"
"k8s.io/kops/pkg/apis/kops"
"k8s.io/kops/pkg/model/components"
"k8s.io/kops/upup/pkg/fi/loader"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// OptionsBuilder fills in the default options for the node-authorizer
type OptionsBuilder struct {
Context *components.OptionsContext
}
var _ loader.OptionsBuilder = &OptionsBuilder{}
var (
// DefaultPort is the default port to listen on
DefaultPort = 10443
// DefaultTimeout is the max time we are willing to wait before erroring
DefaultTimeout = &metav1.Duration{Duration: 20 * time.Second}
// DefaultTokenTTL is the default expiration on a bootstrap token
DefaultTokenTTL = &metav1.Duration{Duration: 5 * time.Minute}
)
// BuildOptions generates the configurations used to create node authorizer
func (b *OptionsBuilder) BuildOptions(o interface{}) error {
cs, ok := o.(*kops.ClusterSpec)
if !ok {
return errors.New("expected a ClusterSpec object")
}
if cs.NodeAuthorization != nil {
na := cs.NodeAuthorization
// NodeAuthorizerSpec
if na.NodeAuthorizer != nil {
if na.NodeAuthorizer.Authorizer == "" {
switch kops.CloudProviderID(cs.CloudProvider) {
case kops.CloudProviderAWS:
na.NodeAuthorizer.Authorizer = "aws"
default:
na.NodeAuthorizer.Authorizer = "alwaysallow"
}
}
if na.NodeAuthorizer.Image == "" {
na.NodeAuthorizer.Image = GetNodeAuthorizerImage()
}
if na.NodeAuthorizer.Port == 0 {
na.NodeAuthorizer.Port = DefaultPort
}
if na.NodeAuthorizer.Timeout == nil {
na.NodeAuthorizer.Timeout = DefaultTimeout
}
if na.NodeAuthorizer.TokenTTL == nil {
na.NodeAuthorizer.TokenTTL = DefaultTokenTTL
}
if na.NodeAuthorizer.NodeURL == "" {
na.NodeAuthorizer.NodeURL = fmt.Sprintf("https://node-authorizer-internal.%s:%d", b.Context.ClusterName, na.NodeAuthorizer.Port)
}
if na.NodeAuthorizer.Features == nil {
features := []string{"verify-registration", "verify-ip"}
switch kops.CloudProviderID(cs.CloudProvider) {
case kops.CloudProviderAWS:
features = append(features, "verify-signature")
}
na.NodeAuthorizer.Features = features
}
}
}
return nil
}
// GetNodeAuthorizerImage returns the image to use for the node-authorizer
func GetNodeAuthorizerImage() string {
if v := os.Getenv("NODE_AUTHORIZER_IMAGE"); v != "" {
return v
}
return "quay.io/gambol99/node-authorizer:v0.0.4@sha256:078b948b8207e43d35885f181713de3d3c0491fe40661d198f9bc00136cff271"
}

View File

@ -37,7 +37,6 @@
// upup/models/cloudup/resources/addons/networking.projectcalico.org.canal/k8s-1.15.yaml.template
// upup/models/cloudup/resources/addons/networking.projectcalico.org.canal/k8s-1.16.yaml.template
// upup/models/cloudup/resources/addons/networking.weave/k8s-1.12.yaml.template
// upup/models/cloudup/resources/addons/node-authorizer.addons.k8s.io/k8s-1.12.yaml.template
// upup/models/cloudup/resources/addons/node-termination-handler.aws/k8s-1.11.yaml.template
// upup/models/cloudup/resources/addons/nodelocaldns.addons.k8s.io/k8s-1.12.yaml.template
// upup/models/cloudup/resources/addons/openstack.addons.k8s.io/k8s-1.13.yaml.template
@ -40356,212 +40355,6 @@ func cloudupResourcesAddonsNetworkingWeaveK8s112YamlTemplate() (*asset, error) {
return a, nil
}
var _cloudupResourcesAddonsNodeAuthorizerAddonsK8sIoK8s112YamlTemplate = []byte(`{{- $proxy := .EgressProxy -}}
{{- $na := .NodeAuthorization.NodeAuthorizer -}}
{{- $name := "node-authorizer" -}}
{{- $namespace := "kube-system" -}}
---
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
---
kind: ClusterRole
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: kops:{{ $name }}:nodes-viewer
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
rules:
- apiGroups:
- "*"
resources:
- nodes
verbs:
- get
- list
---
# permits the node access to create a CSR
kind: ClusterRoleBinding
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: kops:{{ $name }}:system:bootstrappers
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
kind: ClusterRole
name: system:node-bootstrapper
apiGroup: rbac.authorization.k8s.io
subjects:
- kind: Group
name: system:bootstrappers
apiGroup: rbac.authorization.k8s.io
---
# indicates to the controller to auto-sign the CSR for this group
kind: ClusterRoleBinding
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: kops:{{ $name }}:approval
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
kind: ClusterRole
name: system:certificates.k8s.io:certificatesigningrequests:nodeclient
apiGroup: rbac.authorization.k8s.io
subjects:
- kind: Group
name: system:bootstrappers
apiGroup: rbac.authorization.k8s.io
---
# the service permission requires to create the bootstrap tokens
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: kops:{{ $namespace }}:{{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
rules:
- apiGroups:
- "*"
resources:
- secrets
verbs:
- create
- list
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: kops:{{ $namespace }}:{{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: Role
name: kops:{{ $namespace }}:{{ $name }}
subjects:
- kind: ServiceAccount
name: {{ $name }}
namespace: {{ $namespace }}
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: kops:{{ $name }}:nodes-viewer
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: kops:{{ $name }}:nodes-viewer
subjects:
- kind: ServiceAccount
name: {{ $name }}
namespace: {{ $namespace }}
---
kind: DaemonSet
apiVersion: apps/v1
metadata:
name: {{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
spec:
selector:
matchLabels:
k8s-app: {{ $name }}
template:
metadata:
labels:
k8s-app: {{ $name }}
annotations:
dns.alpha.kubernetes.io/internal: {{ $name }}-internal.{{ ClusterName }}
prometheus.io/port: "{{ $na.Port }}"
prometheus.io/scheme: "https"
prometheus.io/scrape: "true"
scheduler.alpha.kubernetes.io/critical-pod: ''
spec:
hostNetwork: true
nodeSelector:
kubernetes.io/role: master
priorityClassName: system-node-critical
serviceAccount: {{ $name }}
securityContext:
fsGroup: 1000
tolerations:
- key: "node-role.kubernetes.io/master"
effect: NoSchedule
volumes:
- name: config
hostPath:
path: /srv/kubernetes/node-authorizer
type: DirectoryOrCreate
containers:
- name: {{ $name }}
image: {{ $na.Image }}
args:
- server
- --authorization-timeout={{ $na.Timeout.Duration }}
- --authorizer={{ $na.Authorizer }}
- --cluster-name={{ ClusterName }}
{{- range $na.Features }}
- --feature={{ . }}
{{- end }}
- --listen=0.0.0.0:{{ $na.Port }}
- --tls-cert=/config/tls.pem
- --tls-client-ca=/config/ca.pem
- --tls-private-key=/config/tls-key.pem
- --token-ttl={{ $na.TokenTTL.Duration }}
{{- if $proxy }}
env:
- name: http_proxy
value: {{ $proxy.HTTPProxy.Host }}:{{ $proxy.HTTPProxy.Port }}
{{- if $proxy.ProxyExcludes }}
- name: no_proxy
value: {{ $proxy.ProxyExcludes }}
{{- end }}
{{- end }}
resources:
limits:
cpu: 100m
memory: 64Mi
requests:
cpu: 10m
memory: 10Mi
volumeMounts:
- mountPath: /config
readOnly: true
name: config
`)
func cloudupResourcesAddonsNodeAuthorizerAddonsK8sIoK8s112YamlTemplateBytes() ([]byte, error) {
return _cloudupResourcesAddonsNodeAuthorizerAddonsK8sIoK8s112YamlTemplate, nil
}
func cloudupResourcesAddonsNodeAuthorizerAddonsK8sIoK8s112YamlTemplate() (*asset, error) {
bytes, err := cloudupResourcesAddonsNodeAuthorizerAddonsK8sIoK8s112YamlTemplateBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "cloudup/resources/addons/node-authorizer.addons.k8s.io/k8s-1.12.yaml.template", size: 0, mode: os.FileMode(0), modTime: time.Unix(0, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
var _cloudupResourcesAddonsNodeTerminationHandlerAwsK8s111YamlTemplate = []byte(`{{ with .NodeTerminationHandler }}
# Sourced from https://github.com/aws/aws-node-termination-handler/releases/download/v1.7.0/all-resources.yaml
---
@ -42656,7 +42449,6 @@ var _bindata = map[string]func() (*asset, error){
"cloudup/resources/addons/networking.projectcalico.org.canal/k8s-1.15.yaml.template": cloudupResourcesAddonsNetworkingProjectcalicoOrgCanalK8s115YamlTemplate,
"cloudup/resources/addons/networking.projectcalico.org.canal/k8s-1.16.yaml.template": cloudupResourcesAddonsNetworkingProjectcalicoOrgCanalK8s116YamlTemplate,
"cloudup/resources/addons/networking.weave/k8s-1.12.yaml.template": cloudupResourcesAddonsNetworkingWeaveK8s112YamlTemplate,
"cloudup/resources/addons/node-authorizer.addons.k8s.io/k8s-1.12.yaml.template": cloudupResourcesAddonsNodeAuthorizerAddonsK8sIoK8s112YamlTemplate,
"cloudup/resources/addons/node-termination-handler.aws/k8s-1.11.yaml.template": cloudupResourcesAddonsNodeTerminationHandlerAwsK8s111YamlTemplate,
"cloudup/resources/addons/nodelocaldns.addons.k8s.io/k8s-1.12.yaml.template": cloudupResourcesAddonsNodelocaldnsAddonsK8sIoK8s112YamlTemplate,
"cloudup/resources/addons/openstack.addons.k8s.io/k8s-1.13.yaml.template": cloudupResourcesAddonsOpenstackAddonsK8sIoK8s113YamlTemplate,
@ -42801,9 +42593,6 @@ var _bintree = &bintree{nil, map[string]*bintree{
"networking.weave": {nil, map[string]*bintree{
"k8s-1.12.yaml.template": {cloudupResourcesAddonsNetworkingWeaveK8s112YamlTemplate, map[string]*bintree{}},
}},
"node-authorizer.addons.k8s.io": {nil, map[string]*bintree{
"k8s-1.12.yaml.template": {cloudupResourcesAddonsNodeAuthorizerAddonsK8sIoK8s112YamlTemplate, map[string]*bintree{}},
}},
"node-termination-handler.aws": {nil, map[string]*bintree{
"k8s-1.11.yaml.template": {cloudupResourcesAddonsNodeTerminationHandlerAwsK8s111YamlTemplate, map[string]*bintree{}},
}},

View File

@ -1,189 +0,0 @@
{{- $proxy := .EgressProxy -}}
{{- $na := .NodeAuthorization.NodeAuthorizer -}}
{{- $name := "node-authorizer" -}}
{{- $namespace := "kube-system" -}}
---
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
---
kind: ClusterRole
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: kops:{{ $name }}:nodes-viewer
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
rules:
- apiGroups:
- "*"
resources:
- nodes
verbs:
- get
- list
---
# permits the node access to create a CSR
kind: ClusterRoleBinding
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: kops:{{ $name }}:system:bootstrappers
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
kind: ClusterRole
name: system:node-bootstrapper
apiGroup: rbac.authorization.k8s.io
subjects:
- kind: Group
name: system:bootstrappers
apiGroup: rbac.authorization.k8s.io
---
# indicates to the controller to auto-sign the CSR for this group
kind: ClusterRoleBinding
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: kops:{{ $name }}:approval
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
kind: ClusterRole
name: system:certificates.k8s.io:certificatesigningrequests:nodeclient
apiGroup: rbac.authorization.k8s.io
subjects:
- kind: Group
name: system:bootstrappers
apiGroup: rbac.authorization.k8s.io
---
# the service permission requires to create the bootstrap tokens
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: kops:{{ $namespace }}:{{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
rules:
- apiGroups:
- "*"
resources:
- secrets
verbs:
- create
- list
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: kops:{{ $namespace }}:{{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: Role
name: kops:{{ $namespace }}:{{ $name }}
subjects:
- kind: ServiceAccount
name: {{ $name }}
namespace: {{ $namespace }}
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: kops:{{ $name }}:nodes-viewer
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: kops:{{ $name }}:nodes-viewer
subjects:
- kind: ServiceAccount
name: {{ $name }}
namespace: {{ $namespace }}
---
kind: DaemonSet
apiVersion: apps/v1
metadata:
name: {{ $name }}
namespace: {{ $namespace }}
labels:
k8s-app: {{ $name }}
k8s-addon: {{ $name }}.addons.k8s.io
spec:
selector:
matchLabels:
k8s-app: {{ $name }}
template:
metadata:
labels:
k8s-app: {{ $name }}
annotations:
dns.alpha.kubernetes.io/internal: {{ $name }}-internal.{{ ClusterName }}
prometheus.io/port: "{{ $na.Port }}"
prometheus.io/scheme: "https"
prometheus.io/scrape: "true"
scheduler.alpha.kubernetes.io/critical-pod: ''
spec:
hostNetwork: true
nodeSelector:
kubernetes.io/role: master
priorityClassName: system-node-critical
serviceAccount: {{ $name }}
securityContext:
fsGroup: 1000
tolerations:
- key: "node-role.kubernetes.io/master"
effect: NoSchedule
volumes:
- name: config
hostPath:
path: /srv/kubernetes/node-authorizer
type: DirectoryOrCreate
containers:
- name: {{ $name }}
image: {{ $na.Image }}
args:
- server
- --authorization-timeout={{ $na.Timeout.Duration }}
- --authorizer={{ $na.Authorizer }}
- --cluster-name={{ ClusterName }}
{{- range $na.Features }}
- --feature={{ . }}
{{- end }}
- --listen=0.0.0.0:{{ $na.Port }}
- --tls-cert=/config/tls.pem
- --tls-client-ca=/config/ca.pem
- --tls-private-key=/config/tls-key.pem
- --token-ttl={{ $na.TokenTTL.Duration }}
{{- if $proxy }}
env:
- name: http_proxy
value: {{ $proxy.HTTPProxy.Host }}:{{ $proxy.HTTPProxy.Port }}
{{- if $proxy.ProxyExcludes }}
- name: no_proxy
value: {{ $proxy.ProxyExcludes }}
{{- end }}
{{- end }}
resources:
limits:
cpu: 100m
memory: 64Mi
requests:
cpu: 10m
memory: 10Mi
volumeMounts:
- mountPath: /config
readOnly: true
name: config

View File

@ -49,7 +49,6 @@ go_library(
"//pkg/model/components:go_default_library",
"//pkg/model/components/etcdmanager:go_default_library",
"//pkg/model/components/kubeapiserver:go_default_library",
"//pkg/model/components/node-authorizer:go_default_library",
"//pkg/model/domodel:go_default_library",
"//pkg/model/gcemodel:go_default_library",
"//pkg/model/iam:go_default_library",

View File

@ -277,26 +277,6 @@ func (b *BootstrapChannelBuilder) buildAddons(c *fi.ModelBuilderContext) (*chann
}
}
if b.Cluster.Spec.NodeAuthorization != nil {
{
key := "node-authorizer.addons.k8s.io"
version := "v0.0.4-kops.2"
{
location := key + "/k8s-1.12.yaml"
id := "k8s-1.12.yaml"
addons.Spec.Addons = append(addons.Spec.Addons, &channelsapi.AddonSpec{
Name: fi.String(key),
Version: fi.String(version),
Selector: map[string]string{"k8s-addon": key},
Manifest: fi.String(location),
Id: id,
})
}
}
}
kubeDNS := b.Cluster.Spec.KubeDNS
// This checks if the Kubernetes version is greater than or equal to 1.20
@ -350,11 +330,11 @@ func (b *BootstrapChannelBuilder) buildAddons(c *fi.ModelBuilderContext) (*chann
}
}
// @check if node authorization or bootstrap tokens are enabled an if so we can forgo applying
// @check if bootstrap tokens are enabled an if so we can forgo applying
// this manifest. For clusters whom are upgrading from RBAC to Node,RBAC the clusterrolebinding
// will remain and have to be deleted manually once all the nodes have been upgraded.
enableRBACAddon := true
if b.UseKopsControllerForNodeBootstrap() || b.Cluster.Spec.NodeAuthorization != nil {
if b.UseKopsControllerForNodeBootstrap() {
enableRBACAddon = false
}
if b.Cluster.Spec.KubeAPIServer != nil {

View File

@ -32,7 +32,6 @@ import (
"k8s.io/kops/pkg/dns"
"k8s.io/kops/pkg/model/components"
"k8s.io/kops/pkg/model/components/etcdmanager"
nodeauthorizer "k8s.io/kops/pkg/model/components/node-authorizer"
"k8s.io/kops/upup/pkg/fi"
"k8s.io/kops/upup/pkg/fi/loader"
"k8s.io/kops/util/pkg/reflectutils"
@ -265,7 +264,6 @@ func (c *populateClusterSpec) run(clientset simple.Clientset) error {
codeModels = append(codeModels, &components.DefaultsOptionsBuilder{Context: optionsContext})
codeModels = append(codeModels, &components.EtcdOptionsBuilder{OptionsContext: optionsContext})
codeModels = append(codeModels, &etcdmanager.EtcdManagerOptionsBuilder{OptionsContext: optionsContext})
codeModels = append(codeModels, &nodeauthorizer.OptionsBuilder{Context: optionsContext})
codeModels = append(codeModels, &components.KubeAPIServerOptionsBuilder{OptionsContext: optionsContext})
codeModels = append(codeModels, &components.DockerOptionsBuilder{OptionsContext: optionsContext})
codeModels = append(codeModels, &components.ContainerdOptionsBuilder{OptionsContext: optionsContext})

View File

@ -229,7 +229,6 @@ func (c *NodeUpCommand) Run(out io.Writer) error {
loader.Builders = append(loader.Builders, &model.CloudConfigBuilder{NodeupModelContext: modelContext})
loader.Builders = append(loader.Builders, &model.FileAssetsBuilder{NodeupModelContext: modelContext})
loader.Builders = append(loader.Builders, &model.HookBuilder{NodeupModelContext: modelContext})
loader.Builders = append(loader.Builders, &model.NodeAuthorizationBuilder{NodeupModelContext: modelContext})
loader.Builders = append(loader.Builders, &model.KubeletBuilder{NodeupModelContext: modelContext})
loader.Builders = append(loader.Builders, &model.KubectlBuilder{NodeupModelContext: modelContext})
loader.Builders = append(loader.Builders, &model.EtcdBuilder{NodeupModelContext: modelContext})

View File

@ -1,5 +0,0 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test

View File

@ -1,15 +0,0 @@
language: go
go:
- 1.1
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- tip
install:
- go install ./...
- go get github.com/BurntSushi/toml-test
script:
- export PATH="$PATH:$HOME/gopath/bin"
- make test

View File

@ -1,20 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"decode.go",
"decode_meta.go",
"doc.go",
"encode.go",
"encoding_types.go",
"encoding_types_1.1.go",
"lex.go",
"parse.go",
"type_check.go",
"type_fields.go",
],
importmap = "k8s.io/kops/vendor/github.com/BurntSushi/toml",
importpath = "github.com/BurntSushi/toml",
visibility = ["//visibility:public"],
)

View File

@ -1,3 +0,0 @@
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/v0.4.0/versions/en/toml-v0.4.0.md)

View File

@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2013 TOML authors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -1,19 +0,0 @@
install:
go install ./...
test: install
go test -v
toml-test toml-test-decoder
toml-test -encoder toml-test-encoder
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

View File

@ -1,218 +0,0 @@
## TOML parser and encoder for Go with reflection
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
Spec: https://github.com/toml-lang/toml
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Documentation: https://godoc.org/github.com/BurntSushi/toml
Installation:
```bash
go get github.com/BurntSushi/toml
```
Try the toml validator:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build Status](https://travis-ci.org/BurntSushi/toml.svg?branch=master)](https://travis-ci.org/BurntSushi/toml) [![GoDoc](https://godoc.org/github.com/BurntSushi/toml?status.svg)](https://godoc.org/github.com/BurntSushi/toml)
### Testing
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.
### Examples
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
```toml
some_key_NAME = "wat"
```
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
}
```
### Using the `encoding.TextUnmarshaler` interface
Here's an example that automatically parses duration strings into
`time.Duration` values:
```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
```
Which can be decoded with:
```go
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
```
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

View File

@ -1,509 +0,0 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"math"
"reflect"
"strings"
"time"
)
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive struct {
undecoded interface{}
context Key
}
// DEPRECATED!
//
// Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
}
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, indirect(rv))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
copy(context, md.context)
rv.Set(reflect.ValueOf(Primitive{
undecoded: data,
context: context,
}))
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
return md.unifyMap(data, rv)
case reflect.Array:
return md.unifyArray(data, rv)
case reflect.Slice:
return md.unifySlice(data, rv)
case reflect.String:
return md.unifyString(data, rv)
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("unsupported type %s", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
if mapping == nil {
return nil
}
return e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
}
for key, datum := range tmap {
var f *field
fields := cachedTypeFields(rv.Type())
for i := range fields {
ff := &fields[i]
if ff.name == key {
f = ff
break
}
if f == nil && strings.EqualFold(ff.name, key) {
f = ff
}
}
if f != nil {
subv := rv
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("cannot write unexported field %s.%s",
rv.Type().String(), f.name)
}
}
}
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
if tmap == nil {
return nil
}
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
}
n := datav.Len()
if rv.IsNil() || rv.Cap() < n {
rv.Set(reflect.MakeSlice(rv.Type(), n, n))
}
rv.SetLen(n)
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("value %d is out of range for int8", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("value %d is out of range for int16", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("value %d is out of range for int32", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("value %d is out of range for uint8", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("value %d is out of range for uint16", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("value %d is out of range for uint32", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
}
return nil
}
return badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
}
s = string(text)
case fmt.Stringer:
s = sdata.String()
case string:
s = sdata
case bool:
s = fmt.Sprintf("%v", sdata)
case int64:
s = fmt.Sprintf("%d", sdata)
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanSet() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
}
}
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("cannot load TOML value of type %T into a Go %s", data, expected)
}

View File

@ -1,121 +0,0 @@
package toml
import "strings"
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
}
return k[i]
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
// and so no keys will be considered decoded.
//
// In this sense, the Undecoded keys correspond to keys in the TOML document
// that do not have a concrete type in your representation.
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
undecoded = append(undecoded, key)
}
}
return undecoded
}

View File

@ -1,27 +0,0 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/toml-lang/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml

View File

@ -1,568 +0,0 @@
package toml
import (
"bufio"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"toml: cannot encode array with mixed element types")
errArrayNilElement = errors.New(
"toml: cannot encode array with nil element")
errNonString = errors.New(
"toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New(
"toml: cannot encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"toml: TOML array element cannot contain a table")
errNoKey = errors.New(
"toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
if terr, ok := r.(tomlEncodeError); ok {
err = terr.error
return
}
panic(r)
}
}()
enc.encode(key, rv)
return nil
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we used that.
// Basically, this prevents the encoder for handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
return
}
k := rv.Kind()
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
}
case reflect.Interface:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Map:
if rv.IsNil() {
return
}
enc.eTable(key, rv)
case reflect.Ptr:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("unexpected primitive type: %s", rv.Kind()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
}
return fstr
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
}
}
enc.wf("]")
}
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra newline between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
case reflect.Struct:
enc.eStruct(key, rv)
default:
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
}
// Sort keys so that we have deterministic output. And write keys directly
// underneath this key first, before writing sub-structs or sub-maps.
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
continue
}
enc.encode(key.add(mapKey), mrv)
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexported fields
if f.PkgPath != "" && !f.Anonymous {
continue
}
frv := rv.Field(i)
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
// Treat anonymous struct fields with
// tag names as though they are not
// anonymous, like encoding/json does.
if getOptions(f.Tag).name == "" {
addFields(t, frv, f.Index)
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct &&
getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), f.Index)
}
continue
}
// Fall through to the normal field encoding logic below
// for non-struct anonymous fields.
}
}
if typeIsHash(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
continue
}
opts := getOptions(sft.Tag)
if opts.skip {
continue
}
keyName := sft.Name
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && isEmpty(sf) {
continue
}
if opts.omitzero && isZero(sf) {
continue
}
enc.encode(key.add(keyName), sf)
}
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
}
// tomlTypeName returns the TOML type name of the Go value's type. It is
// used to determine whether the types of array elements are mixed (which is
// forbidden). If the Go value is nil, then it is illegal for it to be an array
// element, and valueIsNil is returned as true.
// Returns the TOML type of a Go value. The type may be `nil`, which means
// no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInteger
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
}
return tomlArray
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
}
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
type tagOptions struct {
skip bool // "-"
name string
omitempty bool
omitzero bool
}
func getOptions(tag reflect.StructTag) tagOptions {
t := tag.Get("toml")
if t == "-" {
return tagOptions{skip: true}
}
var opts tagOptions
parts := strings.Split(t, ",")
opts.name = parts[0]
for _, s := range parts[1:] {
switch s {
case "omitempty":
opts.omitempty = true
case "omitzero":
opts.omitzero = true
}
}
return opts
}
func isZero(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return rv.Uint() == 0
case reflect.Float32, reflect.Float64:
return rv.Float() == 0.0
}
return false
}
func isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return rv.Len() == 0
case reflect.Bool:
return !rv.Bool()
}
return false
}
func (enc *Encoder) newline() {
if enc.hasWritten {
enc.wf("\n")
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
encPanic(err)
}
enc.hasWritten = true
}
func (enc *Encoder) indentStr(key Key) string {
return strings.Repeat(enc.Indent, len(key)-1)
}
func encPanic(err error) {
panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
return v
}
}
func isNil(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}

View File

@ -1,19 +0,0 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

View File

@ -1,18 +0,0 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}

View File

@ -1,953 +0,0 @@
package toml
import (
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemRawString
itemMultilineString
itemRawMultilineString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
itemInlineTableStart
itemInlineTableEnd
)
const (
eof = 0
comma = ','
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
inlineTableStart = '{'
inlineTableEnd = '}'
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
line int
state stateFn
items chan item
// Allow for backing up up to three runes.
// This is necessary because TOML contains 3-rune tokens (""" and ''').
prevWidths [3]int
nprev int // how many of prevWidths are in use
// If we emit an eof, we can still back up, but it is not OK to call
// next again.
atEOF bool
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
}
func lex(input string) *lexer {
lx := &lexer{
input: input,
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) emitTrim(typ itemType) {
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.atEOF {
panic("next called after EOF")
}
if lx.pos >= len(lx.input) {
lx.atEOF = true
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
lx.prevWidths[2] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[0]
if lx.nprev < 3 {
lx.nprev++
}
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.prevWidths[0] = w
lx.pos += w
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only twice between calls to next.
func (lx *lexer) backup() {
if lx.atEOF {
lx.atEOF = false
return
}
if lx.nprev < 1 {
panic("backed up too far")
}
w := lx.prevWidths[0]
lx.prevWidths[0] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[2]
lx.nprev--
lx.pos -= w
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// skip ignores all input that matches the given predicate.
func (lx *lexer) skip(pred func(rune) bool) {
for {
r := lx.next()
if pred(r) {
continue
}
lx.backup()
lx.ignore()
return
}
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (newlines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("unexpected EOF")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a newline. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a newline for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.emit(itemEOF)
return nil
}
return lx.errorf("expected a top-level item to end with a newline, "+
"comment, or EOF, but got %q instead", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("expected end of table array name delimiter %q, "+
"but got %q instead", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("unexpected end of table name " +
"(table names cannot be empty)")
case r == tableSep:
return lx.errorf("unexpected table separator " +
"(table names cannot be empty)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
return lexValue // reuse string lexing
default:
return lexBareTableName
}
}
// lexBareTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
r := lx.next()
if isBareKeyChar(r) {
return lexBareTableName
}
lx.backup()
lx.emit(itemText)
return lexTableNameEnd
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
case r == tableSep:
lx.ignore()
return lexTableNameStart
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("expected '.' or ']' to end table name, "+
"but got %q instead", r)
}
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("unexpected key separator %q", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.emit(itemKeyStart)
lx.push(lexKeyEnd)
return lexValue // reuse string lexing
default:
lx.ignore()
lx.emit(itemKeyStart)
return lexBareKey
}
}
// lexBareKey consumes the text of a bare key. Assumes that the first character
// (which is not whitespace) has not yet been consumed.
func lexBareKey(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.backup()
lx.emit(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emit(itemText)
return lexKeyEnd
default:
return lx.errorf("bare keys cannot contain %q", r)
}
}
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
// separator).
func lexKeyEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case r == keySep:
return lexSkip(lx, lexValue)
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("expected key separator %q, but got %q instead",
keySep, r)
}
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT newlines.
// In array syntax, the array states are responsible for ignoring newlines.
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexValue)
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
}
switch r {
case arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case inlineTableStart:
lx.ignore()
lx.emit(itemInlineTableStart)
return lexInlineTableValue
case stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
return lexMultilineString
}
lx.backup()
}
lx.ignore() // ignore the '"'
return lexString
case rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore """
return lexMultilineRawString
}
lx.backup()
}
lx.ignore() // ignore the "'"
return lexRawString
case '+', '-':
return lexNumberStart
case '.': // special error case, be kind to users
return lx.errorf("floats must start with a digit, not '.'")
}
if unicode.IsLetter(r) {
// Be permissive here; lexBool will give a nice error if the
// user wrote something like
// x = foo
// (i.e. not 'true' or 'false' but is something else word-like.)
lx.backup()
return lexBool
}
return lx.errorf("expected value but found %q instead", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and newlines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == arrayEnd:
// NOTE(caleb): The spec isn't clear about whether you can have
// a trailing comma or not, so we'll allow it.
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes everything between the end of an array value and
// the next value (or the end of the array): it ignores whitespace and newlines
// and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf(
"expected a comma or array terminator %q, but got %q instead",
arrayEnd, r,
)
}
// lexArrayEnd finishes the lexing of an array.
// It assumes that a ']' has just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexInlineTableValue consumes one key/value pair in an inline table.
// It assumes that '{' or ',' have already been consumed. Whitespace is ignored.
func lexInlineTableValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValue)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == inlineTableEnd:
return lexInlineTableEnd
}
lx.backup()
lx.push(lexInlineTableValueEnd)
return lexKeyStart
}
// lexInlineTableValueEnd consumes everything between the end of an inline table
// key/value pair and the next pair (or the end of the table):
// it ignores whitespace and expects either a ',' or a '}'.
func lexInlineTableValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValueEnd)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexInlineTableValue
case r == inlineTableEnd:
return lexInlineTableEnd
}
return lx.errorf("expected a comma or an inline table terminator %q, "+
"but got %q instead", inlineTableEnd, r)
}
// lexInlineTableEnd finishes the lexing of an inline table.
// It assumes that a '}' has just been consumed.
func lexInlineTableEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemInlineTableEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("strings cannot contain newlines")
case r == '\\':
lx.push(lexString)
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case '\\':
return lexMultilineStringEscape
case stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineString
}
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
// It assumes that the beginning "'" has already been consumed and ignored.
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("strings cannot contain newlines")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexRawString
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'''" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemRawMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineRawString
}
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
return lexMultilineString
}
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '\\':
return lx.pop()
case 'u':
return lexShortUnicodeEscape
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("invalid escape character %q; only the following "+
"escape characters are allowed: "+
`\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(`expected four hexadecimal digits after '\u', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
}
func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(`expected eight hexadecimal digits after '\U', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either an integer, a float, or datetime.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumberOrDate
}
switch r {
case '_':
return lexNumber
case 'e', 'E':
return lexFloat
case '.':
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
// lexNumberOrDate consumes either an integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumberOrDate
}
switch r {
case '-':
return lexDatetime
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDatetime consumes a Datetime, to a first approximation.
// The parser validates that it matches one of the accepted formats.
func lexDatetime(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexDatetime
}
switch r {
case '-', 'T', ':', '.', 'Z', '+':
return lexDatetime
}
lx.backup()
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that a sign
// has already been read, but that *no* digits have been consumed.
// lexNumberStart will move to the appropriate integer or float states.
func lexNumberStart(lx *lexer) stateFn {
// We MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumber
}
switch r {
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloat consumes the elements of a float. It allows any sequence of
// float-like characters, so floats emitted by the lexer are only a first
// approximation and must be validated by the parser.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
switch r {
case '_', '.', '-', '+', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexBool consumes a bool string: 'true' or 'false.
func lexBool(lx *lexer) stateFn {
var rs []rune
for {
r := lx.next()
if !unicode.IsLetter(r) {
lx.backup()
break
}
rs = append(rs, r)
}
s := string(rs)
switch s {
case "true", "false":
lx.emit(itemBool)
return lx.pop()
}
return lx.errorf("expected value but found %q instead", s)
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first newline character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' ||
r == '-'
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}

View File

@ -1,592 +0,0 @@
package toml
import (
"fmt"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panicf("%s", it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// Gets a string for a key (or part of a key in a table name).
func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
if !numUnderscoresOK(it.val) {
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseInt(val, 10, 64)
if err != nil {
// Distinguish integer values. Normally, it'd be a bug if the lexer
// provides an invalid integer, but it's possible that the number is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
return true
}
return false
})
for _, part := range parts {
if !numUnderscoresOK(part) {
p.panicf("Invalid float %q: underscores must be "+
"surrounded by digits", it.val)
}
}
if !numPeriodsOK(it.val) {
// As a special case, numbers like '123.' or '1.e2',
// which are valid as far as Go/strconv are concerned,
// must be rejected because TOML says that a fractional
// part consists of '.' followed by 1+ digits.
p.panicf("Invalid float %q: '.' must be followed "+
"by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseFloat(val, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.panicf("Invalid float value: %q", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
var t time.Time
var ok bool
var err error
for _, format := range []string{
"2006-01-02T15:04:05Z07:00",
"2006-01-02T15:04:05",
"2006-01-02",
} {
t, err = time.ParseInLocation(format, it.val, time.Local)
if err == nil {
ok = true
break
}
}
if !ok {
p.panicf("Invalid TOML Datetime: %q.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
case itemInlineTableStart:
var (
hash = make(map[string]interface{})
outerContext = p.context
outerKey = p.currentKey
)
p.context = append(p.context, p.currentKey)
p.currentKey = ""
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
if it.typ != itemKeyStart {
p.bug("Expected key start but instead found %q, around line %d",
it.val, p.approxLine)
}
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
// retrieve key
k := p.next()
p.approxLine = k.line
kname := p.keyString(k)
// retrieve value
p.currentKey = kname
val, typ := p.value(p.next())
// make sure we keep metadata up to date
p.setType(kname, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[kname] = val
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// numUnderscoresOK checks whether each underscore in s is surrounded by
// characters that are not underscores.
func numUnderscoresOK(s string) bool {
accept := false
for _, r := range s {
if r == '_' {
if !accept {
return false
}
accept = false
continue
}
accept = true
}
return accept
}
// numPeriodsOK checks whether every period in s is followed by a digit.
func numPeriodsOK(s string) bool {
period := false
for _, r := range s {
if period && !isDigit(r) {
return false
}
period = r == '.'
}
return !period
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
return s
}
return s[1:]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
}
}
return strings.Join(esc, "")
}
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
continue
}
r += 1
if r >= len(s) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
replaced = append(replaced, escaped)
r += 9
}
}
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
if !utf8.ValidRune(rune(hex)) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}

View File

@ -1 +0,0 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View File

@ -1,91 +0,0 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
if t1 == nil || t2 == nil {
return false
}
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemMultilineString:
return tomlString
case itemRawString:
return tomlString
case itemRawMultilineString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}

View File

@ -1,242 +0,0 @@
package toml
// Struct field handling is adapted from code in encoding/json:
//
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the Go distribution.
import (
"reflect"
"sort"
"sync"
)
// A field represents a single field found in a struct.
type field struct {
name string // the name of the field (`toml` tag included)
tag bool // whether field has a `toml` tag
index []int // represents the depth of an anonymous field
typ reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth,
// then breaking ties with "name came from toml tag", then
// breaking ties with index sequence.
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.PkgPath != "" && !sf.Anonymous { // unexported
continue
}
opts := getOptions(sf.Tag)
if opts.skip {
continue
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
// Follow pointer.
ft = ft.Elem()
}
// Record found field and index sequence.
if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := opts.name != ""
name := opts.name
if name == "" {
name = sf.Name
}
fields = append(fields, field{name, tagged, index, ft})
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
f := field{name: ft.Name(), index: index, typ: ft}
next = append(next, f)
}
}
}
}
sort.Sort(byName(fields))
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with TOML tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
return fields
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// TOML tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order. The winner
// must therefore be one with the shortest index length. Drop all
// longer entries, which is easy: just truncate the slice.
length := len(fields[0].index)
tagged := -1 // Index of first tagged field.
for i, f := range fields {
if len(f.index) > length {
fields = fields[:i]
break
}
if f.tag {
if tagged >= 0 {
// Multiple tagged fields at the same level: conflict.
// Return no field.
return field{}, false
}
tagged = i
}
}
if tagged >= 0 {
return fields[tagged], true
}
// All remaining fields have the same length. If there's more than one,
// we have a conflict (two fields named "X" at the same level) and we
// return no field.
if len(fields) > 1 {
return field{}, false
}
return fields[0], true
}
var fieldCache struct {
sync.RWMutex
m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) []field {
fieldCache.RLock()
f := fieldCache.m[t]
fieldCache.RUnlock()
if f != nil {
return f
}
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = typeFields(t)
if f == nil {
f = []field{}
}
fieldCache.Lock()
if fieldCache.m == nil {
fieldCache.m = map[reflect.Type][]field{}
}
fieldCache.m[t] = f
fieldCache.Unlock()
return f
}

View File

@ -1,24 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof

View File

@ -1,7 +0,0 @@
language: go
go:
- 1.8
- 1.9
- "1.10"
- tip

View File

@ -1,13 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"ber.go",
"pkcs7.go",
"x509.go",
],
importmap = "k8s.io/kops/vendor/github.com/fullsailor/pkcs7",
importpath = "github.com/fullsailor/pkcs7",
visibility = ["//visibility:public"],
)

View File

@ -1,22 +0,0 @@
The MIT License (MIT)
Copyright (c) 2015 Andrew Smith
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,8 +0,0 @@
# pkcs7
[![GoDoc](https://godoc.org/github.com/fullsailor/pkcs7?status.svg)](https://godoc.org/github.com/fullsailor/pkcs7)
[![Build Status](https://travis-ci.org/fullsailor/pkcs7.svg?branch=master)](https://travis-ci.org/fullsailor/pkcs7)
pkcs7 implements parsing and creating signed and enveloped messages.
- Documentation on [GoDoc](http://godoc.org/github.com/fullsailor/pkcs7)

View File

@ -1,248 +0,0 @@
package pkcs7
import (
"bytes"
"errors"
)
var encodeIndent = 0
type asn1Object interface {
EncodeTo(writer *bytes.Buffer) error
}
type asn1Structured struct {
tagBytes []byte
content []asn1Object
}
func (s asn1Structured) EncodeTo(out *bytes.Buffer) error {
//fmt.Printf("%s--> tag: % X\n", strings.Repeat("| ", encodeIndent), s.tagBytes)
encodeIndent++
inner := new(bytes.Buffer)
for _, obj := range s.content {
err := obj.EncodeTo(inner)
if err != nil {
return err
}
}
encodeIndent--
out.Write(s.tagBytes)
encodeLength(out, inner.Len())
out.Write(inner.Bytes())
return nil
}
type asn1Primitive struct {
tagBytes []byte
length int
content []byte
}
func (p asn1Primitive) EncodeTo(out *bytes.Buffer) error {
_, err := out.Write(p.tagBytes)
if err != nil {
return err
}
if err = encodeLength(out, p.length); err != nil {
return err
}
//fmt.Printf("%s--> tag: % X length: %d\n", strings.Repeat("| ", encodeIndent), p.tagBytes, p.length)
//fmt.Printf("%s--> content length: %d\n", strings.Repeat("| ", encodeIndent), len(p.content))
out.Write(p.content)
return nil
}
func ber2der(ber []byte) ([]byte, error) {
if len(ber) == 0 {
return nil, errors.New("ber2der: input ber is empty")
}
//fmt.Printf("--> ber2der: Transcoding %d bytes\n", len(ber))
out := new(bytes.Buffer)
obj, _, err := readObject(ber, 0)
if err != nil {
return nil, err
}
obj.EncodeTo(out)
// if offset < len(ber) {
// return nil, fmt.Errorf("ber2der: Content longer than expected. Got %d, expected %d", offset, len(ber))
//}
return out.Bytes(), nil
}
// encodes lengths that are longer than 127 into string of bytes
func marshalLongLength(out *bytes.Buffer, i int) (err error) {
n := lengthLength(i)
for ; n > 0; n-- {
err = out.WriteByte(byte(i >> uint((n-1)*8)))
if err != nil {
return
}
}
return nil
}
// computes the byte length of an encoded length value
func lengthLength(i int) (numBytes int) {
numBytes = 1
for i > 255 {
numBytes++
i >>= 8
}
return
}
// encodes the length in DER format
// If the length fits in 7 bits, the value is encoded directly.
//
// Otherwise, the number of bytes to encode the length is first determined.
// This number is likely to be 4 or less for a 32bit length. This number is
// added to 0x80. The length is encoded in big endian encoding follow after
//
// Examples:
// length | byte 1 | bytes n
// 0 | 0x00 | -
// 120 | 0x78 | -
// 200 | 0x81 | 0xC8
// 500 | 0x82 | 0x01 0xF4
//
func encodeLength(out *bytes.Buffer, length int) (err error) {
if length >= 128 {
l := lengthLength(length)
err = out.WriteByte(0x80 | byte(l))
if err != nil {
return
}
err = marshalLongLength(out, length)
if err != nil {
return
}
} else {
err = out.WriteByte(byte(length))
if err != nil {
return
}
}
return
}
func readObject(ber []byte, offset int) (asn1Object, int, error) {
//fmt.Printf("\n====> Starting readObject at offset: %d\n\n", offset)
tagStart := offset
b := ber[offset]
offset++
tag := b & 0x1F // last 5 bits
if tag == 0x1F {
tag = 0
for ber[offset] >= 0x80 {
tag = tag*128 + ber[offset] - 0x80
offset++
}
tag = tag*128 + ber[offset] - 0x80
offset++
}
tagEnd := offset
kind := b & 0x20
/*
if kind == 0 {
fmt.Print("--> Primitive\n")
} else {
fmt.Print("--> Constructed\n")
}
*/
// read length
var length int
l := ber[offset]
offset++
indefinite := false
if l > 0x80 {
numberOfBytes := (int)(l & 0x7F)
if numberOfBytes > 4 { // int is only guaranteed to be 32bit
return nil, 0, errors.New("ber2der: BER tag length too long")
}
if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F {
return nil, 0, errors.New("ber2der: BER tag length is negative")
}
if 0x0 == (int)(ber[offset]) {
return nil, 0, errors.New("ber2der: BER tag length has leading zero")
}
//fmt.Printf("--> (compute length) indicator byte: %x\n", l)
//fmt.Printf("--> (compute length) length bytes: % X\n", ber[offset:offset+numberOfBytes])
for i := 0; i < numberOfBytes; i++ {
length = length*256 + (int)(ber[offset])
offset++
}
} else if l == 0x80 {
indefinite = true
} else {
length = (int)(l)
}
//fmt.Printf("--> length : %d\n", length)
contentEnd := offset + length
if contentEnd > len(ber) {
return nil, 0, errors.New("ber2der: BER tag length is more than available data")
}
//fmt.Printf("--> content start : %d\n", offset)
//fmt.Printf("--> content end : %d\n", contentEnd)
//fmt.Printf("--> content : % X\n", ber[offset:contentEnd])
var obj asn1Object
if indefinite && kind == 0 {
return nil, 0, errors.New("ber2der: Indefinite form tag must have constructed encoding")
}
if kind == 0 {
obj = asn1Primitive{
tagBytes: ber[tagStart:tagEnd],
length: length,
content: ber[offset:contentEnd],
}
} else {
var subObjects []asn1Object
for (offset < contentEnd) || indefinite {
var subObj asn1Object
var err error
subObj, offset, err = readObject(ber, offset)
if err != nil {
return nil, 0, err
}
subObjects = append(subObjects, subObj)
if indefinite {
terminated, err := isIndefiniteTermination(ber, offset)
if err != nil {
return nil, 0, err
}
if terminated {
break
}
}
}
obj = asn1Structured{
tagBytes: ber[tagStart:tagEnd],
content: subObjects,
}
}
// Apply indefinite form length with 0x0000 terminator.
if indefinite {
contentEnd = offset + 2
}
return obj, contentEnd, nil
}
func isIndefiniteTermination(ber []byte, offset int) (bool, error) {
if len(ber) - offset < 2 {
return false, errors.New("ber2der: Invalid BER format")
}
return bytes.Index(ber[offset:], []byte{0x0, 0x0}) == 0, nil
}

View File

@ -1,960 +0,0 @@
// Package pkcs7 implements parsing and generation of some PKCS#7 structures.
package pkcs7
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"math/big"
"sort"
"time"
_ "crypto/sha1" // for crypto.SHA1
)
// PKCS7 Represents a PKCS7 structure
type PKCS7 struct {
Content []byte
Certificates []*x509.Certificate
CRLs []pkix.CertificateList
Signers []signerInfo
raw interface{}
}
type contentInfo struct {
ContentType asn1.ObjectIdentifier
Content asn1.RawValue `asn1:"explicit,optional,tag:0"`
}
// ErrUnsupportedContentType is returned when a PKCS7 content is not supported.
// Currently only Data (1.2.840.113549.1.7.1), Signed Data (1.2.840.113549.1.7.2),
// and Enveloped Data are supported (1.2.840.113549.1.7.3)
var ErrUnsupportedContentType = errors.New("pkcs7: cannot parse data: unimplemented content type")
type unsignedData []byte
var (
oidData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 1}
oidSignedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 2}
oidEnvelopedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 3}
oidSignedAndEnvelopedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 4}
oidDigestedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 5}
oidEncryptedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 6}
oidAttributeContentType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 3}
oidAttributeMessageDigest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 4}
oidAttributeSigningTime = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 5}
)
type signedData struct {
Version int `asn1:"default:1"`
DigestAlgorithmIdentifiers []pkix.AlgorithmIdentifier `asn1:"set"`
ContentInfo contentInfo
Certificates rawCertificates `asn1:"optional,tag:0"`
CRLs []pkix.CertificateList `asn1:"optional,tag:1"`
SignerInfos []signerInfo `asn1:"set"`
}
type rawCertificates struct {
Raw asn1.RawContent
}
type envelopedData struct {
Version int
RecipientInfos []recipientInfo `asn1:"set"`
EncryptedContentInfo encryptedContentInfo
}
type recipientInfo struct {
Version int
IssuerAndSerialNumber issuerAndSerial
KeyEncryptionAlgorithm pkix.AlgorithmIdentifier
EncryptedKey []byte
}
type encryptedContentInfo struct {
ContentType asn1.ObjectIdentifier
ContentEncryptionAlgorithm pkix.AlgorithmIdentifier
EncryptedContent asn1.RawValue `asn1:"tag:0,optional"`
}
type attribute struct {
Type asn1.ObjectIdentifier
Value asn1.RawValue `asn1:"set"`
}
type issuerAndSerial struct {
IssuerName asn1.RawValue
SerialNumber *big.Int
}
// MessageDigestMismatchError is returned when the signer data digest does not
// match the computed digest for the contained content
type MessageDigestMismatchError struct {
ExpectedDigest []byte
ActualDigest []byte
}
func (err *MessageDigestMismatchError) Error() string {
return fmt.Sprintf("pkcs7: Message digest mismatch\n\tExpected: %X\n\tActual : %X", err.ExpectedDigest, err.ActualDigest)
}
type signerInfo struct {
Version int `asn1:"default:1"`
IssuerAndSerialNumber issuerAndSerial
DigestAlgorithm pkix.AlgorithmIdentifier
AuthenticatedAttributes []attribute `asn1:"optional,tag:0"`
DigestEncryptionAlgorithm pkix.AlgorithmIdentifier
EncryptedDigest []byte
UnauthenticatedAttributes []attribute `asn1:"optional,tag:1"`
}
// Parse decodes a DER encoded PKCS7 package
func Parse(data []byte) (p7 *PKCS7, err error) {
if len(data) == 0 {
return nil, errors.New("pkcs7: input data is empty")
}
var info contentInfo
der, err := ber2der(data)
if err != nil {
return nil, err
}
rest, err := asn1.Unmarshal(der, &info)
if len(rest) > 0 {
err = asn1.SyntaxError{Msg: "trailing data"}
return
}
if err != nil {
return
}
// fmt.Printf("--> Content Type: %s", info.ContentType)
switch {
case info.ContentType.Equal(oidSignedData):
return parseSignedData(info.Content.Bytes)
case info.ContentType.Equal(oidEnvelopedData):
return parseEnvelopedData(info.Content.Bytes)
}
return nil, ErrUnsupportedContentType
}
func parseSignedData(data []byte) (*PKCS7, error) {
var sd signedData
asn1.Unmarshal(data, &sd)
certs, err := sd.Certificates.Parse()
if err != nil {
return nil, err
}
// fmt.Printf("--> Signed Data Version %d\n", sd.Version)
var compound asn1.RawValue
var content unsignedData
// The Content.Bytes maybe empty on PKI responses.
if len(sd.ContentInfo.Content.Bytes) > 0 {
if _, err := asn1.Unmarshal(sd.ContentInfo.Content.Bytes, &compound); err != nil {
return nil, err
}
}
// Compound octet string
if compound.IsCompound {
if _, err = asn1.Unmarshal(compound.Bytes, &content); err != nil {
return nil, err
}
} else {
// assuming this is tag 04
content = compound.Bytes
}
return &PKCS7{
Content: content,
Certificates: certs,
CRLs: sd.CRLs,
Signers: sd.SignerInfos,
raw: sd}, nil
}
func (raw rawCertificates) Parse() ([]*x509.Certificate, error) {
if len(raw.Raw) == 0 {
return nil, nil
}
var val asn1.RawValue
if _, err := asn1.Unmarshal(raw.Raw, &val); err != nil {
return nil, err
}
return x509.ParseCertificates(val.Bytes)
}
func parseEnvelopedData(data []byte) (*PKCS7, error) {
var ed envelopedData
if _, err := asn1.Unmarshal(data, &ed); err != nil {
return nil, err
}
return &PKCS7{
raw: ed,
}, nil
}
// Verify checks the signatures of a PKCS7 object
// WARNING: Verify does not check signing time or verify certificate chains at
// this time.
func (p7 *PKCS7) Verify() (err error) {
if len(p7.Signers) == 0 {
return errors.New("pkcs7: Message has no signers")
}
for _, signer := range p7.Signers {
if err := verifySignature(p7, signer); err != nil {
return err
}
}
return nil
}
func verifySignature(p7 *PKCS7, signer signerInfo) error {
signedData := p7.Content
hash, err := getHashForOID(signer.DigestAlgorithm.Algorithm)
if err != nil {
return err
}
if len(signer.AuthenticatedAttributes) > 0 {
// TODO(fullsailor): First check the content type match
var digest []byte
err := unmarshalAttribute(signer.AuthenticatedAttributes, oidAttributeMessageDigest, &digest)
if err != nil {
return err
}
h := hash.New()
h.Write(p7.Content)
computed := h.Sum(nil)
if !hmac.Equal(digest, computed) {
return &MessageDigestMismatchError{
ExpectedDigest: digest,
ActualDigest: computed,
}
}
// TODO(fullsailor): Optionally verify certificate chain
// TODO(fullsailor): Optionally verify signingTime against certificate NotAfter/NotBefore
signedData, err = marshalAttributes(signer.AuthenticatedAttributes)
if err != nil {
return err
}
}
cert := getCertFromCertsByIssuerAndSerial(p7.Certificates, signer.IssuerAndSerialNumber)
if cert == nil {
return errors.New("pkcs7: No certificate for signer")
}
algo := getSignatureAlgorithmFromAI(signer.DigestEncryptionAlgorithm)
if algo == x509.UnknownSignatureAlgorithm {
// I'm not sure what the spec here is, and the openssl sources were not
// helpful. But, this is what App Store receipts appear to do.
// The DigestEncryptionAlgorithm is just "rsaEncryption (PKCS #1)"
// But we're expecting a digest + encryption algorithm. So... we're going
// to determine an algorithm based on the DigestAlgorithm and this
// encryption algorithm.
if signer.DigestEncryptionAlgorithm.Algorithm.Equal(oidEncryptionAlgorithmRSA) {
algo = getRSASignatureAlgorithmForDigestAlgorithm(hash)
}
}
return cert.CheckSignature(algo, signedData, signer.EncryptedDigest)
}
func marshalAttributes(attrs []attribute) ([]byte, error) {
encodedAttributes, err := asn1.Marshal(struct {
A []attribute `asn1:"set"`
}{A: attrs})
if err != nil {
return nil, err
}
// Remove the leading sequence octets
var raw asn1.RawValue
asn1.Unmarshal(encodedAttributes, &raw)
return raw.Bytes, nil
}
var (
oidDigestAlgorithmSHA1 = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 26}
oidEncryptionAlgorithmRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1}
)
func getCertFromCertsByIssuerAndSerial(certs []*x509.Certificate, ias issuerAndSerial) *x509.Certificate {
for _, cert := range certs {
if isCertMatchForIssuerAndSerial(cert, ias) {
return cert
}
}
return nil
}
func getHashForOID(oid asn1.ObjectIdentifier) (crypto.Hash, error) {
switch {
case oid.Equal(oidDigestAlgorithmSHA1):
return crypto.SHA1, nil
}
return crypto.Hash(0), ErrUnsupportedAlgorithm
}
func getRSASignatureAlgorithmForDigestAlgorithm(hash crypto.Hash) x509.SignatureAlgorithm {
for _, details := range signatureAlgorithmDetails {
if details.pubKeyAlgo == x509.RSA && details.hash == hash {
return details.algo
}
}
return x509.UnknownSignatureAlgorithm
}
// GetOnlySigner returns an x509.Certificate for the first signer of the signed
// data payload. If there are more or less than one signer, nil is returned
func (p7 *PKCS7) GetOnlySigner() *x509.Certificate {
if len(p7.Signers) != 1 {
return nil
}
signer := p7.Signers[0]
return getCertFromCertsByIssuerAndSerial(p7.Certificates, signer.IssuerAndSerialNumber)
}
// ErrUnsupportedAlgorithm tells you when our quick dev assumptions have failed
var ErrUnsupportedAlgorithm = errors.New("pkcs7: cannot decrypt data: only RSA, DES, DES-EDE3, AES-256-CBC and AES-128-GCM supported")
// ErrNotEncryptedContent is returned when attempting to Decrypt data that is not encrypted data
var ErrNotEncryptedContent = errors.New("pkcs7: content data is a decryptable data type")
// Decrypt decrypts encrypted content info for recipient cert and private key
func (p7 *PKCS7) Decrypt(cert *x509.Certificate, pk crypto.PrivateKey) ([]byte, error) {
data, ok := p7.raw.(envelopedData)
if !ok {
return nil, ErrNotEncryptedContent
}
recipient := selectRecipientForCertificate(data.RecipientInfos, cert)
if recipient.EncryptedKey == nil {
return nil, errors.New("pkcs7: no enveloped recipient for provided certificate")
}
if priv := pk.(*rsa.PrivateKey); priv != nil {
var contentKey []byte
contentKey, err := rsa.DecryptPKCS1v15(rand.Reader, priv, recipient.EncryptedKey)
if err != nil {
return nil, err
}
return data.EncryptedContentInfo.decrypt(contentKey)
}
fmt.Printf("Unsupported Private Key: %v\n", pk)
return nil, ErrUnsupportedAlgorithm
}
var oidEncryptionAlgorithmDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7}
var oidEncryptionAlgorithmDESEDE3CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 3, 7}
var oidEncryptionAlgorithmAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42}
var oidEncryptionAlgorithmAES128GCM = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 6}
var oidEncryptionAlgorithmAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2}
func (eci encryptedContentInfo) decrypt(key []byte) ([]byte, error) {
alg := eci.ContentEncryptionAlgorithm.Algorithm
if !alg.Equal(oidEncryptionAlgorithmDESCBC) &&
!alg.Equal(oidEncryptionAlgorithmDESEDE3CBC) &&
!alg.Equal(oidEncryptionAlgorithmAES256CBC) &&
!alg.Equal(oidEncryptionAlgorithmAES128CBC) &&
!alg.Equal(oidEncryptionAlgorithmAES128GCM) {
fmt.Printf("Unsupported Content Encryption Algorithm: %s\n", alg)
return nil, ErrUnsupportedAlgorithm
}
// EncryptedContent can either be constructed of multple OCTET STRINGs
// or _be_ a tagged OCTET STRING
var cyphertext []byte
if eci.EncryptedContent.IsCompound {
// Complex case to concat all of the children OCTET STRINGs
var buf bytes.Buffer
cypherbytes := eci.EncryptedContent.Bytes
for {
var part []byte
cypherbytes, _ = asn1.Unmarshal(cypherbytes, &part)
buf.Write(part)
if cypherbytes == nil {
break
}
}
cyphertext = buf.Bytes()
} else {
// Simple case, the bytes _are_ the cyphertext
cyphertext = eci.EncryptedContent.Bytes
}
var block cipher.Block
var err error
switch {
case alg.Equal(oidEncryptionAlgorithmDESCBC):
block, err = des.NewCipher(key)
case alg.Equal(oidEncryptionAlgorithmDESEDE3CBC):
block, err = des.NewTripleDESCipher(key)
case alg.Equal(oidEncryptionAlgorithmAES256CBC):
fallthrough
case alg.Equal(oidEncryptionAlgorithmAES128GCM), alg.Equal(oidEncryptionAlgorithmAES128CBC):
block, err = aes.NewCipher(key)
}
if err != nil {
return nil, err
}
if alg.Equal(oidEncryptionAlgorithmAES128GCM) {
params := aesGCMParameters{}
paramBytes := eci.ContentEncryptionAlgorithm.Parameters.Bytes
_, err := asn1.Unmarshal(paramBytes, &params)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(params.Nonce) != gcm.NonceSize() {
return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect")
}
if params.ICVLen != gcm.Overhead() {
return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect")
}
plaintext, err := gcm.Open(nil, params.Nonce, cyphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
iv := eci.ContentEncryptionAlgorithm.Parameters.Bytes
if len(iv) != block.BlockSize() {
return nil, errors.New("pkcs7: encryption algorithm parameters are malformed")
}
mode := cipher.NewCBCDecrypter(block, iv)
plaintext := make([]byte, len(cyphertext))
mode.CryptBlocks(plaintext, cyphertext)
if plaintext, err = unpad(plaintext, mode.BlockSize()); err != nil {
return nil, err
}
return plaintext, nil
}
func selectRecipientForCertificate(recipients []recipientInfo, cert *x509.Certificate) recipientInfo {
for _, recp := range recipients {
if isCertMatchForIssuerAndSerial(cert, recp.IssuerAndSerialNumber) {
return recp
}
}
return recipientInfo{}
}
func isCertMatchForIssuerAndSerial(cert *x509.Certificate, ias issuerAndSerial) bool {
return cert.SerialNumber.Cmp(ias.SerialNumber) == 0 && bytes.Compare(cert.RawIssuer, ias.IssuerName.FullBytes) == 0
}
func pad(data []byte, blocklen int) ([]byte, error) {
if blocklen < 1 {
return nil, fmt.Errorf("invalid blocklen %d", blocklen)
}
padlen := blocklen - (len(data) % blocklen)
if padlen == 0 {
padlen = blocklen
}
pad := bytes.Repeat([]byte{byte(padlen)}, padlen)
return append(data, pad...), nil
}
func unpad(data []byte, blocklen int) ([]byte, error) {
if blocklen < 1 {
return nil, fmt.Errorf("invalid blocklen %d", blocklen)
}
if len(data)%blocklen != 0 || len(data) == 0 {
return nil, fmt.Errorf("invalid data len %d", len(data))
}
// the last byte is the length of padding
padlen := int(data[len(data)-1])
// check padding integrity, all bytes should be the same
pad := data[len(data)-padlen:]
for _, padbyte := range pad {
if padbyte != byte(padlen) {
return nil, errors.New("invalid padding")
}
}
return data[:len(data)-padlen], nil
}
func unmarshalAttribute(attrs []attribute, attributeType asn1.ObjectIdentifier, out interface{}) error {
for _, attr := range attrs {
if attr.Type.Equal(attributeType) {
_, err := asn1.Unmarshal(attr.Value.Bytes, out)
return err
}
}
return errors.New("pkcs7: attribute type not in attributes")
}
// UnmarshalSignedAttribute decodes a single attribute from the signer info
func (p7 *PKCS7) UnmarshalSignedAttribute(attributeType asn1.ObjectIdentifier, out interface{}) error {
sd, ok := p7.raw.(signedData)
if !ok {
return errors.New("pkcs7: payload is not signedData content")
}
if len(sd.SignerInfos) < 1 {
return errors.New("pkcs7: payload has no signers")
}
attributes := sd.SignerInfos[0].AuthenticatedAttributes
return unmarshalAttribute(attributes, attributeType, out)
}
// SignedData is an opaque data structure for creating signed data payloads
type SignedData struct {
sd signedData
certs []*x509.Certificate
messageDigest []byte
}
// Attribute represents a key value pair attribute. Value must be marshalable byte
// `encoding/asn1`
type Attribute struct {
Type asn1.ObjectIdentifier
Value interface{}
}
// SignerInfoConfig are optional values to include when adding a signer
type SignerInfoConfig struct {
ExtraSignedAttributes []Attribute
}
// NewSignedData initializes a SignedData with content
func NewSignedData(data []byte) (*SignedData, error) {
content, err := asn1.Marshal(data)
if err != nil {
return nil, err
}
ci := contentInfo{
ContentType: oidData,
Content: asn1.RawValue{Class: 2, Tag: 0, Bytes: content, IsCompound: true},
}
digAlg := pkix.AlgorithmIdentifier{
Algorithm: oidDigestAlgorithmSHA1,
}
h := crypto.SHA1.New()
h.Write(data)
md := h.Sum(nil)
sd := signedData{
ContentInfo: ci,
Version: 1,
DigestAlgorithmIdentifiers: []pkix.AlgorithmIdentifier{digAlg},
}
return &SignedData{sd: sd, messageDigest: md}, nil
}
type attributes struct {
types []asn1.ObjectIdentifier
values []interface{}
}
// Add adds the attribute, maintaining insertion order
func (attrs *attributes) Add(attrType asn1.ObjectIdentifier, value interface{}) {
attrs.types = append(attrs.types, attrType)
attrs.values = append(attrs.values, value)
}
type sortableAttribute struct {
SortKey []byte
Attribute attribute
}
type attributeSet []sortableAttribute
func (sa attributeSet) Len() int {
return len(sa)
}
func (sa attributeSet) Less(i, j int) bool {
return bytes.Compare(sa[i].SortKey, sa[j].SortKey) < 0
}
func (sa attributeSet) Swap(i, j int) {
sa[i], sa[j] = sa[j], sa[i]
}
func (sa attributeSet) Attributes() []attribute {
attrs := make([]attribute, len(sa))
for i, attr := range sa {
attrs[i] = attr.Attribute
}
return attrs
}
func (attrs *attributes) ForMarshaling() ([]attribute, error) {
sortables := make(attributeSet, len(attrs.types))
for i := range sortables {
attrType := attrs.types[i]
attrValue := attrs.values[i]
asn1Value, err := asn1.Marshal(attrValue)
if err != nil {
return nil, err
}
attr := attribute{
Type: attrType,
Value: asn1.RawValue{Tag: 17, IsCompound: true, Bytes: asn1Value}, // 17 == SET tag
}
encoded, err := asn1.Marshal(attr)
if err != nil {
return nil, err
}
sortables[i] = sortableAttribute{
SortKey: encoded,
Attribute: attr,
}
}
sort.Sort(sortables)
return sortables.Attributes(), nil
}
// AddSigner signs attributes about the content and adds certificate to payload
func (sd *SignedData) AddSigner(cert *x509.Certificate, pkey crypto.PrivateKey, config SignerInfoConfig) error {
attrs := &attributes{}
attrs.Add(oidAttributeContentType, sd.sd.ContentInfo.ContentType)
attrs.Add(oidAttributeMessageDigest, sd.messageDigest)
attrs.Add(oidAttributeSigningTime, time.Now())
for _, attr := range config.ExtraSignedAttributes {
attrs.Add(attr.Type, attr.Value)
}
finalAttrs, err := attrs.ForMarshaling()
if err != nil {
return err
}
signature, err := signAttributes(finalAttrs, pkey, crypto.SHA1)
if err != nil {
return err
}
ias, err := cert2issuerAndSerial(cert)
if err != nil {
return err
}
signer := signerInfo{
AuthenticatedAttributes: finalAttrs,
DigestAlgorithm: pkix.AlgorithmIdentifier{Algorithm: oidDigestAlgorithmSHA1},
DigestEncryptionAlgorithm: pkix.AlgorithmIdentifier{Algorithm: oidSignatureSHA1WithRSA},
IssuerAndSerialNumber: ias,
EncryptedDigest: signature,
Version: 1,
}
// create signature of signed attributes
sd.certs = append(sd.certs, cert)
sd.sd.SignerInfos = append(sd.sd.SignerInfos, signer)
return nil
}
// AddCertificate adds the certificate to the payload. Useful for parent certificates
func (sd *SignedData) AddCertificate(cert *x509.Certificate) {
sd.certs = append(sd.certs, cert)
}
// Detach removes content from the signed data struct to make it a detached signature.
// This must be called right before Finish()
func (sd *SignedData) Detach() {
sd.sd.ContentInfo = contentInfo{ContentType: oidData}
}
// Finish marshals the content and its signers
func (sd *SignedData) Finish() ([]byte, error) {
sd.sd.Certificates = marshalCertificates(sd.certs)
inner, err := asn1.Marshal(sd.sd)
if err != nil {
return nil, err
}
outer := contentInfo{
ContentType: oidSignedData,
Content: asn1.RawValue{Class: 2, Tag: 0, Bytes: inner, IsCompound: true},
}
return asn1.Marshal(outer)
}
func cert2issuerAndSerial(cert *x509.Certificate) (issuerAndSerial, error) {
var ias issuerAndSerial
// The issuer RDNSequence has to match exactly the sequence in the certificate
// We cannot use cert.Issuer.ToRDNSequence() here since it mangles the sequence
ias.IssuerName = asn1.RawValue{FullBytes: cert.RawIssuer}
ias.SerialNumber = cert.SerialNumber
return ias, nil
}
// signs the DER encoded form of the attributes with the private key
func signAttributes(attrs []attribute, pkey crypto.PrivateKey, hash crypto.Hash) ([]byte, error) {
attrBytes, err := marshalAttributes(attrs)
if err != nil {
return nil, err
}
h := hash.New()
h.Write(attrBytes)
hashed := h.Sum(nil)
switch priv := pkey.(type) {
case *rsa.PrivateKey:
return rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA1, hashed)
}
return nil, ErrUnsupportedAlgorithm
}
// concats and wraps the certificates in the RawValue structure
func marshalCertificates(certs []*x509.Certificate) rawCertificates {
var buf bytes.Buffer
for _, cert := range certs {
buf.Write(cert.Raw)
}
rawCerts, _ := marshalCertificateBytes(buf.Bytes())
return rawCerts
}
// Even though, the tag & length are stripped out during marshalling the
// RawContent, we have to encode it into the RawContent. If its missing,
// then `asn1.Marshal()` will strip out the certificate wrapper instead.
func marshalCertificateBytes(certs []byte) (rawCertificates, error) {
var val = asn1.RawValue{Bytes: certs, Class: 2, Tag: 0, IsCompound: true}
b, err := asn1.Marshal(val)
if err != nil {
return rawCertificates{}, err
}
return rawCertificates{Raw: b}, nil
}
// DegenerateCertificate creates a signed data structure containing only the
// provided certificate or certificate chain.
func DegenerateCertificate(cert []byte) ([]byte, error) {
rawCert, err := marshalCertificateBytes(cert)
if err != nil {
return nil, err
}
emptyContent := contentInfo{ContentType: oidData}
sd := signedData{
Version: 1,
ContentInfo: emptyContent,
Certificates: rawCert,
CRLs: []pkix.CertificateList{},
}
content, err := asn1.Marshal(sd)
if err != nil {
return nil, err
}
signedContent := contentInfo{
ContentType: oidSignedData,
Content: asn1.RawValue{Class: 2, Tag: 0, Bytes: content, IsCompound: true},
}
return asn1.Marshal(signedContent)
}
const (
EncryptionAlgorithmDESCBC = iota
EncryptionAlgorithmAES128GCM
)
// ContentEncryptionAlgorithm determines the algorithm used to encrypt the
// plaintext message. Change the value of this variable to change which
// algorithm is used in the Encrypt() function.
var ContentEncryptionAlgorithm = EncryptionAlgorithmDESCBC
// ErrUnsupportedEncryptionAlgorithm is returned when attempting to encrypt
// content with an unsupported algorithm.
var ErrUnsupportedEncryptionAlgorithm = errors.New("pkcs7: cannot encrypt content: only DES-CBC and AES-128-GCM supported")
const nonceSize = 12
type aesGCMParameters struct {
Nonce []byte `asn1:"tag:4"`
ICVLen int
}
func encryptAES128GCM(content []byte) ([]byte, *encryptedContentInfo, error) {
// Create AES key and nonce
key := make([]byte, 16)
nonce := make([]byte, nonceSize)
_, err := rand.Read(key)
if err != nil {
return nil, nil, err
}
_, err = rand.Read(nonce)
if err != nil {
return nil, nil, err
}
// Encrypt content
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, nil, err
}
ciphertext := gcm.Seal(nil, nonce, content, nil)
// Prepare ASN.1 Encrypted Content Info
paramSeq := aesGCMParameters{
Nonce: nonce,
ICVLen: gcm.Overhead(),
}
paramBytes, err := asn1.Marshal(paramSeq)
if err != nil {
return nil, nil, err
}
eci := encryptedContentInfo{
ContentType: oidData,
ContentEncryptionAlgorithm: pkix.AlgorithmIdentifier{
Algorithm: oidEncryptionAlgorithmAES128GCM,
Parameters: asn1.RawValue{
Tag: asn1.TagSequence,
Bytes: paramBytes,
},
},
EncryptedContent: marshalEncryptedContent(ciphertext),
}
return key, &eci, nil
}
func encryptDESCBC(content []byte) ([]byte, *encryptedContentInfo, error) {
// Create DES key & CBC IV
key := make([]byte, 8)
iv := make([]byte, des.BlockSize)
_, err := rand.Read(key)
if err != nil {
return nil, nil, err
}
_, err = rand.Read(iv)
if err != nil {
return nil, nil, err
}
// Encrypt padded content
block, err := des.NewCipher(key)
if err != nil {
return nil, nil, err
}
mode := cipher.NewCBCEncrypter(block, iv)
plaintext, err := pad(content, mode.BlockSize())
cyphertext := make([]byte, len(plaintext))
mode.CryptBlocks(cyphertext, plaintext)
// Prepare ASN.1 Encrypted Content Info
eci := encryptedContentInfo{
ContentType: oidData,
ContentEncryptionAlgorithm: pkix.AlgorithmIdentifier{
Algorithm: oidEncryptionAlgorithmDESCBC,
Parameters: asn1.RawValue{Tag: 4, Bytes: iv},
},
EncryptedContent: marshalEncryptedContent(cyphertext),
}
return key, &eci, nil
}
// Encrypt creates and returns an envelope data PKCS7 structure with encrypted
// recipient keys for each recipient public key.
//
// The algorithm used to perform encryption is determined by the current value
// of the global ContentEncryptionAlgorithm package variable. By default, the
// value is EncryptionAlgorithmDESCBC. To use a different algorithm, change the
// value before calling Encrypt(). For example:
//
// ContentEncryptionAlgorithm = EncryptionAlgorithmAES128GCM
//
// TODO(fullsailor): Add support for encrypting content with other algorithms
func Encrypt(content []byte, recipients []*x509.Certificate) ([]byte, error) {
var eci *encryptedContentInfo
var key []byte
var err error
// Apply chosen symmetric encryption method
switch ContentEncryptionAlgorithm {
case EncryptionAlgorithmDESCBC:
key, eci, err = encryptDESCBC(content)
case EncryptionAlgorithmAES128GCM:
key, eci, err = encryptAES128GCM(content)
default:
return nil, ErrUnsupportedEncryptionAlgorithm
}
if err != nil {
return nil, err
}
// Prepare each recipient's encrypted cipher key
recipientInfos := make([]recipientInfo, len(recipients))
for i, recipient := range recipients {
encrypted, err := encryptKey(key, recipient)
if err != nil {
return nil, err
}
ias, err := cert2issuerAndSerial(recipient)
if err != nil {
return nil, err
}
info := recipientInfo{
Version: 0,
IssuerAndSerialNumber: ias,
KeyEncryptionAlgorithm: pkix.AlgorithmIdentifier{
Algorithm: oidEncryptionAlgorithmRSA,
},
EncryptedKey: encrypted,
}
recipientInfos[i] = info
}
// Prepare envelope content
envelope := envelopedData{
EncryptedContentInfo: *eci,
Version: 0,
RecipientInfos: recipientInfos,
}
innerContent, err := asn1.Marshal(envelope)
if err != nil {
return nil, err
}
// Prepare outer payload structure
wrapper := contentInfo{
ContentType: oidEnvelopedData,
Content: asn1.RawValue{Class: 2, Tag: 0, IsCompound: true, Bytes: innerContent},
}
return asn1.Marshal(wrapper)
}
func marshalEncryptedContent(content []byte) asn1.RawValue {
asn1Content, _ := asn1.Marshal(content)
return asn1.RawValue{Tag: 0, Class: 2, Bytes: asn1Content, IsCompound: true}
}
func encryptKey(key []byte, recipient *x509.Certificate) ([]byte, error) {
if pub := recipient.PublicKey.(*rsa.PublicKey); pub != nil {
return rsa.EncryptPKCS1v15(rand.Reader, pub, key)
}
return nil, ErrUnsupportedAlgorithm
}

View File

@ -1,133 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the go/golang LICENSE file.
package pkcs7
// These are private constants and functions from the crypto/x509 package that
// are useful when dealing with signatures verified by x509 certificates
import (
"bytes"
"crypto"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
)
var (
oidSignatureMD2WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 2}
oidSignatureMD5WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4}
oidSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5}
oidSignatureSHA256WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11}
oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12}
oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13}
oidSignatureRSAPSS = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10}
oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3}
oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2}
oidSignatureECDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1}
oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2}
oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3}
oidSignatureECDSAWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4}
oidSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}
oidSHA384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 2}
oidSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3}
oidMGF1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 8}
// oidISOSignatureSHA1WithRSA means the same as oidSignatureSHA1WithRSA
// but it's specified by ISO. Microsoft's makecert.exe has been known
// to produce certificates with this OID.
oidISOSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 29}
)
var signatureAlgorithmDetails = []struct {
algo x509.SignatureAlgorithm
name string
oid asn1.ObjectIdentifier
pubKeyAlgo x509.PublicKeyAlgorithm
hash crypto.Hash
}{
{x509.MD2WithRSA, "MD2-RSA", oidSignatureMD2WithRSA, x509.RSA, crypto.Hash(0) /* no value for MD2 */},
{x509.MD5WithRSA, "MD5-RSA", oidSignatureMD5WithRSA, x509.RSA, crypto.MD5},
{x509.SHA1WithRSA, "SHA1-RSA", oidSignatureSHA1WithRSA, x509.RSA, crypto.SHA1},
{x509.SHA1WithRSA, "SHA1-RSA", oidISOSignatureSHA1WithRSA, x509.RSA, crypto.SHA1},
{x509.SHA256WithRSA, "SHA256-RSA", oidSignatureSHA256WithRSA, x509.RSA, crypto.SHA256},
{x509.SHA384WithRSA, "SHA384-RSA", oidSignatureSHA384WithRSA, x509.RSA, crypto.SHA384},
{x509.SHA512WithRSA, "SHA512-RSA", oidSignatureSHA512WithRSA, x509.RSA, crypto.SHA512},
{x509.SHA256WithRSAPSS, "SHA256-RSAPSS", oidSignatureRSAPSS, x509.RSA, crypto.SHA256},
{x509.SHA384WithRSAPSS, "SHA384-RSAPSS", oidSignatureRSAPSS, x509.RSA, crypto.SHA384},
{x509.SHA512WithRSAPSS, "SHA512-RSAPSS", oidSignatureRSAPSS, x509.RSA, crypto.SHA512},
{x509.DSAWithSHA1, "DSA-SHA1", oidSignatureDSAWithSHA1, x509.DSA, crypto.SHA1},
{x509.DSAWithSHA256, "DSA-SHA256", oidSignatureDSAWithSHA256, x509.DSA, crypto.SHA256},
{x509.ECDSAWithSHA1, "ECDSA-SHA1", oidSignatureECDSAWithSHA1, x509.ECDSA, crypto.SHA1},
{x509.ECDSAWithSHA256, "ECDSA-SHA256", oidSignatureECDSAWithSHA256, x509.ECDSA, crypto.SHA256},
{x509.ECDSAWithSHA384, "ECDSA-SHA384", oidSignatureECDSAWithSHA384, x509.ECDSA, crypto.SHA384},
{x509.ECDSAWithSHA512, "ECDSA-SHA512", oidSignatureECDSAWithSHA512, x509.ECDSA, crypto.SHA512},
}
// pssParameters reflects the parameters in an AlgorithmIdentifier that
// specifies RSA PSS. See https://tools.ietf.org/html/rfc3447#appendix-A.2.3
type pssParameters struct {
// The following three fields are not marked as
// optional because the default values specify SHA-1,
// which is no longer suitable for use in signatures.
Hash pkix.AlgorithmIdentifier `asn1:"explicit,tag:0"`
MGF pkix.AlgorithmIdentifier `asn1:"explicit,tag:1"`
SaltLength int `asn1:"explicit,tag:2"`
TrailerField int `asn1:"optional,explicit,tag:3,default:1"`
}
// asn1.NullBytes is not available prior to Go 1.9
var nullBytes = []byte{5, 0}
func getSignatureAlgorithmFromAI(ai pkix.AlgorithmIdentifier) x509.SignatureAlgorithm {
if !ai.Algorithm.Equal(oidSignatureRSAPSS) {
for _, details := range signatureAlgorithmDetails {
if ai.Algorithm.Equal(details.oid) {
return details.algo
}
}
return x509.UnknownSignatureAlgorithm
}
// RSA PSS is special because it encodes important parameters
// in the Parameters.
var params pssParameters
if _, err := asn1.Unmarshal(ai.Parameters.FullBytes, &params); err != nil {
return x509.UnknownSignatureAlgorithm
}
var mgf1HashFunc pkix.AlgorithmIdentifier
if _, err := asn1.Unmarshal(params.MGF.Parameters.FullBytes, &mgf1HashFunc); err != nil {
return x509.UnknownSignatureAlgorithm
}
// PSS is greatly overburdened with options. This code forces
// them into three buckets by requiring that the MGF1 hash
// function always match the message hash function (as
// recommended in
// https://tools.ietf.org/html/rfc3447#section-8.1), that the
// salt length matches the hash length, and that the trailer
// field has the default value.
if !bytes.Equal(params.Hash.Parameters.FullBytes, nullBytes) ||
!params.MGF.Algorithm.Equal(oidMGF1) ||
!mgf1HashFunc.Algorithm.Equal(params.Hash.Algorithm) ||
!bytes.Equal(mgf1HashFunc.Parameters.FullBytes, nullBytes) ||
params.TrailerField != 1 {
return x509.UnknownSignatureAlgorithm
}
switch {
case params.Hash.Algorithm.Equal(oidSHA256) && params.SaltLength == 32:
return x509.SHA256WithRSAPSS
case params.Hash.Algorithm.Equal(oidSHA384) && params.SaltLength == 48:
return x509.SHA384WithRSAPSS
case params.Hash.Algorithm.Equal(oidSHA512) && params.SaltLength == 64:
return x509.SHA512WithRSAPSS
}
return x509.UnknownSignatureAlgorithm
}

View File

@ -1,8 +0,0 @@
# This is the official list of gorilla/mux authors for copyright purposes.
#
# Please keep the list sorted.
Google LLC (https://opensource.google.com/)
Kamil Kisielk <kamil@kamilkisiel.net>
Matt Silverlock <matt@eatsleeprepeat.net>
Rodrigo Moraes (https://github.com/moraes)

View File

@ -1,17 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"context.go",
"doc.go",
"middleware.go",
"mux.go",
"regexp.go",
"route.go",
"test_helpers.go",
],
importmap = "k8s.io/kops/vendor/github.com/gorilla/mux",
importpath = "github.com/gorilla/mux",
visibility = ["//visibility:public"],
)

View File

@ -1,27 +0,0 @@
Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,718 +0,0 @@
# gorilla/mux
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux)
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux)
[![CircleCI](https://circleci.com/gh/gorilla/mux.svg?style=svg)](https://circleci.com/gh/gorilla/mux)
[![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge)
![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png)
https://www.gorillatoolkit.org/pkg/mux
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to
their respective handler.
The name mux stands for "HTTP request multiplexer". Like the standard `http.ServeMux`, `mux.Router` matches incoming requests against a list of registered routes and calls a handler for the route that matches the URL or other conditions. The main features are:
* It implements the `http.Handler` interface so it is compatible with the standard `http.ServeMux`.
* Requests can be matched based on URL host, path, path prefix, schemes, header and query values, HTTP methods or using custom matchers.
* URL hosts, paths and query values can have variables with an optional regular expression.
* Registered URLs can be built, or "reversed", which helps maintaining references to resources.
* Routes can be used as subrouters: nested routes are only tested if the parent route matches. This is useful to define groups of routes that share common conditions like a host, a path prefix or other repeated attributes. As a bonus, this optimizes request matching.
---
* [Install](#install)
* [Examples](#examples)
* [Matching Routes](#matching-routes)
* [Static Files](#static-files)
* [Registered URLs](#registered-urls)
* [Walking Routes](#walking-routes)
* [Graceful Shutdown](#graceful-shutdown)
* [Middleware](#middleware)
* [Handling CORS Requests](#handling-cors-requests)
* [Testing Handlers](#testing-handlers)
* [Full Example](#full-example)
---
## Install
With a [correctly configured](https://golang.org/doc/install#testing) Go toolchain:
```sh
go get -u github.com/gorilla/mux
```
## Examples
Let's start registering a couple of URL paths and handlers:
```go
func main() {
r := mux.NewRouter()
r.HandleFunc("/", HomeHandler)
r.HandleFunc("/products", ProductsHandler)
r.HandleFunc("/articles", ArticlesHandler)
http.Handle("/", r)
}
```
Here we register three routes mapping URL paths to handlers. This is equivalent to how `http.HandleFunc()` works: if an incoming request URL matches one of the paths, the corresponding handler is called passing (`http.ResponseWriter`, `*http.Request`) as parameters.
Paths can have variables. They are defined using the format `{name}` or `{name:pattern}`. If a regular expression pattern is not defined, the matched variable will be anything until the next slash. For example:
```go
r := mux.NewRouter()
r.HandleFunc("/products/{key}", ProductHandler)
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler)
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
```
The names are used to create a map of route variables which can be retrieved calling `mux.Vars()`:
```go
func ArticlesCategoryHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "Category: %v\n", vars["category"])
}
```
And this is all you need to know about the basic usage. More advanced options are explained below.
### Matching Routes
Routes can also be restricted to a domain or subdomain. Just define a host pattern to be matched. They can also have variables:
```go
r := mux.NewRouter()
// Only matches if domain is "www.example.com".
r.Host("www.example.com")
// Matches a dynamic subdomain.
r.Host("{subdomain:[a-z]+}.example.com")
```
There are several other matchers that can be added. To match path prefixes:
```go
r.PathPrefix("/products/")
```
...or HTTP methods:
```go
r.Methods("GET", "POST")
```
...or URL schemes:
```go
r.Schemes("https")
```
...or header values:
```go
r.Headers("X-Requested-With", "XMLHttpRequest")
```
...or query values:
```go
r.Queries("key", "value")
```
...or to use a custom matcher function:
```go
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool {
return r.ProtoMajor == 0
})
```
...and finally, it is possible to combine several matchers in a single route:
```go
r.HandleFunc("/products", ProductsHandler).
Host("www.example.com").
Methods("GET").
Schemes("http")
```
Routes are tested in the order they were added to the router. If two routes match, the first one wins:
```go
r := mux.NewRouter()
r.HandleFunc("/specific", specificHandler)
r.PathPrefix("/").Handler(catchAllHandler)
```
Setting the same matching conditions again and again can be boring, so we have a way to group several routes that share the same requirements. We call it "subrouting".
For example, let's say we have several URLs that should only match when the host is `www.example.com`. Create a route for that host and get a "subrouter" from it:
```go
r := mux.NewRouter()
s := r.Host("www.example.com").Subrouter()
```
Then register routes in the subrouter:
```go
s.HandleFunc("/products/", ProductsHandler)
s.HandleFunc("/products/{key}", ProductHandler)
s.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
```
The three URL paths we registered above will only be tested if the domain is `www.example.com`, because the subrouter is tested first. This is not only convenient, but also optimizes request matching. You can create subrouters combining any attribute matchers accepted by a route.
Subrouters can be used to create domain or path "namespaces": you define subrouters in a central place and then parts of the app can register its paths relatively to a given subrouter.
There's one more thing about subroutes. When a subrouter has a path prefix, the inner routes use it as base for their paths:
```go
r := mux.NewRouter()
s := r.PathPrefix("/products").Subrouter()
// "/products/"
s.HandleFunc("/", ProductsHandler)
// "/products/{key}/"
s.HandleFunc("/{key}/", ProductHandler)
// "/products/{key}/details"
s.HandleFunc("/{key}/details", ProductDetailsHandler)
```
### Static Files
Note that the path provided to `PathPrefix()` represents a "wildcard": calling
`PathPrefix("/static/").Handler(...)` means that the handler will be passed any
request that matches "/static/\*". This makes it easy to serve static files with mux:
```go
func main() {
var dir string
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir")
flag.Parse()
r := mux.NewRouter()
// This will serve files under http://localhost:8000/static/<filename>
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir))))
srv := &http.Server{
Handler: r,
Addr: "127.0.0.1:8000",
// Good practice: enforce timeouts for servers you create!
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
}
log.Fatal(srv.ListenAndServe())
}
```
### Registered URLs
Now let's see how to build registered URLs.
Routes can be named. All routes that define a name can have their URLs built, or "reversed". We define a name calling `Name()` on a route. For example:
```go
r := mux.NewRouter()
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
Name("article")
```
To build a URL, get the route and call the `URL()` method, passing a sequence of key/value pairs for the route variables. For the previous route, we would do:
```go
url, err := r.Get("article").URL("category", "technology", "id", "42")
```
...and the result will be a `url.URL` with the following path:
```
"/articles/technology/42"
```
This also works for host and query value variables:
```go
r := mux.NewRouter()
r.Host("{subdomain}.example.com").
Path("/articles/{category}/{id:[0-9]+}").
Queries("filter", "{filter}").
HandlerFunc(ArticleHandler).
Name("article")
// url.String() will be "http://news.example.com/articles/technology/42?filter=gorilla"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42",
"filter", "gorilla")
```
All variables defined in the route are required, and their values must conform to the corresponding patterns. These requirements guarantee that a generated URL will always match a registered route -- the only exception is for explicitly defined "build-only" routes which never match.
Regex support also exists for matching Headers within a route. For example, we could do:
```go
r.HeadersRegexp("Content-Type", "application/(text|json)")
```
...and the route will match both requests with a Content-Type of `application/json` as well as `application/text`
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do:
```go
// "http://news.example.com/"
host, err := r.Get("article").URLHost("subdomain", "news")
// "/articles/technology/42"
path, err := r.Get("article").URLPath("category", "technology", "id", "42")
```
And if you use subrouters, host and path defined separately can be built as well:
```go
r := mux.NewRouter()
s := r.Host("{subdomain}.example.com").Subrouter()
s.Path("/articles/{category}/{id:[0-9]+}").
HandlerFunc(ArticleHandler).
Name("article")
// "http://news.example.com/articles/technology/42"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42")
```
### Walking Routes
The `Walk` function on `mux.Router` can be used to visit all of the routes that are registered on a router. For example,
the following prints all of the registered routes:
```go
package main
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/mux"
)
func handler(w http.ResponseWriter, r *http.Request) {
return
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.HandleFunc("/products", handler).Methods("POST")
r.HandleFunc("/articles", handler).Methods("GET")
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
err := r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
pathTemplate, err := route.GetPathTemplate()
if err == nil {
fmt.Println("ROUTE:", pathTemplate)
}
pathRegexp, err := route.GetPathRegexp()
if err == nil {
fmt.Println("Path regexp:", pathRegexp)
}
queriesTemplates, err := route.GetQueriesTemplates()
if err == nil {
fmt.Println("Queries templates:", strings.Join(queriesTemplates, ","))
}
queriesRegexps, err := route.GetQueriesRegexp()
if err == nil {
fmt.Println("Queries regexps:", strings.Join(queriesRegexps, ","))
}
methods, err := route.GetMethods()
if err == nil {
fmt.Println("Methods:", strings.Join(methods, ","))
}
fmt.Println()
return nil
})
if err != nil {
fmt.Println(err)
}
http.Handle("/", r)
}
```
### Graceful Shutdown
Go 1.8 introduced the ability to [gracefully shutdown](https://golang.org/doc/go1.8#http_shutdown) a `*http.Server`. Here's how to do that alongside `mux`:
```go
package main
import (
"context"
"flag"
"log"
"net/http"
"os"
"os/signal"
"time"
"github.com/gorilla/mux"
)
func main() {
var wait time.Duration
flag.DurationVar(&wait, "graceful-timeout", time.Second * 15, "the duration for which the server gracefully wait for existing connections to finish - e.g. 15s or 1m")
flag.Parse()
r := mux.NewRouter()
// Add your routes as needed
srv := &http.Server{
Addr: "0.0.0.0:8080",
// Good practice to set timeouts to avoid Slowloris attacks.
WriteTimeout: time.Second * 15,
ReadTimeout: time.Second * 15,
IdleTimeout: time.Second * 60,
Handler: r, // Pass our instance of gorilla/mux in.
}
// Run our server in a goroutine so that it doesn't block.
go func() {
if err := srv.ListenAndServe(); err != nil {
log.Println(err)
}
}()
c := make(chan os.Signal, 1)
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C)
// SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught.
signal.Notify(c, os.Interrupt)
// Block until we receive our signal.
<-c
// Create a deadline to wait for.
ctx, cancel := context.WithTimeout(context.Background(), wait)
defer cancel()
// Doesn't block if no connections, but will otherwise wait
// until the timeout deadline.
srv.Shutdown(ctx)
// Optionally, you could run srv.Shutdown in a goroutine and block on
// <-ctx.Done() if your application should wait for other services
// to finalize based on context cancellation.
log.Println("shutting down")
os.Exit(0)
}
```
### Middleware
Mux supports the addition of middlewares to a [Router](https://godoc.org/github.com/gorilla/mux#Router), which are executed in the order they are added if a match is found, including its subrouters.
Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or `ResponseWriter` hijacking.
Mux middlewares are defined using the de facto standard type:
```go
type MiddlewareFunc func(http.Handler) http.Handler
```
Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed to it, and then calls the handler passed as parameter to the MiddlewareFunc. This takes advantage of closures being able access variables from the context where they are created, while retaining the signature enforced by the receivers.
A very basic middleware which logs the URI of the request being handled could be written as:
```go
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do stuff here
log.Println(r.RequestURI)
// Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(w, r)
})
}
```
Middlewares can be added to a router using `Router.Use()`:
```go
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.Use(loggingMiddleware)
```
A more complex authentication middleware, which maps session token to users, could be written as:
```go
// Define our struct
type authenticationMiddleware struct {
tokenUsers map[string]string
}
// Initialize it somewhere
func (amw *authenticationMiddleware) Populate() {
amw.tokenUsers["00000000"] = "user0"
amw.tokenUsers["aaaaaaaa"] = "userA"
amw.tokenUsers["05f717e5"] = "randomUser"
amw.tokenUsers["deadbeef"] = "user0"
}
// Middleware function, which will be called for each request
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("X-Session-Token")
if user, found := amw.tokenUsers[token]; found {
// We found the token in our map
log.Printf("Authenticated user %s\n", user)
// Pass down the request to the next middleware (or final handler)
next.ServeHTTP(w, r)
} else {
// Write an error and stop the handler chain
http.Error(w, "Forbidden", http.StatusForbidden)
}
})
}
```
```go
r := mux.NewRouter()
r.HandleFunc("/", handler)
amw := authenticationMiddleware{}
amw.Populate()
r.Use(amw.Middleware)
```
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
### Handling CORS Requests
[CORSMethodMiddleware](https://godoc.org/github.com/gorilla/mux#CORSMethodMiddleware) intends to make it easier to strictly set the `Access-Control-Allow-Methods` response header.
* You will still need to use your own CORS handler to set the other CORS headers such as `Access-Control-Allow-Origin`
* The middleware will set the `Access-Control-Allow-Methods` header to all the method matchers (e.g. `r.Methods(http.MethodGet, http.MethodPut, http.MethodOptions)` -> `Access-Control-Allow-Methods: GET,PUT,OPTIONS`) on a route
* If you do not specify any methods, then:
> _Important_: there must be an `OPTIONS` method matcher for the middleware to set the headers.
Here is an example of using `CORSMethodMiddleware` along with a custom `OPTIONS` handler to set all the required CORS headers:
```go
package main
import (
"net/http"
"github.com/gorilla/mux"
)
func main() {
r := mux.NewRouter()
// IMPORTANT: you must specify an OPTIONS method matcher for the middleware to set CORS headers
r.HandleFunc("/foo", fooHandler).Methods(http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodOptions)
r.Use(mux.CORSMethodMiddleware(r))
http.ListenAndServe(":8080", r)
}
func fooHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == http.MethodOptions {
return
}
w.Write([]byte("foo"))
}
```
And an request to `/foo` using something like:
```bash
curl localhost:8080/foo -v
```
Would look like:
```bash
* Trying ::1...
* TCP_NODELAY set
* Connected to localhost (::1) port 8080 (#0)
> GET /foo HTTP/1.1
> Host: localhost:8080
> User-Agent: curl/7.59.0
> Accept: */*
>
< HTTP/1.1 200 OK
< Access-Control-Allow-Methods: GET,PUT,PATCH,OPTIONS
< Access-Control-Allow-Origin: *
< Date: Fri, 28 Jun 2019 20:13:30 GMT
< Content-Length: 3
< Content-Type: text/plain; charset=utf-8
<
* Connection #0 to host localhost left intact
foo
```
### Testing Handlers
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
First, our simple HTTP handler:
```go
// endpoints.go
package main
func HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
// A very simple health check.
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
// In the future we could report back on the status of our DB, or our cache
// (e.g. Redis) by performing a simple PING, and include them in the response.
io.WriteString(w, `{"alive": true}`)
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/health", HealthCheckHandler)
log.Fatal(http.ListenAndServe("localhost:8080", r))
}
```
Our test code:
```go
// endpoints_test.go
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestHealthCheckHandler(t *testing.T) {
// Create a request to pass to our handler. We don't have any query parameters for now, so we'll
// pass 'nil' as the third parameter.
req, err := http.NewRequest("GET", "/health", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(HealthCheckHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response body is what we expect.
expected := `{"alive": true}`
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), expected)
}
}
```
In the case that our routes have [variables](#examples), we can pass those in the request. We could write
[table-driven tests](https://dave.cheney.net/2013/06/09/writing-table-driven-tests-in-go) to test multiple
possible route variables as needed.
```go
// endpoints.go
func main() {
r := mux.NewRouter()
// A route with a route variable:
r.HandleFunc("/metrics/{type}", MetricsHandler)
log.Fatal(http.ListenAndServe("localhost:8080", r))
}
```
Our test file, with a table-driven test of `routeVariables`:
```go
// endpoints_test.go
func TestMetricsHandler(t *testing.T) {
tt := []struct{
routeVariable string
shouldPass bool
}{
{"goroutines", true},
{"heap", true},
{"counters", true},
{"queries", true},
{"adhadaeqm3k", false},
}
for _, tc := range tt {
path := fmt.Sprintf("/metrics/%s", tc.routeVariable)
req, err := http.NewRequest("GET", path, nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
// Need to create a router that we can pass the request through so that the vars will be added to the context
router := mux.NewRouter()
router.HandleFunc("/metrics/{type}", MetricsHandler)
router.ServeHTTP(rr, req)
// In this case, our MetricsHandler returns a non-200 response
// for a route variable it doesn't know about.
if rr.Code == http.StatusOK && !tc.shouldPass {
t.Errorf("handler should have failed on routeVariable %s: got %v want %v",
tc.routeVariable, rr.Code, http.StatusOK)
}
}
}
```
## Full Example
Here's a complete, runnable example of a small `mux` based server:
```go
package main
import (
"net/http"
"log"
"github.com/gorilla/mux"
)
func YourHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Gorilla!\n"))
}
func main() {
r := mux.NewRouter()
// Routes consist of a path and a handler function.
r.HandleFunc("/", YourHandler)
// Bind to a port and pass our router in
log.Fatal(http.ListenAndServe(":8000", r))
}
```
## License
BSD licensed. See the LICENSE file for details.

View File

@ -1,18 +0,0 @@
package mux
import (
"context"
"net/http"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return r.Context().Value(key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
return r.WithContext(context.WithValue(r.Context(), key, val))
}

306
vendor/github.com/gorilla/mux/doc.go generated vendored
View File

@ -1,306 +0,0 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package mux implements a request router and dispatcher.
The name mux stands for "HTTP request multiplexer". Like the standard
http.ServeMux, mux.Router matches incoming requests against a list of
registered routes and calls a handler for the route that matches the URL
or other conditions. The main features are:
* Requests can be matched based on URL host, path, path prefix, schemes,
header and query values, HTTP methods or using custom matchers.
* URL hosts, paths and query values can have variables with an optional
regular expression.
* Registered URLs can be built, or "reversed", which helps maintaining
references to resources.
* Routes can be used as subrouters: nested routes are only tested if the
parent route matches. This is useful to define groups of routes that
share common conditions like a host, a path prefix or other repeated
attributes. As a bonus, this optimizes request matching.
* It implements the http.Handler interface so it is compatible with the
standard http.ServeMux.
Let's start registering a couple of URL paths and handlers:
func main() {
r := mux.NewRouter()
r.HandleFunc("/", HomeHandler)
r.HandleFunc("/products", ProductsHandler)
r.HandleFunc("/articles", ArticlesHandler)
http.Handle("/", r)
}
Here we register three routes mapping URL paths to handlers. This is
equivalent to how http.HandleFunc() works: if an incoming request URL matches
one of the paths, the corresponding handler is called passing
(http.ResponseWriter, *http.Request) as parameters.
Paths can have variables. They are defined using the format {name} or
{name:pattern}. If a regular expression pattern is not defined, the matched
variable will be anything until the next slash. For example:
r := mux.NewRouter()
r.HandleFunc("/products/{key}", ProductHandler)
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler)
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
Groups can be used inside patterns, as long as they are non-capturing (?:re). For example:
r.HandleFunc("/articles/{category}/{sort:(?:asc|desc|new)}", ArticlesCategoryHandler)
The names are used to create a map of route variables which can be retrieved
calling mux.Vars():
vars := mux.Vars(request)
category := vars["category"]
Note that if any capturing groups are present, mux will panic() during parsing. To prevent
this, convert any capturing groups to non-capturing, e.g. change "/{sort:(asc|desc)}" to
"/{sort:(?:asc|desc)}". This is a change from prior versions which behaved unpredictably
when capturing groups were present.
And this is all you need to know about the basic usage. More advanced options
are explained below.
Routes can also be restricted to a domain or subdomain. Just define a host
pattern to be matched. They can also have variables:
r := mux.NewRouter()
// Only matches if domain is "www.example.com".
r.Host("www.example.com")
// Matches a dynamic subdomain.
r.Host("{subdomain:[a-z]+}.domain.com")
There are several other matchers that can be added. To match path prefixes:
r.PathPrefix("/products/")
...or HTTP methods:
r.Methods("GET", "POST")
...or URL schemes:
r.Schemes("https")
...or header values:
r.Headers("X-Requested-With", "XMLHttpRequest")
...or query values:
r.Queries("key", "value")
...or to use a custom matcher function:
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool {
return r.ProtoMajor == 0
})
...and finally, it is possible to combine several matchers in a single route:
r.HandleFunc("/products", ProductsHandler).
Host("www.example.com").
Methods("GET").
Schemes("http")
Setting the same matching conditions again and again can be boring, so we have
a way to group several routes that share the same requirements.
We call it "subrouting".
For example, let's say we have several URLs that should only match when the
host is "www.example.com". Create a route for that host and get a "subrouter"
from it:
r := mux.NewRouter()
s := r.Host("www.example.com").Subrouter()
Then register routes in the subrouter:
s.HandleFunc("/products/", ProductsHandler)
s.HandleFunc("/products/{key}", ProductHandler)
s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
The three URL paths we registered above will only be tested if the domain is
"www.example.com", because the subrouter is tested first. This is not
only convenient, but also optimizes request matching. You can create
subrouters combining any attribute matchers accepted by a route.
Subrouters can be used to create domain or path "namespaces": you define
subrouters in a central place and then parts of the app can register its
paths relatively to a given subrouter.
There's one more thing about subroutes. When a subrouter has a path prefix,
the inner routes use it as base for their paths:
r := mux.NewRouter()
s := r.PathPrefix("/products").Subrouter()
// "/products/"
s.HandleFunc("/", ProductsHandler)
// "/products/{key}/"
s.HandleFunc("/{key}/", ProductHandler)
// "/products/{key}/details"
s.HandleFunc("/{key}/details", ProductDetailsHandler)
Note that the path provided to PathPrefix() represents a "wildcard": calling
PathPrefix("/static/").Handler(...) means that the handler will be passed any
request that matches "/static/*". This makes it easy to serve static files with mux:
func main() {
var dir string
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir")
flag.Parse()
r := mux.NewRouter()
// This will serve files under http://localhost:8000/static/<filename>
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir))))
srv := &http.Server{
Handler: r,
Addr: "127.0.0.1:8000",
// Good practice: enforce timeouts for servers you create!
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
}
log.Fatal(srv.ListenAndServe())
}
Now let's see how to build registered URLs.
Routes can be named. All routes that define a name can have their URLs built,
or "reversed". We define a name calling Name() on a route. For example:
r := mux.NewRouter()
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
Name("article")
To build a URL, get the route and call the URL() method, passing a sequence of
key/value pairs for the route variables. For the previous route, we would do:
url, err := r.Get("article").URL("category", "technology", "id", "42")
...and the result will be a url.URL with the following path:
"/articles/technology/42"
This also works for host and query value variables:
r := mux.NewRouter()
r.Host("{subdomain}.domain.com").
Path("/articles/{category}/{id:[0-9]+}").
Queries("filter", "{filter}").
HandlerFunc(ArticleHandler).
Name("article")
// url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42",
"filter", "gorilla")
All variables defined in the route are required, and their values must
conform to the corresponding patterns. These requirements guarantee that a
generated URL will always match a registered route -- the only exception is
for explicitly defined "build-only" routes which never match.
Regex support also exists for matching Headers within a route. For example, we could do:
r.HeadersRegexp("Content-Type", "application/(text|json)")
...and the route will match both requests with a Content-Type of `application/json` as well as
`application/text`
There's also a way to build only the URL host or path for a route:
use the methods URLHost() or URLPath() instead. For the previous route,
we would do:
// "http://news.domain.com/"
host, err := r.Get("article").URLHost("subdomain", "news")
// "/articles/technology/42"
path, err := r.Get("article").URLPath("category", "technology", "id", "42")
And if you use subrouters, host and path defined separately can be built
as well:
r := mux.NewRouter()
s := r.Host("{subdomain}.domain.com").Subrouter()
s.Path("/articles/{category}/{id:[0-9]+}").
HandlerFunc(ArticleHandler).
Name("article")
// "http://news.domain.com/articles/technology/42"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42")
Mux supports the addition of middlewares to a Router, which are executed in the order they are added if a match is found, including its subrouters. Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or ResponseWriter hijacking.
type MiddlewareFunc func(http.Handler) http.Handler
Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed to it, and then calls the handler passed as parameter to the MiddlewareFunc (closures can access variables from the context where they are created).
A very basic middleware which logs the URI of the request being handled could be written as:
func simpleMw(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do stuff here
log.Println(r.RequestURI)
// Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(w, r)
})
}
Middlewares can be added to a router using `Router.Use()`:
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.Use(simpleMw)
A more complex authentication middleware, which maps session token to users, could be written as:
// Define our struct
type authenticationMiddleware struct {
tokenUsers map[string]string
}
// Initialize it somewhere
func (amw *authenticationMiddleware) Populate() {
amw.tokenUsers["00000000"] = "user0"
amw.tokenUsers["aaaaaaaa"] = "userA"
amw.tokenUsers["05f717e5"] = "randomUser"
amw.tokenUsers["deadbeef"] = "user0"
}
// Middleware function, which will be called for each request
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("X-Session-Token")
if user, found := amw.tokenUsers[token]; found {
// We found the token in our map
log.Printf("Authenticated user %s\n", user)
next.ServeHTTP(w, r)
} else {
http.Error(w, "Forbidden", http.StatusForbidden)
}
})
}
r := mux.NewRouter()
r.HandleFunc("/", handler)
amw := authenticationMiddleware{tokenUsers: make(map[string]string)}
amw.Populate()
r.Use(amw.Middleware)
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to.
*/
package mux

View File

@ -1 +0,0 @@
module github.com/gorilla/mux

View File

@ -1,79 +0,0 @@
package mux
import (
"net/http"
"strings"
)
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
// to it, and then calls the handler passed as parameter to the MiddlewareFunc.
type MiddlewareFunc func(http.Handler) http.Handler
// middleware interface is anything which implements a MiddlewareFunc named Middleware.
type middleware interface {
Middleware(handler http.Handler) http.Handler
}
// Middleware allows MiddlewareFunc to implement the middleware interface.
func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
return mw(handler)
}
// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) Use(mwf ...MiddlewareFunc) {
for _, fn := range mwf {
r.middlewares = append(r.middlewares, fn)
}
}
// useInterface appends a middleware to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw)
}
// CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header
// on requests for routes that have an OPTIONS method matcher to all the method matchers on
// the route. Routes that do not explicitly handle OPTIONS requests will not be processed
// by the middleware. See examples for usage.
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
allMethods, err := getAllMethodsForRoute(r, req)
if err == nil {
for _, v := range allMethods {
if v == http.MethodOptions {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
}
}
}
next.ServeHTTP(w, req)
})
}
}
// getAllMethodsForRoute returns all the methods from method matchers matching a given
// request.
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
return allMethods, err
}

607
vendor/github.com/gorilla/mux/mux.go generated vendored
View File

@ -1,607 +0,0 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"errors"
"fmt"
"net/http"
"path"
"regexp"
)
var (
// ErrMethodMismatch is returned when the method in the request does not match
// the method defined against the route.
ErrMethodMismatch = errors.New("method is not allowed")
// ErrNotFound is returned when no route match is found.
ErrNotFound = errors.New("no matching route was found")
)
// NewRouter returns a new router instance.
func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route)}
}
// Router registers routes to be matched and dispatches a handler.
//
// It implements the http.Handler interface, so it can be registered to serve
// requests:
//
// var router = mux.NewRouter()
//
// func main() {
// http.Handle("/", router)
// }
//
// Or, for Google App Engine, register it in a init() function:
//
// func init() {
// http.Handle("/", router)
// }
//
// This will send all incoming requests to the router.
type Router struct {
// Configurable Handler to be used when no route matches.
NotFoundHandler http.Handler
// Configurable Handler to be used when the request method does not match the route.
MethodNotAllowedHandler http.Handler
// Routes to be matched, in order.
routes []*Route
// Routes by name for URL building.
namedRoutes map[string]*Route
// If true, do not clear the request context after handling the request.
//
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself.
KeepContext bool
// Slice of middlewares to be called after a match is found
middlewares []middleware
// configuration shared with `Route`
routeConf
}
// common route configuration shared between `Router` and `Route`
type routeConf struct {
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
useEncodedPath bool
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, when the path pattern is "/path//to", accessing "/path//to"
// will not redirect
skipClean bool
// Manager for the variables from host and path.
regexp routeRegexpGroup
// List of matchers.
matchers []matcher
// The scheme used when building URLs.
buildScheme string
buildVarsFunc BuildVarsFunc
}
// returns an effective deep copy of `routeConf`
func copyRouteConf(r routeConf) routeConf {
c := r
if r.regexp.path != nil {
c.regexp.path = copyRouteRegexp(r.regexp.path)
}
if r.regexp.host != nil {
c.regexp.host = copyRouteRegexp(r.regexp.host)
}
c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries))
for _, q := range r.regexp.queries {
c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q))
}
c.matchers = make([]matcher, 0, len(r.matchers))
for _, m := range r.matchers {
c.matchers = append(c.matchers, m)
}
return c
}
func copyRouteRegexp(r *routeRegexp) *routeRegexp {
c := *r
return &c
}
// Match attempts to match the given request against the router's registered routes.
//
// If the request matches a route of this router or one of its subrouters the Route,
// Handler, and Vars fields of the the match argument are filled and this function
// returns true.
//
// If the request does not match any of this router's or its subrouters' routes
// then this function returns false. If available, a reason for the match failure
// will be filled in the match argument's MatchErr field. If the match failure type
// (eg: not found) has a registered handler, the handler is assigned to the Handler
// field of the match argument.
func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
for _, route := range r.routes {
if route.Match(req, match) {
// Build middleware chain if no error was found
if match.MatchErr == nil {
for i := len(r.middlewares) - 1; i >= 0; i-- {
match.Handler = r.middlewares[i].Middleware(match.Handler)
}
}
return true
}
}
if match.MatchErr == ErrMethodMismatch {
if r.MethodNotAllowedHandler != nil {
match.Handler = r.MethodNotAllowedHandler
return true
}
return false
}
// Closest match for a router (includes sub-routers)
if r.NotFoundHandler != nil {
match.Handler = r.NotFoundHandler
match.MatchErr = ErrNotFound
return true
}
match.MatchErr = ErrNotFound
return false
}
// ServeHTTP dispatches the handler registered in the matched route.
//
// When there is a match, the route variables can be retrieved calling
// mux.Vars(request).
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if !r.skipClean {
path := req.URL.Path
if r.useEncodedPath {
path = req.URL.EscapedPath()
}
// Clean path to canonical form and redirect.
if p := cleanPath(path); p != path {
// Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query.
// This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue:
// http://code.google.com/p/go/issues/detail?id=5252
url := *req.URL
url.Path = p
p = url.String()
w.Header().Set("Location", p)
w.WriteHeader(http.StatusMovedPermanently)
return
}
}
var match RouteMatch
var handler http.Handler
if r.Match(req, &match) {
handler = match.Handler
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
}
if handler == nil && match.MatchErr == ErrMethodMismatch {
handler = methodNotAllowedHandler()
}
if handler == nil {
handler = http.NotFoundHandler()
}
handler.ServeHTTP(w, req)
}
// Get returns a route registered with the given name.
func (r *Router) Get(name string) *Route {
return r.namedRoutes[name]
}
// GetRoute returns a route registered with the given name. This method
// was renamed to Get() and remains here for backwards compatibility.
func (r *Router) GetRoute(name string) *Route {
return r.namedRoutes[name]
}
// StrictSlash defines the trailing slash behavior for new routes. The initial
// value is false.
//
// When true, if the route path is "/path/", accessing "/path" will perform a redirect
// to the former and vice versa. In other words, your application will always
// see the path as specified in the route.
//
// When false, if the route path is "/path", accessing "/path/" will not match
// this route and vice versa.
//
// The re-direct is a HTTP 301 (Moved Permanently). Note that when this is set for
// routes with a non-idempotent method (e.g. POST, PUT), the subsequent re-directed
// request will be made as a GET by most clients. Use middleware or client settings
// to modify this behaviour as needed.
//
// Special case: when a route sets a path prefix using the PathPrefix() method,
// strict slash is ignored for that route because the redirect behavior can't
// be determined from a prefix alone. However, any subrouters created from that
// route inherit the original StrictSlash setting.
func (r *Router) StrictSlash(value bool) *Router {
r.strictSlash = value
return r
}
// SkipClean defines the path cleaning behaviour for new routes. The initial
// value is false. Users should be careful about which routes are not cleaned
//
// When true, if the route path is "/path//to", it will remain with the double
// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/
//
// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will
// become /fetch/http/xkcd.com/534
func (r *Router) SkipClean(value bool) *Router {
r.skipClean = value
return r
}
// UseEncodedPath tells the router to match the encoded original path
// to the routes.
// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to".
//
// If not called, the router will match the unencoded path to the routes.
// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to"
func (r *Router) UseEncodedPath() *Router {
r.useEncodedPath = true
return r
}
// ----------------------------------------------------------------------------
// Route factories
// ----------------------------------------------------------------------------
// NewRoute registers an empty route.
func (r *Router) NewRoute() *Route {
// initialize a route with a copy of the parent router's configuration
route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.routes = append(r.routes, route)
return route
}
// Name registers a new route with a name.
// See Route.Name().
func (r *Router) Name(name string) *Route {
return r.NewRoute().Name(name)
}
// Handle registers a new route with a matcher for the URL path.
// See Route.Path() and Route.Handler().
func (r *Router) Handle(path string, handler http.Handler) *Route {
return r.NewRoute().Path(path).Handler(handler)
}
// HandleFunc registers a new route with a matcher for the URL path.
// See Route.Path() and Route.HandlerFunc().
func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
*http.Request)) *Route {
return r.NewRoute().Path(path).HandlerFunc(f)
}
// Headers registers a new route with a matcher for request header values.
// See Route.Headers().
func (r *Router) Headers(pairs ...string) *Route {
return r.NewRoute().Headers(pairs...)
}
// Host registers a new route with a matcher for the URL host.
// See Route.Host().
func (r *Router) Host(tpl string) *Route {
return r.NewRoute().Host(tpl)
}
// MatcherFunc registers a new route with a custom matcher function.
// See Route.MatcherFunc().
func (r *Router) MatcherFunc(f MatcherFunc) *Route {
return r.NewRoute().MatcherFunc(f)
}
// Methods registers a new route with a matcher for HTTP methods.
// See Route.Methods().
func (r *Router) Methods(methods ...string) *Route {
return r.NewRoute().Methods(methods...)
}
// Path registers a new route with a matcher for the URL path.
// See Route.Path().
func (r *Router) Path(tpl string) *Route {
return r.NewRoute().Path(tpl)
}
// PathPrefix registers a new route with a matcher for the URL path prefix.
// See Route.PathPrefix().
func (r *Router) PathPrefix(tpl string) *Route {
return r.NewRoute().PathPrefix(tpl)
}
// Queries registers a new route with a matcher for URL query values.
// See Route.Queries().
func (r *Router) Queries(pairs ...string) *Route {
return r.NewRoute().Queries(pairs...)
}
// Schemes registers a new route with a matcher for URL schemes.
// See Route.Schemes().
func (r *Router) Schemes(schemes ...string) *Route {
return r.NewRoute().Schemes(schemes...)
}
// BuildVarsFunc registers a new route with a custom function for modifying
// route variables before building a URL.
func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
return r.NewRoute().BuildVarsFunc(f)
}
// Walk walks the router and all its sub-routers, calling walkFn for each route
// in the tree. The routes are walked in the order they were added. Sub-routers
// are explored depth-first.
func (r *Router) Walk(walkFn WalkFunc) error {
return r.walk(walkFn, []*Route{})
}
// SkipRouter is used as a return value from WalkFuncs to indicate that the
// router that walk is about to descend down to should be skipped.
var SkipRouter = errors.New("skip this router")
// WalkFunc is the type of the function called for each route visited by Walk.
// At every invocation, it is given the current route, and the current router,
// and a list of ancestor routes that lead to the current route.
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
for _, t := range r.routes {
err := walkFn(t, r, ancestors)
if err == SkipRouter {
continue
}
if err != nil {
return err
}
for _, sr := range t.matchers {
if h, ok := sr.(*Router); ok {
ancestors = append(ancestors, t)
err := h.walk(walkFn, ancestors)
if err != nil {
return err
}
ancestors = ancestors[:len(ancestors)-1]
}
}
if h, ok := t.handler.(*Router); ok {
ancestors = append(ancestors, t)
err := h.walk(walkFn, ancestors)
if err != nil {
return err
}
ancestors = ancestors[:len(ancestors)-1]
}
}
return nil
}
// ----------------------------------------------------------------------------
// Context
// ----------------------------------------------------------------------------
// RouteMatch stores information about a matched route.
type RouteMatch struct {
Route *Route
Handler http.Handler
Vars map[string]string
// MatchErr is set to appropriate matching error
// It is set to ErrMethodMismatch if there is a mismatch in
// the request method and route method
MatchErr error
}
type contextKey int
const (
varsKey contextKey = iota
routeKey
)
// Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string {
if rv := contextGet(r, varsKey); rv != nil {
return rv.(map[string]string)
}
return nil
}
// CurrentRoute returns the matched route for the current request, if any.
// This only works when called inside the handler of the matched route
// because the matched route is stored in the request context which is cleared
// after the handler returns, unless the KeepContext option is set on the
// Router.
func CurrentRoute(r *http.Request) *Route {
if rv := contextGet(r, routeKey); rv != nil {
return rv.(*Route)
}
return nil
}
func setVars(r *http.Request, val interface{}) *http.Request {
return contextSet(r, varsKey, val)
}
func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
return contextSet(r, routeKey, val)
}
// ----------------------------------------------------------------------------
// Helpers
// ----------------------------------------------------------------------------
// cleanPath returns the canonical path for p, eliminating . and .. elements.
// Borrowed from the net/http package.
func cleanPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
// uniqueVars returns an error if two slices contain duplicated strings.
func uniqueVars(s1, s2 []string) error {
for _, v1 := range s1 {
for _, v2 := range s2 {
if v1 == v2 {
return fmt.Errorf("mux: duplicated route variable %q", v2)
}
}
}
return nil
}
// checkPairs returns the count of strings passed in, and an error if
// the count is not an even number.
func checkPairs(pairs ...string) (int, error) {
length := len(pairs)
if length%2 != 0 {
return length, fmt.Errorf(
"mux: number of parameters must be multiple of 2, got %v", pairs)
}
return length, nil
}
// mapFromPairsToString converts variadic string parameters to a
// string to string map.
func mapFromPairsToString(pairs ...string) (map[string]string, error) {
length, err := checkPairs(pairs...)
if err != nil {
return nil, err
}
m := make(map[string]string, length/2)
for i := 0; i < length; i += 2 {
m[pairs[i]] = pairs[i+1]
}
return m, nil
}
// mapFromPairsToRegex converts variadic string parameters to a
// string to regex map.
func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
length, err := checkPairs(pairs...)
if err != nil {
return nil, err
}
m := make(map[string]*regexp.Regexp, length/2)
for i := 0; i < length; i += 2 {
regex, err := regexp.Compile(pairs[i+1])
if err != nil {
return nil, err
}
m[pairs[i]] = regex
}
return m, nil
}
// matchInArray returns true if the given string value is in the array.
func matchInArray(arr []string, value string) bool {
for _, v := range arr {
if v == value {
return true
}
}
return false
}
// matchMapWithString returns true if the given key/value pairs exist in a given map.
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
for k, v := range toCheck {
// Check if key exists.
if canonicalKey {
k = http.CanonicalHeaderKey(k)
}
if values := toMatch[k]; values == nil {
return false
} else if v != "" {
// If value was defined as an empty string we only check that the
// key exists. Otherwise we also check for equality.
valueExists := false
for _, value := range values {
if v == value {
valueExists = true
break
}
}
if !valueExists {
return false
}
}
}
return true
}
// matchMapWithRegex returns true if the given key/value pairs exist in a given map compiled against
// the given regex
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
for k, v := range toCheck {
// Check if key exists.
if canonicalKey {
k = http.CanonicalHeaderKey(k)
}
if values := toMatch[k]; values == nil {
return false
} else if v != nil {
// If value was defined as an empty string we only check that the
// key exists. Otherwise we also check for equality.
valueExists := false
for _, value := range values {
if v.MatchString(value) {
valueExists = true
break
}
}
if !valueExists {
return false
}
}
}
return true
}
// methodNotAllowed replies to the request with an HTTP status code 405.
func methodNotAllowed(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusMethodNotAllowed)
}
// methodNotAllowedHandler returns a simple request handler
// that replies to each request with a status code 405.
func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) }

View File

@ -1,345 +0,0 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"bytes"
"fmt"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
)
type routeRegexpOptions struct {
strictSlash bool
useEncodedPath bool
}
type regexpType int
const (
regexpTypePath regexpType = 0
regexpTypeHost regexpType = 1
regexpTypePrefix regexpType = 2
regexpTypeQuery regexpType = 3
)
// newRouteRegexp parses a route template and returns a routeRegexp,
// used to match a host, a path or a query string.
//
// It will extract named variables, assemble a regexp to be matched, create
// a "reverse" template to build URLs and compile regexps to validate variable
// values used in URL building.
//
// Previously we accepted only Python-like identifiers for variable
// names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that
// name and pattern can't be empty, and names can't contain a colon.
func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*routeRegexp, error) {
// Check if it is well-formed.
idxs, errBraces := braceIndices(tpl)
if errBraces != nil {
return nil, errBraces
}
// Backup the original.
template := tpl
// Now let's parse it.
defaultPattern := "[^/]+"
if typ == regexpTypeQuery {
defaultPattern = ".*"
} else if typ == regexpTypeHost {
defaultPattern = "[^.]+"
}
// Only match strict slash if not matching
if typ != regexpTypePath {
options.strictSlash = false
}
// Set a flag for strictSlash.
endSlash := false
if options.strictSlash && strings.HasSuffix(tpl, "/") {
tpl = tpl[:len(tpl)-1]
endSlash = true
}
varsN := make([]string, len(idxs)/2)
varsR := make([]*regexp.Regexp, len(idxs)/2)
pattern := bytes.NewBufferString("")
pattern.WriteByte('^')
reverse := bytes.NewBufferString("")
var end int
var err error
for i := 0; i < len(idxs); i += 2 {
// Set all values we are interested in.
raw := tpl[end:idxs[i]]
end = idxs[i+1]
parts := strings.SplitN(tpl[idxs[i]+1:end-1], ":", 2)
name := parts[0]
patt := defaultPattern
if len(parts) == 2 {
patt = parts[1]
}
// Name or pattern can't be empty.
if name == "" || patt == "" {
return nil, fmt.Errorf("mux: missing name or pattern in %q",
tpl[idxs[i]:end])
}
// Build the regexp pattern.
fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt)
// Build the reverse template.
fmt.Fprintf(reverse, "%s%%s", raw)
// Append variable name and compiled pattern.
varsN[i/2] = name
varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt))
if err != nil {
return nil, err
}
}
// Add the remaining.
raw := tpl[end:]
pattern.WriteString(regexp.QuoteMeta(raw))
if options.strictSlash {
pattern.WriteString("[/]?")
}
if typ == regexpTypeQuery {
// Add the default pattern if the query value is empty
if queryVal := strings.SplitN(template, "=", 2)[1]; queryVal == "" {
pattern.WriteString(defaultPattern)
}
}
if typ != regexpTypePrefix {
pattern.WriteByte('$')
}
var wildcardHostPort bool
if typ == regexpTypeHost {
if !strings.Contains(pattern.String(), ":") {
wildcardHostPort = true
}
}
reverse.WriteString(raw)
if endSlash {
reverse.WriteByte('/')
}
// Compile full regexp.
reg, errCompile := regexp.Compile(pattern.String())
if errCompile != nil {
return nil, errCompile
}
// Check for capturing groups which used to work in older versions
if reg.NumSubexp() != len(idxs)/2 {
panic(fmt.Sprintf("route %s contains capture groups in its regexp. ", template) +
"Only non-capturing groups are accepted: e.g. (?:pattern) instead of (pattern)")
}
// Done!
return &routeRegexp{
template: template,
regexpType: typ,
options: options,
regexp: reg,
reverse: reverse.String(),
varsN: varsN,
varsR: varsR,
wildcardHostPort: wildcardHostPort,
}, nil
}
// routeRegexp stores a regexp to match a host or path and information to
// collect and validate route variables.
type routeRegexp struct {
// The unmodified template.
template string
// The type of match
regexpType regexpType
// Options for matching
options routeRegexpOptions
// Expanded regexp.
regexp *regexp.Regexp
// Reverse template.
reverse string
// Variable names.
varsN []string
// Variable regexps (validators).
varsR []*regexp.Regexp
// Wildcard host-port (no strict port match in hostname)
wildcardHostPort bool
}
// Match matches the regexp against the URL host or path.
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
if r.regexpType == regexpTypeHost {
host := getHost(req)
if r.wildcardHostPort {
// Don't be strict on the port match
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
}
return r.regexp.MatchString(host)
} else {
if r.regexpType == regexpTypeQuery {
return r.matchQueryString(req)
}
path := req.URL.Path
if r.options.useEncodedPath {
path = req.URL.EscapedPath()
}
return r.regexp.MatchString(path)
}
}
// url builds a URL part using the given values.
func (r *routeRegexp) url(values map[string]string) (string, error) {
urlValues := make([]interface{}, len(r.varsN))
for k, v := range r.varsN {
value, ok := values[v]
if !ok {
return "", fmt.Errorf("mux: missing route variable %q", v)
}
if r.regexpType == regexpTypeQuery {
value = url.QueryEscape(value)
}
urlValues[k] = value
}
rv := fmt.Sprintf(r.reverse, urlValues...)
if !r.regexp.MatchString(rv) {
// The URL is checked against the full regexp, instead of checking
// individual variables. This is faster but to provide a good error
// message, we check individual regexps if the URL doesn't match.
for k, v := range r.varsN {
if !r.varsR[k].MatchString(values[v]) {
return "", fmt.Errorf(
"mux: variable %q doesn't match, expected %q", values[v],
r.varsR[k].String())
}
}
}
return rv, nil
}
// getURLQuery returns a single query parameter from a request URL.
// For a URL with foo=bar&baz=ding, we return only the relevant key
// value pair for the routeRegexp.
func (r *routeRegexp) getURLQuery(req *http.Request) string {
if r.regexpType != regexpTypeQuery {
return ""
}
templateKey := strings.SplitN(r.template, "=", 2)[0]
for key, vals := range req.URL.Query() {
if key == templateKey && len(vals) > 0 {
return key + "=" + vals[0]
}
}
return ""
}
func (r *routeRegexp) matchQueryString(req *http.Request) bool {
return r.regexp.MatchString(r.getURLQuery(req))
}
// braceIndices returns the first level curly brace indices from a string.
// It returns an error in case of unbalanced braces.
func braceIndices(s string) ([]int, error) {
var level, idx int
var idxs []int
for i := 0; i < len(s); i++ {
switch s[i] {
case '{':
if level++; level == 1 {
idx = i
}
case '}':
if level--; level == 0 {
idxs = append(idxs, idx, i+1)
} else if level < 0 {
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
}
}
}
if level != 0 {
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
}
return idxs, nil
}
// varGroupName builds a capturing group name for the indexed variable.
func varGroupName(idx int) string {
return "v" + strconv.Itoa(idx)
}
// ----------------------------------------------------------------------------
// routeRegexpGroup
// ----------------------------------------------------------------------------
// routeRegexpGroup groups the route matchers that carry variables.
type routeRegexpGroup struct {
host *routeRegexp
path *routeRegexp
queries []*routeRegexp
}
// setMatch extracts the variables from the URL once a route matches.
func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
// Store host variables.
if v.host != nil {
host := getHost(req)
matches := v.host.regexp.FindStringSubmatchIndex(host)
if len(matches) > 0 {
extractVars(host, matches, v.host.varsN, m.Vars)
}
}
path := req.URL.Path
if r.useEncodedPath {
path = req.URL.EscapedPath()
}
// Store path variables.
if v.path != nil {
matches := v.path.regexp.FindStringSubmatchIndex(path)
if len(matches) > 0 {
extractVars(path, matches, v.path.varsN, m.Vars)
// Check if we should redirect.
if v.path.options.strictSlash {
p1 := strings.HasSuffix(path, "/")
p2 := strings.HasSuffix(v.path.template, "/")
if p1 != p2 {
u, _ := url.Parse(req.URL.String())
if p1 {
u.Path = u.Path[:len(u.Path)-1]
} else {
u.Path += "/"
}
m.Handler = http.RedirectHandler(u.String(), http.StatusMovedPermanently)
}
}
}
}
// Store query string variables.
for _, q := range v.queries {
queryURL := q.getURLQuery(req)
matches := q.regexp.FindStringSubmatchIndex(queryURL)
if len(matches) > 0 {
extractVars(queryURL, matches, q.varsN, m.Vars)
}
}
}
// getHost tries its best to return the request host.
// According to section 14.23 of RFC 2616 the Host header
// can include the port number if the default value of 80 is not used.
func getHost(r *http.Request) string {
if r.URL.IsAbs() {
return r.URL.Host
}
return r.Host
}
func extractVars(input string, matches []int, names []string, output map[string]string) {
for i, name := range names {
output[name] = input[matches[2*i+2]:matches[2*i+3]]
}
}

View File

@ -1,710 +0,0 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
)
// Route stores information to match a request and build URLs.
type Route struct {
// Request handler for the route.
handler http.Handler
// If true, this route never matches: it is only used to build URLs.
buildOnly bool
// The name used to build URLs.
name string
// Error resulted from building a route.
err error
// "global" reference to all named routes
namedRoutes map[string]*Route
// config possibly passed in from `Router`
routeConf
}
// SkipClean reports whether path cleaning is enabled for this route via
// Router.SkipClean.
func (r *Route) SkipClean() bool {
return r.skipClean
}
// Match matches the route against the request.
func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
if r.buildOnly || r.err != nil {
return false
}
var matchErr error
// Match everything.
for _, m := range r.matchers {
if matched := m.Match(req, match); !matched {
if _, ok := m.(methodMatcher); ok {
matchErr = ErrMethodMismatch
continue
}
// Ignore ErrNotFound errors. These errors arise from match call
// to Subrouters.
//
// This prevents subsequent matching subrouters from failing to
// run middleware. If not ignored, the middleware would see a
// non-nil MatchErr and be skipped, even when there was a
// matching route.
if match.MatchErr == ErrNotFound {
match.MatchErr = nil
}
matchErr = nil
return false
}
}
if matchErr != nil {
match.MatchErr = matchErr
return false
}
if match.MatchErr == ErrMethodMismatch {
// We found a route which matches request method, clear MatchErr
match.MatchErr = nil
// Then override the mis-matched handler
match.Handler = r.handler
}
// Yay, we have a match. Let's collect some info about it.
if match.Route == nil {
match.Route = r
}
if match.Handler == nil {
match.Handler = r.handler
}
if match.Vars == nil {
match.Vars = make(map[string]string)
}
// Set variables.
r.regexp.setMatch(req, match, r)
return true
}
// ----------------------------------------------------------------------------
// Route attributes
// ----------------------------------------------------------------------------
// GetError returns an error resulted from building the route, if any.
func (r *Route) GetError() error {
return r.err
}
// BuildOnly sets the route to never match: it is only used to build URLs.
func (r *Route) BuildOnly() *Route {
r.buildOnly = true
return r
}
// Handler --------------------------------------------------------------------
// Handler sets a handler for the route.
func (r *Route) Handler(handler http.Handler) *Route {
if r.err == nil {
r.handler = handler
}
return r
}
// HandlerFunc sets a handler function for the route.
func (r *Route) HandlerFunc(f func(http.ResponseWriter, *http.Request)) *Route {
return r.Handler(http.HandlerFunc(f))
}
// GetHandler returns the handler for the route, if any.
func (r *Route) GetHandler() http.Handler {
return r.handler
}
// Name -----------------------------------------------------------------------
// Name sets the name for the route, used to build URLs.
// It is an error to call Name more than once on a route.
func (r *Route) Name(name string) *Route {
if r.name != "" {
r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
r.name, name)
}
if r.err == nil {
r.name = name
r.namedRoutes[name] = r
}
return r
}
// GetName returns the name for the route, if any.
func (r *Route) GetName() string {
return r.name
}
// ----------------------------------------------------------------------------
// Matchers
// ----------------------------------------------------------------------------
// matcher types try to match a request.
type matcher interface {
Match(*http.Request, *RouteMatch) bool
}
// addMatcher adds a matcher to the route.
func (r *Route) addMatcher(m matcher) *Route {
if r.err == nil {
r.matchers = append(r.matchers, m)
}
return r
}
// addRegexpMatcher adds a host or path matcher and builder to a route.
func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error {
if r.err != nil {
return r.err
}
if typ == regexpTypePath || typ == regexpTypePrefix {
if len(tpl) > 0 && tpl[0] != '/' {
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
}
if r.regexp.path != nil {
tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl
}
}
rr, err := newRouteRegexp(tpl, typ, routeRegexpOptions{
strictSlash: r.strictSlash,
useEncodedPath: r.useEncodedPath,
})
if err != nil {
return err
}
for _, q := range r.regexp.queries {
if err = uniqueVars(rr.varsN, q.varsN); err != nil {
return err
}
}
if typ == regexpTypeHost {
if r.regexp.path != nil {
if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil {
return err
}
}
r.regexp.host = rr
} else {
if r.regexp.host != nil {
if err = uniqueVars(rr.varsN, r.regexp.host.varsN); err != nil {
return err
}
}
if typ == regexpTypeQuery {
r.regexp.queries = append(r.regexp.queries, rr)
} else {
r.regexp.path = rr
}
}
r.addMatcher(rr)
return nil
}
// Headers --------------------------------------------------------------------
// headerMatcher matches the request against header values.
type headerMatcher map[string]string
func (m headerMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchMapWithString(m, r.Header, true)
}
// Headers adds a matcher for request header values.
// It accepts a sequence of key/value pairs to be matched. For example:
//
// r := mux.NewRouter()
// r.Headers("Content-Type", "application/json",
// "X-Requested-With", "XMLHttpRequest")
//
// The above route will only match if both request header values match.
// If the value is an empty string, it will match any value if the key is set.
func (r *Route) Headers(pairs ...string) *Route {
if r.err == nil {
var headers map[string]string
headers, r.err = mapFromPairsToString(pairs...)
return r.addMatcher(headerMatcher(headers))
}
return r
}
// headerRegexMatcher matches the request against the route given a regex for the header
type headerRegexMatcher map[string]*regexp.Regexp
func (m headerRegexMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchMapWithRegex(m, r.Header, true)
}
// HeadersRegexp accepts a sequence of key/value pairs, where the value has regex
// support. For example:
//
// r := mux.NewRouter()
// r.HeadersRegexp("Content-Type", "application/(text|json)",
// "X-Requested-With", "XMLHttpRequest")
//
// The above route will only match if both the request header matches both regular expressions.
// If the value is an empty string, it will match any value if the key is set.
// Use the start and end of string anchors (^ and $) to match an exact value.
func (r *Route) HeadersRegexp(pairs ...string) *Route {
if r.err == nil {
var headers map[string]*regexp.Regexp
headers, r.err = mapFromPairsToRegex(pairs...)
return r.addMatcher(headerRegexMatcher(headers))
}
return r
}
// Host -----------------------------------------------------------------------
// Host adds a matcher for the URL host.
// It accepts a template with zero or more URL variables enclosed by {}.
// Variables can define an optional regexp pattern to be matched:
//
// - {name} matches anything until the next dot.
//
// - {name:pattern} matches the given regexp pattern.
//
// For example:
//
// r := mux.NewRouter()
// r.Host("www.example.com")
// r.Host("{subdomain}.domain.com")
// r.Host("{subdomain:[a-z]+}.domain.com")
//
// Variable names must be unique in a given route. They can be retrieved
// calling mux.Vars(request).
func (r *Route) Host(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, regexpTypeHost)
return r
}
// MatcherFunc ----------------------------------------------------------------
// MatcherFunc is the function signature used by custom matchers.
type MatcherFunc func(*http.Request, *RouteMatch) bool
// Match returns the match for a given request.
func (m MatcherFunc) Match(r *http.Request, match *RouteMatch) bool {
return m(r, match)
}
// MatcherFunc adds a custom function to be used as request matcher.
func (r *Route) MatcherFunc(f MatcherFunc) *Route {
return r.addMatcher(f)
}
// Methods --------------------------------------------------------------------
// methodMatcher matches the request against HTTP methods.
type methodMatcher []string
func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.Method)
}
// Methods adds a matcher for HTTP methods.
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) Methods(methods ...string) *Route {
for k, v := range methods {
methods[k] = strings.ToUpper(v)
}
return r.addMatcher(methodMatcher(methods))
}
// Path -----------------------------------------------------------------------
// Path adds a matcher for the URL path.
// It accepts a template with zero or more URL variables enclosed by {}. The
// template must start with a "/".
// Variables can define an optional regexp pattern to be matched:
//
// - {name} matches anything until the next slash.
//
// - {name:pattern} matches the given regexp pattern.
//
// For example:
//
// r := mux.NewRouter()
// r.Path("/products/").Handler(ProductsHandler)
// r.Path("/products/{key}").Handler(ProductsHandler)
// r.Path("/articles/{category}/{id:[0-9]+}").
// Handler(ArticleHandler)
//
// Variable names must be unique in a given route. They can be retrieved
// calling mux.Vars(request).
func (r *Route) Path(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, regexpTypePath)
return r
}
// PathPrefix -----------------------------------------------------------------
// PathPrefix adds a matcher for the URL path prefix. This matches if the given
// template is a prefix of the full URL path. See Route.Path() for details on
// the tpl argument.
//
// Note that it does not treat slashes specially ("/foobar/" will be matched by
// the prefix "/foo") so you may want to use a trailing slash here.
//
// Also note that the setting of Router.StrictSlash() has no effect on routes
// with a PathPrefix matcher.
func (r *Route) PathPrefix(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, regexpTypePrefix)
return r
}
// Query ----------------------------------------------------------------------
// Queries adds a matcher for URL query values.
// It accepts a sequence of key/value pairs. Values may define variables.
// For example:
//
// r := mux.NewRouter()
// r.Queries("foo", "bar", "id", "{id:[0-9]+}")
//
// The above route will only match if the URL contains the defined queries
// values, e.g.: ?foo=bar&id=42.
//
// If the value is an empty string, it will match any value if the key is set.
//
// Variables can define an optional regexp pattern to be matched:
//
// - {name} matches anything until the next slash.
//
// - {name:pattern} matches the given regexp pattern.
func (r *Route) Queries(pairs ...string) *Route {
length := len(pairs)
if length%2 != 0 {
r.err = fmt.Errorf(
"mux: number of parameters must be multiple of 2, got %v", pairs)
return nil
}
for i := 0; i < length; i += 2 {
if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], regexpTypeQuery); r.err != nil {
return r
}
}
return r
}
// Schemes --------------------------------------------------------------------
// schemeMatcher matches the request against URL schemes.
type schemeMatcher []string
func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.URL.Scheme)
}
// Schemes adds a matcher for URL schemes.
// It accepts a sequence of schemes to be matched, e.g.: "http", "https".
func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes {
schemes[k] = strings.ToLower(v)
}
if len(schemes) > 0 {
r.buildScheme = schemes[0]
}
return r.addMatcher(schemeMatcher(schemes))
}
// BuildVarsFunc --------------------------------------------------------------
// BuildVarsFunc is the function signature used by custom build variable
// functions (which can modify route variables before a route's URL is built).
type BuildVarsFunc func(map[string]string) map[string]string
// BuildVarsFunc adds a custom function to be used to modify build variables
// before a route's URL is built.
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
if r.buildVarsFunc != nil {
// compose the old and new functions
old := r.buildVarsFunc
r.buildVarsFunc = func(m map[string]string) map[string]string {
return f(old(m))
}
} else {
r.buildVarsFunc = f
}
return r
}
// Subrouter ------------------------------------------------------------------
// Subrouter creates a subrouter for the route.
//
// It will test the inner routes only if the parent route matched. For example:
//
// r := mux.NewRouter()
// s := r.Host("www.example.com").Subrouter()
// s.HandleFunc("/products/", ProductsHandler)
// s.HandleFunc("/products/{key}", ProductHandler)
// s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
//
// Here, the routes registered in the subrouter won't be tested if the host
// doesn't match.
func (r *Route) Subrouter() *Router {
// initialize a subrouter with a copy of the parent route's configuration
router := &Router{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.addMatcher(router)
return router
}
// ----------------------------------------------------------------------------
// URL building
// ----------------------------------------------------------------------------
// URL builds a URL for the route.
//
// It accepts a sequence of key/value pairs for the route variables. For
// example, given this route:
//
// r := mux.NewRouter()
// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// Name("article")
//
// ...a URL for it can be built using:
//
// url, err := r.Get("article").URL("category", "technology", "id", "42")
//
// ...which will return an url.URL with the following path:
//
// "/articles/technology/42"
//
// This also works for host variables:
//
// r := mux.NewRouter()
// r.Host("{subdomain}.domain.com").
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// Name("article")
//
// // url.String() will be "http://news.domain.com/articles/technology/42"
// url, err := r.Get("article").URL("subdomain", "news",
// "category", "technology",
// "id", "42")
//
// All variables defined in the route are required, and their values must
// conform to the corresponding patterns.
func (r *Route) URL(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
values, err := r.prepareVars(pairs...)
if err != nil {
return nil, err
}
var scheme, host, path string
queries := make([]string, 0, len(r.regexp.queries))
if r.regexp.host != nil {
if host, err = r.regexp.host.url(values); err != nil {
return nil, err
}
scheme = "http"
if r.buildScheme != "" {
scheme = r.buildScheme
}
}
if r.regexp.path != nil {
if path, err = r.regexp.path.url(values); err != nil {
return nil, err
}
}
for _, q := range r.regexp.queries {
var query string
if query, err = q.url(values); err != nil {
return nil, err
}
queries = append(queries, query)
}
return &url.URL{
Scheme: scheme,
Host: host,
Path: path,
RawQuery: strings.Join(queries, "&"),
}, nil
}
// URLHost builds the host part of the URL for a route. See Route.URL().
//
// The route must have a host defined.
func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp.host == nil {
return nil, errors.New("mux: route doesn't have a host")
}
values, err := r.prepareVars(pairs...)
if err != nil {
return nil, err
}
host, err := r.regexp.host.url(values)
if err != nil {
return nil, err
}
u := &url.URL{
Scheme: "http",
Host: host,
}
if r.buildScheme != "" {
u.Scheme = r.buildScheme
}
return u, nil
}
// URLPath builds the path part of the URL for a route. See Route.URL().
//
// The route must have a path defined.
func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp.path == nil {
return nil, errors.New("mux: route doesn't have a path")
}
values, err := r.prepareVars(pairs...)
if err != nil {
return nil, err
}
path, err := r.regexp.path.url(values)
if err != nil {
return nil, err
}
return &url.URL{
Path: path,
}, nil
}
// GetPathTemplate returns the template used to build the
// route match.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An error will be returned if the route does not define a path.
func (r *Route) GetPathTemplate() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp.path == nil {
return "", errors.New("mux: route doesn't have a path")
}
return r.regexp.path.template, nil
}
// GetPathRegexp returns the expanded regular expression used to match route path.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An error will be returned if the route does not define a path.
func (r *Route) GetPathRegexp() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp.path == nil {
return "", errors.New("mux: route does not have a path")
}
return r.regexp.path.regexp.String(), nil
}
// GetQueriesRegexp returns the expanded regular expressions used to match the
// route queries.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An error will be returned if the route does not have queries.
func (r *Route) GetQueriesRegexp() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
for _, query := range r.regexp.queries {
queries = append(queries, query.regexp.String())
}
return queries, nil
}
// GetQueriesTemplates returns the templates used to build the
// query matching.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An error will be returned if the route does not define queries.
func (r *Route) GetQueriesTemplates() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
for _, query := range r.regexp.queries {
queries = append(queries, query.template)
}
return queries, nil
}
// GetMethods returns the methods the route matches against
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An error will be returned if route does not have methods.
func (r *Route) GetMethods() ([]string, error) {
if r.err != nil {
return nil, r.err
}
for _, m := range r.matchers {
if methods, ok := m.(methodMatcher); ok {
return []string(methods), nil
}
}
return nil, errors.New("mux: route doesn't have methods")
}
// GetHostTemplate returns the template used to build the
// route match.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An error will be returned if the route does not define a host.
func (r *Route) GetHostTemplate() (string, error) {
if r.err != nil {
return "", r.err
}
if r.regexp.host == nil {
return "", errors.New("mux: route doesn't have a host")
}
return r.regexp.host.template, nil
}
// prepareVars converts the route variable pairs into a map. If the route has a
// BuildVarsFunc, it is invoked.
func (r *Route) prepareVars(pairs ...string) (map[string]string, error) {
m, err := mapFromPairsToString(pairs...)
if err != nil {
return nil, err
}
return r.buildVars(m), nil
}
func (r *Route) buildVars(m map[string]string) map[string]string {
if r.buildVarsFunc != nil {
m = r.buildVarsFunc(m)
}
return m
}

View File

@ -1,19 +0,0 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import "net/http"
// SetURLVars sets the URL variables for the given request, to be accessed via
// mux.Vars for testing route behaviour. Arguments are not modified, a shallow
// copy is returned.
//
// This API should only be used for testing purposes; it provides a way to
// inject variables into the request context. Alternatively, URL variables
// can be set by making a route that captures the required variables,
// starting a server and sending the request to that server.
func SetURLVars(r *http.Request, val map[string]string) *http.Request {
return setVars(r, val)
}

View File

@ -1,9 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = ["backoff.go"],
importmap = "k8s.io/kops/vendor/github.com/jpillora/backoff",
importpath = "github.com/jpillora/backoff",
visibility = ["//visibility:public"],
)

View File

@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2017 Jaime Pillora
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,119 +0,0 @@
# Backoff
A simple exponential backoff counter in Go (Golang)
[![GoDoc](https://godoc.org/github.com/jpillora/backoff?status.svg)](https://godoc.org/github.com/jpillora/backoff) [![Circle CI](https://circleci.com/gh/jpillora/backoff.svg?style=shield)](https://circleci.com/gh/jpillora/backoff)
### Install
```
$ go get -v github.com/jpillora/backoff
```
### Usage
Backoff is a `time.Duration` counter. It starts at `Min`. After every call to `Duration()` it is multiplied by `Factor`. It is capped at `Max`. It returns to `Min` on every call to `Reset()`. `Jitter` adds randomness ([see below](#example-using-jitter)). Used in conjunction with the `time` package.
---
#### Simple example
``` go
b := &backoff.Backoff{
//These are the defaults
Min: 100 * time.Millisecond,
Max: 10 * time.Second,
Factor: 2,
Jitter: false,
}
fmt.Printf("%s\n", b.Duration())
fmt.Printf("%s\n", b.Duration())
fmt.Printf("%s\n", b.Duration())
fmt.Printf("Reset!\n")
b.Reset()
fmt.Printf("%s\n", b.Duration())
```
```
100ms
200ms
400ms
Reset!
100ms
```
---
#### Example using `net` package
``` go
b := &backoff.Backoff{
Max: 5 * time.Minute,
}
for {
conn, err := net.Dial("tcp", "example.com:5309")
if err != nil {
d := b.Duration()
fmt.Printf("%s, reconnecting in %s", err, d)
time.Sleep(d)
continue
}
//connected
b.Reset()
conn.Write([]byte("hello world!"))
// ... Read ... Write ... etc
conn.Close()
//disconnected
}
```
---
#### Example using `Jitter`
Enabling `Jitter` adds some randomization to the backoff durations. [See Amazon's writeup of performance gains using jitter](http://www.awsarchitectureblog.com/2015/03/backoff.html). Seeding is not necessary but doing so gives repeatable results.
```go
import "math/rand"
b := &backoff.Backoff{
Jitter: true,
}
rand.Seed(42)
fmt.Printf("%s\n", b.Duration())
fmt.Printf("%s\n", b.Duration())
fmt.Printf("%s\n", b.Duration())
fmt.Printf("Reset!\n")
b.Reset()
fmt.Printf("%s\n", b.Duration())
fmt.Printf("%s\n", b.Duration())
fmt.Printf("%s\n", b.Duration())
```
```
100ms
106.600049ms
281.228155ms
Reset!
100ms
104.381845ms
214.957989ms
```
#### Documentation
https://godoc.org/github.com/jpillora/backoff
#### Credits
Forked from [some JavaScript](https://github.com/segmentio/backo) written by [@tj](https://github.com/tj)

View File

@ -1,88 +0,0 @@
// Package backoff provides an exponential-backoff implementation.
package backoff
import (
"math"
"math/rand"
"time"
)
// Backoff is a time.Duration counter, starting at Min. After every call to
// the Duration method the current timing is multiplied by Factor, but it
// never exceeds Max.
//
// Backoff is not generally concurrent-safe, but the ForAttempt method can
// be used concurrently.
type Backoff struct {
//Factor is the multiplying factor for each increment step
attempt, Factor float64
//Jitter eases contention by randomizing backoff steps
Jitter bool
//Min and Max are the minimum and maximum values of the counter
Min, Max time.Duration
}
// Duration returns the duration for the current attempt before incrementing
// the attempt counter. See ForAttempt.
func (b *Backoff) Duration() time.Duration {
d := b.ForAttempt(b.attempt)
b.attempt++
return d
}
const maxInt64 = float64(math.MaxInt64 - 512)
// ForAttempt returns the duration for a specific attempt. This is useful if
// you have a large number of independent Backoffs, but don't want use
// unnecessary memory storing the Backoff parameters per Backoff. The first
// attempt should be 0.
//
// ForAttempt is concurrent-safe.
func (b *Backoff) ForAttempt(attempt float64) time.Duration {
// Zero-values are nonsensical, so we use
// them to apply defaults
min := b.Min
if min <= 0 {
min = 100 * time.Millisecond
}
max := b.Max
if max <= 0 {
max = 10 * time.Second
}
if min >= max {
// short-circuit
return max
}
factor := b.Factor
if factor <= 0 {
factor = 2
}
//calculate this duration
minf := float64(min)
durf := minf * math.Pow(factor, attempt)
if b.Jitter {
durf = rand.Float64()*(durf-minf) + minf
}
//ensure float64 wont overflow int64
if durf > maxInt64 {
return max
}
dur := time.Duration(durf)
//keep within bounds
if dur < min {
return min
} else if dur > max {
return max
}
return dur
}
// Reset restarts the current attempt counter at zero.
func (b *Backoff) Reset() {
b.attempt = 0
}
// Attempt returns the current attempt counter value.
func (b *Backoff) Attempt() float64 {
return b.attempt
}

View File

@ -1,2 +0,0 @@
[flake8]
max-line-length = 120

View File

@ -1,3 +0,0 @@
*.coverprofile
node_modules/
vendor

View File

@ -1,35 +0,0 @@
language: go
sudo: false
dist: bionic
osx_image: xcode10
go:
- 1.11.x
- 1.12.x
- 1.13.x
os:
- linux
- osx
env:
GO111MODULE=on
GOPROXY=https://proxy.golang.org
cache:
directories:
- node_modules
before_script:
- go get github.com/urfave/gfmrun/cmd/gfmrun
- go get golang.org/x/tools/cmd/goimports
- npm install markdown-toc
- go mod tidy
script:
- go run build.go vet
- go run build.go test
- go run build.go gfmrun docs/v1/manual.md
- go run build.go toc docs/v1/manual.md
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -1,38 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"app.go",
"category.go",
"cli.go",
"command.go",
"context.go",
"docs.go",
"errors.go",
"fish.go",
"flag.go",
"flag_bool.go",
"flag_bool_t.go",
"flag_duration.go",
"flag_float64.go",
"flag_generic.go",
"flag_int.go",
"flag_int64.go",
"flag_int64_slice.go",
"flag_int_slice.go",
"flag_string.go",
"flag_string_slice.go",
"flag_uint.go",
"flag_uint64.go",
"funcs.go",
"help.go",
"parse.go",
"sort.go",
"template.go",
],
importmap = "k8s.io/kops/vendor/github.com/urfave/cli",
importpath = "github.com/urfave/cli",
visibility = ["//visibility:public"],
deps = ["//vendor/github.com/cpuguy83/go-md2man/v2/md2man:go_default_library"],
)

View File

@ -1,74 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, gender identity and expression, level of experience,
education, socio-economic status, nationality, personal appearance, race,
religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting Dan Buch at dan@meatballhat.com. All complaints will be
reviewed and investigated and will result in a response that is deemed necessary
and appropriate to the circumstances. The project team is obligated to maintain
confidentiality with regard to the reporter of an incident. Further details of
specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org

21
vendor/github.com/urfave/cli/LICENSE generated vendored
View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2016 Jeremy Saenz & Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,70 +0,0 @@
cli
===
[![Build Status](https://travis-ci.org/urfave/cli.svg?branch=master)](https://travis-ci.org/urfave/cli)
[![Windows Build Status](https://ci.appveyor.com/api/projects/status/rtgk5xufi932pb2v?svg=true)](https://ci.appveyor.com/project/urfave/cli)
[![GoDoc](https://godoc.org/github.com/urfave/cli?status.svg)](https://godoc.org/github.com/urfave/cli)
[![codebeat](https://codebeat.co/badges/0a8f30aa-f975-404b-b878-5fab3ae1cc5f)](https://codebeat.co/projects/github-com-urfave-cli)
[![Go Report Card](https://goreportcard.com/badge/urfave/cli)](https://goreportcard.com/report/urfave/cli)
[![codecov](https://codecov.io/gh/urfave/cli/branch/master/graph/badge.svg)](https://codecov.io/gh/urfave/cli)
cli is a simple, fast, and fun package for building command line apps in Go. The
goal is to enable developers to write fast and distributable command line
applications in an expressive way.
## Usage Documentation
Usage documentation exists for each major version
- `v1` - [./docs/v1/manual.md](./docs/v1/manual.md)
- `v2` - 🚧 documentation for `v2` is WIP 🚧
## Installation
Make sure you have a working Go environment. Go version 1.10+ is supported. [See
the install instructions for Go](http://golang.org/doc/install.html).
### GOPATH
Make sure your `PATH` includes the `$GOPATH/bin` directory so your commands can
be easily used:
```
export PATH=$PATH:$GOPATH/bin
```
### Supported platforms
cli is tested against multiple versions of Go on Linux, and against the latest
released version of Go on OS X and Windows. For full details, see
[`./.travis.yml`](./.travis.yml) and [`./appveyor.yml`](./appveyor.yml).
### Using `v1` releases
```
$ go get github.com/urfave/cli
```
```go
...
import (
"github.com/urfave/cli"
)
...
```
### Using `v2` releases
**Warning**: `v2` is in a pre-release state.
```
$ go get github.com/urfave/cli.v2
```
```go
...
import (
"github.com/urfave/cli.v2" // imports as package "cli"
)
...
```

530
vendor/github.com/urfave/cli/app.go generated vendored
View File

@ -1,530 +0,0 @@
package cli
import (
"flag"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"time"
)
var (
changeLogURL = "https://github.com/urfave/cli/blob/master/CHANGELOG.md"
appActionDeprecationURL = fmt.Sprintf("%s#deprecated-cli-app-action-signature", changeLogURL)
// unused variable. commented for now. will remove in future if agreed upon by everyone
//runAndExitOnErrorDeprecationURL = fmt.Sprintf("%s#deprecated-cli-app-runandexitonerror", changeLogURL)
contactSysadmin = "This is an error in the application. Please contact the distributor of this application if this is not you."
errInvalidActionType = NewExitError("ERROR invalid Action type. "+
fmt.Sprintf("Must be `func(*Context`)` or `func(*Context) error). %s", contactSysadmin)+
fmt.Sprintf("See %s", appActionDeprecationURL), 2)
)
// App is the main structure of a cli application. It is recommended that
// an app be created with the cli.NewApp() function
type App struct {
// The name of the program. Defaults to path.Base(os.Args[0])
Name string
// Full name of command for help, defaults to Name
HelpName string
// Description of the program.
Usage string
// Text to override the USAGE section of help
UsageText string
// Description of the program argument format.
ArgsUsage string
// Version of the program
Version string
// Description of the program
Description string
// List of commands to execute
Commands []Command
// List of flags to parse
Flags []Flag
// Boolean to enable bash completion commands
EnableBashCompletion bool
// Boolean to hide built-in help command
HideHelp bool
// Boolean to hide built-in version flag and the VERSION section of help
HideVersion bool
// Populate on app startup, only gettable through method Categories()
categories CommandCategories
// An action to execute when the bash-completion flag is set
BashComplete BashCompleteFunc
// An action to execute before any subcommands are run, but after the context is ready
// If a non-nil error is returned, no subcommands are run
Before BeforeFunc
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After AfterFunc
// The action to execute when no subcommands are specified
// Expects a `cli.ActionFunc` but will accept the *deprecated* signature of `func(*cli.Context) {}`
// *Note*: support for the deprecated `Action` signature will be removed in a future version
Action interface{}
// Execute this function if the proper command cannot be found
CommandNotFound CommandNotFoundFunc
// Execute this function if an usage error occurs
OnUsageError OnUsageErrorFunc
// Compilation date
Compiled time.Time
// List of all authors who contributed
Authors []Author
// Copyright of the binary if any
Copyright string
// Name of Author (Note: Use App.Authors, this is deprecated)
Author string
// Email of Author (Note: Use App.Authors, this is deprecated)
Email string
// Writer writer to write output to
Writer io.Writer
// ErrWriter writes error output
ErrWriter io.Writer
// Execute this function to handle ExitErrors. If not provided, HandleExitCoder is provided to
// function as a default, so this is optional.
ExitErrHandler ExitErrHandlerFunc
// Other custom info
Metadata map[string]interface{}
// Carries a function which returns app specific info.
ExtraInfo func() map[string]string
// CustomAppHelpTemplate the text template for app help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
CustomAppHelpTemplate string
// Boolean to enable short-option handling so user can combine several
// single-character bool arguements into one
// i.e. foobar -o -v -> foobar -ov
UseShortOptionHandling bool
didSetup bool
}
// Tries to find out when this binary was compiled.
// Returns the current time if it fails to find it.
func compileTime() time.Time {
info, err := os.Stat(os.Args[0])
if err != nil {
return time.Now()
}
return info.ModTime()
}
// NewApp creates a new cli Application with some reasonable defaults for Name,
// Usage, Version and Action.
func NewApp() *App {
return &App{
Name: filepath.Base(os.Args[0]),
HelpName: filepath.Base(os.Args[0]),
Usage: "A new cli application",
UsageText: "",
Version: "0.0.0",
BashComplete: DefaultAppComplete,
Action: helpCommand.Action,
Compiled: compileTime(),
Writer: os.Stdout,
}
}
// Setup runs initialization code to ensure all data structures are ready for
// `Run` or inspection prior to `Run`. It is internally called by `Run`, but
// will return early if setup has already happened.
func (a *App) Setup() {
if a.didSetup {
return
}
a.didSetup = true
if a.Author != "" || a.Email != "" {
a.Authors = append(a.Authors, Author{Name: a.Author, Email: a.Email})
}
var newCmds []Command
for _, c := range a.Commands {
if c.HelpName == "" {
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, c.Name)
}
newCmds = append(newCmds, c)
}
a.Commands = newCmds
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
if !a.HideVersion {
a.appendFlag(VersionFlag)
}
a.categories = CommandCategories{}
for _, command := range a.Commands {
a.categories = a.categories.AddCommand(command.Category, command)
}
sort.Sort(a.categories)
if a.Metadata == nil {
a.Metadata = make(map[string]interface{})
}
if a.Writer == nil {
a.Writer = os.Stdout
}
}
func (a *App) newFlagSet() (*flag.FlagSet, error) {
return flagSet(a.Name, a.Flags)
}
func (a *App) useShortOptionHandling() bool {
return a.UseShortOptionHandling
}
// Run is the entry point to the cli app. Parses the arguments slice and routes
// to the proper flag/args combination
func (a *App) Run(arguments []string) (err error) {
a.Setup()
// handle the completion flag separately from the flagset since
// completion could be attempted after a flag, but before its value was put
// on the command line. this causes the flagset to interpret the completion
// flag name as the value of the flag before it which is undesirable
// note that we can only do this because the shell autocomplete function
// always appends the completion flag at the end of the command
shellComplete, arguments := checkShellCompleteFlag(a, arguments)
set, err := a.newFlagSet()
if err != nil {
return err
}
err = parseIter(set, a, arguments[1:])
nerr := normalizeFlags(a.Flags, set)
context := NewContext(a, set, nil)
if nerr != nil {
_, _ = fmt.Fprintln(a.Writer, nerr)
_ = ShowAppHelp(context)
return nerr
}
context.shellComplete = shellComplete
if checkCompletions(context) {
return nil
}
if err != nil {
if a.OnUsageError != nil {
err := a.OnUsageError(context, err, false)
a.handleExitCoder(context, err)
return err
}
_, _ = fmt.Fprintf(a.Writer, "%s %s\n\n", "Incorrect Usage.", err.Error())
_ = ShowAppHelp(context)
return err
}
if !a.HideHelp && checkHelp(context) {
_ = ShowAppHelp(context)
return nil
}
if !a.HideVersion && checkVersion(context) {
ShowVersion(context)
return nil
}
cerr := checkRequiredFlags(a.Flags, context)
if cerr != nil {
_ = ShowAppHelp(context)
return cerr
}
if a.After != nil {
defer func() {
if afterErr := a.After(context); afterErr != nil {
if err != nil {
err = NewMultiError(err, afterErr)
} else {
err = afterErr
}
}
}()
}
if a.Before != nil {
beforeErr := a.Before(context)
if beforeErr != nil {
_, _ = fmt.Fprintf(a.Writer, "%v\n\n", beforeErr)
_ = ShowAppHelp(context)
a.handleExitCoder(context, beforeErr)
err = beforeErr
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
if a.Action == nil {
a.Action = helpCommand.Action
}
// Run default Action
err = HandleAction(a.Action, context)
a.handleExitCoder(context, err)
return err
}
// RunAndExitOnError calls .Run() and exits non-zero if an error was returned
//
// Deprecated: instead you should return an error that fulfills cli.ExitCoder
// to cli.App.Run. This will cause the application to exit with the given eror
// code in the cli.ExitCoder
func (a *App) RunAndExitOnError() {
if err := a.Run(os.Args); err != nil {
_, _ = fmt.Fprintln(a.errWriter(), err)
OsExiter(1)
}
}
// RunAsSubcommand invokes the subcommand given the context, parses ctx.Args() to
// generate command-specific flags
func (a *App) RunAsSubcommand(ctx *Context) (err error) {
// append help to commands
if len(a.Commands) > 0 {
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
}
newCmds := []Command{}
for _, c := range a.Commands {
if c.HelpName == "" {
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, c.Name)
}
newCmds = append(newCmds, c)
}
a.Commands = newCmds
set, err := a.newFlagSet()
if err != nil {
return err
}
err = parseIter(set, a, ctx.Args().Tail())
nerr := normalizeFlags(a.Flags, set)
context := NewContext(a, set, ctx)
if nerr != nil {
_, _ = fmt.Fprintln(a.Writer, nerr)
_, _ = fmt.Fprintln(a.Writer)
if len(a.Commands) > 0 {
_ = ShowSubcommandHelp(context)
} else {
_ = ShowCommandHelp(ctx, context.Args().First())
}
return nerr
}
if checkCompletions(context) {
return nil
}
if err != nil {
if a.OnUsageError != nil {
err = a.OnUsageError(context, err, true)
a.handleExitCoder(context, err)
return err
}
_, _ = fmt.Fprintf(a.Writer, "%s %s\n\n", "Incorrect Usage.", err.Error())
_ = ShowSubcommandHelp(context)
return err
}
if len(a.Commands) > 0 {
if checkSubcommandHelp(context) {
return nil
}
} else {
if checkCommandHelp(ctx, context.Args().First()) {
return nil
}
}
cerr := checkRequiredFlags(a.Flags, context)
if cerr != nil {
_ = ShowSubcommandHelp(context)
return cerr
}
if a.After != nil {
defer func() {
afterErr := a.After(context)
if afterErr != nil {
a.handleExitCoder(context, err)
if err != nil {
err = NewMultiError(err, afterErr)
} else {
err = afterErr
}
}
}()
}
if a.Before != nil {
beforeErr := a.Before(context)
if beforeErr != nil {
a.handleExitCoder(context, beforeErr)
err = beforeErr
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
// Run default Action
err = HandleAction(a.Action, context)
a.handleExitCoder(context, err)
return err
}
// Command returns the named command on App. Returns nil if the command does not exist
func (a *App) Command(name string) *Command {
for _, c := range a.Commands {
if c.HasName(name) {
return &c
}
}
return nil
}
// Categories returns a slice containing all the categories with the commands they contain
func (a *App) Categories() CommandCategories {
return a.categories
}
// VisibleCategories returns a slice of categories and commands that are
// Hidden=false
func (a *App) VisibleCategories() []*CommandCategory {
ret := []*CommandCategory{}
for _, category := range a.categories {
if visible := func() *CommandCategory {
for _, command := range category.Commands {
if !command.Hidden {
return category
}
}
return nil
}(); visible != nil {
ret = append(ret, visible)
}
}
return ret
}
// VisibleCommands returns a slice of the Commands with Hidden=false
func (a *App) VisibleCommands() []Command {
var ret []Command
for _, command := range a.Commands {
if !command.Hidden {
ret = append(ret, command)
}
}
return ret
}
// VisibleFlags returns a slice of the Flags with Hidden=false
func (a *App) VisibleFlags() []Flag {
return visibleFlags(a.Flags)
}
func (a *App) hasFlag(flag Flag) bool {
for _, f := range a.Flags {
if flag == f {
return true
}
}
return false
}
func (a *App) errWriter() io.Writer {
// When the app ErrWriter is nil use the package level one.
if a.ErrWriter == nil {
return ErrWriter
}
return a.ErrWriter
}
func (a *App) appendFlag(flag Flag) {
if !a.hasFlag(flag) {
a.Flags = append(a.Flags, flag)
}
}
func (a *App) handleExitCoder(context *Context, err error) {
if a.ExitErrHandler != nil {
a.ExitErrHandler(context, err)
} else {
HandleExitCoder(err)
}
}
// Author represents someone who has contributed to a cli project.
type Author struct {
Name string // The Authors name
Email string // The Authors email
}
// String makes Author comply to the Stringer interface, to allow an easy print in the templating process
func (a Author) String() string {
e := ""
if a.Email != "" {
e = " <" + a.Email + ">"
}
return fmt.Sprintf("%v%v", a.Name, e)
}
// HandleAction attempts to figure out which Action signature was used. If
// it's an ActionFunc or a func with the legacy signature for Action, the func
// is run!
func HandleAction(action interface{}, context *Context) (err error) {
switch a := action.(type) {
case ActionFunc:
return a(context)
case func(*Context) error:
return a(context)
case func(*Context): // deprecated function signature
a(context)
return nil
}
return errInvalidActionType
}

Some files were not shown because too many files have changed in this diff Show More