mirror of https://github.com/fluxcd/flagger.git
Merge pull request #90 from cloudang/ngrinder
Support delegation to external load testing tools
This commit is contained in:
commit
62f4a6cb96
|
|
@ -128,6 +128,14 @@
|
|||
pruneopts = "NUT"
|
||||
revision = "3befbb6ad0cc97d4c25d851e9528915809e1a22f"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:7313d6b9095eb86581402557bbf3871620cf82adf41853c5b9bee04b894290c7"
|
||||
name = "github.com/h2non/parth"
|
||||
packages = ["."]
|
||||
pruneopts = "NUT"
|
||||
revision = "b4df798d65426f8c8ab5ca5f9987aec5575d26c9"
|
||||
version = "v2.0.1"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:52094d0f8bdf831d1a2401e9b6fee5795fdc0b2a2d1f8bb1980834c289e79129"
|
||||
name = "github.com/hashicorp/golang-lru"
|
||||
|
|
@ -405,6 +413,14 @@
|
|||
revision = "e9657d882bb81064595ca3b56cbe2546bbabf7b1"
|
||||
version = "v1.4.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:fe9eb931d7b59027c4a3467f7edc16cc8552dac5328039bec05045143c18e1ce"
|
||||
name = "gopkg.in/h2non/gock.v1"
|
||||
packages = ["."]
|
||||
pruneopts = "NUT"
|
||||
revision = "ba88c4862a27596539531ce469478a91bc5a0511"
|
||||
version = "v1.0.14"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:2d1fbdc6777e5408cabeb02bf336305e724b925ff4546ded0fa8715a7267922a"
|
||||
name = "gopkg.in/inf.v0"
|
||||
|
|
@ -684,6 +700,7 @@
|
|||
"github.com/prometheus/client_golang/prometheus/promhttp",
|
||||
"go.uber.org/zap",
|
||||
"go.uber.org/zap/zapcore",
|
||||
"gopkg.in/h2non/gock.v1",
|
||||
"k8s.io/api/apps/v1",
|
||||
"k8s.io/api/autoscaling/v1",
|
||||
"k8s.io/api/autoscaling/v2beta1",
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@ required = [
|
|||
name = "go.uber.org/zap"
|
||||
version = "v1.9.1"
|
||||
|
||||
[[constraint]]
|
||||
name = "gopkg.in/h2non/gock.v1"
|
||||
version = "v1.0.14"
|
||||
|
||||
[[override]]
|
||||
name = "gopkg.in/yaml.v2"
|
||||
version = "v2.2.1"
|
||||
|
|
|
|||
|
|
@ -71,4 +71,6 @@ spec:
|
|||
url: http://flagger-loadtester.test/
|
||||
timeout: 5s
|
||||
metadata:
|
||||
type: cmd
|
||||
cmd: "hey -z 1m -q 10 -c 2 http://podinfo.test:9898/"
|
||||
logCmdOutput: "true"
|
||||
|
|
|
|||
|
|
@ -56,15 +56,13 @@ Parameter | Description | Default
|
|||
`nodeSelector` | node labels for pod assignment | `{}`
|
||||
`service.type` | type of service | `ClusterIP`
|
||||
`service.port` | ClusterIP port | `80`
|
||||
`cmd.logOutput` | Log the command output to stderr | `true`
|
||||
`cmd.timeout` | Command execution timeout | `1h`
|
||||
`logLevel` | Log level can be debug, info, warning, error or panic | `info`
|
||||
|
||||
Specify each parameter using the `--set key=value[,key=value]` argument to `helm install`. For example,
|
||||
|
||||
```console
|
||||
helm install flagger/loadtester --name flagger-loadtester \
|
||||
--set cmd.logOutput=false
|
||||
helm install flagger/loadtester --name flagger-loadtester
|
||||
```
|
||||
|
||||
Alternatively, a YAML file that specifies the values for the above parameters can be provided while installing the chart. For example,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ spec:
|
|||
- -port=8080
|
||||
- -log-level={{ .Values.logLevel }}
|
||||
- -timeout={{ .Values.cmd.timeout }}
|
||||
- -log-cmd-output={{ .Values.cmd.logOutput }}
|
||||
livenessProbe:
|
||||
exec:
|
||||
command:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ image:
|
|||
|
||||
logLevel: info
|
||||
cmd:
|
||||
logOutput: true
|
||||
timeout: 1h
|
||||
|
||||
nameOverride: ""
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ var (
|
|||
logLevel string
|
||||
port string
|
||||
timeout time.Duration
|
||||
logCmdOutput bool
|
||||
zapReplaceGlobals bool
|
||||
zapEncoding string
|
||||
)
|
||||
|
|
@ -23,8 +22,7 @@ var (
|
|||
func init() {
|
||||
flag.StringVar(&logLevel, "log-level", "debug", "Log level can be: debug, info, warning, error.")
|
||||
flag.StringVar(&port, "port", "9090", "Port to listen on.")
|
||||
flag.DurationVar(&timeout, "timeout", time.Hour, "Command exec timeout.")
|
||||
flag.BoolVar(&logCmdOutput, "log-cmd-output", true, "Log command output to stderr")
|
||||
flag.DurationVar(&timeout, "timeout", time.Hour, "Load test exec timeout.")
|
||||
flag.BoolVar(&zapReplaceGlobals, "zap-replace-globals", false, "Whether to change the logging level of the global zap logger.")
|
||||
flag.StringVar(&zapEncoding, "zap-encoding", "json", "Zap logger encoding.")
|
||||
}
|
||||
|
|
@ -44,7 +42,7 @@ func main() {
|
|||
|
||||
stopCh := signals.SetupSignalHandler()
|
||||
|
||||
taskRunner := loadtester.NewTaskRunner(logger, timeout, logCmdOutput)
|
||||
taskRunner := loadtester.NewTaskRunner(logger, timeout)
|
||||
|
||||
go taskRunner.Start(100*time.Millisecond, stopCh)
|
||||
|
||||
|
|
|
|||
|
|
@ -562,7 +562,6 @@ helm repo add flagger https://flagger.app
|
|||
|
||||
helm upgrade -i flagger-loadtester flagger/loadtester \
|
||||
--namespace=test \
|
||||
--set cmd.logOutput=true \
|
||||
--set cmd.timeout=1h
|
||||
```
|
||||
|
||||
|
|
@ -576,11 +575,13 @@ webhooks:
|
|||
url: http://flagger-loadtester.test/
|
||||
timeout: 5s
|
||||
metadata:
|
||||
type: cmd
|
||||
cmd: "hey -z 1m -q 10 -c 2 http://podinfo.test:9898/"
|
||||
- name: load-test-post
|
||||
url: http://flagger-loadtester.test/
|
||||
timeout: 5s
|
||||
metadata:
|
||||
type: cmd
|
||||
cmd: "hey -z 1m -q 10 -c 2 -m POST -d '{test: 2}' http://podinfo.test:9898/echo"
|
||||
```
|
||||
|
||||
|
|
@ -597,6 +598,7 @@ webhooks:
|
|||
url: http://flagger-loadtester.test/
|
||||
timeout: 5s
|
||||
metadata:
|
||||
type: cmd
|
||||
cmd: "hey -z 1m -q 10 -c 2 -h2 https://podinfo.example.com/"
|
||||
```
|
||||
|
||||
|
|
@ -609,3 +611,32 @@ FROM quay.io/stefanprodan/flagger-loadtester:<VER>
|
|||
RUN curl -Lo /usr/local/bin/my-cli https://github.com/user/repo/releases/download/ver/my-cli \
|
||||
&& chmod +x /usr/local/bin/my-cli
|
||||
```
|
||||
|
||||
### Load Testing Delegation
|
||||
|
||||
The load tester can also forward testing tasks to external tools, by now [nGrinder](https://github.com/naver/ngrinder)
|
||||
is supported.
|
||||
|
||||
To use this feature, add a load test task of type 'ngrinder' to the canary analysis spec:
|
||||
|
||||
```yaml
|
||||
webhooks:
|
||||
- name: load-test-post
|
||||
url: http://flagger-loadtester.test/
|
||||
timeout: 5s
|
||||
metadata:
|
||||
# type of this load test task, cmd or ngrinder
|
||||
type: ngrinder
|
||||
# base url of your nGrinder controller server
|
||||
server: http://ngrinder-server:port
|
||||
# id of the test to clone from, the test must have been defined.
|
||||
clone: 100
|
||||
# user name and base64 encoded password to authenticate against the nGrinder server
|
||||
username: admin
|
||||
passwd: YWRtaW4=
|
||||
# the interval between between nGrinder test status polling, default to 1s
|
||||
pollInterval: 5s
|
||||
```
|
||||
When the canary analysis starts, the load tester will initiate a [clone_and_start request](https://github.com/naver/ngrinder/wiki/REST-API-PerfTest)
|
||||
to the nGrinder server and start a new performance test. the load tester will periodically poll the nGrinder server
|
||||
for the status of the test, and prevent duplicate requests from being sent in subsequent analysis loops.
|
||||
|
|
@ -126,7 +126,6 @@ Deploy the load test runner with Helm:
|
|||
```bash
|
||||
helm upgrade -i flagger-loadtester flagger/loadtester \
|
||||
--namespace=test \
|
||||
--set cmd.logOutput=true \
|
||||
--set cmd.timeout=1h
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,7 @@ package loadtester
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"hash/fnv"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
|
@ -21,24 +17,12 @@ type TaskRunner struct {
|
|||
logCmdOutput bool
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
Canary string
|
||||
Command string
|
||||
}
|
||||
|
||||
func (t Task) Hash() string {
|
||||
fnvHash := fnv.New32()
|
||||
fnvBytes := fnvHash.Sum([]byte(t.Canary + t.Command))
|
||||
return hex.EncodeToString(fnvBytes[:])
|
||||
}
|
||||
|
||||
func NewTaskRunner(logger *zap.SugaredLogger, timeout time.Duration, logCmdOutput bool) *TaskRunner {
|
||||
func NewTaskRunner(logger *zap.SugaredLogger, timeout time.Duration) *TaskRunner {
|
||||
return &TaskRunner{
|
||||
logger: logger,
|
||||
todoTasks: new(sync.Map),
|
||||
runningTasks: new(sync.Map),
|
||||
timeout: timeout,
|
||||
logCmdOutput: logCmdOutput,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -69,24 +53,15 @@ func (tr *TaskRunner) runAll() {
|
|||
// increment the total exec counter
|
||||
atomic.AddUint64(&tr.totalExecs, 1)
|
||||
|
||||
tr.logger.With("canary", t.Canary).Infof("command starting %s", t.Command)
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", t.Command)
|
||||
tr.logger.With("canary", t.Canary()).Infof("task starting %s", t)
|
||||
|
||||
// execute task
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
tr.logger.With("canary", t.Canary).Errorf("command failed %s %v %s", t.Command, err, out)
|
||||
} else {
|
||||
if tr.logCmdOutput {
|
||||
fmt.Printf("%s\n", out)
|
||||
}
|
||||
tr.logger.With("canary", t.Canary).Infof("command finished %s", t.Command)
|
||||
}
|
||||
// run task with the timeout context
|
||||
t.Run(ctx)
|
||||
|
||||
// remove task from the running list
|
||||
tr.runningTasks.Delete(t.Hash())
|
||||
} else {
|
||||
tr.logger.With("canary", t.Canary).Infof("command skipped %s is already running", t.Command)
|
||||
tr.logger.With("canary", t.Canary()).Infof("command skipped %s is already running", t)
|
||||
}
|
||||
}(task)
|
||||
return true
|
||||
|
|
|
|||
|
|
@ -9,18 +9,13 @@ import (
|
|||
func TestTaskRunner_Start(t *testing.T) {
|
||||
stop := make(chan struct{})
|
||||
logger, _ := logging.NewLogger("debug")
|
||||
tr := NewTaskRunner(logger, time.Hour, false)
|
||||
tr := NewTaskRunner(logger, time.Hour)
|
||||
|
||||
go tr.Start(10*time.Millisecond, stop)
|
||||
|
||||
task1 := Task{
|
||||
Canary: "podinfo.default",
|
||||
Command: "sleep 0.6",
|
||||
}
|
||||
task2 := Task{
|
||||
Canary: "podinfo.default",
|
||||
Command: "sleep 0.7",
|
||||
}
|
||||
taskFactory, _ := GetTaskFactory(TaskTypeShell)
|
||||
task1, _ := taskFactory(map[string]string{"type": "cmd", "cmd": "sleep 0.6"}, "podinfo.default", logger)
|
||||
task2, _ := taskFactory(map[string]string{"cmd": "sleep 0.7", "logCmdOutput": "true"}, "podinfo.default", logger)
|
||||
|
||||
tr.Add(task1)
|
||||
tr.Add(task2)
|
||||
|
|
|
|||
|
|
@ -39,16 +39,25 @@ func ListenAndServe(port string, timeout time.Duration, logger *zap.SugaredLogge
|
|||
}
|
||||
|
||||
if len(payload.Metadata) > 0 {
|
||||
if cmd, ok := payload.Metadata["cmd"]; ok {
|
||||
taskRunner.Add(Task{
|
||||
Canary: fmt.Sprintf("%s.%s", payload.Name, payload.Namespace),
|
||||
Command: cmd,
|
||||
})
|
||||
} else {
|
||||
metadata := payload.Metadata
|
||||
var typ, ok = metadata["type"]
|
||||
if !ok {
|
||||
typ = TaskTypeShell
|
||||
}
|
||||
taskFactory, ok := GetTaskFactory(typ)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("cmd not found in metadata"))
|
||||
w.Write([]byte(fmt.Sprintf("unknown task type %s", typ)))
|
||||
return
|
||||
}
|
||||
canary := fmt.Sprintf("%s.%s", payload.Name, payload.Namespace)
|
||||
task, err := taskFactory(metadata, canary, logger)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
taskRunner.Add(task)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("metadata not found in payload"))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
package loadtester
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"go.uber.org/zap"
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Modeling a loadtester task
|
||||
type Task interface {
|
||||
Hash() string
|
||||
Run(ctx context.Context) bool
|
||||
String() string
|
||||
Canary() string
|
||||
}
|
||||
|
||||
type TaskBase struct {
|
||||
canary string
|
||||
logger *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func (task *TaskBase) Canary() string {
|
||||
return task.canary
|
||||
}
|
||||
func hash(str string) string {
|
||||
fnvHash := fnv.New32()
|
||||
fnvBytes := fnvHash.Sum([]byte(str))
|
||||
return hex.EncodeToString(fnvBytes[:])
|
||||
}
|
||||
|
||||
var taskFactories = new(sync.Map)
|
||||
|
||||
type TaskFactory = func(metadata map[string]string, canary string, logger *zap.SugaredLogger) (Task, error)
|
||||
|
||||
func GetTaskFactory(typ string) (TaskFactory, bool) {
|
||||
factory, ok := taskFactories.Load(typ)
|
||||
return factory.(TaskFactory), ok
|
||||
}
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
package loadtester
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const TaskTypeNGrinder = "ngrinder"
|
||||
|
||||
func init() {
|
||||
taskFactories.Store(TaskTypeNGrinder, func(metadata map[string]string, canary string, logger *zap.SugaredLogger) (Task, error) {
|
||||
server := metadata["server"]
|
||||
clone := metadata["clone"]
|
||||
username := metadata["username"]
|
||||
passwd := metadata["passwd"]
|
||||
pollInterval := metadata["pollInterval"]
|
||||
if server == "" || clone == "" || username == "" || passwd == "" {
|
||||
return nil, errors.New("server, clone, username and passwd are required metadata")
|
||||
}
|
||||
baseUrl, err := url.Parse(server)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("invalid url: %s", server))
|
||||
}
|
||||
cloneId, err := strconv.Atoi(clone)
|
||||
if err != nil {
|
||||
return nil, errors.New("metadata clone must be integer")
|
||||
}
|
||||
|
||||
passwdDecoded, err := base64.StdEncoding.DecodeString(passwd)
|
||||
if err != nil {
|
||||
return nil, errors.New("metadata auth provided is invalid, base64 encoded username:password required")
|
||||
}
|
||||
interval, err := time.ParseDuration(pollInterval)
|
||||
if err != nil {
|
||||
interval = 1
|
||||
}
|
||||
|
||||
return &NGrinderTask{
|
||||
TaskBase{canary, logger},
|
||||
baseUrl, cloneId, username, string(passwdDecoded), -1, interval,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
type NGrinderTask struct {
|
||||
TaskBase
|
||||
// base url of ngrinder server, e.g. http://ngrinder:8080
|
||||
baseUrl *url.URL
|
||||
// template test to clone from
|
||||
cloneId int
|
||||
// http basic auth
|
||||
username string
|
||||
passwd string
|
||||
// current ngrinder test id
|
||||
testId int
|
||||
// task status polling interval
|
||||
pollInterval time.Duration
|
||||
}
|
||||
|
||||
func (task *NGrinderTask) Hash() string {
|
||||
return hash(task.canary + string(task.cloneId))
|
||||
}
|
||||
|
||||
// nGrinder REST endpoints
|
||||
func (task *NGrinderTask) CloneAndStartEndpoint() *url.URL {
|
||||
path, _ := url.Parse(fmt.Sprintf("perftest/api/%d/clone_and_start", task.cloneId))
|
||||
return task.baseUrl.ResolveReference(path)
|
||||
}
|
||||
func (task *NGrinderTask) StatusEndpoint() *url.URL {
|
||||
path, _ := url.Parse(fmt.Sprintf("perftest/api/%d/status", task.testId))
|
||||
return task.baseUrl.ResolveReference(path)
|
||||
}
|
||||
func (task *NGrinderTask) StopEndpoint() *url.URL {
|
||||
path, _ := url.Parse(fmt.Sprintf("perftest/api/%d?action=stop", task.testId))
|
||||
return task.baseUrl.ResolveReference(path)
|
||||
}
|
||||
|
||||
// initiate a clone_and_start request and get new test id from response
|
||||
func (task *NGrinderTask) Run(ctx context.Context) bool {
|
||||
url := task.CloneAndStartEndpoint().String()
|
||||
result, err := task.request("POST", url, ctx)
|
||||
if err != nil {
|
||||
task.logger.With("canary", task.canary).Errorf("failed to clone and start ngrinder test %s: %s", url, err.Error())
|
||||
return false
|
||||
}
|
||||
id := result["id"]
|
||||
task.testId = int(id.(float64))
|
||||
return task.PollStatus(ctx)
|
||||
}
|
||||
|
||||
func (task *NGrinderTask) String() string {
|
||||
return task.canary + task.CloneAndStartEndpoint().String()
|
||||
}
|
||||
|
||||
// polling execution status of the new test and check if finished
|
||||
func (task *NGrinderTask) PollStatus(ctx context.Context) bool {
|
||||
// wait until ngrinder test finished/canceled or timedout
|
||||
tickChan := time.NewTicker(time.Second * task.pollInterval).C
|
||||
for {
|
||||
select {
|
||||
case <-tickChan:
|
||||
result, err := task.request("GET", task.StatusEndpoint().String(), ctx)
|
||||
if err == nil {
|
||||
statusArray, ok := result["status"].([]interface{})
|
||||
if ok && len(statusArray) > 0 {
|
||||
status := statusArray[0].(map[string]interface{})
|
||||
statusId := status["status_id"]
|
||||
task.logger.Debugf("status of ngrinder task %d is %s", task.testId, statusId)
|
||||
if statusId == "FINISHED" {
|
||||
return true
|
||||
} else if statusId == "STOP_BY_ERROR" || statusId == "CANCELED" || statusId == "UNKNOWN" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
task.logger.Warnf("context timedout, top ngrinder task %d forcibly", task.testId)
|
||||
task.request("PUT", task.StopEndpoint().String(), nil)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// send request, handle error, and eavl response json
|
||||
func (task *NGrinderTask) request(method, url string, ctx context.Context) (map[string]interface{}, error) {
|
||||
task.logger.Debugf("send %s request to %s", method, url)
|
||||
req, _ := http.NewRequest(method, url, nil)
|
||||
req.SetBasicAuth(task.username, task.passwd)
|
||||
if ctx != nil {
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
task.logger.Errorf("bad request: %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
respBytes, err := ioutil.ReadAll(resp.Body)
|
||||
res := make(map[string]interface{})
|
||||
err = json.Unmarshal(respBytes, &res)
|
||||
if err != nil {
|
||||
task.logger.Errorf("bad response, %s ,json expected:\n %s", err.Error(), string(respBytes))
|
||||
} else if success, ok := res["success"]; ok && success == false {
|
||||
err = errors.New(res["message"].(string))
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
package loadtester
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/stefanprodan/flagger/pkg/logging"
|
||||
"gopkg.in/h2non/gock.v1"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTaskNGrinder(t *testing.T) {
|
||||
server := "http://ngrinder:8080"
|
||||
cloneId := "960"
|
||||
logger, _ := logging.NewLoggerWithEncoding("debug", "console")
|
||||
canary := "podinfo.default"
|
||||
taskFactory, ok := GetTaskFactory(TaskTypeNGrinder)
|
||||
if !ok {
|
||||
t.Errorf("Failed to get ngrinder task factory")
|
||||
}
|
||||
|
||||
defer gock.Off()
|
||||
gock.New(server).Post(fmt.Sprintf("perftest/api/%s/clone_and_start", cloneId)).
|
||||
Reply(200).BodyString(`{"status": "READY","id": 961}`)
|
||||
gock.New(server).Get("perftest/api/961/status").Reply(200).
|
||||
BodyString(`{"status": [{"status_id": "FINISHED"}]}`)
|
||||
gock.New(server).Put("perftest/api/961").MatchParam("action", "stop").Reply(200).
|
||||
BodyString(`{"success": true}`)
|
||||
|
||||
t.Run("NormalRequest", func(t *testing.T) {
|
||||
task, err := taskFactory(map[string]string{
|
||||
"server": server,
|
||||
"clone": cloneId,
|
||||
"username": "admin",
|
||||
"passwd": "YWRtaW4=",
|
||||
"pollInterval": "1s",
|
||||
}, canary, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ngrinder task: %s", err.Error())
|
||||
return
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second*3)
|
||||
task.Run(ctx)
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
package loadtester
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const TaskTypeShell = "cmd"
|
||||
|
||||
func init() {
|
||||
taskFactories.Store(TaskTypeShell, func(metadata map[string]string, canary string, logger *zap.SugaredLogger) (Task, error) {
|
||||
cmd, ok := metadata["cmd"]
|
||||
if !ok {
|
||||
return nil, errors.New("cmd not found in metadata")
|
||||
}
|
||||
logCmdOutput, _ := strconv.ParseBool(metadata["logCmdOutput"])
|
||||
return &CmdTask{TaskBase{canary, logger}, cmd, logCmdOutput}, nil
|
||||
})
|
||||
}
|
||||
|
||||
type CmdTask struct {
|
||||
TaskBase
|
||||
command string
|
||||
logCmdOutput bool
|
||||
}
|
||||
|
||||
func (task *CmdTask) Hash() string {
|
||||
return hash(task.canary + task.command)
|
||||
}
|
||||
|
||||
func (task *CmdTask) Run(ctx context.Context) bool {
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", task.command)
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
task.logger.With("canary", task.canary).Errorf("command failed %s %v %s", task.command, err, out)
|
||||
} else {
|
||||
if task.logCmdOutput {
|
||||
fmt.Printf("%s\n", out)
|
||||
}
|
||||
task.logger.With("canary", task.canary).Infof("command finished %s", task.command)
|
||||
}
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (task *CmdTask) String() string {
|
||||
return task.command
|
||||
}
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2018 codemodus
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
|
@ -0,0 +1,349 @@
|
|||
// Package parth provides path parsing for segment unmarshaling and slicing. In
|
||||
// other words, parth provides simple and flexible access to (URL) path
|
||||
// parameters.
|
||||
//
|
||||
// Along with string, all basic non-alias types are supported. An interface is
|
||||
// available for implementation by user-defined types. When handling an int,
|
||||
// uint, or float of any size, the first valid value within the specified
|
||||
// segment will be used.
|
||||
package parth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Unmarshaler is the interface implemented by types that can unmarshal a path
|
||||
// segment representation of themselves. It is safe to assume that the segment
|
||||
// data will not include slashes.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalSegment(string) error
|
||||
}
|
||||
|
||||
// Err{Name} values facilitate error identification.
|
||||
var (
|
||||
ErrUnknownType = errors.New("unknown type provided")
|
||||
|
||||
ErrFirstSegNotFound = errors.New("first segment not found by index")
|
||||
ErrLastSegNotFound = errors.New("last segment not found by index")
|
||||
ErrSegOrderReversed = errors.New("first segment must precede last segment")
|
||||
ErrKeySegNotFound = errors.New("segment not found by key")
|
||||
|
||||
ErrDataUnparsable = errors.New("data cannot be parsed")
|
||||
)
|
||||
|
||||
// Segment locates the path segment indicated by the index i and unmarshals it
|
||||
// into the provided type v. If the index is negative, the negative count
|
||||
// begins with the last segment. An error is returned if: 1. The type is not a
|
||||
// pointer to an instance of one of the basic non-alias types and does not
|
||||
// implement the Unmarshaler interface; 2. The index is out of range of the
|
||||
// path; 3. The located path segment data cannot be parsed as the provided type
|
||||
// or if an error is returned when using a provided Unmarshaler implementation.
|
||||
func Segment(path string, i int, v interface{}) error { //nolint
|
||||
var err error
|
||||
|
||||
switch v := v.(type) {
|
||||
case *bool:
|
||||
*v, err = segmentToBool(path, i)
|
||||
|
||||
case *float32:
|
||||
var f float64
|
||||
f, err = segmentToFloatN(path, i, 32)
|
||||
*v = float32(f)
|
||||
|
||||
case *float64:
|
||||
*v, err = segmentToFloatN(path, i, 64)
|
||||
|
||||
case *int:
|
||||
var n int64
|
||||
n, err = segmentToIntN(path, i, 0)
|
||||
*v = int(n)
|
||||
|
||||
case *int16:
|
||||
var n int64
|
||||
n, err = segmentToIntN(path, i, 16)
|
||||
*v = int16(n)
|
||||
|
||||
case *int32:
|
||||
var n int64
|
||||
n, err = segmentToIntN(path, i, 32)
|
||||
*v = int32(n)
|
||||
|
||||
case *int64:
|
||||
*v, err = segmentToIntN(path, i, 64)
|
||||
|
||||
case *int8:
|
||||
var n int64
|
||||
n, err = segmentToIntN(path, i, 8)
|
||||
*v = int8(n)
|
||||
|
||||
case *string:
|
||||
*v, err = segmentToString(path, i)
|
||||
|
||||
case *uint:
|
||||
var n uint64
|
||||
n, err = segmentToUintN(path, i, 0)
|
||||
*v = uint(n)
|
||||
|
||||
case *uint16:
|
||||
var n uint64
|
||||
n, err = segmentToUintN(path, i, 16)
|
||||
*v = uint16(n)
|
||||
|
||||
case *uint32:
|
||||
var n uint64
|
||||
n, err = segmentToUintN(path, i, 32)
|
||||
*v = uint32(n)
|
||||
|
||||
case *uint64:
|
||||
*v, err = segmentToUintN(path, i, 64)
|
||||
|
||||
case *uint8:
|
||||
var n uint64
|
||||
n, err = segmentToUintN(path, i, 8)
|
||||
*v = uint8(n)
|
||||
|
||||
case Unmarshaler:
|
||||
var s string
|
||||
s, err = segmentToString(path, i)
|
||||
if err == nil {
|
||||
err = v.UnmarshalSegment(s)
|
||||
}
|
||||
|
||||
default:
|
||||
err = ErrUnknownType
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Sequent is similar to Segment, but uses a key to locate a segment and then
|
||||
// unmarshal the subsequent segment. It is a simple wrapper over SubSeg with an
|
||||
// index of 0.
|
||||
func Sequent(path, key string, v interface{}) error {
|
||||
return SubSeg(path, key, 0, v)
|
||||
}
|
||||
|
||||
// Span returns the path segments between two segment indexes i and j including
|
||||
// the first segment. If an index is negative, the negative count begins with
|
||||
// the last segment. Providing a 0 for the last index j is a special case which
|
||||
// acts as an alias for the end of the path. If the first segment does not begin
|
||||
// with a slash and it is part of the requested span, no slash will be added. An
|
||||
// error is returned if: 1. Either index is out of range of the path; 2. The
|
||||
// first index i does not precede the last index j.
|
||||
func Span(path string, i, j int) (string, error) {
|
||||
var f, l int
|
||||
var ok bool
|
||||
|
||||
if i < 0 {
|
||||
f, ok = segStartIndexFromEnd(path, i)
|
||||
} else {
|
||||
f, ok = segStartIndexFromStart(path, i)
|
||||
}
|
||||
if !ok {
|
||||
return "", ErrFirstSegNotFound
|
||||
}
|
||||
|
||||
if j > 0 {
|
||||
l, ok = segEndIndexFromStart(path, j)
|
||||
} else {
|
||||
l, ok = segEndIndexFromEnd(path, j)
|
||||
}
|
||||
if !ok {
|
||||
return "", ErrLastSegNotFound
|
||||
}
|
||||
|
||||
if f == l {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if f > l {
|
||||
return "", ErrSegOrderReversed
|
||||
}
|
||||
|
||||
return path[f:l], nil
|
||||
}
|
||||
|
||||
// SubSeg is similar to Segment, but only handles the portion of the path
|
||||
// subsequent to the provided key. For example, to access the segment
|
||||
// immediately after a key, an index of 0 should be provided (see Sequent). An
|
||||
// error is returned if the key cannot be found in the path.
|
||||
func SubSeg(path, key string, i int, v interface{}) error { //nolint
|
||||
var err error
|
||||
|
||||
switch v := v.(type) {
|
||||
case *bool:
|
||||
*v, err = subSegToBool(path, key, i)
|
||||
|
||||
case *float32:
|
||||
var f float64
|
||||
f, err = subSegToFloatN(path, key, i, 32)
|
||||
*v = float32(f)
|
||||
|
||||
case *float64:
|
||||
*v, err = subSegToFloatN(path, key, i, 64)
|
||||
|
||||
case *int:
|
||||
var n int64
|
||||
n, err = subSegToIntN(path, key, i, 0)
|
||||
*v = int(n)
|
||||
|
||||
case *int16:
|
||||
var n int64
|
||||
n, err = subSegToIntN(path, key, i, 16)
|
||||
*v = int16(n)
|
||||
|
||||
case *int32:
|
||||
var n int64
|
||||
n, err = subSegToIntN(path, key, i, 32)
|
||||
*v = int32(n)
|
||||
|
||||
case *int64:
|
||||
*v, err = subSegToIntN(path, key, i, 64)
|
||||
|
||||
case *int8:
|
||||
var n int64
|
||||
n, err = subSegToIntN(path, key, i, 8)
|
||||
*v = int8(n)
|
||||
|
||||
case *string:
|
||||
*v, err = subSegToString(path, key, i)
|
||||
|
||||
case *uint:
|
||||
var n uint64
|
||||
n, err = subSegToUintN(path, key, i, 0)
|
||||
*v = uint(n)
|
||||
|
||||
case *uint16:
|
||||
var n uint64
|
||||
n, err = subSegToUintN(path, key, i, 16)
|
||||
*v = uint16(n)
|
||||
|
||||
case *uint32:
|
||||
var n uint64
|
||||
n, err = subSegToUintN(path, key, i, 32)
|
||||
*v = uint32(n)
|
||||
|
||||
case *uint64:
|
||||
*v, err = subSegToUintN(path, key, i, 64)
|
||||
|
||||
case *uint8:
|
||||
var n uint64
|
||||
n, err = subSegToUintN(path, key, i, 8)
|
||||
*v = uint8(n)
|
||||
|
||||
case Unmarshaler:
|
||||
var s string
|
||||
s, err = subSegToString(path, key, i)
|
||||
if err == nil {
|
||||
err = v.UnmarshalSegment(s)
|
||||
}
|
||||
|
||||
default:
|
||||
err = ErrUnknownType
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SubSpan is similar to Span, but only handles the portion of the path
|
||||
// subsequent to the provided key. An error is returned if the key cannot be
|
||||
// found in the path.
|
||||
func SubSpan(path, key string, i, j int) (string, error) {
|
||||
si, ok := segIndexByKey(path, key)
|
||||
if !ok {
|
||||
return "", ErrKeySegNotFound
|
||||
}
|
||||
|
||||
if i >= 0 {
|
||||
i++
|
||||
}
|
||||
if j > 0 {
|
||||
j++
|
||||
}
|
||||
|
||||
s, err := Span(path[si:], i, j)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Parth manages path and error data for processing a single path multiple
|
||||
// times while error checking only once. Only the first encountered error is
|
||||
// stored as all subsequent calls to Parth methods that can error are elided.
|
||||
type Parth struct {
|
||||
path string
|
||||
err error
|
||||
}
|
||||
|
||||
// New constructs a pointer to an instance of Parth around the provided path.
|
||||
func New(path string) *Parth {
|
||||
return &Parth{path: path}
|
||||
}
|
||||
|
||||
// NewBySpan constructs a pointer to an instance of Parth after preprocessing
|
||||
// the provided path with Span.
|
||||
func NewBySpan(path string, i, j int) *Parth {
|
||||
s, err := Span(path, i, j)
|
||||
return &Parth{s, err}
|
||||
}
|
||||
|
||||
// NewBySubSpan constructs a pointer to an instance of Parth after
|
||||
// preprocessing the provided path with SubSpan.
|
||||
func NewBySubSpan(path, key string, i, j int) *Parth {
|
||||
s, err := SubSpan(path, key, i, j)
|
||||
return &Parth{s, err}
|
||||
}
|
||||
|
||||
// Err returns the first error encountered by the *Parth receiver.
|
||||
func (p *Parth) Err() error {
|
||||
return p.err
|
||||
}
|
||||
|
||||
// Segment operates the same as the package-level function Segment.
|
||||
func (p *Parth) Segment(i int, v interface{}) {
|
||||
if p.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.err = Segment(p.path, i, v)
|
||||
}
|
||||
|
||||
// Sequent operates the same as the package-level function Sequent.
|
||||
func (p *Parth) Sequent(key string, v interface{}) {
|
||||
p.SubSeg(key, 0, v)
|
||||
}
|
||||
|
||||
// Span operates the same as the package-level function Span.
|
||||
func (p *Parth) Span(i, j int) string {
|
||||
if p.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
s, err := Span(p.path, i, j)
|
||||
p.err = err
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SubSeg operates the same as the package-level function SubSeg.
|
||||
func (p *Parth) SubSeg(key string, i int, v interface{}) {
|
||||
if p.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.err = SubSeg(p.path, key, i, v)
|
||||
}
|
||||
|
||||
// SubSpan operates the same as the package-level function SubSpan.
|
||||
func (p *Parth) SubSpan(key string, i, j int) string {
|
||||
if p.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
s, err := SubSpan(p.path, key, i, j)
|
||||
p.err = err
|
||||
|
||||
return s
|
||||
}
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
package parth
|
||||
|
||||
func segStartIndexFromStart(path string, seg int) (int, bool) {
|
||||
if seg < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for n, ct := 0, 0; n < len(path); n++ {
|
||||
if n > 0 && path[n] == '/' {
|
||||
ct++
|
||||
}
|
||||
|
||||
if ct == seg {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func segStartIndexFromEnd(path string, seg int) (int, bool) {
|
||||
if seg > -1 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for n, ct := len(path)-1, 0; n >= 0; n-- {
|
||||
if path[n] == '/' || n == 0 {
|
||||
ct--
|
||||
}
|
||||
|
||||
if ct == seg {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func segEndIndexFromStart(path string, seg int) (int, bool) {
|
||||
if seg < 1 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for n, ct := 0, 0; n < len(path); n++ {
|
||||
if path[n] == '/' && n > 0 {
|
||||
ct++
|
||||
}
|
||||
|
||||
if ct == seg {
|
||||
return n, true
|
||||
}
|
||||
|
||||
if n+1 == len(path) && ct+1 == seg {
|
||||
return n + 1, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func segEndIndexFromEnd(path string, seg int) (int, bool) {
|
||||
if seg > 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if seg == 0 {
|
||||
return len(path), true
|
||||
}
|
||||
|
||||
if len(path) == 1 && path[0] == '/' {
|
||||
return 0, true
|
||||
}
|
||||
|
||||
for n, ct := len(path)-1, 0; n >= 0; n-- {
|
||||
if n == 0 || path[n] == '/' {
|
||||
ct--
|
||||
}
|
||||
|
||||
if ct == seg {
|
||||
return n, true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func segIndexByKey(path, key string) (int, bool) { //nolint
|
||||
if path == "" || key == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for n := 0; n < len(path); n++ {
|
||||
si, ok := segStartIndexFromStart(path, n)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if len(path[si:]) == len(key)+1 {
|
||||
if path[si+1:] == key {
|
||||
return si, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
tmpEI, ok := segStartIndexFromStart(path[si:], 1)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if path[si+1:tmpEI+si] == key || n == 0 && path[0] != '/' && path[si:tmpEI+si] == key {
|
||||
return si, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
|
@ -0,0 +1,305 @@
|
|||
package parth
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func segmentToBool(path string, i int) (bool, error) {
|
||||
s, err := segmentToString(path, i)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return false, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func segmentToFloatN(path string, i, size int) (float64, error) {
|
||||
ss, err := segmentToString(path, i)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
s, ok := firstFloatFromString(ss)
|
||||
if !ok {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
v, err := strconv.ParseFloat(s, size)
|
||||
if err != nil {
|
||||
return 0.0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func segmentToIntN(path string, i, size int) (int64, error) {
|
||||
ss, err := segmentToString(path, i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s, ok := firstIntFromString(ss)
|
||||
if !ok {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
v, err := strconv.ParseInt(s, 10, size)
|
||||
if err != nil {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func segmentToString(path string, i int) (string, error) {
|
||||
j := i + 1
|
||||
if i < 0 {
|
||||
i--
|
||||
}
|
||||
|
||||
s, err := Span(path, i, j)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if s[0] == '/' {
|
||||
s = s[1:]
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func segmentToUintN(path string, i, size int) (uint64, error) {
|
||||
ss, err := segmentToString(path, i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s, ok := firstUintFromString(ss)
|
||||
if !ok {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
v, err := strconv.ParseUint(s, 10, size)
|
||||
if err != nil {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func subSegToBool(path, key string, i int) (bool, error) {
|
||||
s, err := subSegToString(path, key, i)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return false, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func subSegToFloatN(path, key string, i, size int) (float64, error) {
|
||||
ss, err := subSegToString(path, key, i)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
s, ok := firstFloatFromString(ss)
|
||||
if !ok {
|
||||
return 0.0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
v, err := strconv.ParseFloat(s, size)
|
||||
if err != nil {
|
||||
return 0.0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func subSegToIntN(path, key string, i, size int) (int64, error) {
|
||||
ss, err := subSegToString(path, key, i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s, ok := firstIntFromString(ss)
|
||||
if !ok {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
v, err := strconv.ParseInt(s, 10, size)
|
||||
if err != nil {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func subSegToString(path, key string, i int) (string, error) {
|
||||
ki, ok := segIndexByKey(path, key)
|
||||
if !ok {
|
||||
return "", ErrKeySegNotFound
|
||||
}
|
||||
|
||||
i++
|
||||
|
||||
s, err := segmentToString(path[ki:], i)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func subSegToUintN(path, key string, i, size int) (uint64, error) {
|
||||
ss, err := subSegToString(path, key, i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s, ok := firstUintFromString(ss)
|
||||
if !ok {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
v, err := strconv.ParseUint(s, 10, size)
|
||||
if err != nil {
|
||||
return 0, ErrDataUnparsable
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func firstUintFromString(s string) (string, bool) {
|
||||
ind, l := 0, 0
|
||||
|
||||
for n := 0; n < len(s); n++ {
|
||||
if unicode.IsDigit(rune(s[n])) {
|
||||
if l == 0 {
|
||||
ind = n
|
||||
}
|
||||
|
||||
l++
|
||||
} else {
|
||||
if l == 0 && s[n] == '.' {
|
||||
if n+1 < len(s) && unicode.IsDigit(rune(s[n+1])) {
|
||||
return "0", true
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if l > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if l == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return s[ind : ind+l], true
|
||||
}
|
||||
|
||||
func firstIntFromString(s string) (string, bool) { //nolint
|
||||
ind, l := 0, 0
|
||||
|
||||
for n := 0; n < len(s); n++ {
|
||||
if unicode.IsDigit(rune(s[n])) {
|
||||
if l == 0 {
|
||||
ind = n
|
||||
}
|
||||
|
||||
l++
|
||||
} else if s[n] == '-' {
|
||||
if l == 0 {
|
||||
ind = n
|
||||
l++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if l == 0 && s[n] == '.' {
|
||||
if n+1 < len(s) && unicode.IsDigit(rune(s[n+1])) {
|
||||
return "0", true
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if l > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if l == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return s[ind : ind+l], true
|
||||
}
|
||||
|
||||
func firstFloatFromString(s string) (string, bool) { //nolint
|
||||
c, ind, l := 0, 0, 0
|
||||
|
||||
for n := 0; n < len(s); n++ {
|
||||
if unicode.IsDigit(rune(s[n])) {
|
||||
if l == 0 {
|
||||
ind = n
|
||||
}
|
||||
|
||||
l++
|
||||
} else if s[n] == '-' {
|
||||
if l == 0 {
|
||||
ind = n
|
||||
l++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else if s[n] == '.' {
|
||||
if l == 0 {
|
||||
ind = n
|
||||
}
|
||||
|
||||
if c > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
l++
|
||||
c++
|
||||
} else if s[n] == 'e' && l > 0 && n+1 < len(s) && s[n+1] == '+' {
|
||||
l++
|
||||
} else if s[n] == '+' && l > 0 && s[n-1] == 'e' {
|
||||
if n+1 < len(s) && unicode.IsDigit(rune(s[n+1])) {
|
||||
l++
|
||||
continue
|
||||
}
|
||||
|
||||
l--
|
||||
break
|
||||
} else {
|
||||
if l > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if l == 0 || s[ind:ind+l] == "." {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return s[ind : ind+l], true
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
The MIT License
|
||||
|
||||
Copyright (c) 2016-2018 Tomas Aparicio
|
||||
|
||||
Permission is hereby granted, free of charge, to any person
|
||||
obtaining a copy of this software and associated documentation
|
||||
files (the "Software"), to deal in the Software without
|
||||
restriction, including without limitation the rights to use,
|
||||
copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following
|
||||
conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
||||
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
||||
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// mutex is used interally for locking thread-sensitive functions.
|
||||
var mutex = &sync.Mutex{}
|
||||
|
||||
// config global singleton store.
|
||||
var config = struct {
|
||||
Networking bool
|
||||
NetworkingFilters []FilterRequestFunc
|
||||
Observer ObserverFunc
|
||||
}{}
|
||||
|
||||
// ObserverFunc is implemented by users to inspect the outgoing intercepted HTTP traffic
|
||||
type ObserverFunc func(*http.Request, Mock)
|
||||
|
||||
// DumpRequest is a default implementation of ObserverFunc that dumps
|
||||
// the HTTP/1.x wire representation of the http request
|
||||
var DumpRequest ObserverFunc = func(request *http.Request, mock Mock) {
|
||||
bytes, _ := httputil.DumpRequestOut(request, true)
|
||||
fmt.Println(string(bytes))
|
||||
fmt.Printf("\nMatches: %v\n---\n", mock != nil)
|
||||
}
|
||||
|
||||
// track unmatched requests so they can be tested for
|
||||
var unmatchedRequests = []*http.Request{}
|
||||
|
||||
// New creates and registers a new HTTP mock with
|
||||
// default settings and returns the Request DSL for HTTP mock
|
||||
// definition and set up.
|
||||
func New(uri string) *Request {
|
||||
Intercept()
|
||||
|
||||
res := NewResponse()
|
||||
req := NewRequest()
|
||||
req.URLStruct, res.Error = url.Parse(normalizeURI(uri))
|
||||
|
||||
// Create the new mock expectation
|
||||
exp := NewMock(req, res)
|
||||
Register(exp)
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
// Intercepting returns true if gock is currently able to intercept.
|
||||
func Intercepting() bool {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
return http.DefaultTransport == DefaultTransport
|
||||
}
|
||||
|
||||
// Intercept enables HTTP traffic interception via http.DefaultTransport.
|
||||
// If you are using a custom HTTP transport, you have to use `gock.Transport()`
|
||||
func Intercept() {
|
||||
if !Intercepting() {
|
||||
http.DefaultTransport = DefaultTransport
|
||||
}
|
||||
}
|
||||
|
||||
// InterceptClient allows the developer to intercept HTTP traffic using
|
||||
// a custom http.Client who uses a non default http.Transport/http.RoundTripper implementation.
|
||||
func InterceptClient(cli *http.Client) {
|
||||
_, ok := cli.Transport.(*Transport)
|
||||
if ok {
|
||||
return // if transport already intercepted, just ignore it
|
||||
}
|
||||
trans := NewTransport()
|
||||
trans.Transport = cli.Transport
|
||||
cli.Transport = trans
|
||||
}
|
||||
|
||||
// RestoreClient allows the developer to disable and restore the
|
||||
// original transport in the given http.Client.
|
||||
func RestoreClient(cli *http.Client) {
|
||||
trans, ok := cli.Transport.(*Transport)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cli.Transport = trans.Transport
|
||||
}
|
||||
|
||||
// Disable disables HTTP traffic interception by gock.
|
||||
func Disable() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
http.DefaultTransport = NativeTransport
|
||||
}
|
||||
|
||||
// Off disables the default HTTP interceptors and removes
|
||||
// all the registered mocks, even if they has not been intercepted yet.
|
||||
func Off() {
|
||||
Flush()
|
||||
Disable()
|
||||
}
|
||||
|
||||
// OffAll is like `Off()`, but it also removes the unmatched requests registry.
|
||||
func OffAll() {
|
||||
Flush()
|
||||
Disable()
|
||||
CleanUnmatchedRequest()
|
||||
}
|
||||
|
||||
// Observe provides a hook to support inspection of the request and matched mock
|
||||
func Observe(fn ObserverFunc) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
config.Observer = fn
|
||||
}
|
||||
|
||||
// EnableNetworking enables real HTTP networking
|
||||
func EnableNetworking() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
config.Networking = true
|
||||
}
|
||||
|
||||
// DisableNetworking disables real HTTP networking
|
||||
func DisableNetworking() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
config.Networking = false
|
||||
}
|
||||
|
||||
// NetworkingFilter determines if an http.Request should be triggered or not.
|
||||
func NetworkingFilter(fn FilterRequestFunc) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
config.NetworkingFilters = append(config.NetworkingFilters, fn)
|
||||
}
|
||||
|
||||
// DisableNetworkingFilters disables registered networking filters.
|
||||
func DisableNetworkingFilters() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
config.NetworkingFilters = []FilterRequestFunc{}
|
||||
}
|
||||
|
||||
// GetUnmatchedRequests returns all requests that have been received but haven't matched any mock
|
||||
func GetUnmatchedRequests() []*http.Request {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
return unmatchedRequests
|
||||
}
|
||||
|
||||
// HasUnmatchedRequest returns true if gock has received any requests that didn't match a mock
|
||||
func HasUnmatchedRequest() bool {
|
||||
return len(GetUnmatchedRequests()) > 0
|
||||
}
|
||||
|
||||
// CleanUnmatchedRequest cleans the unmatched requests internal registry.
|
||||
func CleanUnmatchedRequest() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
unmatchedRequests = []*http.Request{}
|
||||
}
|
||||
|
||||
func trackUnmatchedRequest(req *http.Request) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
unmatchedRequests = append(unmatchedRequests, req)
|
||||
}
|
||||
|
||||
func normalizeURI(uri string) string {
|
||||
if ok, _ := regexp.MatchString("^http[s]?", uri); !ok {
|
||||
return "http://" + uri
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
package gock
|
||||
|
||||
import "net/http"
|
||||
|
||||
// MatchersHeader exposes an slice of HTTP header specific mock matchers.
|
||||
var MatchersHeader = []MatchFunc{
|
||||
MatchMethod,
|
||||
MatchScheme,
|
||||
MatchHost,
|
||||
MatchPath,
|
||||
MatchHeaders,
|
||||
MatchQueryParams,
|
||||
MatchPathParams,
|
||||
}
|
||||
|
||||
// MatchersBody exposes an slice of HTTP body specific built-in mock matchers.
|
||||
var MatchersBody = []MatchFunc{
|
||||
MatchBody,
|
||||
}
|
||||
|
||||
// Matchers stores all the built-in mock matchers.
|
||||
var Matchers = append(MatchersHeader, MatchersBody...)
|
||||
|
||||
// DefaultMatcher stores the default Matcher instance used to match mocks.
|
||||
var DefaultMatcher = NewMatcher()
|
||||
|
||||
// MatchFunc represents the required function
|
||||
// interface implemented by matchers.
|
||||
type MatchFunc func(*http.Request, *Request) (bool, error)
|
||||
|
||||
// Matcher represents the required interface implemented by mock matchers.
|
||||
type Matcher interface {
|
||||
// Get returns a slice of registered function matchers.
|
||||
Get() []MatchFunc
|
||||
|
||||
// Add adds a new matcher function.
|
||||
Add(MatchFunc)
|
||||
|
||||
// Set sets the matchers functions stack.
|
||||
Set([]MatchFunc)
|
||||
|
||||
// Flush flushes the current matchers function stack.
|
||||
Flush()
|
||||
|
||||
// Match matches the given http.Request with a mock Request.
|
||||
Match(*http.Request, *Request) (bool, error)
|
||||
}
|
||||
|
||||
// MockMatcher implements a mock matcher
|
||||
type MockMatcher struct {
|
||||
Matchers []MatchFunc
|
||||
}
|
||||
|
||||
// NewMatcher creates a new mock matcher
|
||||
// using the default matcher functions.
|
||||
func NewMatcher() *MockMatcher {
|
||||
return &MockMatcher{Matchers: Matchers}
|
||||
}
|
||||
|
||||
// NewBasicMatcher creates a new matcher with header only mock matchers.
|
||||
func NewBasicMatcher() *MockMatcher {
|
||||
return &MockMatcher{Matchers: MatchersHeader}
|
||||
}
|
||||
|
||||
// NewEmptyMatcher creates a new empty matcher with out default amtchers.
|
||||
func NewEmptyMatcher() *MockMatcher {
|
||||
return &MockMatcher{Matchers: []MatchFunc{}}
|
||||
}
|
||||
|
||||
// Get returns a slice of registered function matchers.
|
||||
func (m *MockMatcher) Get() []MatchFunc {
|
||||
return m.Matchers
|
||||
}
|
||||
|
||||
// Add adds a new function matcher.
|
||||
func (m *MockMatcher) Add(fn MatchFunc) {
|
||||
m.Matchers = append(m.Matchers, fn)
|
||||
}
|
||||
|
||||
// Set sets a new stack of matchers functions.
|
||||
func (m *MockMatcher) Set(stack []MatchFunc) {
|
||||
m.Matchers = stack
|
||||
}
|
||||
|
||||
// Flush flushes the current matcher
|
||||
func (m *MockMatcher) Flush() {
|
||||
m.Matchers = []MatchFunc{}
|
||||
}
|
||||
|
||||
// Match matches the given http.Request with a mock request
|
||||
// returning true in case that the request matches, otherwise false.
|
||||
func (m *MockMatcher) Match(req *http.Request, ereq *Request) (bool, error) {
|
||||
for _, matcher := range m.Matchers {
|
||||
matches, err := matcher(req, ereq)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !matches {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// MatchMock is a helper function that matches the given http.Request
|
||||
// in the list of registered mocks, returning it if matches or error if it fails.
|
||||
func MatchMock(req *http.Request) (Mock, error) {
|
||||
for _, mock := range GetAll() {
|
||||
matches, err := mock.Match(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if matches {
|
||||
return mock, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/parth"
|
||||
)
|
||||
|
||||
// EOL represents the end of line character.
|
||||
const EOL = 0xa
|
||||
|
||||
// BodyTypes stores the supported MIME body types for matching.
|
||||
// Currently only text-based types.
|
||||
var BodyTypes = []string{
|
||||
"text/html",
|
||||
"text/plain",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"multipart/form-data",
|
||||
"application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
// BodyTypeAliases stores a generic MIME type by alias.
|
||||
var BodyTypeAliases = map[string]string{
|
||||
"html": "text/html",
|
||||
"text": "text/plain",
|
||||
"json": "application/json",
|
||||
"xml": "application/xml",
|
||||
"form": "multipart/form-data",
|
||||
"url": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
// CompressionSchemes stores the supported Content-Encoding types for decompression.
|
||||
var CompressionSchemes = []string{
|
||||
"gzip",
|
||||
}
|
||||
|
||||
// MatchMethod matches the HTTP method of the given request.
|
||||
func MatchMethod(req *http.Request, ereq *Request) (bool, error) {
|
||||
return ereq.Method == "" || req.Method == ereq.Method, nil
|
||||
}
|
||||
|
||||
// MatchScheme matches the request URL protocol scheme.
|
||||
func MatchScheme(req *http.Request, ereq *Request) (bool, error) {
|
||||
return ereq.URLStruct.Scheme == "" || req.URL.Scheme == "" || ereq.URLStruct.Scheme == req.URL.Scheme, nil
|
||||
}
|
||||
|
||||
// MatchHost matches the HTTP host header field of the given request.
|
||||
func MatchHost(req *http.Request, ereq *Request) (bool, error) {
|
||||
url := ereq.URLStruct
|
||||
if strings.EqualFold(url.Host, req.URL.Host) {
|
||||
return true, nil
|
||||
}
|
||||
return regexp.MatchString(url.Host, req.URL.Host)
|
||||
}
|
||||
|
||||
// MatchPath matches the HTTP URL path of the given request.
|
||||
func MatchPath(req *http.Request, ereq *Request) (bool, error) {
|
||||
if req.URL.Path == ereq.URLStruct.Path {
|
||||
return true, nil
|
||||
}
|
||||
return regexp.MatchString(ereq.URLStruct.Path, req.URL.Path)
|
||||
}
|
||||
|
||||
// MatchHeaders matches the headers fields of the given request.
|
||||
func MatchHeaders(req *http.Request, ereq *Request) (bool, error) {
|
||||
for key, value := range ereq.Header {
|
||||
var err error
|
||||
var match bool
|
||||
|
||||
for _, field := range req.Header[key] {
|
||||
match, err = regexp.MatchString(value[0], field)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if match {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !match {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// MatchQueryParams matches the URL query params fields of the given request.
|
||||
func MatchQueryParams(req *http.Request, ereq *Request) (bool, error) {
|
||||
for key, value := range ereq.URLStruct.Query() {
|
||||
var err error
|
||||
var match bool
|
||||
|
||||
for _, field := range req.URL.Query()[key] {
|
||||
match, err = regexp.MatchString(value[0], field)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if match {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !match {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// MatchPathParams matches the URL path parameters of the given request.
|
||||
func MatchPathParams(req *http.Request, ereq *Request) (bool, error) {
|
||||
for key, value := range ereq.PathParams {
|
||||
var s string
|
||||
|
||||
if err := parth.Sequent(req.URL.Path, key, &s); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if s != value {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// MatchBody tries to match the request body.
|
||||
// TODO: not too smart now, needs several improvements.
|
||||
func MatchBody(req *http.Request, ereq *Request) (bool, error) {
|
||||
// If match body is empty, just continue
|
||||
if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Only can match certain MIME body types
|
||||
if !supportedType(req) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Can only match certain compression schemes
|
||||
if !supportedCompressionScheme(req) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Create a reader for the body depending on compression type
|
||||
bodyReader := req.Body
|
||||
if ereq.CompressionScheme != "" {
|
||||
if ereq.CompressionScheme != req.Header.Get("Content-Encoding") {
|
||||
return false, nil
|
||||
}
|
||||
compressedBodyReader, err := compressionReader(req.Body, ereq.CompressionScheme)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
bodyReader = compressedBodyReader
|
||||
}
|
||||
|
||||
// Read the whole request body
|
||||
body, err := ioutil.ReadAll(bodyReader)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Restore body reader stream
|
||||
req.Body = createReadCloser(body)
|
||||
|
||||
// If empty, ignore the match
|
||||
if len(body) == 0 && len(ereq.BodyBuffer) != 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Match body by atomic string comparison
|
||||
bodyStr := castToString(body)
|
||||
matchStr := castToString(ereq.BodyBuffer)
|
||||
if bodyStr == matchStr {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Match request body by regexp
|
||||
match, _ := regexp.MatchString(matchStr, bodyStr)
|
||||
if match == true {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// todo - add conditional do only perform the conversion of body bytes
|
||||
// representation of JSON to a map and then compare them for equality.
|
||||
|
||||
// Check if the key + value pairs match
|
||||
var bodyMap map[string]interface{}
|
||||
var matchMap map[string]interface{}
|
||||
|
||||
// Ensure that both byte bodies that that should be JSON can be converted to maps.
|
||||
umErr := json.Unmarshal(body, &bodyMap)
|
||||
umErr2 := json.Unmarshal(ereq.BodyBuffer, &matchMap)
|
||||
if umErr == nil && umErr2 == nil && reflect.DeepEqual(bodyMap, matchMap) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func supportedType(req *http.Request) bool {
|
||||
mime := req.Header.Get("Content-Type")
|
||||
if mime == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, kind := range BodyTypes {
|
||||
if match, _ := regexp.MatchString(kind, mime); match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func supportedCompressionScheme(req *http.Request) bool {
|
||||
encoding := req.Header.Get("Content-Encoding")
|
||||
if encoding == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, kind := range CompressionSchemes {
|
||||
if match, _ := regexp.MatchString(kind, encoding); match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func castToString(buf []byte) string {
|
||||
str := string(buf)
|
||||
tail := len(str) - 1
|
||||
if str[tail] == EOL {
|
||||
str = str[:tail]
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) {
|
||||
switch scheme {
|
||||
case "gzip":
|
||||
return gzip.NewReader(r)
|
||||
default:
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Mock represents the required interface that must
|
||||
// be implemented by HTTP mock instances.
|
||||
type Mock interface {
|
||||
// Disable disables the current mock manually.
|
||||
Disable()
|
||||
|
||||
// Done returns true if the current mock is disabled.
|
||||
Done() bool
|
||||
|
||||
// Request returns the mock Request instance.
|
||||
Request() *Request
|
||||
|
||||
// Response returns the mock Response instance.
|
||||
Response() *Response
|
||||
|
||||
// Match matches the given http.Request with the current mock.
|
||||
Match(*http.Request) (bool, error)
|
||||
|
||||
// AddMatcher adds a new matcher function.
|
||||
AddMatcher(MatchFunc)
|
||||
|
||||
// SetMatcher uses a new matcher implementation.
|
||||
SetMatcher(Matcher)
|
||||
}
|
||||
|
||||
// Mocker implements a Mock capable interface providing
|
||||
// a default mock configuration used internally to store mocks.
|
||||
type Mocker struct {
|
||||
// disabled stores if the current mock is disabled.
|
||||
disabled bool
|
||||
|
||||
// mutex stores the mock mutex for thread safity.
|
||||
mutex sync.Mutex
|
||||
|
||||
// matcher stores a Matcher capable instance to match the given http.Request.
|
||||
matcher Matcher
|
||||
|
||||
// request stores the mock Request to match.
|
||||
request *Request
|
||||
|
||||
// response stores the mock Response to use in case of match.
|
||||
response *Response
|
||||
}
|
||||
|
||||
// NewMock creates a new HTTP mock based on the given request and response instances.
|
||||
// It's mostly used internally.
|
||||
func NewMock(req *Request, res *Response) *Mocker {
|
||||
mock := &Mocker{
|
||||
request: req,
|
||||
response: res,
|
||||
matcher: DefaultMatcher,
|
||||
}
|
||||
res.Mock = mock
|
||||
req.Mock = mock
|
||||
req.Response = res
|
||||
return mock
|
||||
}
|
||||
|
||||
// Disable disables the current mock manually.
|
||||
func (m *Mocker) Disable() {
|
||||
m.disabled = true
|
||||
}
|
||||
|
||||
// Done returns true in case that the current mock
|
||||
// instance is disabled and therefore must be removed.
|
||||
func (m *Mocker) Done() bool {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.disabled || (!m.request.Persisted && m.request.Counter == 0)
|
||||
}
|
||||
|
||||
// Request returns the Request instance
|
||||
// configured for the current HTTP mock.
|
||||
func (m *Mocker) Request() *Request {
|
||||
return m.request
|
||||
}
|
||||
|
||||
// Response returns the Response instance
|
||||
// configured for the current HTTP mock.
|
||||
func (m *Mocker) Response() *Response {
|
||||
return m.response
|
||||
}
|
||||
|
||||
// Match matches the given http.Request with the current Request
|
||||
// mock expectation, returning true if matches.
|
||||
func (m *Mocker) Match(req *http.Request) (bool, error) {
|
||||
if m.disabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Filter
|
||||
for _, filter := range m.request.Filters {
|
||||
if !filter(req) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Map
|
||||
for _, mapper := range m.request.Mappers {
|
||||
if treq := mapper(req); treq != nil {
|
||||
req = treq
|
||||
}
|
||||
}
|
||||
|
||||
// Match
|
||||
matches, err := m.matcher.Match(req, m.request)
|
||||
if matches {
|
||||
m.decrement()
|
||||
}
|
||||
|
||||
return matches, err
|
||||
}
|
||||
|
||||
// SetMatcher sets a new matcher implementation
|
||||
// for the current mock expectation.
|
||||
func (m *Mocker) SetMatcher(matcher Matcher) {
|
||||
m.matcher = matcher
|
||||
}
|
||||
|
||||
// AddMatcher adds a new matcher function
|
||||
// for the current mock expectation.
|
||||
func (m *Mocker) AddMatcher(fn MatchFunc) {
|
||||
m.matcher.Add(fn)
|
||||
}
|
||||
|
||||
// decrement decrements the current mock Request counter.
|
||||
func (m *Mocker) decrement() {
|
||||
if m.request.Persisted {
|
||||
return
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.request.Counter--
|
||||
if m.request.Counter == 0 {
|
||||
m.disabled = true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,299 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MapRequestFunc represents the required function interface for request mappers.
|
||||
type MapRequestFunc func(*http.Request) *http.Request
|
||||
|
||||
// FilterRequestFunc represents the required function interface for request filters.
|
||||
type FilterRequestFunc func(*http.Request) bool
|
||||
|
||||
// Request represents the high-level HTTP request used to store
|
||||
// request fields used to match intercepted requests.
|
||||
type Request struct {
|
||||
// Mock stores the parent mock reference for the current request mock used for method delegation.
|
||||
Mock Mock
|
||||
|
||||
// Response stores the current Response instance for the current matches Request.
|
||||
Response *Response
|
||||
|
||||
// Error stores the latest mock request configuration error.
|
||||
Error error
|
||||
|
||||
// Counter stores the pending times that the current mock should be active.
|
||||
Counter int
|
||||
|
||||
// Persisted stores if the current mock should be always active.
|
||||
Persisted bool
|
||||
|
||||
// URLStruct stores the parsed URL as *url.URL struct.
|
||||
URLStruct *url.URL
|
||||
|
||||
// Method stores the Request HTTP method to match.
|
||||
Method string
|
||||
|
||||
// CompressionScheme stores the Request Compression scheme to match and use for decompression.
|
||||
CompressionScheme string
|
||||
|
||||
// Header stores the HTTP header fields to match.
|
||||
Header http.Header
|
||||
|
||||
// Cookies stores the Request HTTP cookies values to match.
|
||||
Cookies []*http.Cookie
|
||||
|
||||
// PathParams stores the path parameters to match.
|
||||
PathParams map[string]string
|
||||
|
||||
// BodyBuffer stores the body data to match.
|
||||
BodyBuffer []byte
|
||||
|
||||
// Mappers stores the request functions mappers used for matching.
|
||||
Mappers []MapRequestFunc
|
||||
|
||||
// Filters stores the request functions filters used for matching.
|
||||
Filters []FilterRequestFunc
|
||||
}
|
||||
|
||||
// NewRequest creates a new Request instance.
|
||||
func NewRequest() *Request {
|
||||
return &Request{
|
||||
Counter: 1,
|
||||
URLStruct: &url.URL{},
|
||||
Header: make(http.Header),
|
||||
PathParams: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// URL defines the mock URL to match.
|
||||
func (r *Request) URL(uri string) *Request {
|
||||
r.URLStruct, r.Error = url.Parse(uri)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetURL defines the url.URL struct to be used for matching.
|
||||
func (r *Request) SetURL(u *url.URL) *Request {
|
||||
r.URLStruct = u
|
||||
return r
|
||||
}
|
||||
|
||||
// Path defines the mock URL path value to match.
|
||||
func (r *Request) Path(path string) *Request {
|
||||
r.URLStruct.Path = path
|
||||
return r
|
||||
}
|
||||
|
||||
// Get specifies the GET method and the given URL path to match.
|
||||
func (r *Request) Get(path string) *Request {
|
||||
return r.method("GET", path)
|
||||
}
|
||||
|
||||
// Post specifies the POST method and the given URL path to match.
|
||||
func (r *Request) Post(path string) *Request {
|
||||
return r.method("POST", path)
|
||||
}
|
||||
|
||||
// Put specifies the PUT method and the given URL path to match.
|
||||
func (r *Request) Put(path string) *Request {
|
||||
return r.method("PUT", path)
|
||||
}
|
||||
|
||||
// Delete specifies the DELETE method and the given URL path to match.
|
||||
func (r *Request) Delete(path string) *Request {
|
||||
return r.method("DELETE", path)
|
||||
}
|
||||
|
||||
// Patch specifies the PATCH method and the given URL path to match.
|
||||
func (r *Request) Patch(path string) *Request {
|
||||
return r.method("PATCH", path)
|
||||
}
|
||||
|
||||
// Head specifies the HEAD method and the given URL path to match.
|
||||
func (r *Request) Head(path string) *Request {
|
||||
return r.method("HEAD", path)
|
||||
}
|
||||
|
||||
// method is a DRY shortcut used to declare the expected HTTP method and URL path.
|
||||
func (r *Request) method(method, path string) *Request {
|
||||
if path != "/" {
|
||||
r.URLStruct.Path = path
|
||||
}
|
||||
r.Method = strings.ToUpper(method)
|
||||
return r
|
||||
}
|
||||
|
||||
// Body defines the body data to match based on a io.Reader interface.
|
||||
func (r *Request) Body(body io.Reader) *Request {
|
||||
r.BodyBuffer, r.Error = ioutil.ReadAll(body)
|
||||
return r
|
||||
}
|
||||
|
||||
// BodyString defines the body to match based on a given string.
|
||||
func (r *Request) BodyString(body string) *Request {
|
||||
r.BodyBuffer = []byte(body)
|
||||
return r
|
||||
}
|
||||
|
||||
// File defines the body to match based on the given file path string.
|
||||
func (r *Request) File(path string) *Request {
|
||||
r.BodyBuffer, r.Error = ioutil.ReadFile(path)
|
||||
return r
|
||||
}
|
||||
|
||||
// Compression defines the request compression scheme, and enables automatic body decompression.
|
||||
// Supports only the "gzip" scheme so far.
|
||||
func (r *Request) Compression(scheme string) *Request {
|
||||
r.Header.Set("Content-Encoding", scheme)
|
||||
r.CompressionScheme = scheme
|
||||
return r
|
||||
}
|
||||
|
||||
// JSON defines the JSON body to match based on a given structure.
|
||||
func (r *Request) JSON(data interface{}) *Request {
|
||||
if r.Header.Get("Content-Type") == "" {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
r.BodyBuffer, r.Error = readAndDecode(data, "json")
|
||||
return r
|
||||
}
|
||||
|
||||
// XML defines the XML body to match based on a given structure.
|
||||
func (r *Request) XML(data interface{}) *Request {
|
||||
if r.Header.Get("Content-Type") == "" {
|
||||
r.Header.Set("Content-Type", "application/xml")
|
||||
}
|
||||
r.BodyBuffer, r.Error = readAndDecode(data, "xml")
|
||||
return r
|
||||
}
|
||||
|
||||
// MatchType defines the request Content-Type MIME header field.
|
||||
// Supports type alias. E.g: json, xml, form, text...
|
||||
func (r *Request) MatchType(kind string) *Request {
|
||||
mime := BodyTypeAliases[kind]
|
||||
if mime != "" {
|
||||
kind = mime
|
||||
}
|
||||
r.Header.Set("Content-Type", kind)
|
||||
return r
|
||||
}
|
||||
|
||||
// MatchHeader defines a new key and value header to match.
|
||||
func (r *Request) MatchHeader(key, value string) *Request {
|
||||
r.Header.Set(key, value)
|
||||
return r
|
||||
}
|
||||
|
||||
// HeaderPresent defines that a header field must be present in the request.
|
||||
func (r *Request) HeaderPresent(key string) *Request {
|
||||
r.Header.Set(key, ".*")
|
||||
return r
|
||||
}
|
||||
|
||||
// MatchHeaders defines a map of key-value headers to match.
|
||||
func (r *Request) MatchHeaders(headers map[string]string) *Request {
|
||||
for key, value := range headers {
|
||||
r.Header.Set(key, value)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// MatchParam defines a new key and value URL query param to match.
|
||||
func (r *Request) MatchParam(key, value string) *Request {
|
||||
query := r.URLStruct.Query()
|
||||
query.Set(key, value)
|
||||
r.URLStruct.RawQuery = query.Encode()
|
||||
return r
|
||||
}
|
||||
|
||||
// MatchParams defines a map of URL query param key-value to match.
|
||||
func (r *Request) MatchParams(params map[string]string) *Request {
|
||||
query := r.URLStruct.Query()
|
||||
for key, value := range params {
|
||||
query.Set(key, value)
|
||||
}
|
||||
r.URLStruct.RawQuery = query.Encode()
|
||||
return r
|
||||
}
|
||||
|
||||
// ParamPresent matches if the given query param key is present in the URL.
|
||||
func (r *Request) ParamPresent(key string) *Request {
|
||||
r.MatchParam(key, ".*")
|
||||
return r
|
||||
}
|
||||
|
||||
// PathParam matches if a given path parameter key is present in the URL.
|
||||
//
|
||||
// The value is representative of the restful resource the key defines, e.g.
|
||||
// // /users/123/name
|
||||
// r.PathParam("users", "123")
|
||||
// would match.
|
||||
func (r *Request) PathParam(key, val string) *Request {
|
||||
r.PathParams[key] = val
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Persist defines the current HTTP mock as persistent and won't be removed after intercepting it.
|
||||
func (r *Request) Persist() *Request {
|
||||
r.Persisted = true
|
||||
return r
|
||||
}
|
||||
|
||||
// Times defines the number of times that the current HTTP mock should remain active.
|
||||
func (r *Request) Times(num int) *Request {
|
||||
r.Counter = num
|
||||
return r
|
||||
}
|
||||
|
||||
// AddMatcher adds a new matcher function to match the request.
|
||||
func (r *Request) AddMatcher(fn MatchFunc) *Request {
|
||||
r.Mock.AddMatcher(fn)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetMatcher sets a new matcher function to match the request.
|
||||
func (r *Request) SetMatcher(matcher Matcher) *Request {
|
||||
r.Mock.SetMatcher(matcher)
|
||||
return r
|
||||
}
|
||||
|
||||
// Map adds a new request mapper function to map http.Request before the matching process.
|
||||
func (r *Request) Map(fn MapRequestFunc) *Request {
|
||||
r.Mappers = append(r.Mappers, fn)
|
||||
return r
|
||||
}
|
||||
|
||||
// Filter filters a new request filter function to filter http.Request before the matching process.
|
||||
func (r *Request) Filter(fn FilterRequestFunc) *Request {
|
||||
r.Filters = append(r.Filters, fn)
|
||||
return r
|
||||
}
|
||||
|
||||
// EnableNetworking enables the use real networking for the current mock.
|
||||
func (r *Request) EnableNetworking() *Request {
|
||||
if r.Response != nil {
|
||||
r.Response.UseNetwork = true
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Reply defines the Response status code and returns the mock Response DSL.
|
||||
func (r *Request) Reply(status int) *Response {
|
||||
return r.Response.Status(status)
|
||||
}
|
||||
|
||||
// ReplyError defines the Response simulated error.
|
||||
func (r *Request) ReplyError(err error) *Response {
|
||||
return r.Response.SetError(err)
|
||||
}
|
||||
|
||||
// ReplyFunc allows the developer to define the mock response via a custom function.
|
||||
func (r *Request) ReplyFunc(replier func(*Response)) *Response {
|
||||
replier(r.Response)
|
||||
return r.Response
|
||||
}
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Responder builds a mock http.Response based on the given Response mock.
|
||||
func Responder(req *http.Request, mock *Response, res *http.Response) (*http.Response, error) {
|
||||
// If error present, reply it
|
||||
err := mock.Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
res = createResponse(req)
|
||||
}
|
||||
|
||||
// Apply response filter
|
||||
for _, filter := range mock.Filters {
|
||||
if !filter(res) {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Define mock status code
|
||||
if mock.StatusCode != 0 {
|
||||
res.Status = strconv.Itoa(mock.StatusCode) + " " + http.StatusText(mock.StatusCode)
|
||||
res.StatusCode = mock.StatusCode
|
||||
}
|
||||
|
||||
// Define headers by merging fields
|
||||
res.Header = mergeHeaders(res, mock)
|
||||
|
||||
// Define mock body, if present
|
||||
if len(mock.BodyBuffer) > 0 {
|
||||
res.ContentLength = int64(len(mock.BodyBuffer))
|
||||
res.Body = createReadCloser(mock.BodyBuffer)
|
||||
}
|
||||
|
||||
// Apply response mappers
|
||||
for _, mapper := range mock.Mappers {
|
||||
if tres := mapper(res); tres != nil {
|
||||
res = tres
|
||||
}
|
||||
}
|
||||
|
||||
// Sleep to simulate delay, if necessary
|
||||
if mock.ResponseDelay > 0 {
|
||||
time.Sleep(mock.ResponseDelay)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// createResponse creates a new http.Response with default fields.
|
||||
func createResponse(req *http.Request) *http.Response {
|
||||
return &http.Response{
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Proto: "HTTP/1.1",
|
||||
Request: req,
|
||||
Header: make(http.Header),
|
||||
Body: createReadCloser([]byte{}),
|
||||
}
|
||||
}
|
||||
|
||||
// mergeHeaders copies the mock headers.
|
||||
func mergeHeaders(res *http.Response, mres *Response) http.Header {
|
||||
for key, values := range mres.Header {
|
||||
for _, value := range values {
|
||||
res.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
return res.Header
|
||||
}
|
||||
|
||||
// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an
|
||||
// http response body.
|
||||
func createReadCloser(body []byte) io.ReadCloser {
|
||||
return ioutil.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MapResponseFunc represents the required function interface impletemed by response mappers.
|
||||
type MapResponseFunc func(*http.Response) *http.Response
|
||||
|
||||
// FilterResponseFunc represents the required function interface impletemed by response filters.
|
||||
type FilterResponseFunc func(*http.Response) bool
|
||||
|
||||
// Response represents high-level HTTP fields to configure
|
||||
// and define HTTP responses intercepted by gock.
|
||||
type Response struct {
|
||||
// Mock stores the parent mock reference for the current response mock used for method delegation.
|
||||
Mock Mock
|
||||
|
||||
// Error stores the latest response configuration or injected error.
|
||||
Error error
|
||||
|
||||
// UseNetwork enables the use of real network for the current mock.
|
||||
UseNetwork bool
|
||||
|
||||
// StatusCode stores the response status code.
|
||||
StatusCode int
|
||||
|
||||
// Headers stores the response headers.
|
||||
Header http.Header
|
||||
|
||||
// Cookies stores the response cookie fields.
|
||||
Cookies []*http.Cookie
|
||||
|
||||
// BodyBuffer stores the array of bytes to use as body.
|
||||
BodyBuffer []byte
|
||||
|
||||
// ResponseDelay stores the simulated response delay.
|
||||
ResponseDelay time.Duration
|
||||
|
||||
// Mappers stores the request functions mappers used for matching.
|
||||
Mappers []MapResponseFunc
|
||||
|
||||
// Filters stores the request functions filters used for matching.
|
||||
Filters []FilterResponseFunc
|
||||
}
|
||||
|
||||
// NewResponse creates a new Response.
|
||||
func NewResponse() *Response {
|
||||
return &Response{Header: make(http.Header)}
|
||||
}
|
||||
|
||||
// Status defines the desired HTTP status code to reply in the current response.
|
||||
func (r *Response) Status(code int) *Response {
|
||||
r.StatusCode = code
|
||||
return r
|
||||
}
|
||||
|
||||
// Type defines the response Content-Type MIME header field.
|
||||
// Supports type alias. E.g: json, xml, form, text...
|
||||
func (r *Response) Type(kind string) *Response {
|
||||
mime := BodyTypeAliases[kind]
|
||||
if mime != "" {
|
||||
kind = mime
|
||||
}
|
||||
r.Header.Set("Content-Type", kind)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetHeader sets a new header field in the mock response.
|
||||
func (r *Response) SetHeader(key, value string) *Response {
|
||||
r.Header.Set(key, value)
|
||||
return r
|
||||
}
|
||||
|
||||
// AddHeader adds a new header field in the mock response
|
||||
// with out removing an existent one.
|
||||
func (r *Response) AddHeader(key, value string) *Response {
|
||||
r.Header.Add(key, value)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetHeaders sets a map of header fields in the mock response.
|
||||
func (r *Response) SetHeaders(headers map[string]string) *Response {
|
||||
for key, value := range headers {
|
||||
r.Header.Add(key, value)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Body sets the HTTP response body to be used.
|
||||
func (r *Response) Body(body io.Reader) *Response {
|
||||
r.BodyBuffer, r.Error = ioutil.ReadAll(body)
|
||||
return r
|
||||
}
|
||||
|
||||
// BodyString defines the response body as string.
|
||||
func (r *Response) BodyString(body string) *Response {
|
||||
r.BodyBuffer = []byte(body)
|
||||
return r
|
||||
}
|
||||
|
||||
// File defines the response body reading the data
|
||||
// from disk based on the file path string.
|
||||
func (r *Response) File(path string) *Response {
|
||||
r.BodyBuffer, r.Error = ioutil.ReadFile(path)
|
||||
return r
|
||||
}
|
||||
|
||||
// JSON defines the response body based on a JSON based input.
|
||||
func (r *Response) JSON(data interface{}) *Response {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.BodyBuffer, r.Error = readAndDecode(data, "json")
|
||||
return r
|
||||
}
|
||||
|
||||
// XML defines the response body based on a XML based input.
|
||||
func (r *Response) XML(data interface{}) *Response {
|
||||
r.Header.Set("Content-Type", "application/xml")
|
||||
r.BodyBuffer, r.Error = readAndDecode(data, "xml")
|
||||
return r
|
||||
}
|
||||
|
||||
// SetError defines the response simulated error.
|
||||
func (r *Response) SetError(err error) *Response {
|
||||
r.Error = err
|
||||
return r
|
||||
}
|
||||
|
||||
// Delay defines the response simulated delay.
|
||||
// This feature is still experimental and will be improved in the future.
|
||||
func (r *Response) Delay(delay time.Duration) *Response {
|
||||
r.ResponseDelay = delay
|
||||
return r
|
||||
}
|
||||
|
||||
// Map adds a new response mapper function to map http.Response before the matching process.
|
||||
func (r *Response) Map(fn MapResponseFunc) *Response {
|
||||
r.Mappers = append(r.Mappers, fn)
|
||||
return r
|
||||
}
|
||||
|
||||
// Filter filters a new request filter function to filter http.Request before the matching process.
|
||||
func (r *Response) Filter(fn FilterResponseFunc) *Response {
|
||||
r.Filters = append(r.Filters, fn)
|
||||
return r
|
||||
}
|
||||
|
||||
// EnableNetworking enables the use real networking for the current mock.
|
||||
func (r *Response) EnableNetworking() *Response {
|
||||
r.UseNetwork = true
|
||||
return r
|
||||
}
|
||||
|
||||
// Done returns true if the mock was done and disabled.
|
||||
func (r *Response) Done() bool {
|
||||
return r.Mock.Done()
|
||||
}
|
||||
|
||||
func readAndDecode(data interface{}, kind string) ([]byte, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
switch data.(type) {
|
||||
case string:
|
||||
buf.WriteString(data.(string))
|
||||
case []byte:
|
||||
buf.Write(data.([]byte))
|
||||
default:
|
||||
var err error
|
||||
if kind == "xml" {
|
||||
err = xml.NewEncoder(buf).Encode(data)
|
||||
} else {
|
||||
err = json.NewEncoder(buf).Encode(data)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(buf)
|
||||
}
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// storeMutex is used interally for store synchronization.
|
||||
var storeMutex = sync.RWMutex{}
|
||||
|
||||
// mocks is internally used to store registered mocks.
|
||||
var mocks = []Mock{}
|
||||
|
||||
// Register registers a new mock in the current mocks stack.
|
||||
func Register(mock Mock) {
|
||||
if Exists(mock) {
|
||||
return
|
||||
}
|
||||
|
||||
// Make ops thread safe
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// Expose mock in request/response for delegation
|
||||
mock.Request().Mock = mock
|
||||
mock.Response().Mock = mock
|
||||
|
||||
// Registers the mock in the global store
|
||||
mocks = append(mocks, mock)
|
||||
}
|
||||
|
||||
// GetAll returns the current stack of registed mocks.
|
||||
func GetAll() []Mock {
|
||||
storeMutex.RLock()
|
||||
defer storeMutex.RUnlock()
|
||||
return mocks
|
||||
}
|
||||
|
||||
// Exists checks if the given Mock is already registered.
|
||||
func Exists(m Mock) bool {
|
||||
storeMutex.RLock()
|
||||
defer storeMutex.RUnlock()
|
||||
for _, mock := range mocks {
|
||||
if mock == m {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Remove removes a registered mock by reference.
|
||||
func Remove(m Mock) {
|
||||
for i, mock := range mocks {
|
||||
if mock == m {
|
||||
storeMutex.Lock()
|
||||
mocks = append(mocks[:i], mocks[i+1:]...)
|
||||
storeMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush flushes the current stack of registered mocks.
|
||||
func Flush() {
|
||||
storeMutex.Lock()
|
||||
defer storeMutex.Unlock()
|
||||
mocks = []Mock{}
|
||||
}
|
||||
|
||||
// Pending returns an slice of pending mocks.
|
||||
func Pending() []Mock {
|
||||
Clean()
|
||||
storeMutex.RLock()
|
||||
defer storeMutex.RUnlock()
|
||||
return mocks
|
||||
}
|
||||
|
||||
// IsDone returns true if all the registered mocks has been triggered successfully.
|
||||
func IsDone() bool {
|
||||
return !IsPending()
|
||||
}
|
||||
|
||||
// IsPending returns true if there are pending mocks.
|
||||
func IsPending() bool {
|
||||
return len(Pending()) > 0
|
||||
}
|
||||
|
||||
// Clean cleans the mocks store removing disabled or obsolete mocks.
|
||||
func Clean() {
|
||||
storeMutex.Lock()
|
||||
defer storeMutex.Unlock()
|
||||
|
||||
buf := []Mock{}
|
||||
for _, mock := range mocks {
|
||||
if mock.Done() {
|
||||
continue
|
||||
}
|
||||
buf = append(buf, mock)
|
||||
}
|
||||
|
||||
mocks = buf
|
||||
}
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
package gock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// var mutex *sync.Mutex = &sync.Mutex{}
|
||||
|
||||
var (
|
||||
// DefaultTransport stores the default mock transport used by gock.
|
||||
DefaultTransport = NewTransport()
|
||||
|
||||
// NativeTransport stores the native net/http default transport
|
||||
// in order to restore it when needed.
|
||||
NativeTransport = http.DefaultTransport
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrCannotMatch store the error returned in case of no matches.
|
||||
ErrCannotMatch = errors.New("gock: cannot match any request")
|
||||
)
|
||||
|
||||
// Transport implements http.RoundTripper, which fulfills single http requests issued by
|
||||
// an http.Client.
|
||||
//
|
||||
// gock's Transport encapsulates a given or default http.Transport for further
|
||||
// delegation, if needed.
|
||||
type Transport struct {
|
||||
// mutex is used to make transport thread-safe of concurrent uses across goroutines.
|
||||
mutex sync.Mutex
|
||||
|
||||
// Transport encapsulates the original http.RoundTripper transport interface for delegation.
|
||||
Transport http.RoundTripper
|
||||
}
|
||||
|
||||
// NewTransport creates a new *Transport with no responders.
|
||||
func NewTransport() *Transport {
|
||||
return &Transport{Transport: NativeTransport}
|
||||
}
|
||||
|
||||
// RoundTrip receives HTTP requests and routes them to the appropriate responder. It is required to
|
||||
// implement the http.RoundTripper interface. You will not interact with this directly, instead
|
||||
// the *http.Client you are using will call it for you.
|
||||
func (m *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Just act as a proxy if not intercepting
|
||||
if !Intercepting() {
|
||||
return m.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer Clean()
|
||||
|
||||
var err error
|
||||
var res *http.Response
|
||||
|
||||
// Match mock for the incoming http.Request
|
||||
mock, err := MatchMock(req)
|
||||
if err != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invoke the observer with the intercepted http.Request and matched mock
|
||||
if config.Observer != nil {
|
||||
config.Observer(req, mock)
|
||||
}
|
||||
|
||||
// Verify if should use real networking
|
||||
networking := shouldUseNetwork(req, mock)
|
||||
if !networking && mock == nil {
|
||||
m.mutex.Unlock()
|
||||
trackUnmatchedRequest(req)
|
||||
return nil, ErrCannotMatch
|
||||
}
|
||||
|
||||
// Ensure me unlock the mutex before building the response
|
||||
m.mutex.Unlock()
|
||||
|
||||
// Perform real networking via original transport
|
||||
if networking {
|
||||
res, err = m.Transport.RoundTrip(req)
|
||||
// In no mock matched, continue with the response
|
||||
if err != nil || mock == nil {
|
||||
return res, err
|
||||
}
|
||||
}
|
||||
|
||||
return Responder(req, mock.Response(), res)
|
||||
}
|
||||
|
||||
// CancelRequest is a no-op function.
|
||||
func (m *Transport) CancelRequest(req *http.Request) {}
|
||||
|
||||
func shouldUseNetwork(req *http.Request, mock Mock) bool {
|
||||
if mock != nil && mock.Response().UseNetwork {
|
||||
return true
|
||||
}
|
||||
if !config.Networking {
|
||||
return false
|
||||
}
|
||||
if len(config.NetworkingFilters) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, filter := range config.NetworkingFilters {
|
||||
if !filter(req) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
package gock
|
||||
|
||||
// Version defines the current package semantic version.
|
||||
const Version = "1.0.14"
|
||||
Loading…
Reference in New Issue