Compare commits


6 Commits

Author  SHA1  Message  Date

Zach Reyes  5b67e5ea44
    Update version.go to v1.56.1 (#6386)
    2023-06-21 15:17:20 -04:00

Zach Reyes  d0f5150384
    client: handle empty address lists correctly in addrConn.updateAddrs (#6354) (#6385)
    Co-authored-by: Doug Fawley <dfawley@google.com>
    2023-06-21 14:56:36 -04:00

Doug Fawley  997c1ea101
    Change version to 1.56.1-dev (#6345)
    Co-authored-by: Arvind Bright <arvind.bright100@gmail.com>
    2023-06-16 09:04:45 -05:00

Doug Fawley  2b6ff72f08
    Change version to 1.56.0 (#6344)
    Co-authored-by: Arvind Bright <arvind.bright100@gmail.com>
    2023-06-15 14:30:56 -05:00

Zach Reyes  799642536e
    xds/outlierdetection: fix config handling (#6361) (#6367)
    2023-06-09 20:00:03 -04:00

Doug Fawley  a5ae5c6408
    weightedroundrobin: cherry-pick 2 commits from master (#6360)
    2023-06-07 15:52:01 -07:00
1153 changed files with 62709 additions and 126290 deletions

.github/codecov.yml (vendored)

@ -1,25 +0,0 @@
coverage:
status:
project:
default:
informational: true
patch:
default:
informational: true
ignore:
# All 'pb.go's.
- "**/*.pb.go"
# Tests and test related files.
- "**/test"
- "**/testdata"
- "**/testutils"
- "benchmark"
- "interop"
# Other submodules.
- "cmd"
- "examples"
- "gcp"
- "security"
- "stats/opencensus"
comment:
layout: "header, diff, files"

.github/mergeable.yml (vendored)

@ -0,0 +1,21 @@
version: 2
mergeable:
- when: pull_request.*
validate:
- do: label
must_include:
regex: '^Type:'
- do: description
must_include:
# Allow:
# RELEASE NOTES: none (case insensitive)
#
# RELEASE NOTES: N/A (case insensitive)
#
# RELEASE NOTES:
# * <text>
regex: '^RELEASE NOTES:\s*([Nn][Oo][Nn][Ee]|[Nn]/[Aa]|\n(\*|-)\s*.+)$'
regex_flag: 'm'
- do: milestone
must_include:
regex: 'Release$'


@ -1,4 +0,0 @@
Thank you for your PR. Please read and follow
https://github.com/grpc/grpc-go/blob/master/CONTRIBUTING.md, especially the
"Guidelines for Pull Requests" section, and then delete this text before
entering your PR description.


@ -8,6 +8,9 @@ on:
permissions:
contents: read
security-events: write
pull-requests: read
actions: read
jobs:
analyze:
@ -15,17 +18,12 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 30
permissions:
security-events: write
pull-requests: read
actions: read
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v2
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL


@ -1,29 +0,0 @@
name: codecov
on: [push, pull_request]
permissions:
contents: read
jobs:
upload:
runs-on: ubuntu-latest
steps:
- name: Install checkout
uses: actions/checkout@v4
- name: Install checkout
uses: actions/setup-go@v5
with:
go-version: "stable"
- name: Run coverage
run: go test -coverprofile=coverage.out -coverpkg=./... ./...
- name: Run coverage with old pickfirst
run: GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST=false go test -coverprofile=coverage_old_pickfirst.out -coverpkg=./... ./...
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true


@ -1,60 +0,0 @@
name: Dependency Changes
# Trigger on PRs.
on:
pull_request:
permissions:
contents: read
jobs:
# Compare dependencies before and after this PR.
dependencies:
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
fail-fast: true
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: stable
cache-dependency-path: "**/*go.sum"
# Run the commands to generate dependencies before and after and compare.
- name: Compare dependencies
run: |
set -eu
TEMP_DIR="$(mktemp -d)"
# GITHUB_BASE_REF is set when the job is triggered by a PR.
TARGET_REF="${GITHUB_BASE_REF:-master}"
mkdir "${TEMP_DIR}/after"
scripts/gen-deps.sh "${TEMP_DIR}/after"
git checkout "origin/${TARGET_REF}"
mkdir "${TEMP_DIR}/before"
scripts/gen-deps.sh "${TEMP_DIR}/before"
echo -e " \nComparing dependencies..."
cd "${TEMP_DIR}"
# Run grep in a sub-shell since bash does not support ! in the middle of a pipe.
if diff -u0 -r "before" "after" | bash -c '! grep -v "@@"'; then
echo "No changes detected."
exit 0
fi
# Print packages in `after` but not `before`.
for x in $(ls -1 after | grep -vF "$(ls -1 before)"); do
echo -e " \nDependencies of new package $x:"
cat "after/$x"
done
echo -e " \nChanges detected; exiting with error."
exit 1


@ -6,17 +6,15 @@ on:
- cron: '22 1 * * *'
permissions:
contents: read
issues: write
pull-requests: write
jobs:
lock:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: dessant/lock-threads@v5
- uses: dessant/lock-threads@v2
with:
github-token: ${{ github.token }}
issue-inactive-days: 180
pr-inactive-days: 180
issue-lock-inactive-days: 180
pr-lock-inactive-days: 180


@ -1,55 +0,0 @@
name: PR Validation
on:
pull_request:
types: [opened, edited, synchronize, labeled, unlabeled, milestoned, demilestoned]
permissions:
contents: read
jobs:
validate:
name: Validate PR
runs-on: ubuntu-latest
steps:
- name: Validate Label
uses: actions/github-script@v6
with:
script: |
const labels = context.payload.pull_request.labels.map(label => label.name);
const requiredRegex = new RegExp('^Type:');
const hasRequiredLabel = labels.some(label => requiredRegex.test(label));
if (!hasRequiredLabel) {
core.setFailed("This PR must have a label starting with 'Type:'.");
}
- name: Validate Description
uses: actions/github-script@v6
with:
script: |
const body = context.payload.pull_request.body;
const requiredRegex = new RegExp('^RELEASE NOTES:\\s*([Nn][Oo][Nn][Ee]|[Nn]/[Aa]|\\n(\\*|-)\\s*.+)$', 'm');
if (!requiredRegex.test(body)) {
core.setFailed(`
The PR description must include a RELEASE NOTES section.
It should be in one of the following formats:
- "RELEASE NOTES: none" (case-insensitive)
- "RELEASE NOTES: N/A" (case-insensitive)
- A bulleted list under "RELEASE NOTES:", for example:
RELEASE NOTES:
* my_package: Fix bug causing crash...
`);
}
- name: Validate Milestone
uses: actions/github-script@v6
with:
script: |
const milestone = context.payload.pull_request.milestone;
if (!milestone) {
core.setFailed("This PR must be associated with a milestone.");
} else {
const requiredRegex = new RegExp('Release$');
if (!requiredRegex.test(milestone.title)) {
core.setFailed("The milestone for this PR must end with 'Release'.");
}
}


@ -4,9 +4,6 @@ on:
release:
types: [published]
permissions:
contents: read
jobs:
release:
permissions:
@ -25,10 +22,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@v2
- name: Download dependencies
run: |
@ -54,10 +51,14 @@ jobs:
run: |
PACKAGE_NAME=protoc-gen-go-grpc.${GITHUB_REF#refs/tags/cmd/protoc-gen-go-grpc/}.${{ matrix.goos }}.${{ matrix.goarch }}.tar.gz
tar -czvf $PACKAGE_NAME -C build .
echo "name=${PACKAGE_NAME}" >> $GITHUB_OUTPUT
echo ::set-output name=name::${PACKAGE_NAME}
- name: Upload asset
run: |
gh release upload ${{ github.event.release.tag_name }} ./${{ steps.package.outputs.name }}
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ github.event.release.upload_url }}
asset_path: ./${{ steps.package.outputs.name }}
asset_name: ${{ steps.package.outputs.name }}
asset_content_type: application/gzip


@ -5,9 +5,6 @@ on:
schedule:
- cron: "44 */2 * * *"
permissions:
contents: read
jobs:
stale:
runs-on: ubuntu-latest
@ -16,7 +13,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v8
- uses: actions/stale@v4
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
days-before-stale: 6


@ -19,60 +19,58 @@ jobs:
vet-proto:
runs-on: ubuntu-latest
timeout-minutes: 20
env:
VET_ONLY_PROTO: 1
steps:
- name: Checkout repo
uses: actions/checkout@v4
# Setup the environment.
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v3
with:
go-version: '1.25'
cache-dependency-path: "**/go.sum"
go-version: '1.20'
- name: Checkout repo
uses: actions/checkout@v3
# Run the vet-proto checks.
- name: vet-proto
run: ./scripts/vet-proto.sh -install && ./scripts/vet-proto.sh
# Run the vet checks.
- name: vet
run: ./vet.sh -install && ./vet.sh
# Run the main gRPC-Go tests.
tests:
# Use the matrix variable to set the runner, with 'ubuntu-latest' as the
# default.
runs-on: ${{ matrix.runner || 'ubuntu-latest' }}
# Proto checks are run in the above job.
env:
VET_SKIP_PROTO: 1
runs-on: ubuntu-latest
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
include:
- type: vet
goversion: '1.24'
- type: extras
goversion: '1.25'
goversion: '1.20'
- type: tests
goversion: '1.25'
goversion: '1.20'
- type: tests
goversion: '1.25'
goversion: '1.20'
testflags: -race
- type: tests
goversion: '1.25'
goversion: '1.20'
goarch: 386
- type: tests
goversion: '1.25'
goversion: '1.20'
goarch: arm64
runner: ubuntu-24.04-arm
- type: tests
goversion: '1.24'
goversion: '1.19'
- type: tests
goversion: '1.25'
testflags: -race
grpcenv: 'GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST=false'
goversion: '1.18'
- type: extras
goversion: '1.20'
steps:
# Setup the environment.
@ -80,23 +78,28 @@ jobs:
if: matrix.goarch != ''
run: echo "GOARCH=${{ matrix.goarch }}" >> $GITHUB_ENV
- name: Setup qemu emulator
if: matrix.goarch == 'arm64'
# setup qemu-user-static emulator and register it with binfmt_misc so that aarch64 binaries
# are automatically executed using qemu.
run: docker run --rm --privileged multiarch/qemu-user-static:5.2.0-2 --reset --credential yes --persistent yes
- name: Setup GRPC environment
if: matrix.grpcenv != ''
run: echo "${{ matrix.grpcenv }}" >> $GITHUB_ENV
- name: Checkout repo
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.goversion }}
cache-dependency-path: "**/*go.sum"
- name: Checkout repo
uses: actions/checkout@v3
# Only run vet for 'vet' runs.
- name: Run vet.sh
if: matrix.type == 'vet'
run: ./scripts/vet.sh -install && ./scripts/vet.sh
run: ./vet.sh -install && ./vet.sh
# Main tests run for everything except when testing "extras"
# (where we run a reduced set of tests).
@ -104,7 +107,7 @@ jobs:
if: matrix.type == 'tests'
run: |
go version
go test ${{ matrix.testflags }} -cpu 1,4 -timeout 7m ./...
go test ${{ matrix.testflags }} -cpu 1,4 -timeout 7m google.golang.org/grpc/...
cd "${GITHUB_WORKSPACE}"
for MOD_FILE in $(find . -name 'go.mod' | grep -Ev '^\./go\.mod'); do
pushd "$(dirname ${MOD_FILE})"
@ -125,7 +128,4 @@ jobs:
echo -e "\n-- Running Interop Test --"
interop/interop_test.sh
echo -e "\n-- Running xDS E2E Test --"
internal/xds/test/e2e/run.sh
echo -e "\n-- Running protoc-gen-go-grpc test --"
./scripts/vet-proto.sh -install
cmd/protoc-gen-go-grpc/protoc-gen-go-grpc_test.sh
xds/internal/test/e2e/run.sh


@ -1,159 +1,73 @@
# How to contribute
We welcome your patches and contributions to gRPC! Please read the gRPC
organization's [governance
rules](https://github.com/grpc/grpc-community/blob/master/governance.md) before
proceeding.
We definitely welcome your patches and contributions to gRPC! Please read the gRPC
organization's [governance rules](https://github.com/grpc/grpc-community/blob/master/governance.md)
and [contribution guidelines](https://github.com/grpc/grpc-community/blob/master/CONTRIBUTING.md) before proceeding.
If you are new to GitHub, please start by reading [Pull Request howto](https://help.github.com/articles/about-pull-requests/)
If you are new to github, please start by reading [Pull Request howto](https://help.github.com/articles/about-pull-requests/)
## Legal requirements
In order to protect both you and ourselves, you will need to sign the
[Contributor License
Agreement](https://identity.linuxfoundation.org/projects/cncf). When you create
your first PR, a link will be added as a comment that contains the steps needed
to complete this process.
## Getting Started
A great way to start is by searching through our open issues. [Unassigned issues
labeled as "help
wanted"](https://github.com/grpc/grpc-go/issues?q=sort%3Aupdated-desc%20is%3Aissue%20is%3Aopen%20label%3A%22Status%3A%20Help%20Wanted%22%20no%3Aassignee)
are especially nice for first-time contributors, as they should be well-defined
problems that already have agreed-upon solutions.
## Code Style
We follow [Google's published Go style
guide](https://google.github.io/styleguide/go/). Note that there are three
primary documents that make up this style guide; please follow them as closely
as possible. If a reviewer recommends something that contradicts those
guidelines, there may be valid reasons to do so, but it should be rare.
[Contributor License Agreement](https://identity.linuxfoundation.org/projects/cncf).
## Guidelines for Pull Requests
Please read the following carefully to ensure your contributions can be merged
smoothly and quickly.
### PR Contents
How to get your contributions merged smoothly and quickly.
- Create **small PRs** that are narrowly focused on **addressing a single
concern**. We often receive PRs that attempt to fix several things at the same
time, and if one part of the PR has a problem, that will hold up the entire
PR.
concern**. We often times receive PRs that are trying to fix several things at
a time, but only one fix is considered acceptable, nothing gets merged and
both author's & review's time is wasted. Create more PRs to address different
concerns and everyone will be happy.
- If your change does not address an **open issue** with an **agreed
resolution**, consider opening an issue and discussing it first. If you are
suggesting a behavioral or API change, consider starting with a [gRFC
proposal](https://github.com/grpc/proposal). Many new features that are not
bug fixes will require cross-language agreement.
- If you are searching for features to work on, issues labeled [Status: Help
Wanted](https://github.com/grpc/grpc-go/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc+label%3A%22Status%3A+Help+Wanted%22)
is a great place to start. These issues are well-documented and usually can be
resolved with a single pull request.
- If you want to fix **formatting or style**, consider whether your changes are
an obvious improvement or might be considered a personal preference. If a
style change is based on preference, it likely will not be accepted. If it
corrects widely agreed-upon anti-patterns, then please do create a PR and
explain the benefits of the change.
- For correcting **misspellings**, please be aware that we use some terms that
are sometimes flagged by spell checkers. As an example, "if and only if" is
often written as "iff". Please do not make spelling correction changes unless
you are certain they are misspellings.
- **All tests need to be passing** before your change can be merged. We
recommend you run tests locally before creating your PR to catch breakages
early on:
- `./scripts/vet.sh` to catch vet errors.
- `go test -cpu 1,4 -timeout 7m ./...` to run the tests.
- `go test -race -cpu 1,4 -timeout 7m ./...` to run tests in race mode.
Note that we have a multi-module repo, so `go test` commands may need to be
run from the root of each module in order to cause all tests to run.
*Alternatively*, you may find it easier to push your changes to your fork on
GitHub, which will trigger a GitHub Actions run that you can use to verify
everything is passing.
- Note that there are two github actions checks that need not be green:
1. We test the freshness of the generated proto code we maintain via the
`vet-proto` check. If the source proto files are updated, but our repo is
not updated, an optional checker will fail. This will be fixed by our team
in a separate PR and will not prevent the merge of your PR.
2. We run a checker that will fail if there is any change in dependencies of
an exported package via the `dependencies` check. If new dependencies are
added that are not appropriate, we may not accept your PR (see below).
- If you are adding a **new file**, make sure it has the **copyright message**
template at the top as a comment. You can copy the message from an existing
file and update the year.
- If you are adding a new file, make sure it has the copyright message template
at the top as a comment. You can copy over the message from an existing file
and update the year.
- The grpc package should only depend on standard Go packages and a small number
of exceptions. **If your contribution introduces new dependencies**, you will
need a discussion with gRPC-Go maintainers.
of exceptions. If your contribution introduces new dependencies which are NOT
in the [list](https://godoc.org/google.golang.org/grpc?imports), you need a
discussion with gRPC-Go authors and consultants.
### PR Descriptions
- For speculative changes, consider opening an issue and discussing it first. If
you are suggesting a behavioral or API change, consider starting with a [gRFC
proposal](https://github.com/grpc/proposal).
- **PR titles** should start with the name of the component being addressed, or
the type of change. Examples: transport, client, server, round_robin, xds,
cleanup, deps.
- Provide a good **PR description** as a record of **what** change is being made
and **why** it was made. Link to a github issue if it exists.
- Read and follow the **guidelines for PR titles and descriptions** here:
https://google.github.io/eng-practices/review/developer/cl-descriptions.html
- If you want to fix formatting or style, consider whether your changes are an
obvious improvement or might be considered a personal preference. If a style
change is based on preference, it likely will not be accepted. If it corrects
widely agreed-upon anti-patterns, then please do create a PR and explain the
benefits of the change.
*particularly* the sections "First Line" and "Body is Informative".
Note: your PR description will be used as the git commit message in a
squash-and-merge if your PR is approved. We may make changes to this as
necessary.
- **Does this PR relate to an open issue?** On the first line, please use the
tag `Fixes #<issue>` to ensure the issue is closed when the PR is merged. Or
use `Updates #<issue>` if the PR is related to an open issue, but does not fix
it. Consider filing an issue if one does not already exist.
- PR descriptions *must* conclude with **release notes** as follows:
```
RELEASE NOTES:
* <component>: <summary>
```
This need not match the PR title.
The summary must:
* be something that gRPC users will understand.
* clearly explain the feature being added, the issue being fixed, or the
behavior being changed, etc. If fixing a bug, be clear about how the bug
can be triggered by an end-user.
* begin with a capital letter and use complete sentences.
* be as short as possible to describe the change being made.
If a PR is *not* end-user visible -- e.g. a cleanup, testing change, or
github-related, use `RELEASE NOTES: n/a`.
### PR Process
- Please **self-review** your code changes before sending your PR. This will
prevent simple, obvious errors from causing delays.
- Maintain a **clean commit history** and use **meaningful commit messages**.
PRs with messy commit histories are difficult to review and won't be merged.
Before sending your PR, ensure your changes are based on top of the latest
`upstream/master` commits, and avoid rebasing in the middle of a code review.
You should **never use `git push -f`** unless absolutely necessary during a
review, as it can interfere with GitHub's tracking of comments.
- Unless your PR is trivial, you should **expect reviewer comments** that you
will need to address before merging. We'll label the PR as `Status: Requires
- Unless your PR is trivial, you should expect there will be reviewer comments
that you'll need to address before merging. We'll mark it as `Status: Requires
Reporter Clarification` if we expect you to respond to these comments in a
timely manner. If the PR remains inactive for 6 days, it will be marked as
`stale`, and we will automatically close it after 7 days if we don't hear back
from you. Please feel free to ping issues or bugs if you do not get a response
within a week.
`stale` and automatically close 7 days after that if we don't hear back from
you.
- Maintain **clean commit history** and use **meaningful commit messages**. PRs
with messy commit history are difficult to review and won't be merged. Use
`rebase -i upstream/master` to curate your commit history and/or to bring in
latest changes from master (but avoid rebasing in the middle of a code
review).
- Keep your PR up to date with upstream/master (if there are merge conflicts, we
can't really merge your change).
- **All tests need to be passing** before your change can be merged. We
recommend you **run tests locally** before creating your PR to catch breakages
early on.
- `VET_SKIP_PROTO=1 ./vet.sh` to catch vet errors
- `go test -cpu 1,4 -timeout 7m ./...` to run the tests
- `go test -race -cpu 1,4 -timeout 7m ./...` to run tests in race mode
- Exceptions to the rules can be made if there's a compelling reason for doing so.


@ -1,97 +1,103 @@
## Anti-Patterns of Client creation
## Anti-Patterns
### How to properly create a `ClientConn`: `grpc.NewClient`
[`grpc.NewClient`](https://pkg.go.dev/google.golang.org/grpc#NewClient) is the
function in the gRPC library that creates a virtual connection from a client
application to a gRPC server. It takes a target URI (which represents the name
of a logical backend service and resolves to one or more physical addresses) and
a list of options, and returns a
### Dialing in gRPC
[`grpc.Dial`](https://pkg.go.dev/google.golang.org/grpc#Dial) is a function in
the gRPC library that creates a virtual connection from the gRPC client to the
gRPC server. It takes a target URI (which can represent the name of a logical
backend service and could resolve to multiple actual addresses) and a list of
options, and returns a
[`ClientConn`](https://pkg.go.dev/google.golang.org/grpc#ClientConn) object that
represents the virtual connection to the server. The `ClientConn` contains one
or more actual connections to real servers and attempts to maintain these
connections by automatically reconnecting to them when they break. `NewClient`
was introduced in gRPC-Go v1.63.
represents the connection to the server. The `ClientConn` contains one or more
actual connections to real server backends and attempts to keep these
connections healthy by automatically reconnecting to them when they break.
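For illustration, here is a minimal sketch of creating a channel with `NewClient` and building a stub from it. The target URI, the credentials choice, and the generated `pb.NewMyServiceClient` constructor are placeholders, not a prescribed setup:

```go
// NewClient only validates the target and options; it performs no I/O, so an
// error here means the configuration is bad, not that the server is down.
conn, err := grpc.NewClient("dns:///my-service.example.com:443",
    grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")))
if err != nil {
    log.Fatalf("grpc.NewClient failed: %v", err)
}
defer conn.Close()

client := pb.NewMyServiceClient(conn) // hypothetical generated stub
res, err := client.MyRPC(context.TODO(), &pb.MyRequest{})
```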
### The wrong way: `grpc.Dial`
The `Dial` function can also be configured with various options to customize the
behavior of the client connection. For example, developers could use options
such as
[`WithTransportCredentials`](https://pkg.go.dev/google.golang.org/grpc#WithTransportCredentials)
to configure the transport credentials to use.
[`grpc.Dial`](https://pkg.go.dev/google.golang.org/grpc#Dial) is a deprecated
function that also creates the same virtual connection pool as `grpc.NewClient`.
However, unlike `grpc.NewClient`, it immediately starts connecting and supports
a few additional `DialOption`s that control this initial connection attempt.
These are: `WithBlock`, `WithTimeout`, `WithReturnConnectionError`, and
`FailOnNonTempDialError`.
While `Dial` is commonly referred to as a "dialing" function, it doesn't
actually perform the low-level network dialing operation like
[`net.Dial`](https://pkg.go.dev/net#Dial) would. Instead, it creates a virtual
connection from the gRPC client to the gRPC server.
That `grpc.Dial` creates connections immediately is not a problem in and of
itself, but this behavior differs from how gRPC works in all other languages,
and it can be convenient to have a constructor that does not perform I/O. It
can also be confusing to users, as most people expect a function called `Dial`
to create _a_ connection which may need to be recreated if it is lost.
`Dial` does initiate the process of connecting to the server, but it uses the
ClientConn object to manage and maintain that connection over time. This is why
errors encountered during the initial connection are no different from those
that occur later on, and why it's important to handle errors from RPCs rather
than relying on options like
[`FailOnNonTempDialError`](https://pkg.go.dev/google.golang.org/grpc#FailOnNonTempDialError),
[`WithBlock`](https://pkg.go.dev/google.golang.org/grpc#WithBlock), and
[`WithReturnConnectionError`](https://pkg.go.dev/google.golang.org/grpc#WithReturnConnectionError).
In fact, `Dial` does not always establish a connection to servers by default.
The connection behavior is determined by the load balancing policy being used.
For instance, an "active" load balancing policy such as Round Robin attempts to
maintain a constant connection, while the default "pick first" policy delays
connection until an RPC is executed. Instead of using the WithBlock option, which
may not be recommended in some cases, you can call the
[`ClientConn.Connect`](https://pkg.go.dev/google.golang.org/grpc#ClientConn.Connect)
method to explicitly initiate a connection.
`grpc.Dial` uses "passthrough" as the default name resolver for backward
compatibility while `grpc.NewClient` uses "dns" as its default name resolver.
This subtle difference is important to legacy systems that also specified a
custom dialer and expected it to receive the target string directly.
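As a rough sketch of what that difference looks like in practice (the target is illustrative, and `opts` stands for whatever `DialOption`s you already pass), a caller migrating to `NewClient` can keep the old behavior by spelling out the scheme:

```go
// Default for NewClient: the "dns" resolver resolves the host before connecting.
conn1, err := grpc.NewClient("dns:///my-service.example.com:50051", opts...)

// Old Dial default: "passthrough" hands the target string to the dialer
// unchanged, which matters if a custom dialer expects the raw string.
conn2, err := grpc.NewClient("passthrough:///my-service.example.com:50051", opts...)
```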
### Using `FailOnNonTempDialError`, `WithBlock`, and `WithReturnConnectionError`
For these reasons, using `grpc.Dial` is discouraged. Even though it is marked
as deprecated, we will continue to support it until a v2 is released (and no
plans for a v2 exist at the time this was written).
The gRPC API provides several options that can be used to configure the behavior
of dialing and connecting to a gRPC server. Some of these options, such as
`FailOnNonTempDialError`, `WithBlock`, and `WithReturnConnectionError`, rely on
failures at dial time. However, we strongly discourage developers from using
these options, as they can introduce race conditions and result in unreliable
and difficult-to-debug code.
### Especially bad: using deprecated `DialOptions`
One of the most important reasons for avoiding these options, which is often
overlooked, is that connections can fail at any point in time. This means that
you need to handle RPC failures caused by connection issues, regardless of
whether a connection was never established in the first place, or if it was
created and then immediately lost. Implementing proper error handling for RPCs
is crucial for maintaining the reliability and stability of your gRPC
communication.
`FailOnNonTempDialError`, `WithBlock`, and `WithReturnConnectionError` are three
`DialOption`s that are only supported by `Dial` because they only affect the
behavior of `Dial` itself. `WithBlock` causes `Dial` to wait until the
`ClientConn` reports its `State` as `connectivity.Connected`. The other two deal
with returning connection errors before the timeout (`WithTimeout` or on the
context when using `DialContext`).
### Why we discourage using `FailOnNonTempDialError`, `WithBlock`, and `WithReturnConnectionError`
The reason these options can be a problem is that connections with a
`ClientConn` are dynamic -- they may come and go over time. If your client
successfully connects, the server could go down 1 second later, and your RPCs
will fail. "Knowing you are connected" does not tell you much in this regard.
When a client attempts to connect to a gRPC server, it can encounter a variety
of errors, including network connectivity issues, server-side errors, and
incorrect usage of the gRPC API. The options `FailOnNonTempDialError`,
`WithBlock`, and `WithReturnConnectionError` are designed to handle some of
these errors, but they do so by relying on failures at dial time. This means
that they may not provide reliable or accurate information about the status of
the connection.
Additionally, _all_ RPCs created on an "idle" or a "connecting" `ClientConn`
will wait until their deadline or until a connection is established before
failing. This means that you don't need to check that a `ClientConn` is "ready"
before starting your RPCs. By default, RPCs will fail if the `ClientConn`
enters the "transient failure" state, but setting `WaitForReady(true)` on a
call will cause it to queue even in the "transient failure" state, and it will
only ever fail due to a deadline, a server response, or a connection loss after
the RPC was sent to a server.
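A small sketch of the `WaitForReady` behavior described above, using the same placeholder `client.MyRPC` as the other examples in this document:

```go
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// With WaitForReady(true), the RPC queues while the channel is IDLE,
// CONNECTING, or TRANSIENT_FAILURE instead of failing fast; it only fails on
// deadline expiration, a server response, or connection loss after being sent.
res, err := client.MyRPC(ctx, &MyRequest{}, grpc.WaitForReady(true))
if err != nil {
    log.Printf("MyRPC failed: %v", err)
    return
}
log.Printf("MyRPC response: %v", res)
```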
Some users of `Dial` use it as a way to validate the configuration of their
system. If you wish to maintain this behavior but migrate to `NewClient`, you
can call `GetState`, then `Connect` if the state is `Idle` and
`WaitForStateChange` until the channel is connected. However, if this fails,
it does not mean that your configuration was bad - it could also mean the
service is not reachable by the client due to connectivity reasons.
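A rough sketch of that migration, assuming `cc` is the `*grpc.ClientConn` returned by `NewClient`; note that `GetState` and `WaitForStateChange` are experimental APIs, and the timeout here is illustrative:

```go
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

if cc.GetState() == connectivity.Idle {
    cc.Connect() // kick the channel out of IDLE so it starts connecting
}
for state := cc.GetState(); state != connectivity.Ready; state = cc.GetState() {
    if !cc.WaitForStateChange(ctx, state) {
        // Timed out before reaching READY. This does not prove the
        // configuration is bad; the service may simply be unreachable right now.
        log.Printf("channel not READY; last observed state: %v", state)
        break
    }
}
```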
For example, if a client uses `WithBlock` to wait for a connection to be
established, it may end up waiting indefinitely if the server is not responding.
Similarly, if a client uses `WithReturnConnectionError` to return a connection
error if dialing fails, it may miss opportunities to recover from transient
network issues that are resolved shortly after the initial dial attempt.
## Best practices for error handling in gRPC
Instead of relying on failures at dial time, we strongly encourage developers to
rely on errors from RPCs. When a client makes an RPC, it can receive an error
response from the server. These errors can provide valuable information about
what went wrong, including information about network issues, server-side errors,
and incorrect usage of the gRPC API.
By handling errors from RPCs correctly, developers can write more reliable and
robust gRPC applications. Here are some best practices for error handling in
gRPC:
- Always check for error responses from RPCs and handle them appropriately.
- Use the `status` field of the error response to determine the type of error
that occurred.
- When retrying failed RPCs, consider using the built-in retry mechanism
provided by gRPC-Go, if available, instead of manually implementing retries.
Refer to the [gRPC-Go retry example
documentation](https://github.com/grpc/grpc-go/blob/master/examples/features/retry/README.md)
for more information. Note that this is not a substitute for client-side
retries as errors that occur after an RPC starts on a server cannot be
retried through gRPC's built-in mechanism.
- If making an outgoing RPC from a server handler, be sure to translate the
status code before returning the error from your method handler. For example,
if the error is an `INVALID_ARGUMENT` status code, that probably means
for more information.
- Avoid using `FailOnNonTempDialError`, `WithBlock`, and
`WithReturnConnectionError`, as these options can introduce race conditions and
result in unreliable and difficult-to-debug code.
- If making the outgoing RPC in order to handle an incoming RPC, be sure to
translate the status code before returning the error from your method handler.
For example, if the error is an `INVALID_ARGUMENT` error, that probably means
your service has a bug (otherwise it shouldn't have triggered this error), in
which case `INTERNAL` is more appropriate to return back to your users.
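A sketch of that last point; the handler, the backend client, and the proto types are placeholders:

```go
func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, error) {
    res, err := s.backend.OtherRPC(ctx, &pb.OtherRequest{})
    if err != nil {
        if status.Code(err) == codes.InvalidArgument {
            // An INVALID_ARGUMENT from the backend means *this* service built a
            // bad request, so report INTERNAL to our own caller instead.
            return nil, status.Errorf(codes.Internal, "internal error calling backend: %v", err)
        }
        return nil, err
    }
    _ = res // use the backend response as appropriate
    return &pb.SomeResponse{}, nil
}
```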
@ -100,7 +106,7 @@ gRPC:
The following code snippet demonstrates how to handle errors from an RPC in
gRPC:
```go
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@ -112,72 +118,89 @@ if err != nil {
return nil, err
}
// Use the response as appropriate
log.Printf("MyRPC response: %v", res)
```
To determine the type of error that occurred, you can use the status field of
the error response:
```go
resp, err := client.MakeRPC(context.TODO(), request)
resp, err := client.MakeRPC(context.Background(), request)
if err != nil {
if status, ok := status.FromError(err); ok {
// Handle the error based on its status code
status, ok := status.FromError(err)
if ok {
// Handle the error based on its status code
if status.Code() == codes.NotFound {
log.Println("Requested resource not found")
} else {
log.Printf("RPC error: %v", status.Message())
}
} else {
// Handle non-RPC errors
//Handle non-RPC errors
log.Printf("Non-RPC error: %v", err)
}
return
}
}
// Use the response as needed
log.Printf("Response received: %v", resp)
```
### Example: Using a backoff strategy
When retrying failed RPCs, use a backoff strategy to avoid overwhelming the
server or exacerbating network issues:
```go
var res *MyResponse
var err error
retryableStatusCodes := map[codes.Code]bool{
codes.Unavailable: true, // etc
}
// If the user doesn't have a context with a deadline, create one
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Retry the RPC a maximum number of times.
// Retry the RPC call a maximum number of times
for i := 0; i < maxRetries; i++ {
// Make the RPC.
res, err = client.MyRPC(context.TODO(), &MyRequest{})
// Check if the RPC was successful.
if !retryableStatusCodes[status.Code(err)] {
// The RPC was successful or errored in a non-retryable way;
// do not retry.
// Make the RPC call
res, err = client.MyRPC(ctx, &MyRequest{})
// Check if the RPC call was successful
if err == nil {
// The RPC was successful, so break out of the loop
break
}
// The RPC is retryable; wait for a backoff period before retrying.
backoff := time.Duration(i+1) * time.Second
// The RPC failed, so wait for a backoff period before retrying
backoff := time.Duration(i) * time.Second
log.Printf("Error calling MyRPC: %v; retrying in %v", err, backoff)
time.Sleep(backoff)
}
// Check if the RPC was successful after all retries.
// Check if the RPC call was successful after all retries
if err != nil {
// All retries failed, so handle the error appropriately
log.Printf("Error calling MyRPC: %v", err)
return nil, err
}
// Use the response as appropriate.
// Use the response as appropriate
log.Printf("MyRPC response: %v", res)
```
## Conclusion
The
[`FailOnNonTempDialError`](https://pkg.go.dev/google.golang.org/grpc#FailOnNonTempDialError),
[`WithBlock`](https://pkg.go.dev/google.golang.org/grpc#WithBlock), and
[`WithReturnConnectionError`](https://pkg.go.dev/google.golang.org/grpc#WithReturnConnectionError)
options are designed to handle errors at dial time, but they can introduce race
conditions and result in unreliable and difficult-to-debug code. Instead of
relying on these options, we strongly encourage developers to rely on errors
from RPCs for error handling. By following best practices for error handling in
gRPC, developers can write more reliable and robust gRPC applications.


@ -13,9 +13,9 @@ simulate your application:
```bash
$ go run google.golang.org/grpc/benchmark/benchmain/main.go \
-workloads=streaming \
-reqSizeBytes=1024 \
-respSizeBytes=1024 \
-compression=gzip
```
Pass the `-h` flag to the `benchmain` utility to see other flags and workloads
@ -45,8 +45,8 @@ Assume that `benchmain` is invoked like so:
```bash
$ go run google.golang.org/grpc/benchmark/benchmain/main.go \
-workloads=unary \
-reqPayloadCurveFiles=/path/to/csv \
-respPayloadCurveFiles=/path/to/csv
```
This tells the `benchmain` utility to generate unary RPC requests with a 25%
@ -61,8 +61,8 @@ following command will execute four benchmarks:
```bash
$ go run google.golang.org/grpc/benchmark/benchmain/main.go \
-workloads=unary \
-reqPayloadCurveFiles=/path/to/csv1,/path/to/csv2 \
-respPayloadCurveFiles=/path/to/csv3,/path/to/csv4
```
You may also combine `PayloadCurveFiles` with `SizeBytes` options. For example:
@ -70,6 +70,6 @@ You may also combine `PayloadCurveFiles` with `SizeBytes` options. For example:
```
$ go run google.golang.org/grpc/benchmark/benchmain/main.go \
-workloads=unary \
-reqPayloadCurveFiles=/path/to/csv \
-respSizeBytes=1
```


@ -22,7 +22,7 @@ package proto
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCodec(protoCodec{})
}
// ... implementation of protoCodec ...
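For illustration, a minimal sketch of what such a `Codec` implementation might look like (the real `protoCodec` in the `encoding/proto` package differs in detail); it implements the three methods of the `encoding.Codec` interface:

```go
package proto

import (
    "fmt"

    "google.golang.org/grpc/encoding"
    "google.golang.org/protobuf/proto"
)

type protoCodec struct{}

func (protoCodec) Marshal(v any) ([]byte, error) {
    msg, ok := v.(proto.Message)
    if !ok {
        return nil, fmt.Errorf("failed to marshal: %v is not a proto.Message", v)
    }
    return proto.Marshal(msg)
}

func (protoCodec) Unmarshal(data []byte, v any) error {
    msg, ok := v.(proto.Message)
    if !ok {
        return fmt.Errorf("failed to unmarshal: %v is not a proto.Message", v)
    }
    return proto.Unmarshal(data, msg)
}

// Name is the identifier used as the content-subtype and in RegisterCodec.
func (protoCodec) Name() string { return "proto" }

func init() {
    encoding.RegisterCodec(protoCodec{})
}
```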
@ -50,14 +50,14 @@ On the client-side, to specify a `Codec` to use for message transmission, the
`CallOption` `CallContentSubtype` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.CallContentSubtype("mycodec"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient := grpc.NewClient(target, grpc.WithDefaultCallOptions(grpc.CallContentSubtype("mycodec")))
myclient := grpc.Dial(ctx, target, grpc.WithDefaultCallOptions(grpc.CallContentSubtype("mycodec")))
```
When specified in either of these ways, messages will be encoded using this
@ -83,7 +83,7 @@ performing compression and decompression.
A `Compressor` contains code to compress and decompress by wrapping `io.Writer`s
and `io.Reader`s, respectively. (The form of `Compress` and `Decompress` were
chosen to most closely match Go's standard package
[implementations](https://golang.org/pkg/compress/) of compressors). Like
[implementations](https://golang.org/pkg/compress/) of compressors. Like
`Codec`s, `Compressor`s are registered by name into a global registry maintained
in the `encoding` package.
@ -98,7 +98,7 @@ package gzip
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCompressor(compressor{})
}
// ... implementation of compressor ...
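Similarly, here is a rough sketch of a `Compressor` built on `compress/gzip`; the real `encoding/gzip` package adds writer pooling and compression-level options:

```go
package gzip

import (
    "compress/gzip"
    "io"

    "google.golang.org/grpc/encoding"
)

type compressor struct{}

// Compress wraps w so that everything written to the returned WriteCloser is
// gzip-compressed into w.
func (compressor) Compress(w io.Writer) (io.WriteCloser, error) {
    return gzip.NewWriter(w), nil
}

// Decompress wraps r so that reads return the decompressed payload.
func (compressor) Decompress(r io.Reader) (io.Reader, error) {
    zr, err := gzip.NewReader(r)
    if err != nil {
        return nil, err
    }
    return zr, nil
}

// Name is the identifier clients pass to grpc.UseCompressor.
func (compressor) Name() string { return "gzip" }

func init() {
    encoding.RegisterCompressor(compressor{})
}
```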
@ -125,14 +125,14 @@ On the client-side, to specify a `Compressor` to use for message transmission,
the `CallOption` `UseCompressor` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.UseCompressor("gzip"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient := grpc.NewClient(target, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
myclient := grpc.Dial(ctx, target, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
```
When specified in either of these ways, messages will be compressed using this


@ -1,11 +1,11 @@
# Authentication
As outlined in the [gRPC authentication guide](https://grpc.io/docs/guides/auth.html) there are a number of different mechanisms for asserting identity between a client and server. We'll present some code-samples here demonstrating how to provide TLS support encryption and identity assertions as well as passing OAuth2 tokens to services that support it.
As outlined in the [gRPC authentication guide](https://grpc.io/docs/guides/auth.html) there are a number of different mechanisms for asserting identity between an client and server. We'll present some code-samples here demonstrating how to provide TLS support encryption and identity assertions as well as passing OAuth2 tokens to services that support it.
# Enabling TLS on a gRPC client
```Go
conn, err := grpc.NewClient(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")))
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")))
```
# Enabling TLS on a gRPC server
@ -63,7 +63,7 @@ to prevent any insecure transmission of tokens.
## Google Compute Engine (GCE)
```Go
conn, err := grpc.NewClient(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(oauth.NewComputeEngine()))
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(oauth.NewComputeEngine()))
```
## JWT
@ -73,6 +73,6 @@ jwtCreds, err := oauth.NewServiceAccountFromFile(*serviceAccountKeyFile, *oauthS
if err != nil {
log.Fatalf("Failed to create JWT credentials: %v", err)
}
conn, err := grpc.NewClient(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(jwtCreds))
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(jwtCreds))
```


@ -12,11 +12,11 @@ Four kinds of service method:
- [Client streaming RPC](https://grpc.io/docs/guides/concepts.html#client-streaming-rpc)
- [Bidirectional streaming RPC](https://grpc.io/docs/guides/concepts.html#bidirectional-streaming-rpc)
And concept of [metadata].
And concept of [metadata](https://grpc.io/docs/guides/concepts.html#metadata).
## Constructing metadata
A metadata can be created using package [metadata].
A metadata can be created using package [metadata](https://godoc.org/google.golang.org/grpc/metadata).
The type MD is actually a map from string to a list of strings:
```go
@ -64,10 +64,20 @@ md := metadata.Pairs(
)
```
## Retrieving metadata from context
Metadata can be retrieved from context using `FromIncomingContext`:
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
// do something with metadata
}
```
## Sending and receiving metadata - client side
Client side metadata sending and receiving examples are available
[here](../examples/features/metadata/client/main.go).
### Sending metadata
@ -117,8 +127,7 @@ Metadata that a client can receive includes header and trailer.
#### Unary call
Header and trailer sent along with a unary call can be retrieved using function
[Header] and [Trailer] in [CallOption]:
Header and trailer sent along with a unary call can be retrieved using function [Header](https://godoc.org/google.golang.org/grpc#Header) and [Trailer](https://godoc.org/google.golang.org/grpc#Trailer) in [CallOption](https://godoc.org/google.golang.org/grpc#CallOption):
```go
var header, trailer metadata.MD // variable to store header and trailer
@ -140,8 +149,7 @@ For streaming calls including:
- Client streaming RPC
- Bidirectional streaming RPC
Header and trailer can be retrieved from the returned stream using function
`Header` and `Trailer` in interface [ClientStream]:
Header and trailer can be retrieved from the returned stream using function `Header` and `Trailer` in interface [ClientStream](https://godoc.org/google.golang.org/grpc#ClientStream):
```go
stream, err := client.SomeStreamingRPC(ctx)
@ -156,13 +164,11 @@ trailer := stream.Trailer()
## Sending and receiving metadata - server side
Server side metadata sending and receiving examples are available
[here](../examples/features/metadata/server/main.go).
### Receiving metadata
To read metadata sent by the client, the server needs to retrieve it from RPC
context using [FromIncomingContext].
To read metadata sent by the client, the server needs to retrieve it from RPC context.
If it is a unary call, the RPC handler's context can be used.
For streaming calls, the server needs to get context from the stream.
@ -188,16 +194,15 @@ func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) erro
#### Unary call
To send header and trailer to client in unary call, the server can call
[SetHeader] and [SetTrailer] functions in module [grpc].
To send header and trailer to client in unary call, the server can call [SendHeader](https://godoc.org/google.golang.org/grpc#SendHeader) and [SetTrailer](https://godoc.org/google.golang.org/grpc#SetTrailer) functions in module [grpc](https://godoc.org/google.golang.org/grpc).
These two functions take a context as the first parameter.
It should be the RPC handler's context or one derived from it:
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
// create and set header
// create and send header
header := metadata.Pairs("header-key", "val")
grpc.SetHeader(ctx, header)
grpc.SendHeader(ctx, header)
// create and set trailer
trailer := metadata.Pairs("trailer-key", "val")
grpc.SetTrailer(ctx, trailer)
@ -206,39 +211,20 @@ func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someRespo
#### Streaming call
For streaming calls, header and trailer can be sent using function
[SetHeader] and [SetTrailer] in interface [ServerStream]:
For streaming calls, header and trailer can be sent using function `SendHeader` and `SetTrailer` in interface [ServerStream](https://godoc.org/google.golang.org/grpc#ServerStream):
```go
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
// create and set header
// create and send header
header := metadata.Pairs("header-key", "val")
stream.SetHeader(header)
stream.SendHeader(header)
// create and set trailer
trailer := metadata.Pairs("trailer-key", "val")
stream.SetTrailer(trailer)
}
```
**Important**
Do not use
[FromOutgoingContext] on the server to write metadata to be sent to the client.
[FromOutgoingContext] is for client-side use only.
## Updating metadata from a server interceptor
An example for updating metadata from a server interceptor is
available [here](../examples/features/metadata_interceptor/server/main.go).
[FromIncomingContext]: <https://pkg.go.dev/google.golang.org/grpc/metadata#FromIncomingContext>
[SetHeader]: <https://godoc.org/google.golang.org/grpc#SetHeader>
[SetTrailer]: https://godoc.org/google.golang.org/grpc#SetTrailer
[FromOutgoingContext]: https://pkg.go.dev/google.golang.org/grpc/metadata#FromOutgoingContext
[ServerStream]: https://godoc.org/google.golang.org/grpc#ServerStream
[grpc]: https://godoc.org/google.golang.org/grpc
[ClientStream]: https://godoc.org/google.golang.org/grpc#ClientStream
[Header]: https://godoc.org/google.golang.org/grpc#Header
[Trailer]: https://godoc.org/google.golang.org/grpc#Trailer
[CallOption]: https://godoc.org/google.golang.org/grpc#CallOption
[metadata]: https://godoc.org/google.golang.org/grpc/metadata


@ -1,8 +1,9 @@
# Proxy
HTTP CONNECT proxies are supported by default in gRPC. The proxy address can be
specified by the environment variables `HTTPS_PROXY` and `NO_PROXY`. (Note that
these environment variables are case insensitive.)
specified by the environment variables `HTTPS_PROXY` and `NO_PROXY`. Before Go
1.16, if the `HTTPS_PROXY` environment variable is unset, `HTTP_PROXY` will be
used instead. (Note that these environment variables are case insensitive.)
## Custom proxy
@ -12,4 +13,4 @@ connection before giving it to gRPC.
If the default proxy doesn't work for you, replace the default dialer with your
custom proxy dialer. This can be done using
[`WithContextDialer`](https://pkg.go.dev/google.golang.org/grpc#WithContextDialer).
[`WithDialer`](https://godoc.org/google.golang.org/grpc#WithDialer).
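A minimal sketch of the custom-dialer approach; the proxy address is illustrative and `proxyHandshake` is a hypothetical helper that performs whatever CONNECT/SOCKS exchange your proxy needs before the connection is handed to gRPC:

```go
dialer := func(ctx context.Context, addr string) (net.Conn, error) {
    // Dial the proxy rather than the target address gRPC asked for.
    conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", "myproxy.example.com:3128")
    if err != nil {
        return nil, err
    }
    // Negotiate with the proxy for the real target before returning.
    if err := proxyHandshake(ctx, conn, addr); err != nil {
        conn.Close()
        return nil, err
    }
    return conn, nil
}

conn, err := grpc.NewClient(target,
    grpc.WithContextDialer(dialer),
    grpc.WithTransportCredentials(insecure.NewCredentials()))
```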


@ -65,4 +65,4 @@ exit status 1
[details]: https://godoc.org/google.golang.org/grpc/internal/status#Status.Details
[status-err]: https://godoc.org/google.golang.org/grpc/internal/status#Status.Err
[status-error]: https://godoc.org/google.golang.org/grpc/status#Error
[example]: https://github.com/grpc/grpc-go/tree/master/examples/features/error_details
[example]: https://github.com/grpc/grpc-go/tree/master/examples/features/errors


@ -103,7 +103,7 @@ The `list` command lists services exposed at a given port:
- Describe all services
The `describe` command inspects a service given its full name (in the format
of \<package\>.\<service\>).
```sh
$ grpcurl -plaintext localhost:50051 describe helloworld.Greeter


@ -9,28 +9,20 @@ for general contribution guidelines.
## Maintainers (in alphabetical order)
- [arjan-bal](https://github.com/arjan-bal), Google LLC
- [arvindbr8](https://github.com/arvindbr8), Google LLC
- [atollena](https://github.com/atollena), Datadog, Inc.
- [cesarghali](https://github.com/cesarghali), Google LLC
- [dfawley](https://github.com/dfawley), Google LLC
- [easwars](https://github.com/easwars), Google LLC
- [gtcooke94](https://github.com/gtcooke94), Google LLC
- [menghanl](https://github.com/menghanl), Google LLC
- [srini100](https://github.com/srini100), Google LLC
## Emeritus Maintainers (in alphabetical order)
- [adelez](https://github.com/adelez)
- [aranjans](https://github.com/aranjans)
- [canguler](https://github.com/canguler)
- [cesarghali](https://github.com/cesarghali)
- [erm-g](https://github.com/erm-g)
- [iamqizhao](https://github.com/iamqizhao)
- [jeanbza](https://github.com/jeanbza)
- [jtattermusch](https://github.com/jtattermusch)
- [lyuxuan](https://github.com/lyuxuan)
- [makmukhi](https://github.com/makmukhi)
- [matt-kwong](https://github.com/matt-kwong)
- [menghanl](https://github.com/menghanl)
- [nicolasnoble](https://github.com/nicolasnoble)
- [purnesh42h](https://github.com/purnesh42h)
- [srini100](https://github.com/srini100)
- [yongni](https://github.com/yongni)
- [zasweq](https://github.com/zasweq)
- [adelez](https://github.com/adelez), Google LLC
- [canguler](https://github.com/canguler), Google LLC
- [iamqizhao](https://github.com/iamqizhao), Google LLC
- [jadekler](https://github.com/jadekler), Google LLC
- [jtattermusch](https://github.com/jtattermusch), Google LLC
- [lyuxuan](https://github.com/lyuxuan), Google LLC
- [makmukhi](https://github.com/makmukhi), Google LLC
- [matt-kwong](https://github.com/matt-kwong), Google LLC
- [nicolasnoble](https://github.com/nicolasnoble), Google LLC
- [yongni](https://github.com/yongni), Google LLC


@ -30,20 +30,17 @@ testdeps:
GO111MODULE=on go get -d -v -t google.golang.org/grpc/...
vet: vetdeps
./scripts/vet.sh
./vet.sh
vetdeps:
./scripts/vet.sh -install
./vet.sh -install
.PHONY: \
all \
build \
clean \
deps \
proto \
test \
testsubmodule \
testrace \
testdeps \
vet \
vetdeps


@ -1,8 +1,8 @@
# gRPC-Go
[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go)
[![GoDoc](https://pkg.go.dev/badge/google.golang.org/grpc)][API]
[![GoReportCard](https://goreportcard.com/badge/grpc/grpc-go)](https://goreportcard.com/report/github.com/grpc/grpc-go)
[![codecov](https://codecov.io/gh/grpc/grpc-go/graph/badge.svg)](https://codecov.io/gh/grpc/grpc-go)
The [Go][] implementation of [gRPC][]: A high performance, open source, general
RPC framework that puts mobile and HTTP/2 first. For more information see the
@ -10,18 +10,25 @@ RPC framework that puts mobile and HTTP/2 first. For more information see the
## Prerequisites
- **[Go][]**: any one of the **two latest major** [releases][go-releases].
- **[Go][]**: any one of the **three latest major** [releases][go-releases].
## Installation
Simply add the following import to your code, and then `go [build|run|test]`
will automatically fetch the necessary dependencies:
With [Go module][] support (Go 1.11+), simply add the following import
```go
import "google.golang.org/grpc"
```
to your code, and then `go [build|run|test]` will automatically fetch the
necessary dependencies.
Otherwise, to install the `grpc-go` package, run the following command:
```console
$ go get -u google.golang.org/grpc
```
> **Note:** If you are trying to access `grpc-go` from **China**, see the
> [FAQ](#FAQ) below.
@ -32,7 +39,6 @@ import "google.golang.org/grpc"
- [Low-level technical docs](Documentation) from this repository
- [Performance benchmark][]
- [Examples](examples)
- [Contribution guidelines](CONTRIBUTING.md)
## FAQ
@ -50,6 +56,15 @@ To build Go code, there are several options:
- Set up a VPN and access google.golang.org through that.
- Without Go module support: `git clone` the repo manually:
```sh
git clone https://github.com/grpc/grpc-go.git $GOPATH/src/google.golang.org/grpc
```
You will need to do the same for all of grpc's dependencies in `golang.org`,
e.g. `golang.org/x/net`.
- With Go module support: it is possible to use the `replace` feature of `go
mod` to create aliases for golang.org packages. In your project's directory:
@ -61,13 +76,33 @@ To build Go code, there are several options:
```
Again, this will need to be done for all transitive dependencies hosted on
golang.org as well. For details, refer to [golang/go issue
#28652](https://github.com/golang/go/issues/28652).
### Compiling error, undefined: grpc.SupportPackageIsVersion
Please update to the latest version of gRPC-Go using
`go get google.golang.org/grpc`.
#### If you are using Go modules:
Ensure your gRPC-Go version is `require`d at the appropriate version in
the same module containing the generated `.pb.go` files. For example,
`SupportPackageIsVersion6` needs `v1.27.0`, so in your `go.mod` file:
```go
module <your module name>
require (
google.golang.org/grpc v1.27.0
)
```
#### If you are *not* using Go modules:
Update the `proto` package, gRPC package, and rebuild the `.proto` files:
```sh
go get -u github.com/golang/protobuf/{proto,protoc-gen-go}
go get -u google.golang.org/grpc
protoc --go_out=plugins=grpc:. *.proto
```
### How to turn on logging
@ -86,11 +121,9 @@ possible reasons, including:
1. mis-configured transport credentials, connection failed on handshaking
1. bytes disrupted, possibly by a proxy in between
1. server shutdown
1. Keepalive parameters caused connection shutdown, for example if you have
configured your server to terminate connections regularly to [trigger DNS
lookups](https://github.com/grpc/grpc-go/issues/3170#issuecomment-552517779).
If this is the case, you may want to increase your
[MaxConnectionAgeGrace](https://pkg.go.dev/google.golang.org/grpc/keepalive?tab=doc#ServerParameters),
1. Keepalive parameters caused connection shutdown, for example if you have configured
your server to terminate connections regularly to [trigger DNS lookups](https://github.com/grpc/grpc-go/issues/3170#issuecomment-552517779).
If this is the case, you may want to increase your [MaxConnectionAgeGrace](https://pkg.go.dev/google.golang.org/grpc/keepalive?tab=doc#ServerParameters),
to allow longer RPC calls to finish.
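For example, a server that intentionally recycles connections might use settings along these lines (values are illustrative only):

```go
srv := grpc.NewServer(grpc.KeepaliveParams(keepalive.ServerParameters{
    // Recycle connections after roughly five minutes so clients re-resolve DNS...
    MaxConnectionAge: 5 * time.Minute,
    // ...but give in-flight RPCs up to another minute to finish first.
    MaxConnectionAgeGrace: time.Minute,
}))
```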
It can be tricky to debug this because the error happens on the client side but


@ -1,3 +1,3 @@
# Security Policy
For information on gRPC Security Policy and reporting potential security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md).
For information on gRPC Security Policy and reporting potentional security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md).


@ -26,10 +26,12 @@ import (
"testing"
"time"
"github.com/google/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/admin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/testutils/xds/bootstrap"
"google.golang.org/grpc/status"
v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3"
@ -52,6 +54,16 @@ type ExpectedStatusCodes struct {
// RunRegisterTests makes a client, runs the RPCs, and compares the status
// codes.
func RunRegisterTests(t *testing.T, ec ExpectedStatusCodes) {
nodeID := uuid.New().String()
bootstrapCleanup, err := bootstrap.CreateFile(bootstrap.Options{
NodeID: nodeID,
ServerURI: "no.need.for.a.server",
})
if err != nil {
t.Fatal(err)
}
defer bootstrapCleanup()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("cannot create listener: %v", err)
@ -68,9 +80,9 @@ func RunRegisterTests(t *testing.T, ec ExpectedStatusCodes) {
server.Serve(lis)
}()
conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", lis.Addr().String(), err)
t.Fatalf("cannot connect to server: %v", err)
}
t.Run("channelz", func(t *testing.T) {

View File

@ -34,26 +34,26 @@ import (
// key/value pairs. Keys must be hashable, and users should define their own
// types for keys. Values should not be modified after they are added to an
// Attributes or if they were received from one. If values implement 'Equal(o
// any) bool', it will be called by (*Attributes).Equal to determine whether
// two values with the same key should be considered equal.
// interface{}) bool', it will be called by (*Attributes).Equal to determine
// whether two values with the same key should be considered equal.
type Attributes struct {
m map[any]any
m map[interface{}]interface{}
}
// New returns a new Attributes containing the key/value pair.
func New(key, value any) *Attributes {
return &Attributes{m: map[any]any{key: value}}
func New(key, value interface{}) *Attributes {
return &Attributes{m: map[interface{}]interface{}{key: value}}
}
// WithValue returns a new Attributes containing the previous keys and values
// and the new key/value pair. If the same key appears multiple times, the
// last value overwrites all previous values for that key. To remove an
// existing key, use a nil value. value should not be modified later.
func (a *Attributes) WithValue(key, value any) *Attributes {
func (a *Attributes) WithValue(key, value interface{}) *Attributes {
if a == nil {
return New(key, value)
}
n := &Attributes{m: make(map[any]any, len(a.m)+1)}
n := &Attributes{m: make(map[interface{}]interface{}, len(a.m)+1)}
for k, v := range a.m {
n.m[k] = v
}
@ -63,19 +63,20 @@ func (a *Attributes) WithValue(key, value any) *Attributes {
// Value returns the value associated with these attributes for key, or nil if
// no value is associated with key. The returned value should not be modified.
func (a *Attributes) Value(key any) any {
func (a *Attributes) Value(key interface{}) interface{} {
if a == nil {
return nil
}
return a.m[key]
}
// Equal returns whether a and o are equivalent. If 'Equal(o any) bool' is
// implemented for a value in the attributes, it is called to determine if the
// value matches the one stored in the other attributes. If Equal is not
// implemented, standard equality is used to determine if the two values are
// equal. Note that some types (e.g. maps) aren't comparable by default, so
// they must be wrapped in a struct, or in an alias type, with Equal defined.
// Equal returns whether a and o are equivalent. If 'Equal(o interface{})
// bool' is implemented for a value in the attributes, it is called to
// determine if the value matches the one stored in the other attributes. If
// Equal is not implemented, standard equality is used to determine if the two
// values are equal. Note that some types (e.g. maps) aren't comparable by
// default, so they must be wrapped in a struct, or in an alias type, with Equal
// defined.
func (a *Attributes) Equal(o *Attributes) bool {
if a == nil && o == nil {
return true
@ -92,7 +93,7 @@ func (a *Attributes) Equal(o *Attributes) bool {
// o missing element of a
return false
}
if eq, ok := v.(interface{ Equal(o any) bool }); ok {
if eq, ok := v.(interface{ Equal(o interface{}) bool }); ok {
if !eq.Equal(ov) {
return false
}
@ -111,31 +112,19 @@ func (a *Attributes) String() string {
sb.WriteString("{")
first := true
for k, v := range a.m {
var key, val string
if str, ok := k.(interface{ String() string }); ok {
key = str.String()
}
if str, ok := v.(interface{ String() string }); ok {
val = str.String()
}
if !first {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%q: %q ", str(k), str(v)))
sb.WriteString(fmt.Sprintf("%q: %q, ", key, val))
first = false
}
sb.WriteString("}")
return sb.String()
}
func str(x any) (s string) {
if v, ok := x.(fmt.Stringer); ok {
return fmt.Sprint(v)
} else if v, ok := x.(string); ok {
return v
}
return fmt.Sprintf("<%p>", x)
}
// MarshalJSON helps implement the json.Marshaler interface, thereby rendering
// the Attributes correctly when printing (via pretty.JSON) structs containing
// Attributes as fields.
//
// It is impossible to unmarshal attributes from a JSON representation, so this
// method is meant only for debugging purposes.
func (a *Attributes) MarshalJSON() ([]byte, error) {
return []byte(a.String()), nil
}
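Taken together, the API above is used immutably: each write returns a new `*Attributes`, and values are read back by key. A minimal sketch, with a made-up key and values:
```go
package main

import (
	"fmt"

	"google.golang.org/grpc/attributes"
)

// Keys should be unexported types owned by the package that defines them.
type regionKey struct{}

func main() {
	a := attributes.New(regionKey{}, "us-east1")
	a = a.WithValue(regionKey{}, "us-west1") // the last value written for a key wins
	fmt.Println(a.Value(regionKey{}))        // prints us-west1
}
```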

View File

@ -29,19 +29,11 @@ type stringVal struct {
s string
}
func (s stringVal) Equal(o any) bool {
func (s stringVal) Equal(o interface{}) bool {
os, ok := o.(stringVal)
return ok && s.s == os.s
}
type stringerVal struct {
s string
}
func (s stringerVal) String() string {
return s.s
}
func ExampleAttributes() {
type keyOne struct{}
type keyTwo struct{}
@ -65,36 +57,6 @@ func ExampleAttributes_WithValue() {
// Key two: {two}
}
func ExampleAttributes_String() {
type key struct{}
var typedNil *stringerVal
a1 := attributes.New(key{}, typedNil) // typed nil implements [fmt.Stringer]
a2 := attributes.New(key{}, (*stringerVal)(nil)) // typed nil implements [fmt.Stringer]
a3 := attributes.New(key{}, (*stringVal)(nil)) // typed nil not implements [fmt.Stringer]
a4 := attributes.New(key{}, nil) // untyped nil
a5 := attributes.New(key{}, 1)
a6 := attributes.New(key{}, stringerVal{s: "two"})
a7 := attributes.New(key{}, stringVal{s: "two"})
a8 := attributes.New(1, true)
fmt.Println("a1:", a1.String())
fmt.Println("a2:", a2.String())
fmt.Println("a3:", a3.String())
fmt.Println("a4:", a4.String())
fmt.Println("a5:", a5.String())
fmt.Println("a6:", a6.String())
fmt.Println("a7:", a7.String())
fmt.Println("a8:", a8.String())
// Output:
// a1: {"<%!p(attributes_test.key={})>": "<nil>" }
// a2: {"<%!p(attributes_test.key={})>": "<nil>" }
// a3: {"<%!p(attributes_test.key={})>": "<0x0>" }
// a4: {"<%!p(attributes_test.key={})>": "<%!p(<nil>)>" }
// a5: {"<%!p(attributes_test.key={})>": "<%!p(int=1)>" }
// a6: {"<%!p(attributes_test.key={})>": "two" }
// a7: {"<%!p(attributes_test.key={})>": "<%!p(attributes_test.stringVal={two})>" }
// a8: {"<%!p(int=1)>": "<%!p(bool=true)>" }
}
// Test that two attributes with the same content are Equal.
func TestEqual(t *testing.T) {
type keyOne struct{}

View File

@ -89,9 +89,9 @@ type LoggerConfig interface {
// decision meets the condition for audit, all the configured audit loggers'
// Log() method will be invoked to log that event.
//
// Please refer to
// https://github.com/grpc/proposal/blob/master/A59-audit-logging.md for more
// details about audit logging.
// TODO(lwge): Change the link to the merged gRFC once it's ready.
// Please refer to https://github.com/grpc/proposal/pull/346 for more details
// about audit logging.
type Logger interface {
// Log performs audit logging for the provided audit event.
//
@ -107,9 +107,9 @@ type Logger interface {
// implement this interface, along with the Logger interface, and register
// it by calling RegisterLoggerBuilder() at init time.
//
// Please refer to
// https://github.com/grpc/proposal/blob/master/A59-audit-logging.md for more
// details about audit logging.
// TODO(lwge): Change the link to the merged gRFC once it's ready.
// Please refer to https://github.com/grpc/proposal/pull/346 for more details
// about audit logging.
type LoggerBuilder interface {
// ParseLoggerConfig parses the given JSON bytes into a structured
// logger config this builder can use to build an audit logger.
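A no-op implementation makes the wiring concrete: a builder parses its config, builds a `Logger`, and is registered once at init time so policies can refer to it by name. The names below are illustrative:
```go
package example

import (
	"encoding/json"

	"google.golang.org/grpc/authz/audit"
)

type noopLogger struct{}

// Log receives every audit event selected by the configured audit condition.
func (noopLogger) Log(*audit.Event) {}

type noopLoggerBuilder struct{}

func (noopLoggerBuilder) Name() string { return "noop_logger" }

func (noopLoggerBuilder) Build(audit.LoggerConfig) audit.Logger { return noopLogger{} }

func (noopLoggerBuilder) ParseLoggerConfig(json.RawMessage) (audit.LoggerConfig, error) {
	return nil, nil
}

func init() {
	// Lets an authorization policy reference this logger by name.
	audit.RegisterLoggerBuilder(noopLoggerBuilder{})
}
```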

View File

@ -24,6 +24,7 @@ import (
"crypto/x509"
"encoding/json"
"io"
"net"
"os"
"testing"
"time"
@ -78,7 +79,7 @@ func (lb *loggerBuilder) Build(audit.LoggerConfig) audit.Logger {
}
}
func (*loggerBuilder) ParseLoggerConfig(json.RawMessage) (audit.LoggerConfig, error) {
func (*loggerBuilder) ParseLoggerConfig(config json.RawMessage) (audit.LoggerConfig, error) {
return nil, nil
}
@ -239,24 +240,23 @@ func (s) TestAuditLogger(t *testing.T) {
wantStreamingCallCode: codes.PermissionDenied,
},
}
// Construct the credentials for the tests and the stub server
serverCreds := loadServerCreds(t)
clientCreds := loadClientCreds(t)
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
_, err := stream.Recv()
if err != io.EOF {
return err
}
return nil
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Construct the credentials for the tests and the stub server
serverCreds := loadServerCreds(t)
clientCreds := loadClientCreds(t)
ss := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
_, err := stream.Recv()
if err != io.EOF {
return err
}
return nil
},
}
// Setup test statAuditLogger, gRPC test server with authzPolicy, unary
// and stream interceptors.
lb := &loggerBuilder{
@ -266,18 +266,25 @@ func (s) TestAuditLogger(t *testing.T) {
audit.RegisterLoggerBuilder(lb)
i, _ := authz.NewStatic(test.authzPolicy)
s := grpc.NewServer(grpc.Creds(serverCreds), grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor))
s := grpc.NewServer(
grpc.Creds(serverCreds),
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
grpc.ChainStreamInterceptor(i.StreamInterceptor))
defer s.Stop()
ss.S = s
stubserver.StartTestService(t, ss)
testgrpc.RegisterTestServiceServer(s, ss)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error listening: %v", err)
}
go s.Serve(lis)
// Setup gRPC test client with certificates containing a SPIFFE Id.
cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", ss.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@ -289,7 +296,7 @@ func (s) TestAuditLogger(t *testing.T) {
}
stream, err := client.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("StreamingInputCall failed: %v", err)
t.Fatalf("StreamingInputCall failed:%v", err)
}
req := &testpb.StreamingInputCallRequest{
Payload: &testpb.Payload{
@ -297,7 +304,7 @@ func (s) TestAuditLogger(t *testing.T) {
},
}
if err := stream.Send(req); err != nil && err != io.EOF {
t.Fatalf("stream.Send failed: %v", err)
t.Fatalf("stream.Send failed:%v", err)
}
if _, err := stream.CloseAndRecv(); status.Code(err) != test.wantStreamingCallCode {
t.Errorf("Unexpected stream.CloseAndRecv fail: got %v want %v", err, test.wantStreamingCallCode)

View File

@ -56,7 +56,7 @@ type logger struct {
// Log marshals the audit.Event to json and prints it to standard output.
func (l *logger) Log(event *audit.Event) {
jsonContainer := map[string]any{
jsonContainer := map[string]interface{}{
"grpc_audit_log": convertEvent(event),
}
jsonBytes, err := json.Marshal(jsonContainer)

View File

@ -72,11 +72,11 @@ func (s) TestStdoutLogger_Log(t *testing.T) {
auditLogger.Log(test.event)
var container map[string]any
var container map[string]interface{}
if err := json.Unmarshal(buf.Bytes(), &container); err != nil {
t.Fatalf("Failed to unmarshal audit log event: %v", err)
}
innerEvent := extractEvent(container["grpc_audit_log"].(map[string]any))
innerEvent := extractEvent(container["grpc_audit_log"].(map[string]interface{}))
if innerEvent.Timestamp == "" {
t.Fatalf("Resulted event has no timestamp: %v", innerEvent)
}
@ -116,7 +116,7 @@ func (s) TestStdoutLoggerBuilder_Registration(t *testing.T) {
// extractEvent extracts an stdout.event from a map
// unmarshalled from a logged json message.
func extractEvent(container map[string]any) event {
func extractEvent(container map[string]interface{}) event {
return event{
FullMethodName: container["rpc_method"].(string),
Principal: container["principal"].(string),

View File

@ -23,6 +23,7 @@ import (
"crypto/tls"
"crypto/x509"
"io"
"net"
"os"
"testing"
"time"
@ -33,7 +34,6 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
@ -42,6 +42,26 @@ import (
testpb "google.golang.org/grpc/interop/grpc_testing"
)
type testServer struct {
testgrpc.UnimplementedTestServiceServer
}
func (s *testServer) UnaryCall(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
}
func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
for {
_, err := stream.Recv()
if err == io.EOF {
return stream.SendAndClose(&testpb.StreamingInputCallResponse{})
}
if err != nil {
return err
}
}
}
type s struct {
grpctest.Tester
}
@ -58,7 +78,7 @@ var authzTests = map[string]struct {
"DeniesRPCMatchInDenyNoMatchInAllow": {
authzPolicy: `{
"name": "authz",
"allow_rules":
"allow_rules":
[
{
"name": "allow_StreamingOutputCall",
@ -146,11 +166,11 @@ var authzTests = map[string]struct {
"/grpc.testing.TestService/UnaryCall",
"/grpc.testing.TestService/StreamingInputCall"
],
"headers":
"headers":
[
{
"key": "key-abc",
"values":
"values":
[
"val-abc",
"val-def"
@ -230,7 +250,7 @@ var authzTests = map[string]struct {
[
{
"name": "allow_StreamingOutputCall",
"request":
"request":
{
"paths":
[
@ -293,34 +313,25 @@ func (s) TestStaticPolicyEnd2End(t *testing.T) {
t.Run(name, func(t *testing.T) {
// Start a gRPC server with gRPC authz unary and stream server interceptors.
i, _ := authz.NewStatic(test.authzPolicy)
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
grpc.ChainStreamInterceptor(i.StreamInterceptor))
defer s.Stop()
testgrpc.RegisterTestServiceServer(s, &testServer{})
stub := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error {
for {
_, err := stream.Recv()
if err == io.EOF {
return stream.SendAndClose(&testpb.StreamingInputCallResponse{})
}
if err != nil {
return err
}
}
},
S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor)),
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
stubserver.StartTestService(t, stub)
defer stub.Stop()
go s.Serve(lis)
// Establish a connection to the server.
cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@ -372,27 +383,29 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnTLSAuthenticatedConnection(t *
if err != nil {
t.Fatalf("failed to generate credentials: %v", err)
}
s := grpc.NewServer(
grpc.Creds(creds),
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
testgrpc.RegisterTestServiceServer(s, &testServer{})
stub := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
S: grpc.NewServer(grpc.Creds(creds), grpc.ChainUnaryInterceptor(i.UnaryInterceptor)),
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
stubserver.StartTestService(t, stub)
defer stub.S.Stop()
go s.Serve(lis)
// Establish a connection to the server.
creds, err = credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
t.Fatalf("failed to load credentials: %v", err)
}
cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(creds))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@ -435,14 +448,17 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnMTLSAuthenticatedConnection(t
Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
})
stub := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
S: grpc.NewServer(grpc.Creds(creds), grpc.ChainUnaryInterceptor(i.UnaryInterceptor)),
s := grpc.NewServer(
grpc.Creds(creds),
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
testgrpc.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
stubserver.StartTestService(t, stub)
defer stub.Stop()
go s.Serve(lis)
// Establish a connection to the server.
cert, err = tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem"))
@ -462,12 +478,12 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnMTLSAuthenticatedConnection(t
RootCAs: roots,
ServerName: "x.test.example.com",
})
cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(creds))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@ -485,34 +501,27 @@ func (s) TestFileWatcherEnd2End(t *testing.T) {
i, _ := authz.NewFileWatcher(file, 1*time.Second)
defer i.Close()
stub := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error {
for {
_, err := stream.Recv()
if err == io.EOF {
return stream.SendAndClose(&testpb.StreamingInputCallResponse{})
}
if err != nil {
return err
}
}
},
// Start a gRPC server with gRPC authz unary and stream server interceptors.
S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor)),
// Start a gRPC server with gRPC authz unary and stream server interceptors.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
grpc.ChainStreamInterceptor(i.StreamInterceptor))
defer s.Stop()
testgrpc.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
stubserver.StartTestService(t, stub)
defer stub.Stop()
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@ -527,7 +536,7 @@ func (s) TestFileWatcherEnd2End(t *testing.T) {
// Verifying authorization decision for Streaming RPC.
stream, err := client.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("failed StreamingInputCall : %v", err)
t.Fatalf("failed StreamingInputCall err: %v", err)
}
req := &testpb.StreamingInputCallRequest{
Payload: &testpb.Payload{
@ -535,7 +544,7 @@ func (s) TestFileWatcherEnd2End(t *testing.T) {
},
}
if err := stream.Send(req); err != nil && err != io.EOF {
t.Fatalf("failed stream.Send : %v", err)
t.Fatalf("failed stream.Send err: %v", err)
}
_, err = stream.CloseAndRecv()
if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() {
@ -562,23 +571,26 @@ func (s) TestFileWatcher_ValidPolicyRefresh(t *testing.T) {
i, _ := authz.NewFileWatcher(file, 100*time.Millisecond)
defer i.Close()
stub := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
// Start a gRPC server with gRPC authz unary server interceptor.
S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor)),
// Start a gRPC server with gRPC authz unary server interceptor.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
testgrpc.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
stubserver.StartTestService(t, stub)
defer stub.Stop()
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@ -607,23 +619,26 @@ func (s) TestFileWatcher_InvalidPolicySkipReload(t *testing.T) {
i, _ := authz.NewFileWatcher(file, 20*time.Millisecond)
defer i.Close()
stub := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
// Start a gRPC server with gRPC authz unary server interceptors.
S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor)),
// Start a gRPC server with gRPC authz unary server interceptors.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
testgrpc.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
stubserver.StartTestService(t, stub)
defer stub.Stop()
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@ -655,22 +670,26 @@ func (s) TestFileWatcher_RecoversFromReloadFailure(t *testing.T) {
i, _ := authz.NewFileWatcher(file, 100*time.Millisecond)
defer i.Close()
stub := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor)),
// Start a gRPC server with gRPC authz unary server interceptors.
s := grpc.NewServer(
grpc.ChainUnaryInterceptor(i.UnaryInterceptor))
defer s.Stop()
testgrpc.RegisterTestServiceServer(s, &testServer{})
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
stubserver.StartTestService(t, stub)
defer stub.Stop()
defer lis.Close()
go s.Serve(lis)
// Establish a connection to the server.
cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err)
t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
defer clientConn.Close()
client := testgrpc.NewTestServiceClient(clientConn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@ -58,7 +58,7 @@ func NewStatic(authzPolicy string) (*StaticInterceptor, error) {
// UnaryInterceptor intercepts incoming Unary RPC requests.
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
// error is returned to the client.
func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
err := i.engines.IsAuthorized(ctx)
if err != nil {
if status.Code(err) == codes.PermissionDenied {
@ -75,7 +75,7 @@ func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req any, _ *gr
// StreamInterceptor intercepts incoming Stream RPC requests.
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
// error is returned to the client.
func (i *StaticInterceptor) StreamInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
err := i.engines.IsAuthorized(ss.Context())
if err != nil {
if status.Code(err) == codes.PermissionDenied {
@ -166,13 +166,13 @@ func (i *FileWatcherInterceptor) Close() {
// UnaryInterceptor intercepts incoming Unary RPC requests.
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
// error is returned to the client.
func (i *FileWatcherInterceptor) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
func (i *FileWatcherInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).UnaryInterceptor(ctx, req, info, handler)
}
// StreamInterceptor intercepts incoming Stream RPC requests.
// Only authorized requests are allowed to pass. Otherwise, an unauthorized
// error is returned to the client.
func (i *FileWatcherInterceptor) StreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
func (i *FileWatcherInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).StreamInterceptor(srv, ss, info, handler)
}

View File

@ -57,9 +57,9 @@ func (s) TestNewStatic(t *testing.T) {
wantErr: fmt.Errorf(`"name" is not present`),
},
"ValidPolicyCreatesInterceptor": {
authzPolicy: `{
authzPolicy: `{
"name": "authz",
"allow_rules":
"allow_rules":
[
{
"name": "allow_all"

View File

@ -308,7 +308,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_NONE,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -342,7 +342,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_ON_ALLOW,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -404,7 +404,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_ON_DENY,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -438,7 +438,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_ON_DENY_AND_ALLOW,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -500,7 +500,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_NONE,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -534,7 +534,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_NONE,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -596,7 +596,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_NONE,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{"abc": 123, "xyz": "123"}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{"abc": 123, "xyz": "123"}, "stdout_logger")},
IsOptional: false,
},
},
@ -630,7 +630,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_NONE,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{"abc": 123, "xyz": "123"}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{"abc": 123, "xyz": "123"}, "stdout_logger")},
IsOptional: false,
},
},
@ -688,7 +688,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_NONE,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{"abc": 123, "xyz": map[string]any{"abc": 123}}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{"abc": 123, "xyz": map[string]interface{}{"abc": 123}}, "stdout_logger")},
IsOptional: false,
},
},
@ -792,7 +792,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_NONE,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -853,7 +853,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_ON_DENY,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -887,7 +887,7 @@ func TestTranslatePolicy(t *testing.T) {
AuditLoggingOptions: &v3rbacpb.RBAC_AuditLoggingOptions{
AuditCondition: v3rbacpb.RBAC_AuditLoggingOptions_ON_DENY,
LoggerConfigs: []*v3rbacpb.RBAC_AuditLoggingOptions_AuditLoggerConfig{
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]any{}, "stdout_logger")},
{AuditLogger: &v3corepb.TypedExtensionConfig{Name: "stdout_logger", TypedConfig: anyPbHelper(t, map[string]interface{}{}, "stdout_logger")},
IsOptional: false,
},
},
@ -1034,7 +1034,7 @@ func TestTranslatePolicy(t *testing.T) {
}
}
func anyPbHelper(t *testing.T, in map[string]any, name string) *anypb.Any {
func anyPbHelper(t *testing.T, in map[string]interface{}, name string) *anypb.Any {
t.Helper()
pb, err := structpb.NewStruct(in)
typedStruct := &v1xdsudpatypepb.TypedStruct{

View File

@ -39,7 +39,7 @@ type Config struct {
MaxDelay time.Duration
}
// DefaultConfig is a backoff configuration with the default values specified
// DefaultConfig is a backoff configuration with the default values specfied
// at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
//
// This should be useful for callers who want to configure backoff with
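On the client side, this config (or a tweaked copy of it) is typically applied through the connect options; the target below is only a placeholder:
```go
package main

import (
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/backoff"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Start from DefaultConfig and override only the fields you care about.
	bc := backoff.DefaultConfig
	bc.MaxDelay = 30 * time.Second

	conn, err := grpc.NewClient("dns:///example.com:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithConnectParams(grpc.ConnectParams{Backoff: bc}),
	)
	if err != nil {
		log.Fatalf("failed to create client: %v", err)
	}
	defer conn.Close()
}
```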

View File

@ -30,8 +30,6 @@ import (
"google.golang.org/grpc/channelz"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
@ -41,8 +39,6 @@ import (
var (
// m is a map from name to balancer builder.
m = make(map[string]Builder)
logger = grpclog.Component("balancer")
)
// Register registers the balancer builder to the balancer map. b.Name
@ -55,14 +51,7 @@ var (
// an init() function), and is not thread-safe. If multiple Balancers are
// registered with the same name, the one registered last will take effect.
func Register(b Builder) {
name := strings.ToLower(b.Name())
if name != b.Name() {
// TODO: Skip the use of strings.ToLower() to index the map after v1.59
// is released to switch to case sensitive balancer registry. Also,
// remove this warning and update the docstrings for Register and Get.
logger.Warningf("Balancer registered with name %q. grpc-go will be switching to case sensitive balancer registries soon", b.Name())
}
m[name] = b
m[strings.ToLower(b.Name())] = b
}
// unregisterForTesting deletes the balancer with the given name from the
@ -75,26 +64,59 @@ func unregisterForTesting(name string) {
func init() {
internal.BalancerUnregister = unregisterForTesting
internal.ConnectedAddress = connectedAddress
internal.SetConnectedAddress = setConnectedAddress
}
// Get returns the resolver builder registered with the given name.
// Note that the compare is done in a case-insensitive fashion.
// If no builder is register with the name, nil will be returned.
func Get(name string) Builder {
if strings.ToLower(name) != name {
// TODO: Skip the use of strings.ToLower() to index the map after v1.59
// is released to switch to case sensitive balancer registry. Also,
// remove this warning and update the docstrings for Register and Get.
logger.Warningf("Balancer retrieved for name %q. grpc-go will be switching to case sensitive balancer registries soon", name)
}
if b, ok := m[strings.ToLower(name)]; ok {
return b
}
return nil
}
// A SubConn represents a single connection to a gRPC backend service.
//
// Each SubConn contains a list of addresses.
//
// All SubConns start in IDLE, and will not try to connect. To trigger the
// connecting, Balancers must call Connect. If a connection re-enters IDLE,
// Balancers must call Connect again to trigger a new connection attempt.
//
// gRPC will try to connect to the addresses in sequence, and stop trying the
// remainder once the first connection is successful. If an attempt to connect
// to all addresses encounters an error, the SubConn will enter
// TRANSIENT_FAILURE for a backoff period, and then transition to IDLE.
//
// Once established, if a connection is lost, the SubConn will transition
// directly to IDLE.
//
// This interface is to be implemented by gRPC. Users should not need their own
// implementation of this interface. For situations like testing, any
// implementations should embed this interface. This allows gRPC to add new
// methods to this interface.
type SubConn interface {
// UpdateAddresses updates the addresses used in this SubConn.
// gRPC checks if currently-connected address is still in the new list.
// If it's in the list, the connection will be kept.
// If it's not in the list, the connection will be gracefully closed, and
// a new connection will be created.
//
// This will trigger a state transition for the SubConn.
//
// Deprecated: This method is now part of the ClientConn interface and will
// eventually be removed from here.
UpdateAddresses([]resolver.Address)
// Connect starts the connecting for this SubConn.
Connect()
// GetOrBuildProducer returns a reference to the existing Producer for this
// ProducerBuilder in this SubConn, or, if one does not currently exist,
// creates a new one and returns it. Returns a close function which must
// be called when the Producer is no longer needed.
GetOrBuildProducer(ProducerBuilder) (p Producer, close func())
}
// NewSubConnOptions contains options to create new SubConn.
type NewSubConnOptions struct {
// CredsBundle is the credentials bundle that will be used in the created
@ -107,11 +129,6 @@ type NewSubConnOptions struct {
// HealthCheckEnabled indicates whether health check service should be
// enabled on this SubConn
HealthCheckEnabled bool
// StateListener is called when the state of the subconn changes. If nil,
// Balancer.UpdateSubConnState will be called instead. Will never be
// invoked until after Connect() is called on the SubConn created with
// these options.
StateListener func(SubConnState)
}
// State contains the balancer's state relevant to the gRPC ClientConn.
@ -129,35 +146,20 @@ type State struct {
// brand new implementation of this interface. For the situations like
// testing, the new implementation should embed this interface. This allows
// gRPC to add new methods to this interface.
//
// NOTICE: This interface is intended to be implemented by gRPC, or intercepted
// by custom load balancing policies. Users should not need their own complete
// implementation of this interface -- they should always delegate to a
// ClientConn passed to Builder.Build() by embedding it in their
// implementations. An embedded ClientConn must never be nil, or runtime panics
// will occur.
type ClientConn interface {
// NewSubConn is called by balancer to create a new SubConn.
// It doesn't block and wait for the connections to be established.
// Behaviors of the SubConn can be controlled by options.
//
// Deprecated: please be aware that in a future version, SubConns will only
// support one address per SubConn.
NewSubConn([]resolver.Address, NewSubConnOptions) (SubConn, error)
// RemoveSubConn removes the SubConn from ClientConn.
// The SubConn will be shutdown.
//
// Deprecated: use SubConn.Shutdown instead.
RemoveSubConn(SubConn)
// UpdateAddresses updates the addresses used in the passed in SubConn.
// gRPC checks if the currently connected address is still in the new list.
// If so, the connection will be kept. Else, the connection will be
// gracefully closed, and a new connection will be created.
//
// This may trigger a state transition for the SubConn.
//
// Deprecated: this method will be removed. Create new SubConns for new
// addresses instead.
// This will trigger a state transition for the SubConn.
UpdateAddresses(SubConn, []resolver.Address)
// UpdateState notifies gRPC that the balancer's internal state has
@ -174,17 +176,6 @@ type ClientConn interface {
//
// Deprecated: Use the Target field in the BuildOptions instead.
Target() string
// MetricsRecorder provides the metrics recorder that balancers can use to
// record metrics. Balancer implementations which do not register metrics on
// metrics registry and record on them can ignore this method. The returned
// MetricsRecorder is guaranteed to never be nil.
MetricsRecorder() estats.MetricsRecorder
// EnforceClientConnEmbedding is included to force implementers to embed
// another implementation of this interface, allowing gRPC to add methods
// without breaking users.
internal.EnforceClientConnEmbedding
}
// BuildOptions contains additional information for Build.
@ -206,8 +197,8 @@ type BuildOptions struct {
// implementations which do not communicate with a remote load balancer
// server can ignore this field.
Authority string
// ChannelzParent is the parent ClientConn's channelz channel.
ChannelzParent channelz.Identifier
// ChannelzParentID is the parent ClientConn's channelz ID.
ChannelzParentID *channelz.Identifier
// CustomUserAgent is the custom user agent set on the parent ClientConn.
// The balancer should set the same custom user agent if it creates a
// ClientConn.
@ -259,7 +250,7 @@ type DoneInfo struct {
// trailing metadata.
//
// The only supported type now is *orca_v3.LoadReport.
ServerLoad any
ServerLoad interface{}
}
var (
@ -352,18 +343,10 @@ type Balancer interface {
ResolverError(error)
// UpdateSubConnState is called by gRPC when the state of a SubConn
// changes.
//
// Deprecated: Use NewSubConnOptions.StateListener when creating the
// SubConn instead.
UpdateSubConnState(SubConn, SubConnState)
// Close closes the balancer. The balancer is not currently required to
// call SubConn.Shutdown for its existing SubConns; however, this will be
// required in a future release, so it is recommended.
// Close closes the balancer. The balancer is not required to call
// ClientConn.RemoveSubConn for its existing SubConns.
Close()
// ExitIdle instructs the LB policy to reconnect to backends / exit the
// IDLE state, if appropriate and possible. Note that SubConns that enter
// the IDLE state will not reconnect until SubConn.Connect is called.
ExitIdle()
}
// ExitIdler is an optional interface for balancers to implement. If
@ -371,8 +354,8 @@ type Balancer interface {
// the ClientConn is idle. If unimplemented, ClientConn.Connect will cause
// all SubConns to connect.
//
// Deprecated: All balancers must implement this interface. This interface will
// be removed in a future release.
// Notice: it will be required for all balancers to implement this in a future
// release.
type ExitIdler interface {
// ExitIdle instructs the LB policy to reconnect to backends / exit the
// IDLE state, if appropriate and possible. Note that SubConns that enter
@ -380,6 +363,15 @@ type ExitIdler interface {
ExitIdle()
}
// SubConnState describes the state of a SubConn.
type SubConnState struct {
// ConnectivityState is the connectivity state of the SubConn.
ConnectivityState connectivity.State
// ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error
}
// ClientConnState describes the state of a ClientConn relevant to the
// balancer.
type ClientConnState struct {
@ -392,3 +384,21 @@ type ClientConnState struct {
// ErrBadResolverState may be returned by UpdateClientConnState to indicate a
// problem with the provided name resolver data.
var ErrBadResolverState = errors.New("bad resolver state")
// A ProducerBuilder is a simple constructor for a Producer. It is used by the
// SubConn to create producers when needed.
type ProducerBuilder interface {
// Build creates a Producer. The first parameter is always a
// grpc.ClientConnInterface (a type to allow creating RPCs/streams on the
// associated SubConn), but is declared as interface{} to avoid a
// dependency cycle. Should also return a close function that will be
// called when all references to the Producer have been given up.
Build(grpcClientConnInterface interface{}) (p Producer, close func())
}
// A Producer is a type shared among potentially many consumers. It is
// associated with a SubConn, and an implementation will typically contain
// other methods to provide additional functionality, e.g. configuration or
// subscription registration.
type Producer interface {
}

View File

@ -36,12 +36,12 @@ type baseBuilder struct {
config Config
}
func (bb *baseBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer {
func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
bal := &baseBalancer{
cc: cc,
pickerBuilder: bb.pickerBuilder,
subConns: resolver.NewAddressMapV2[balancer.SubConn](),
subConns: resolver.NewAddressMap(),
scStates: make(map[balancer.SubConn]connectivity.State),
csEvltr: &balancer.ConnectivityStateEvaluator{},
config: bb.config,
@ -65,7 +65,7 @@ type baseBalancer struct {
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
subConns *resolver.AddressMapV2[balancer.SubConn]
subConns *resolver.AddressMap
scStates map[balancer.SubConn]connectivity.State
picker balancer.Picker
config Config
@ -100,17 +100,12 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
// Successful resolution; clear resolver error and ensure we return nil.
b.resolverErr = nil
// addrsSet is the set converted from addrs, it's used for quick lookup of an address.
addrsSet := resolver.NewAddressMapV2[any]()
addrsSet := resolver.NewAddressMap()
for _, a := range s.ResolverState.Addresses {
addrsSet.Set(a, nil)
if _, ok := b.subConns.Get(a); !ok {
// a is a new address (not existing in b.subConns).
var sc balancer.SubConn
opts := balancer.NewSubConnOptions{
HealthCheckEnabled: b.config.HealthCheck,
StateListener: func(scs balancer.SubConnState) { b.updateSubConnState(sc, scs) },
}
sc, err := b.cc.NewSubConn([]resolver.Address{a}, opts)
sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: b.config.HealthCheck})
if err != nil {
logger.Warningf("base.baseBalancer: failed to create new SubConn: %v", err)
continue
@ -122,17 +117,18 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
}
}
for _, a := range b.subConns.Keys() {
sc, _ := b.subConns.Get(a)
sci, _ := b.subConns.Get(a)
sc := sci.(balancer.SubConn)
// a was removed by resolver.
if _, ok := addrsSet.Get(a); !ok {
sc.Shutdown()
b.cc.RemoveSubConn(sc)
b.subConns.Delete(a)
// Keep the state of this sc in b.scStates until sc's state becomes Shutdown.
// The entry will be deleted in updateSubConnState.
// The entry will be deleted in UpdateSubConnState.
}
}
// If resolver state contains no addresses, return an error so ClientConn
// will trigger re-resolve. Also records this as a resolver error, so when
// will trigger re-resolve. Also records this as an resolver error, so when
// the overall state turns transient failure, the error message will have
// the zero address information.
if len(s.ResolverState.Addresses) == 0 {
@ -172,7 +168,8 @@ func (b *baseBalancer) regeneratePicker() {
// Filter out all ready SCs from full subConn map.
for _, addr := range b.subConns.Keys() {
sc, _ := b.subConns.Get(addr)
sci, _ := b.subConns.Get(addr)
sc := sci.(balancer.SubConn)
if st, ok := b.scStates[sc]; ok && st == connectivity.Ready {
readySCs[sc] = SubConnInfo{Address: addr}
}
@ -180,12 +177,7 @@ func (b *baseBalancer) regeneratePicker() {
b.picker = b.pickerBuilder.Build(PickerBuildInfo{ReadySCs: readySCs})
}
// UpdateSubConnState is a nop because a StateListener is always set in NewSubConn.
func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
logger.Errorf("base.baseBalancer: UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
}
func (b *baseBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
s := state.ConnectivityState
if logger.V(2) {
logger.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s)
@ -212,8 +204,8 @@ func (b *baseBalancer) updateSubConnState(sc balancer.SubConn, state balancer.Su
case connectivity.Idle:
sc.Connect()
case connectivity.Shutdown:
// When an address was removed by resolver, b called Shutdown but kept
// the sc's state in scStates. Remove state for this sc here.
// When an address was removed by resolver, b called RemoveSubConn but
// kept the sc's state in scStates. Remove state for this sc here.
delete(b.scStates, sc)
case connectivity.TransientFailure:
// Save error to be reported via picker.
@ -234,7 +226,7 @@ func (b *baseBalancer) updateSubConnState(sc balancer.SubConn, state balancer.Su
}
// Close is a nop because base balancer doesn't have internal state to clean up,
// and it doesn't need to call Shutdown for the SubConns.
// and it doesn't need to call RemoveSubConn for the SubConns.
func (b *baseBalancer) Close() {
}
@ -257,6 +249,6 @@ type errPicker struct {
err error // Pick() always returns this err.
}
func (p *errPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
func (p *errPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
return balancer.PickResult{}, p.err
}
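For orientation, a simple policy built on the `base` package above only needs to supply a `PickerBuilder` and register the resulting builder; the policy name and random pick below are illustrative:
```go
package randomexample

import (
	"math/rand"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/base"
)

type pickerBuilder struct{}

// Build is invoked by the base balancer whenever the set of READY SubConns changes.
func (pickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
	if len(info.ReadySCs) == 0 {
		return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
	}
	scs := make([]balancer.SubConn, 0, len(info.ReadySCs))
	for sc := range info.ReadySCs {
		scs = append(scs, sc)
	}
	return &randomPicker{subConns: scs}
}

type randomPicker struct {
	subConns []balancer.SubConn
}

// Pick chooses a READY SubConn at random; it keeps no mutable state, so it is
// safe for concurrent use by multiple RPCs.
func (p *randomPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
	return balancer.PickResult{SubConn: p.subConns[rand.Intn(len(p.subConns))]}, nil
}

func init() {
	balancer.Register(base.NewBalancerBuilder("random_example", pickerBuilder{}, base.Config{HealthCheck: true}))
}
```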

View File

@ -19,9 +19,7 @@
package base
import (
"context"
"testing"
"time"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/balancer"
@ -40,24 +38,16 @@ func (c *testClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewS
func (c *testClientConn) UpdateState(balancer.State) {}
type testSubConn struct {
balancer.SubConn
updateState func(balancer.SubConnState)
}
type testSubConn struct{}
func (sc *testSubConn) UpdateAddresses([]resolver.Address) {}
func (sc *testSubConn) UpdateAddresses(addresses []resolver.Address) {}
func (sc *testSubConn) Connect() {}
func (sc *testSubConn) Shutdown() {}
func (sc *testSubConn) GetOrBuildProducer(balancer.ProducerBuilder) (balancer.Producer, func()) {
return nil, nil
}
// RegisterHealthListener is a no-op.
func (*testSubConn) RegisterHealthListener(func(balancer.SubConnState)) {}
// testPickBuilder creates balancer.Picker for test.
type testPickBuilder struct {
validate func(info PickerBuildInfo)
@ -69,11 +59,7 @@ func (p *testPickBuilder) Build(info PickerBuildInfo) balancer.Picker {
}
func TestBaseBalancerReserveAttributes(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
validated := make(chan struct{}, 1)
v := func(info PickerBuildInfo) {
defer func() { validated <- struct{}{} }()
var v = func(info PickerBuildInfo) {
for _, sc := range info.ReadySCs {
if sc.Address.Addr == "1.1.1.1" {
if sc.Address.Attributes == nil {
@ -92,8 +78,8 @@ func TestBaseBalancerReserveAttributes(t *testing.T) {
}
pickBuilder := &testPickBuilder{validate: v}
b := (&baseBuilder{pickerBuilder: pickBuilder}).Build(&testClientConn{
newSubConn: func(_ []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
return &testSubConn{updateState: opts.StateListener}, nil
newSubConn: func(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) {
return &testSubConn{}, nil
},
}, balancer.BuildOptions{}).(*baseBalancer)
@ -105,18 +91,8 @@ func TestBaseBalancerReserveAttributes(t *testing.T) {
},
},
})
select {
case <-validated:
case <-ctx.Done():
t.Fatalf("timed out waiting for UpdateClientConnState to call picker.Build")
}
for sc := range b.scStates {
sc.(*testSubConn).updateState(balancer.SubConnState{ConnectivityState: connectivity.Ready, ConnectionError: nil})
select {
case <-validated:
case <-ctx.Done():
t.Fatalf("timed out waiting for UpdateClientConnState to call picker.Build")
}
b.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready, ConnectionError: nil})
}
}

View File

@ -1,389 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package endpointsharding implements a load balancing policy that manages
// homogeneous child policies each owning a single endpoint.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package endpointsharding
import (
"errors"
rand "math/rand/v2"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
var randIntN = rand.IntN
// ChildState is the balancer state of a child along with the endpoint which
// identifies the child balancer.
type ChildState struct {
Endpoint resolver.Endpoint
State balancer.State
// Balancer exposes only the ExitIdler interface of the child LB policy.
// Other methods of the child policy are called only by endpointsharding.
Balancer ExitIdler
}
// ExitIdler provides access to only the ExitIdle method of the child balancer.
type ExitIdler interface {
// ExitIdle instructs the LB policy to reconnect to backends / exit the
// IDLE state, if appropriate and possible. Note that SubConns that enter
// the IDLE state will not reconnect until SubConn.Connect is called.
ExitIdle()
}
// Options are the options to configure the behaviour of the
// endpointsharding balancer.
type Options struct {
// DisableAutoReconnect allows the balancer to keep child balancers in the
// IDLE state until they are explicitly triggered to exit using the
// ChildState obtained from the endpointsharding picker. When set to false,
// the endpointsharding balancer will automatically call ExitIdle on child
// connections that report IDLE.
DisableAutoReconnect bool
}
// ChildBuilderFunc creates a new balancer with the ClientConn. It has the same
// type as the balancer.Builder.Build method.
type ChildBuilderFunc func(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer
// NewBalancer returns a load balancing policy that manages homogeneous child
// policies each owning a single endpoint. The endpointsharding balancer
// forwards the LoadBalancingConfig in ClientConn state updates to its children.
func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions, childBuilder ChildBuilderFunc, esOpts Options) balancer.Balancer {
es := &endpointSharding{
cc: cc,
bOpts: opts,
esOpts: esOpts,
childBuilder: childBuilder,
}
es.children.Store(resolver.NewEndpointMap[*balancerWrapper]())
return es
}
// endpointSharding is a balancer that wraps child balancers. It creates a child
// balancer with child config for every unique Endpoint received. It updates the
// child states on any update from parent or child.
type endpointSharding struct {
cc balancer.ClientConn
bOpts balancer.BuildOptions
esOpts Options
childBuilder ChildBuilderFunc
// childMu synchronizes calls to any single child. It must be held for all
// calls into a child. To avoid deadlocks, do not acquire childMu while
// holding mu.
childMu sync.Mutex
children atomic.Pointer[resolver.EndpointMap[*balancerWrapper]]
// inhibitChildUpdates is set during UpdateClientConnState/ResolverError
// calls (each call into a child produces an update; we only want to send
// one aggregated update).
inhibitChildUpdates atomic.Bool
// mu synchronizes access to the state stored in balancerWrappers in the
// children field. mu must not be held during calls into a child since
// synchronous calls back from the child may require taking mu, causing a
// deadlock. To avoid deadlocks, do not acquire childMu while holding mu.
mu sync.Mutex
}
// rotateEndpoints returns a copy of the input endpoints rotated left by a
// random amount.
func rotateEndpoints(es []resolver.Endpoint) []resolver.Endpoint {
les := len(es)
if les == 0 {
return es
}
r := randIntN(les)
// Make a copy to avoid mutating data beyond the end of es.
ret := make([]resolver.Endpoint, les)
copy(ret, es[r:])
copy(ret[les-r:], es[:r])
return ret
}
// UpdateClientConnState creates a child for new endpoints and deletes children
// for endpoints that are no longer present. It also updates all the children,
// and sends a single synchronous update of the children's aggregated state at
// the end of the UpdateClientConnState operation. Endpoints with no addresses
// are ignored. It returns the first error found from a child, but fully
// processes the new update regardless.
func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState) error {
es.childMu.Lock()
defer es.childMu.Unlock()
es.inhibitChildUpdates.Store(true)
defer func() {
es.inhibitChildUpdates.Store(false)
es.updateState()
}()
var ret error
children := es.children.Load()
newChildren := resolver.NewEndpointMap[*balancerWrapper]()
// Update/Create new children.
for _, endpoint := range rotateEndpoints(state.ResolverState.Endpoints) {
if _, ok := newChildren.Get(endpoint); ok {
// Endpoint child was already created, continue to avoid duplicate
// update.
continue
}
childBalancer, ok := children.Get(endpoint)
if ok {
// Endpoint attributes may have changed, update the stored endpoint.
es.mu.Lock()
childBalancer.childState.Endpoint = endpoint
es.mu.Unlock()
} else {
childBalancer = &balancerWrapper{
childState: ChildState{Endpoint: endpoint},
ClientConn: es.cc,
es: es,
}
childBalancer.childState.Balancer = childBalancer
childBalancer.child = es.childBuilder(childBalancer, es.bOpts)
}
newChildren.Set(endpoint, childBalancer)
if err := childBalancer.updateClientConnStateLocked(balancer.ClientConnState{
BalancerConfig: state.BalancerConfig,
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{endpoint},
Attributes: state.ResolverState.Attributes,
},
}); err != nil && ret == nil {
// Return the first error found, but always fully process the update for
// all children. If more granular, per-endpoint error handling is desired,
// the caller should perform those validations itself; this is a current
// limitation kept for simplicity's sake.
ret = err
}
}
// Delete old children that are no longer present.
for _, e := range children.Keys() {
child, _ := children.Get(e)
if _, ok := newChildren.Get(e); !ok {
child.closeLocked()
}
}
es.children.Store(newChildren)
if newChildren.Len() == 0 {
return balancer.ErrBadResolverState
}
return ret
}
// ResolverError forwards the resolver error to all of the endpointSharding's
// children and sends a single synchronous update of the childStates at the end
// of the ResolverError operation.
func (es *endpointSharding) ResolverError(err error) {
es.childMu.Lock()
defer es.childMu.Unlock()
es.inhibitChildUpdates.Store(true)
defer func() {
es.inhibitChildUpdates.Store(false)
es.updateState()
}()
children := es.children.Load()
for _, child := range children.Values() {
child.resolverErrorLocked(err)
}
}
func (es *endpointSharding) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) {
// UpdateSubConnState is deprecated.
}
func (es *endpointSharding) Close() {
es.childMu.Lock()
defer es.childMu.Unlock()
children := es.children.Load()
for _, child := range children.Values() {
child.closeLocked()
}
}
func (es *endpointSharding) ExitIdle() {
es.childMu.Lock()
defer es.childMu.Unlock()
for _, bw := range es.children.Load().Values() {
if !bw.isClosed {
bw.child.ExitIdle()
}
}
}
// updateState updates this component's state. It sends the aggregated state,
// and a picker with round robin behavior with all the child states present if
// needed.
func (es *endpointSharding) updateState() {
if es.inhibitChildUpdates.Load() {
return
}
var readyPickers, connectingPickers, idlePickers, transientFailurePickers []balancer.Picker
es.mu.Lock()
defer es.mu.Unlock()
children := es.children.Load()
childStates := make([]ChildState, 0, children.Len())
for _, child := range children.Values() {
childState := child.childState
childStates = append(childStates, childState)
childPicker := childState.State.Picker
switch childState.State.ConnectivityState {
case connectivity.Ready:
readyPickers = append(readyPickers, childPicker)
case connectivity.Connecting:
connectingPickers = append(connectingPickers, childPicker)
case connectivity.Idle:
idlePickers = append(idlePickers, childPicker)
case connectivity.TransientFailure:
transientFailurePickers = append(transientFailurePickers, childPicker)
// connectivity.Shutdown shouldn't appear.
}
}
// Construct the round robin picker based on the aggregated state. Whatever
// the aggregated state is, use only the pickers of the children currently in
// that state.
var aggState connectivity.State
var pickers []balancer.Picker
if len(readyPickers) >= 1 {
aggState = connectivity.Ready
pickers = readyPickers
} else if len(connectingPickers) >= 1 {
aggState = connectivity.Connecting
pickers = connectingPickers
} else if len(idlePickers) >= 1 {
aggState = connectivity.Idle
pickers = idlePickers
} else if len(transientFailurePickers) >= 1 {
aggState = connectivity.TransientFailure
pickers = transientFailurePickers
} else {
aggState = connectivity.TransientFailure
pickers = []balancer.Picker{base.NewErrPicker(errors.New("no children to pick from"))}
} // No children (resolver error before valid update).
p := &pickerWithChildStates{
pickers: pickers,
childStates: childStates,
next: uint32(randIntN(len(pickers))),
}
es.cc.UpdateState(balancer.State{
ConnectivityState: aggState,
Picker: p,
})
}
// pickerWithChildStates delegates to the pickers it holds in a round robin
// fashion. It also contains the childStates of all the endpointSharding's
// children.
type pickerWithChildStates struct {
pickers []balancer.Picker
childStates []ChildState
next uint32
}
func (p *pickerWithChildStates) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
nextIndex := atomic.AddUint32(&p.next, 1)
picker := p.pickers[nextIndex%uint32(len(p.pickers))]
return picker.Pick(info)
}
// ChildStatesFromPicker returns the state of all the children managed by the
// endpoint sharding balancer that created this picker.
func ChildStatesFromPicker(picker balancer.Picker) []ChildState {
p, ok := picker.(*pickerWithChildStates)
if !ok {
return nil
}
return p.childStates
}
// balancerWrapper is a wrapper of a balancer. It identifies a child balancer
// by endpoint, and persists the most recent child balancer state.
type balancerWrapper struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
// child contains the wrapped balancer. Access its methods only through
// methods on balancerWrapper to ensure proper synchronization.
child balancer.Balancer
balancer.ClientConn // embed to intercept UpdateState, doesn't deal with SubConns
es *endpointSharding
// Access to the following fields is guarded by es.mu.
childState ChildState
isClosed bool
}
func (bw *balancerWrapper) UpdateState(state balancer.State) {
bw.es.mu.Lock()
bw.childState.State = state
bw.es.mu.Unlock()
if state.ConnectivityState == connectivity.Idle && !bw.es.esOpts.DisableAutoReconnect {
bw.ExitIdle()
}
bw.es.updateState()
}
// ExitIdle pings an IDLE child balancer to exit idle in a new goroutine to
// avoid deadlocks due to synchronous balancer state updates.
func (bw *balancerWrapper) ExitIdle() {
go func() {
bw.es.childMu.Lock()
if !bw.isClosed {
bw.child.ExitIdle()
}
bw.es.childMu.Unlock()
}()
}
// updateClientConnStateLocked delivers the ClientConnState to the child
// balancer. Callers must hold the child mutex of the parent endpointsharding
// balancer.
func (bw *balancerWrapper) updateClientConnStateLocked(ccs balancer.ClientConnState) error {
return bw.child.UpdateClientConnState(ccs)
}
// closeLocked closes the child balancer. Callers must hold the child mutex of
// the parent endpointsharding balancer.
func (bw *balancerWrapper) closeLocked() {
bw.child.Close()
bw.isClosed = true
}
func (bw *balancerWrapper) resolverErrorLocked(err error) {
bw.child.ResolverError(err)
}
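A short sketch of how a parent (petiole) policy that sets DisableAutoReconnect can later wake an IDLE child using the ChildState exposed through the picker; parentPolicy is a hypothetical wrapper, and only its interception of UpdateState is shown:

func (p *parentPolicy) UpdateState(state balancer.State) {
	for _, cs := range endpointsharding.ChildStatesFromPicker(state.Picker) {
		if cs.State.ConnectivityState == connectivity.Idle {
			// With DisableAutoReconnect set, children stay IDLE until the
			// parent explicitly asks them to reconnect.
			cs.Balancer.ExitIdle()
		}
	}
	p.ClientConn.UpdateState(state)
}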

View File

@ -1,353 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package endpointsharding_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/roundrobin"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/status"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)
var (
defaultTestTimeout = time.Second * 10
defaultTestShortTimeout = time.Millisecond * 10
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
var logger = grpclog.Component("endpoint-sharding-test")
func init() {
balancer.Register(fakePetioleBuilder{})
}
const fakePetioleName = "fake_petiole"
type fakePetioleBuilder struct{}
func (fakePetioleBuilder) Name() string {
return fakePetioleName
}
func (fakePetioleBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
fp := &fakePetiole{
ClientConn: cc,
bOpts: opts,
}
fp.Balancer = endpointsharding.NewBalancer(fp, opts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{})
return fp
}
func (fakePetioleBuilder) ParseConfig(json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
return nil, nil
}
// fakePetiole is a load balancer that wraps the endpointsharding balancer and
// forwards ClientConn updates to it, building a pick_first child per endpoint.
// It also intercepts UpdateState to make sure it can access the child state
// maintained by endpointsharding.
type fakePetiole struct {
balancer.Balancer
balancer.ClientConn
bOpts balancer.BuildOptions
}
func (fp *fakePetiole) UpdateClientConnState(state balancer.ClientConnState) error {
if el := state.ResolverState.Endpoints; len(el) != 2 {
return fmt.Errorf("UpdateClientConnState wants two endpoints, got: %v", el)
}
return fp.Balancer.UpdateClientConnState(state)
}
func (fp *fakePetiole) UpdateState(state balancer.State) {
childStates := endpointsharding.ChildStatesFromPicker(state.Picker)
// Both child states should be present in the child picker. The states and
// picker change over the lifecycle of the test, but there should always be two.
if len(childStates) != 2 {
logger.Fatal(fmt.Errorf("length of child states received: %v, want 2", len(childStates)))
}
fp.ClientConn.UpdateState(state)
}
// TestEndpointShardingBasic tests the basic functionality of the endpoint
// sharding balancer. It specifies a petiole policy that is essentially a
// wrapper around the endpoint sharder. Two backends are started, with each
// backend's address specified in an endpoint. The petiole does not have a
// special picker, so it should fall back to the default behavior, which is to
// round_robin amongst the endpoint children that are in the aggregated state.
// It also verifies the petiole has access to the raw child state in case it
// wants to implement a custom picker. The test sends a resolver error to the
// endpointsharding balancer and verifies an error picker from the children
// is used while making an RPC.
func (s) TestEndpointShardingBasic(t *testing.T) {
backend1 := stubserver.StartTestService(t, nil)
defer backend1.Stop()
backend2 := stubserver.StartTestService(t, nil)
defer backend2.Stop()
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()
json := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, fakePetioleName)
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend1.Address}}},
{Addresses: []resolver.Address{{Addr: backend2.Address}}},
},
ServiceConfig: sc,
})
dOpts := []grpc.DialOption{
grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()),
// Use a large backoff delay to avoid the error picker being updated
// too quickly.
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
BaseDelay: 2 * defaultTestTimeout,
Multiplier: float64(0),
Jitter: float64(0),
MaxDelay: 2 * defaultTestTimeout,
},
}),
}
cc, err := grpc.NewClient(mr.Scheme()+":///", dOpts...)
if err != nil {
t.Fatalf("Failed to create new client: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
// Assert a round robin distribution between the two spun up backends. This
// requires a poll and eventual consistency as both endpoint children do not
// start in state READY.
if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend1.Address}, {Addr: backend2.Address}}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Stopping both the backends should make the channel enter
// TransientFailure.
backend1.Stop()
backend2.Stop()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
// When the resolver reports an error, the picker should get updated to
// return the resolver error.
mr.CC().ReportError(errors.New("test error"))
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
for ; ctx.Err() == nil; <-time.After(time.Millisecond) {
_, err := client.EmptyCall(ctx, &testpb.Empty{})
if err == nil {
t.Fatalf("EmptyCall succeeded when expected to fail with %q", "test error")
}
if strings.Contains(err.Error(), "test error") {
break
}
}
if ctx.Err() != nil {
t.Fatalf("Context timed out waiting for picker with resolver error.")
}
}
// Tests that endpointsharding doesn't automatically re-connect IDLE children.
// The test creates an endpoint with two servers and another with a single
// server. The active server in the first endpoint is closed to make the child
// pickfirst enter the IDLE state. The test verifies that the child pickfirst
// doesn't connect to the second address in the endpoint.
func (s) TestEndpointShardingReconnectDisabled(t *testing.T) {
backend1 := stubserver.StartTestService(t, nil)
defer backend1.Stop()
backend2 := stubserver.StartTestService(t, nil)
defer backend2.Stop()
backend3 := stubserver.StartTestService(t, nil)
defer backend3.Stop()
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()
name := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "")
bf := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
epOpts := endpointsharding.Options{DisableAutoReconnect: true}
bd.ChildBalancer = endpointsharding.NewBalancer(bd.ClientConn, bd.BuildOptions, balancer.Get(pickfirstleaf.Name).Build, epOpts)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
stub.Register(name, bf)
json := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, name)
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend1.Address}, {Addr: backend2.Address}}},
{Addresses: []resolver.Address{{Addr: backend3.Address}}},
},
ServiceConfig: sc,
})
cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create new client: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
// Assert a round robin distribution between the two spun up backends. This
// requires a poll and eventual consistency as both endpoint children do not
// start in state READY.
if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend1.Address}, {Addr: backend3.Address}}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// On closing the first server, the first child balancer should enter
// IDLE. Since endpointsharding is configured not to auto-reconnect, it will
// remain IDLE and will not try to connect to the second backend in the same
// endpoint.
backend1.Stop()
// CheckRoundRobinRPCs waits for all the backends to become reachable; we
// call it to ensure the picker no longer sends RPCs to the closed backend.
if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend3.Address}}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Verify requests go only to backend3 for a short time.
shortCtx, cancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer cancel()
for ; shortCtx.Err() == nil; <-time.After(time.Millisecond) {
var peer peer.Peer
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
if status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() returned unexpected error %v", err)
}
break
}
if got, want := peer.Addr.String(), backend3.Address; got != want {
t.Fatalf("EmptyCall() went to unexpected backend: got %q, want %q", got, want)
}
}
}
// Tests that endpointsharding doesn't automatically re-connect IDLE children
// until cc.Connect() is called. The test creates an endpoint with a single
// address. The client is connected and the active server is closed to make the
// child pickfirst enter IDLE state. The test verifies that the child pickfirst
// doesn't re-connect automatically. The test then calls cc.Connect() and
// verifies that the balancer attempts to connect, causing the channel to enter
// TransientFailure.
func (s) TestEndpointShardingExitIdle(t *testing.T) {
backend := stubserver.StartTestService(t, nil)
defer backend.Stop()
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()
name := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "")
bf := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
epOpts := endpointsharding.Options{DisableAutoReconnect: true}
bd.ChildBalancer = endpointsharding.NewBalancer(bd.ClientConn, bd.BuildOptions, balancer.Get(pickfirstleaf.Name).Build, epOpts)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
ExitIdle: func(bd *stub.BalancerData) {
bd.ChildBalancer.ExitIdle()
},
}
stub.Register(name, bf)
json := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, name)
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend.Address}}},
},
ServiceConfig: sc,
})
cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create new client: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Errorf("client.EmptyCall() returned unexpected error: %v", err)
}
// On closing the backend, the child balancer should enter IDLE. Since
// endpointsharding is configured not to auto-reconnect, it will remain IDLE
// and will not try to re-connect.
backend.Stop()
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer shortCancel()
testutils.AwaitNoStateChange(shortCtx, t, cc, connectivity.Idle)
// The balancer should try to re-connect and fail.
cc.Connect()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
}
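Outside of tests, a registered petiole policy like the ones above is normally selected through the channel's default service config rather than a manual resolver. A hedged sketch; the policy name "fake_petiole" and the target are illustrative:

conn, err := grpc.NewClient(
	"dns:///service.example.com",
	grpc.WithTransportCredentials(insecure.NewCredentials()),
	// Select the registered LB policy by name; "fake_petiole" stands in for
	// whatever name the wrapping policy was registered under.
	grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"fake_petiole":{}}]}`),
)
if err != nil {
	// handle the error
}
defer conn.Close()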

View File

@ -1,83 +0,0 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package endpointsharding
import (
"fmt"
"testing"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/resolver"
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
func (s) TestRotateEndpoints(t *testing.T) {
ep := func(addr string) resolver.Endpoint {
return resolver.Endpoint{Addresses: []resolver.Address{{Addr: addr}}}
}
endpoints := []resolver.Endpoint{ep("1"), ep("2"), ep("3"), ep("4"), ep("5")}
testCases := []struct {
rval int
want []resolver.Endpoint
}{
{
rval: 0,
want: []resolver.Endpoint{ep("1"), ep("2"), ep("3"), ep("4"), ep("5")},
},
{
rval: 1,
want: []resolver.Endpoint{ep("2"), ep("3"), ep("4"), ep("5"), ep("1")},
},
{
rval: 2,
want: []resolver.Endpoint{ep("3"), ep("4"), ep("5"), ep("1"), ep("2")},
},
{
rval: 3,
want: []resolver.Endpoint{ep("4"), ep("5"), ep("1"), ep("2"), ep("3")},
},
{
rval: 4,
want: []resolver.Endpoint{ep("5"), ep("1"), ep("2"), ep("3"), ep("4")},
},
}
defer func(r func(int) int) {
randIntN = r
}(randIntN)
for _, tc := range testCases {
t.Run(fmt.Sprint(tc.rval), func(t *testing.T) {
randIntN = func(int) int {
return tc.rval
}
got := rotateEndpoints(endpoints)
if fmt.Sprint(got) != fmt.Sprint(tc.want) {
t.Fatalf("rand=%v; rotateEndpoints(%v) = %v; want %v", tc.rval, endpoints, got, tc.want)
}
})
}
}
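As a concrete illustration of the rotation verified above (same randIntN override and ep helper as in the test):

randIntN = func(int) int { return 2 }
got := rotateEndpoints([]resolver.Endpoint{ep("a"), ep("b"), ep("c")})
// got is [c a b]: the input is rotated left by the stubbed offset, and the
// slice passed in is left unmodified.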

View File

@ -19,8 +19,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v5.27.1
// protoc-gen-go v1.30.0
// protoc v4.22.0
// source: grpc/lb/v1/load_balancer.proto
package grpc_lb_v1
@ -32,7 +32,6 @@ import (
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
@ -43,21 +42,24 @@ const (
)
type LoadBalanceRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to LoadBalanceRequestType:
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to LoadBalanceRequestType:
//
// *LoadBalanceRequest_InitialRequest
// *LoadBalanceRequest_ClientStats
LoadBalanceRequestType isLoadBalanceRequest_LoadBalanceRequestType `protobuf_oneof:"load_balance_request_type"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *LoadBalanceRequest) Reset() {
*x = LoadBalanceRequest{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *LoadBalanceRequest) String() string {
@ -68,7 +70,7 @@ func (*LoadBalanceRequest) ProtoMessage() {}
func (x *LoadBalanceRequest) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[0]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -83,27 +85,23 @@ func (*LoadBalanceRequest) Descriptor() ([]byte, []int) {
return file_grpc_lb_v1_load_balancer_proto_rawDescGZIP(), []int{0}
}
func (x *LoadBalanceRequest) GetLoadBalanceRequestType() isLoadBalanceRequest_LoadBalanceRequestType {
if x != nil {
return x.LoadBalanceRequestType
func (m *LoadBalanceRequest) GetLoadBalanceRequestType() isLoadBalanceRequest_LoadBalanceRequestType {
if m != nil {
return m.LoadBalanceRequestType
}
return nil
}
func (x *LoadBalanceRequest) GetInitialRequest() *InitialLoadBalanceRequest {
if x != nil {
if x, ok := x.LoadBalanceRequestType.(*LoadBalanceRequest_InitialRequest); ok {
return x.InitialRequest
}
if x, ok := x.GetLoadBalanceRequestType().(*LoadBalanceRequest_InitialRequest); ok {
return x.InitialRequest
}
return nil
}
func (x *LoadBalanceRequest) GetClientStats() *ClientStats {
if x != nil {
if x, ok := x.LoadBalanceRequestType.(*LoadBalanceRequest_ClientStats); ok {
return x.ClientStats
}
if x, ok := x.GetLoadBalanceRequestType().(*LoadBalanceRequest_ClientStats); ok {
return x.ClientStats
}
return nil
}
@ -128,21 +126,24 @@ func (*LoadBalanceRequest_InitialRequest) isLoadBalanceRequest_LoadBalanceReques
func (*LoadBalanceRequest_ClientStats) isLoadBalanceRequest_LoadBalanceRequestType() {}
type InitialLoadBalanceRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// The name of the load balanced service (e.g., service.googleapis.com). Its
// length should be less than 256 bytes.
// The name might include a port number. How to handle the port number is up
// to the balancer.
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
}
func (x *InitialLoadBalanceRequest) Reset() {
*x = InitialLoadBalanceRequest{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *InitialLoadBalanceRequest) String() string {
@ -153,7 +154,7 @@ func (*InitialLoadBalanceRequest) ProtoMessage() {}
func (x *InitialLoadBalanceRequest) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[1]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -177,20 +178,23 @@ func (x *InitialLoadBalanceRequest) GetName() string {
// Contains the number of calls finished for a particular load balance token.
type ClientStatsPerToken struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// See Server.load_balance_token.
LoadBalanceToken string `protobuf:"bytes,1,opt,name=load_balance_token,json=loadBalanceToken,proto3" json:"load_balance_token,omitempty"`
// The total number of RPCs that finished associated with the token.
NumCalls int64 `protobuf:"varint,2,opt,name=num_calls,json=numCalls,proto3" json:"num_calls,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
NumCalls int64 `protobuf:"varint,2,opt,name=num_calls,json=numCalls,proto3" json:"num_calls,omitempty"`
}
func (x *ClientStatsPerToken) Reset() {
*x = ClientStatsPerToken{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ClientStatsPerToken) String() string {
@ -201,7 +205,7 @@ func (*ClientStatsPerToken) ProtoMessage() {}
func (x *ClientStatsPerToken) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[2]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -233,7 +237,10 @@ func (x *ClientStatsPerToken) GetNumCalls() int64 {
// Contains client level statistics that are useful to load balancing. Each
// count except the timestamp should be reset to zero after reporting the stats.
type ClientStats struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// The timestamp of generating the report.
Timestamp *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// The total number of RPCs that started.
@ -247,15 +254,15 @@ type ClientStats struct {
NumCallsFinishedKnownReceived int64 `protobuf:"varint,7,opt,name=num_calls_finished_known_received,json=numCallsFinishedKnownReceived,proto3" json:"num_calls_finished_known_received,omitempty"`
// The list of dropped calls.
CallsFinishedWithDrop []*ClientStatsPerToken `protobuf:"bytes,8,rep,name=calls_finished_with_drop,json=callsFinishedWithDrop,proto3" json:"calls_finished_with_drop,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ClientStats) Reset() {
*x = ClientStats{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ClientStats) String() string {
@ -266,7 +273,7 @@ func (*ClientStats) ProtoMessage() {}
func (x *ClientStats) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[3]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -324,22 +331,25 @@ func (x *ClientStats) GetCallsFinishedWithDrop() []*ClientStatsPerToken {
}
type LoadBalanceResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to LoadBalanceResponseType:
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to LoadBalanceResponseType:
//
// *LoadBalanceResponse_InitialResponse
// *LoadBalanceResponse_ServerList
// *LoadBalanceResponse_FallbackResponse
LoadBalanceResponseType isLoadBalanceResponse_LoadBalanceResponseType `protobuf_oneof:"load_balance_response_type"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *LoadBalanceResponse) Reset() {
*x = LoadBalanceResponse{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *LoadBalanceResponse) String() string {
@ -350,7 +360,7 @@ func (*LoadBalanceResponse) ProtoMessage() {}
func (x *LoadBalanceResponse) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[4]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -365,36 +375,30 @@ func (*LoadBalanceResponse) Descriptor() ([]byte, []int) {
return file_grpc_lb_v1_load_balancer_proto_rawDescGZIP(), []int{4}
}
func (x *LoadBalanceResponse) GetLoadBalanceResponseType() isLoadBalanceResponse_LoadBalanceResponseType {
if x != nil {
return x.LoadBalanceResponseType
func (m *LoadBalanceResponse) GetLoadBalanceResponseType() isLoadBalanceResponse_LoadBalanceResponseType {
if m != nil {
return m.LoadBalanceResponseType
}
return nil
}
func (x *LoadBalanceResponse) GetInitialResponse() *InitialLoadBalanceResponse {
if x != nil {
if x, ok := x.LoadBalanceResponseType.(*LoadBalanceResponse_InitialResponse); ok {
return x.InitialResponse
}
if x, ok := x.GetLoadBalanceResponseType().(*LoadBalanceResponse_InitialResponse); ok {
return x.InitialResponse
}
return nil
}
func (x *LoadBalanceResponse) GetServerList() *ServerList {
if x != nil {
if x, ok := x.LoadBalanceResponseType.(*LoadBalanceResponse_ServerList); ok {
return x.ServerList
}
if x, ok := x.GetLoadBalanceResponseType().(*LoadBalanceResponse_ServerList); ok {
return x.ServerList
}
return nil
}
func (x *LoadBalanceResponse) GetFallbackResponse() *FallbackResponse {
if x != nil {
if x, ok := x.LoadBalanceResponseType.(*LoadBalanceResponse_FallbackResponse); ok {
return x.FallbackResponse
}
if x, ok := x.GetLoadBalanceResponseType().(*LoadBalanceResponse_FallbackResponse); ok {
return x.FallbackResponse
}
return nil
}
@ -427,16 +431,18 @@ func (*LoadBalanceResponse_ServerList) isLoadBalanceResponse_LoadBalanceResponse
func (*LoadBalanceResponse_FallbackResponse) isLoadBalanceResponse_LoadBalanceResponseType() {}
type FallbackResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *FallbackResponse) Reset() {
*x = FallbackResponse{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *FallbackResponse) String() string {
@ -447,7 +453,7 @@ func (*FallbackResponse) ProtoMessage() {}
func (x *FallbackResponse) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[5]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -463,20 +469,23 @@ func (*FallbackResponse) Descriptor() ([]byte, []int) {
}
type InitialLoadBalanceResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// This interval defines how often the client should send the client stats
// to the load balancer. Stats should only be reported when the duration is
// positive.
ClientStatsReportInterval *durationpb.Duration `protobuf:"bytes,2,opt,name=client_stats_report_interval,json=clientStatsReportInterval,proto3" json:"client_stats_report_interval,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InitialLoadBalanceResponse) Reset() {
*x = InitialLoadBalanceResponse{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *InitialLoadBalanceResponse) String() string {
@ -487,7 +496,7 @@ func (*InitialLoadBalanceResponse) ProtoMessage() {}
func (x *InitialLoadBalanceResponse) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[6]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -510,21 +519,24 @@ func (x *InitialLoadBalanceResponse) GetClientStatsReportInterval() *durationpb.
}
type ServerList struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Contains a list of servers selected by the load balancer. The list will
// be updated when server resolutions change or as needed to balance load
// across more servers. The client should consume the server list in order
// unless instructed otherwise via the client_config.
Servers []*Server `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
Servers []*Server `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
}
func (x *ServerList) Reset() {
*x = ServerList{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ServerList) String() string {
@ -535,7 +547,7 @@ func (*ServerList) ProtoMessage() {}
func (x *ServerList) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[7]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -560,7 +572,10 @@ func (x *ServerList) GetServers() []*Server {
// Contains server information. When the drop field is not true, use the other
// fields.
type Server struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// A resolved address for the server, serialized in network-byte-order. It may
// either be an IPv4 or IPv6 address.
IpAddress []byte `protobuf:"bytes,1,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"`
@ -577,16 +592,16 @@ type Server struct {
// Indicates whether this particular request should be dropped by the client.
// If the request is dropped, there will be a corresponding entry in
// ClientStats.calls_finished_with_drop.
Drop bool `protobuf:"varint,4,opt,name=drop,proto3" json:"drop,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
Drop bool `protobuf:"varint,4,opt,name=drop,proto3" json:"drop,omitempty"`
}
func (x *Server) Reset() {
*x = Server{}
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Server) String() string {
@ -597,7 +612,7 @@ func (*Server) ProtoMessage() {}
func (x *Server) ProtoReflect() protoreflect.Message {
mi := &file_grpc_lb_v1_load_balancer_proto_msgTypes[8]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -642,62 +657,130 @@ func (x *Server) GetDrop() bool {
var File_grpc_lb_v1_load_balancer_proto protoreflect.FileDescriptor
const file_grpc_lb_v1_load_balancer_proto_rawDesc = "" +
"\n" +
"\x1egrpc/lb/v1/load_balancer.proto\x12\n" +
"grpc.lb.v1\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xc1\x01\n" +
"\x12LoadBalanceRequest\x12P\n" +
"\x0finitial_request\x18\x01 \x01(\v2%.grpc.lb.v1.InitialLoadBalanceRequestH\x00R\x0einitialRequest\x12<\n" +
"\fclient_stats\x18\x02 \x01(\v2\x17.grpc.lb.v1.ClientStatsH\x00R\vclientStatsB\x1b\n" +
"\x19load_balance_request_type\"/\n" +
"\x19InitialLoadBalanceRequest\x12\x12\n" +
"\x04name\x18\x01 \x01(\tR\x04name\"`\n" +
"\x13ClientStatsPerToken\x12,\n" +
"\x12load_balance_token\x18\x01 \x01(\tR\x10loadBalanceToken\x12\x1b\n" +
"\tnum_calls\x18\x02 \x01(\x03R\bnumCalls\"\xb0\x03\n" +
"\vClientStats\x128\n" +
"\ttimestamp\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12*\n" +
"\x11num_calls_started\x18\x02 \x01(\x03R\x0fnumCallsStarted\x12,\n" +
"\x12num_calls_finished\x18\x03 \x01(\x03R\x10numCallsFinished\x12]\n" +
"-num_calls_finished_with_client_failed_to_send\x18\x06 \x01(\x03R&numCallsFinishedWithClientFailedToSend\x12H\n" +
"!num_calls_finished_known_received\x18\a \x01(\x03R\x1dnumCallsFinishedKnownReceived\x12X\n" +
"\x18calls_finished_with_drop\x18\b \x03(\v2\x1f.grpc.lb.v1.ClientStatsPerTokenR\x15callsFinishedWithDropJ\x04\b\x04\x10\x05J\x04\b\x05\x10\x06\"\x90\x02\n" +
"\x13LoadBalanceResponse\x12S\n" +
"\x10initial_response\x18\x01 \x01(\v2&.grpc.lb.v1.InitialLoadBalanceResponseH\x00R\x0finitialResponse\x129\n" +
"\vserver_list\x18\x02 \x01(\v2\x16.grpc.lb.v1.ServerListH\x00R\n" +
"serverList\x12K\n" +
"\x11fallback_response\x18\x03 \x01(\v2\x1c.grpc.lb.v1.FallbackResponseH\x00R\x10fallbackResponseB\x1c\n" +
"\x1aload_balance_response_type\"\x12\n" +
"\x10FallbackResponse\"~\n" +
"\x1aInitialLoadBalanceResponse\x12Z\n" +
"\x1cclient_stats_report_interval\x18\x02 \x01(\v2\x19.google.protobuf.DurationR\x19clientStatsReportIntervalJ\x04\b\x01\x10\x02\"@\n" +
"\n" +
"ServerList\x12,\n" +
"\aservers\x18\x01 \x03(\v2\x12.grpc.lb.v1.ServerR\aserversJ\x04\b\x03\x10\x04\"\x83\x01\n" +
"\x06Server\x12\x1d\n" +
"\n" +
"ip_address\x18\x01 \x01(\fR\tipAddress\x12\x12\n" +
"\x04port\x18\x02 \x01(\x05R\x04port\x12,\n" +
"\x12load_balance_token\x18\x03 \x01(\tR\x10loadBalanceToken\x12\x12\n" +
"\x04drop\x18\x04 \x01(\bR\x04dropJ\x04\b\x05\x10\x062b\n" +
"\fLoadBalancer\x12R\n" +
"\vBalanceLoad\x12\x1e.grpc.lb.v1.LoadBalanceRequest\x1a\x1f.grpc.lb.v1.LoadBalanceResponse(\x010\x01BW\n" +
"\rio.grpc.lb.v1B\x11LoadBalancerProtoP\x01Z1google.golang.org/grpc/balancer/grpclb/grpc_lb_v1b\x06proto3"
var file_grpc_lb_v1_load_balancer_proto_rawDesc = []byte{
0x0a, 0x1e, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x62, 0x2f, 0x76, 0x31, 0x2f, 0x6c, 0x6f, 0x61,
0x64, 0x5f, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x12, 0x0a, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x1a, 0x1e, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75,
0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69,
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xc1, 0x01,
0x0a, 0x12, 0x4c, 0x6f, 0x61, 0x64, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x12, 0x50, 0x0a, 0x0f, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f,
0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e,
0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x6e, 0x69, 0x74, 0x69,
0x61, 0x6c, 0x4c, 0x6f, 0x61, 0x64, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0e, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x3c, 0x0a, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x5f, 0x73, 0x74, 0x61, 0x74, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67,
0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x53, 0x74, 0x61, 0x74, 0x73, 0x48, 0x00, 0x52, 0x0b, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53,
0x74, 0x61, 0x74, 0x73, 0x42, 0x1b, 0x0a, 0x19, 0x6c, 0x6f, 0x61, 0x64, 0x5f, 0x62, 0x61, 0x6c,
0x61, 0x6e, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x79, 0x70,
0x65, 0x22, 0x2f, 0x0a, 0x19, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x4c, 0x6f, 0x61, 0x64,
0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12,
0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61,
0x6d, 0x65, 0x22, 0x60, 0x0a, 0x13, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74,
0x73, 0x50, 0x65, 0x72, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x6c, 0x6f, 0x61,
0x64, 0x5f, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6c, 0x6f, 0x61, 0x64, 0x42, 0x61, 0x6c, 0x61, 0x6e,
0x63, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x6e, 0x75, 0x6d, 0x5f, 0x63,
0x61, 0x6c, 0x6c, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x6e, 0x75, 0x6d, 0x43,
0x61, 0x6c, 0x6c, 0x73, 0x22, 0xb0, 0x03, 0x0a, 0x0b, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53,
0x74, 0x61, 0x74, 0x73, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d,
0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74,
0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x2a,
0x0a, 0x11, 0x6e, 0x75, 0x6d, 0x5f, 0x63, 0x61, 0x6c, 0x6c, 0x73, 0x5f, 0x73, 0x74, 0x61, 0x72,
0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0f, 0x6e, 0x75, 0x6d, 0x43, 0x61,
0x6c, 0x6c, 0x73, 0x53, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x12, 0x2c, 0x0a, 0x12, 0x6e, 0x75,
0x6d, 0x5f, 0x63, 0x61, 0x6c, 0x6c, 0x73, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64,
0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x10, 0x6e, 0x75, 0x6d, 0x43, 0x61, 0x6c, 0x6c, 0x73,
0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x12, 0x5d, 0x0a, 0x2d, 0x6e, 0x75, 0x6d, 0x5f,
0x63, 0x61, 0x6c, 0x6c, 0x73, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x5f, 0x77,
0x69, 0x74, 0x68, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x66, 0x61, 0x69, 0x6c, 0x65,
0x64, 0x5f, 0x74, 0x6f, 0x5f, 0x73, 0x65, 0x6e, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52,
0x26, 0x6e, 0x75, 0x6d, 0x43, 0x61, 0x6c, 0x6c, 0x73, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65,
0x64, 0x57, 0x69, 0x74, 0x68, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x46, 0x61, 0x69, 0x6c, 0x65,
0x64, 0x54, 0x6f, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x48, 0x0a, 0x21, 0x6e, 0x75, 0x6d, 0x5f, 0x63,
0x61, 0x6c, 0x6c, 0x73, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x5f, 0x6b, 0x6e,
0x6f, 0x77, 0x6e, 0x5f, 0x72, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x18, 0x07, 0x20, 0x01,
0x28, 0x03, 0x52, 0x1d, 0x6e, 0x75, 0x6d, 0x43, 0x61, 0x6c, 0x6c, 0x73, 0x46, 0x69, 0x6e, 0x69,
0x73, 0x68, 0x65, 0x64, 0x4b, 0x6e, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65,
0x64, 0x12, 0x58, 0x0a, 0x18, 0x63, 0x61, 0x6c, 0x6c, 0x73, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73,
0x68, 0x65, 0x64, 0x5f, 0x77, 0x69, 0x74, 0x68, 0x5f, 0x64, 0x72, 0x6f, 0x70, 0x18, 0x08, 0x20,
0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31,
0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x73, 0x50, 0x65, 0x72, 0x54,
0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x15, 0x63, 0x61, 0x6c, 0x6c, 0x73, 0x46, 0x69, 0x6e, 0x69, 0x73,
0x68, 0x65, 0x64, 0x57, 0x69, 0x74, 0x68, 0x44, 0x72, 0x6f, 0x70, 0x4a, 0x04, 0x08, 0x04, 0x10,
0x05, 0x4a, 0x04, 0x08, 0x05, 0x10, 0x06, 0x22, 0x90, 0x02, 0x0a, 0x13, 0x4c, 0x6f, 0x61, 0x64,
0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
0x53, 0x0a, 0x10, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x67, 0x72, 0x70, 0x63,
0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x4c, 0x6f,
0x61, 0x64, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x48, 0x00, 0x52, 0x0f, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x6c,
0x69, 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x72, 0x70, 0x63,
0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4c, 0x69, 0x73,
0x74, 0x48, 0x00, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4c, 0x69, 0x73, 0x74, 0x12,
0x4b, 0x0a, 0x11, 0x66, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x5f, 0x72, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x67, 0x72, 0x70,
0x63, 0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x10, 0x66, 0x61, 0x6c, 0x6c,
0x62, 0x61, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x1c, 0x0a, 0x1a,
0x6c, 0x6f, 0x61, 0x64, 0x5f, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x46, 0x61,
0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x7e,
0x0a, 0x1a, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x4c, 0x6f, 0x61, 0x64, 0x42, 0x61, 0x6c,
0x61, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x1c,
0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x73, 0x5f, 0x72, 0x65, 0x70,
0x6f, 0x72, 0x74, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01,
0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x19, 0x63,
0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x73, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74,
0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x22, 0x40,
0x0a, 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x07,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e,
0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x4a, 0x04, 0x08, 0x03, 0x10, 0x04,
0x22, 0x83, 0x01, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x1d, 0x0a, 0x0a, 0x69,
0x70, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x09, 0x69, 0x70, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f,
0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x2c,
0x0a, 0x12, 0x6c, 0x6f, 0x61, 0x64, 0x5f, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x5f, 0x74,
0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6c, 0x6f, 0x61, 0x64,
0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x12, 0x0a, 0x04,
0x64, 0x72, 0x6f, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x64, 0x72, 0x6f, 0x70,
0x4a, 0x04, 0x08, 0x05, 0x10, 0x06, 0x32, 0x62, 0x0a, 0x0c, 0x4c, 0x6f, 0x61, 0x64, 0x42, 0x61,
0x6c, 0x61, 0x6e, 0x63, 0x65, 0x72, 0x12, 0x52, 0x0a, 0x0b, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63,
0x65, 0x4c, 0x6f, 0x61, 0x64, 0x12, 0x1e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e,
0x76, 0x31, 0x2e, 0x4c, 0x6f, 0x61, 0x64, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e,
0x76, 0x31, 0x2e, 0x4c, 0x6f, 0x61, 0x64, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x57, 0x0a, 0x0d, 0x69, 0x6f,
0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x62, 0x2e, 0x76, 0x31, 0x42, 0x11, 0x4c, 0x6f, 0x61,
0x64, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01,
0x5a, 0x31, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e,
0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x65,
0x72, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x6c, 0x62, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x6c, 0x62,
0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_grpc_lb_v1_load_balancer_proto_rawDescOnce sync.Once
file_grpc_lb_v1_load_balancer_proto_rawDescData []byte
file_grpc_lb_v1_load_balancer_proto_rawDescData = file_grpc_lb_v1_load_balancer_proto_rawDesc
)
func file_grpc_lb_v1_load_balancer_proto_rawDescGZIP() []byte {
file_grpc_lb_v1_load_balancer_proto_rawDescOnce.Do(func() {
file_grpc_lb_v1_load_balancer_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_grpc_lb_v1_load_balancer_proto_rawDesc), len(file_grpc_lb_v1_load_balancer_proto_rawDesc)))
file_grpc_lb_v1_load_balancer_proto_rawDescData = protoimpl.X.CompressGZIP(file_grpc_lb_v1_load_balancer_proto_rawDescData)
})
return file_grpc_lb_v1_load_balancer_proto_rawDescData
}
var file_grpc_lb_v1_load_balancer_proto_msgTypes = make([]protoimpl.MessageInfo, 9)
var file_grpc_lb_v1_load_balancer_proto_goTypes = []any{
var file_grpc_lb_v1_load_balancer_proto_goTypes = []interface{}{
(*LoadBalanceRequest)(nil), // 0: grpc.lb.v1.LoadBalanceRequest
(*InitialLoadBalanceRequest)(nil), // 1: grpc.lb.v1.InitialLoadBalanceRequest
(*ClientStatsPerToken)(nil), // 2: grpc.lb.v1.ClientStatsPerToken
@ -734,11 +817,121 @@ func file_grpc_lb_v1_load_balancer_proto_init() {
if File_grpc_lb_v1_load_balancer_proto != nil {
return
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[0].OneofWrappers = []any{
if !protoimpl.UnsafeEnabled {
file_grpc_lb_v1_load_balancer_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*LoadBalanceRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*InitialLoadBalanceRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ClientStatsPerToken); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ClientStats); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*LoadBalanceResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*FallbackResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*InitialLoadBalanceResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ServerList); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Server); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[0].OneofWrappers = []interface{}{
(*LoadBalanceRequest_InitialRequest)(nil),
(*LoadBalanceRequest_ClientStats)(nil),
}
file_grpc_lb_v1_load_balancer_proto_msgTypes[4].OneofWrappers = []any{
file_grpc_lb_v1_load_balancer_proto_msgTypes[4].OneofWrappers = []interface{}{
(*LoadBalanceResponse_InitialResponse)(nil),
(*LoadBalanceResponse_ServerList)(nil),
(*LoadBalanceResponse_FallbackResponse)(nil),
@ -747,7 +940,7 @@ func file_grpc_lb_v1_load_balancer_proto_init() {
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_grpc_lb_v1_load_balancer_proto_rawDesc), len(file_grpc_lb_v1_load_balancer_proto_rawDesc)),
RawDescriptor: file_grpc_lb_v1_load_balancer_proto_rawDesc,
NumEnums: 0,
NumMessages: 9,
NumExtensions: 0,
@ -758,6 +951,7 @@ func file_grpc_lb_v1_load_balancer_proto_init() {
MessageInfos: file_grpc_lb_v1_load_balancer_proto_msgTypes,
}.Build()
File_grpc_lb_v1_load_balancer_proto = out.File
file_grpc_lb_v1_load_balancer_proto_rawDesc = nil
file_grpc_lb_v1_load_balancer_proto_goTypes = nil
file_grpc_lb_v1_load_balancer_proto_depIdxs = nil
}
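
The rawDescGZIP helper just above compresses the raw descriptor only on first use, guarded by a sync.Once, so later callers reuse the cached compressed bytes. Below is a minimal illustrative sketch of that lazy-initialization pattern, not part of the diff; all names (rawDesc, compressedDesc, and the placeholder bytes) are made up for the example.

package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"sync"
)

var (
	rawDesc     = []byte("stand-in for the generated raw descriptor")
	rawDescOnce sync.Once
	rawDescGZ   []byte
)

// compressedDesc gzips rawDesc exactly once; subsequent calls return the cached result.
func compressedDesc() []byte {
	rawDescOnce.Do(func() {
		var buf bytes.Buffer
		zw := gzip.NewWriter(&buf)
		zw.Write(rawDesc)
		zw.Close()
		rawDescGZ = buf.Bytes()
	})
	return rawDescGZ
}

func main() {
	fmt.Println(len(compressedDesc()))
}
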


@ -19,8 +19,8 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.27.1
// - protoc-gen-go-grpc v1.3.0
// - protoc v4.22.0
// source: grpc/lb/v1/load_balancer.proto
package grpc_lb_v1
@ -34,8 +34,8 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
const (
LoadBalancer_BalanceLoad_FullMethodName = "/grpc.lb.v1.LoadBalancer/BalanceLoad"
@ -46,7 +46,7 @@ const (
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type LoadBalancerClient interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[LoadBalanceRequest, LoadBalanceResponse], error)
BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error)
}
type loadBalancerClient struct {
@ -57,38 +57,52 @@ func NewLoadBalancerClient(cc grpc.ClientConnInterface) LoadBalancerClient {
return &loadBalancerClient{cc}
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[LoadBalanceRequest, LoadBalanceResponse], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &LoadBalancer_ServiceDesc.Streams[0], LoadBalancer_BalanceLoad_FullMethodName, cOpts...)
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error) {
stream, err := c.cc.NewStream(ctx, &LoadBalancer_ServiceDesc.Streams[0], LoadBalancer_BalanceLoad_FullMethodName, opts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[LoadBalanceRequest, LoadBalanceResponse]{ClientStream: stream}
x := &loadBalancerBalanceLoadClient{stream}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type LoadBalancer_BalanceLoadClient = grpc.BidiStreamingClient[LoadBalanceRequest, LoadBalanceResponse]
type LoadBalancer_BalanceLoadClient interface {
Send(*LoadBalanceRequest) error
Recv() (*LoadBalanceResponse, error)
grpc.ClientStream
}
type loadBalancerBalanceLoadClient struct {
grpc.ClientStream
}
func (x *loadBalancerBalanceLoadClient) Send(m *LoadBalanceRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadClient) Recv() (*LoadBalanceResponse, error) {
m := new(LoadBalanceResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// LoadBalancerServer is the server API for LoadBalancer service.
// All implementations should embed UnimplementedLoadBalancerServer
// for forward compatibility.
// for forward compatibility
type LoadBalancerServer interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(grpc.BidiStreamingServer[LoadBalanceRequest, LoadBalanceResponse]) error
BalanceLoad(LoadBalancer_BalanceLoadServer) error
}
// UnimplementedLoadBalancerServer should be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedLoadBalancerServer struct{}
func (UnimplementedLoadBalancerServer) BalanceLoad(grpc.BidiStreamingServer[LoadBalanceRequest, LoadBalanceResponse]) error {
return status.Error(codes.Unimplemented, "method BalanceLoad not implemented")
// UnimplementedLoadBalancerServer should be embedded to have forward compatible implementations.
type UnimplementedLoadBalancerServer struct {
}
func (UnimplementedLoadBalancerServer) BalanceLoad(LoadBalancer_BalanceLoadServer) error {
return status.Errorf(codes.Unimplemented, "method BalanceLoad not implemented")
}
func (UnimplementedLoadBalancerServer) testEmbeddedByValue() {}
// UnsafeLoadBalancerServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to LoadBalancerServer will
@ -98,22 +112,34 @@ type UnsafeLoadBalancerServer interface {
}
func RegisterLoadBalancerServer(s grpc.ServiceRegistrar, srv LoadBalancerServer) {
// If the following call panics, it indicates UnimplementedLoadBalancerServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&LoadBalancer_ServiceDesc, srv)
}
func _LoadBalancer_BalanceLoad_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(LoadBalancerServer).BalanceLoad(&grpc.GenericServerStream[LoadBalanceRequest, LoadBalanceResponse]{ServerStream: stream})
return srv.(LoadBalancerServer).BalanceLoad(&loadBalancerBalanceLoadServer{stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type LoadBalancer_BalanceLoadServer = grpc.BidiStreamingServer[LoadBalanceRequest, LoadBalanceResponse]
type LoadBalancer_BalanceLoadServer interface {
Send(*LoadBalanceResponse) error
Recv() (*LoadBalanceRequest, error)
grpc.ServerStream
}
type loadBalancerBalanceLoadServer struct {
grpc.ServerStream
}
func (x *loadBalancerBalanceLoadServer) Send(m *LoadBalanceResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadServer) Recv() (*LoadBalanceRequest, error) {
m := new(LoadBalanceRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// LoadBalancer_ServiceDesc is the grpc.ServiceDesc for LoadBalancer service.
// It's only intended for direct use with grpc.RegisterService,
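
The comments in the stubs above describe why implementations should embed UnimplementedLoadBalancerServer by value: the embedded methods keep older implementations compiling when new RPCs are added, and RegisterLoadBalancerServer checks at registration time that the struct was not embedded as a nil pointer. The following is an illustrative sketch only, not part of the diff; the type fakeBalancer, the echo-an-empty-server-list behavior, and the localhost listener are invented for the example.

package main

import (
	"io"
	"log"
	"net"

	"google.golang.org/grpc"

	lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
	lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
)

// fakeBalancer embeds UnimplementedLoadBalancerServer by value, so methods
// added to the service in the future will not break compilation.
type fakeBalancer struct {
	lbgrpc.UnimplementedLoadBalancerServer
}

// BalanceLoad replies with an empty server list for every request received.
// This is placeholder logic, not a real grpclb balancer.
func (fakeBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
	for {
		if _, err := stream.Recv(); err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		resp := &lbpb.LoadBalanceResponse{
			LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
				ServerList: &lbpb.ServerList{},
			},
		}
		if err := stream.Send(resp); err != nil {
			return err
		}
	}
}

func main() {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		log.Fatal(err)
	}
	s := grpc.NewServer()
	lbgrpc.RegisterLoadBalancerServer(s, fakeBalancer{})
	s.Serve(lis)
}
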


@ -32,20 +32,16 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/internal/resolver/dns"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/protobuf/types/known/durationpb"
durationpb "github.com/golang/protobuf/ptypes/duration"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
)
@ -136,11 +132,7 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal
// This generates a manual resolver builder with a fixed scheme. This
// scheme will be used to dial to remote LB, so we can send filtered
// address updates to remote LB ClientConn using this manual resolver.
mr := manual.NewBuilderWithScheme("grpclb-internal")
// ResolveNow() on this manual resolver is forwarded to the parent
// ClientConn, so when grpclb client loses contact with the remote balancer,
// the parent ClientConn's resolver will re-resolve.
mr.ResolveNowCallback = cc.ResolveNow
r := &lbManualResolver{scheme: "grpclb-internal", ccb: cc}
lb := &lbBalancer{
cc: newLBCacheClientConn(cc),
@ -150,24 +142,23 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal
fallbackTimeout: b.fallbackTimeout,
doneCh: make(chan struct{}),
manualResolver: mr,
manualResolver: r,
subConns: make(map[resolver.Address]balancer.SubConn),
scStates: make(map[balancer.SubConn]connectivity.State),
picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable),
picker: &errPicker{err: balancer.ErrNoSubConnAvailable},
clientStats: newRPCStats(),
backoff: backoff.DefaultExponential, // TODO: make backoff configurable.
}
lb.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[grpclb %p] ", lb))
var err error
if opt.CredsBundle != nil {
lb.grpclbClientConnCreds, err = opt.CredsBundle.NewWithMode(internal.CredsBundleModeBalancer)
if err != nil {
lb.logger.Warningf("Failed to create credentials used for connecting to grpclb: %v", err)
logger.Warningf("lbBalancer: client connection creds NewWithMode failed: %v", err)
}
lb.grpclbBackendCreds, err = opt.CredsBundle.NewWithMode(internal.CredsBundleModeBackendFromBalancer)
if err != nil {
lb.logger.Warningf("Failed to create credentials used for connecting to backends returned by grpclb: %v", err)
logger.Warningf("lbBalancer: backend creds NewWithMode failed: %v", err)
}
}
@ -179,7 +170,6 @@ type lbBalancer struct {
dialTarget string // user's dial target
target string // same as dialTarget unless overridden in service config
opt balancer.BuildOptions
logger *internalgrpclog.PrefixLogger
usePickFirst bool
@ -197,8 +187,8 @@ type lbBalancer struct {
// manualResolver is used in the remote LB ClientConn inside grpclb. When
// resolved address updates are received by grpclb, filtered updates will be
// sent to remote LB ClientConn through this resolver.
manualResolver *manual.Resolver
// send to remote LB ClientConn through this resolver.
manualResolver *lbManualResolver
// The ClientConn to talk to the remote balancer.
ccRemoteLB *remoteBalancerCCWrapper
// backoff for calling remote balancer.
@ -219,11 +209,11 @@ type lbBalancer struct {
// All backends addresses, with metadata set to nil. This list contains all
// backend addresses in the same order and with the same duplicates as in
// serverlist. When generating picker, a SubConn slice with the same order
// but with only READY SCs will be generated.
// but with only READY SCs will be gerenated.
backendAddrsWithoutMetadata []resolver.Address
// Roundrobin functionalities.
state connectivity.State
subConns map[resolver.Address]balancer.SubConn // Used to new/shutdown SubConn.
subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn.
scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns.
picker balancer.Picker
// Support fallback to resolved backend addresses if there's no response
@ -246,12 +236,12 @@ type lbBalancer struct {
// Caller must hold lb.mu.
func (lb *lbBalancer) regeneratePicker(resetDrop bool) {
if lb.state == connectivity.TransientFailure {
lb.picker = base.NewErrPicker(fmt.Errorf("all SubConns are in TransientFailure, last connection error: %v", lb.connErr))
lb.picker = &errPicker{err: fmt.Errorf("all SubConns are in TransientFailure, last connection error: %v", lb.connErr)}
return
}
if lb.state == connectivity.Connecting {
lb.picker = base.NewErrPicker(balancer.ErrNoSubConnAvailable)
lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable}
return
}
@ -278,7 +268,7 @@ func (lb *lbBalancer) regeneratePicker(resetDrop bool) {
//
// This doesn't seem to be necessary after the connecting check above.
// Kept for safety.
lb.picker = base.NewErrPicker(balancer.ErrNoSubConnAvailable)
lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable}
return
}
if lb.inFallback {
@ -300,7 +290,7 @@ func (lb *lbBalancer) regeneratePicker(resetDrop bool) {
// aggregateSubConnStats calculate the aggregated state of SubConns in
// lb.SubConns. These SubConns are subconns in use (when switching between
// fallback and grpclb). lb.scState contains states for all SubConns, including
// those in cache (SubConns are cached for 10 seconds after shutdown).
// those in cache (SubConns are cached for 10 seconds after remove).
//
// The aggregated state is:
// - If at least one SubConn in Ready, the aggregated state is Ready;
@ -329,24 +319,18 @@ func (lb *lbBalancer) aggregateSubConnStates() connectivity.State {
return connectivity.TransientFailure
}
// UpdateSubConnState is unused; NewSubConn's options always specifies
// updateSubConnState as the listener.
func (lb *lbBalancer) UpdateSubConnState(sc balancer.SubConn, scs balancer.SubConnState) {
lb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, scs)
}
func (lb *lbBalancer) updateSubConnState(sc balancer.SubConn, scs balancer.SubConnState) {
s := scs.ConnectivityState
if lb.logger.V(2) {
lb.logger.Infof("SubConn state change: %p, %v", sc, s)
if logger.V(2) {
logger.Infof("lbBalancer: handle SubConn state change: %p, %v", sc, s)
}
lb.mu.Lock()
defer lb.mu.Unlock()
oldS, ok := lb.scStates[sc]
if !ok {
if lb.logger.V(2) {
lb.logger.Infof("Received state change for an unknown SubConn: %p, %v", sc, s)
if logger.V(2) {
logger.Infof("lbBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
}
return
}
@ -355,8 +339,8 @@ func (lb *lbBalancer) updateSubConnState(sc balancer.SubConn, scs balancer.SubCo
case connectivity.Idle:
sc.Connect()
case connectivity.Shutdown:
// When an address was removed by resolver, b called Shutdown but kept
// the sc's state in scStates. Remove state for this sc here.
// When an address was removed by resolver, b called RemoveSubConn but
// kept the sc's state in scStates. Remove state for this sc here.
delete(lb.scStates, sc)
case connectivity.TransientFailure:
lb.connErr = scs.ConnectionError
@ -389,13 +373,8 @@ func (lb *lbBalancer) updateStateAndPicker(forceRegeneratePicker bool, resetDrop
if forceRegeneratePicker || (lb.state != oldAggrState) {
lb.regeneratePicker(resetDrop)
}
var cc balancer.ClientConn = lb.cc
if lb.usePickFirst {
// Bypass the caching layer that would wrap the picker.
cc = lb.cc.ClientConn
}
cc.UpdateState(balancer.State{ConnectivityState: lb.state, Picker: lb.picker})
lb.cc.UpdateState(balancer.State{ConnectivityState: lb.state, Picker: lb.picker})
}
// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use
@ -451,8 +430,8 @@ func (lb *lbBalancer) handleServiceConfig(gc *grpclbServiceConfig) {
if lb.usePickFirst == newUsePickFirst {
return
}
if lb.logger.V(2) {
lb.logger.Infof("Switching mode. Is pick_first used for backends? %v", newUsePickFirst)
if logger.V(2) {
logger.Infof("lbBalancer: switching mode, new usePickFirst: %+v", newUsePickFirst)
}
lb.refreshSubConns(lb.backendAddrs, lb.inFallback, newUsePickFirst)
}
@ -463,15 +442,23 @@ func (lb *lbBalancer) ResolverError(error) {
}
func (lb *lbBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
if lb.logger.V(2) {
lb.logger.Infof("UpdateClientConnState: %s", pretty.ToJSON(ccs))
if logger.V(2) {
logger.Infof("lbBalancer: UpdateClientConnState: %+v", ccs)
}
gc, _ := ccs.BalancerConfig.(*grpclbServiceConfig)
lb.handleServiceConfig(gc)
backendAddrs := ccs.ResolverState.Addresses
addrs := ccs.ResolverState.Addresses
var remoteBalancerAddrs []resolver.Address
var remoteBalancerAddrs, backendAddrs []resolver.Address
for _, a := range addrs {
if a.Type == resolver.GRPCLB {
a.Type = resolver.Backend
remoteBalancerAddrs = append(remoteBalancerAddrs, a)
} else {
backendAddrs = append(backendAddrs, a)
}
}
if sd := grpclbstate.Get(ccs.ResolverState); sd != nil {
// Override any balancer addresses provided via
// ccs.ResolverState.Addresses.
@ -492,9 +479,7 @@ func (lb *lbBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error
} else if lb.ccRemoteLB == nil {
// First time receiving resolved addresses, create a cc to remote
// balancers.
if err := lb.newRemoteBalancerCCWrapper(); err != nil {
return err
}
lb.newRemoteBalancerCCWrapper()
// Start the fallback goroutine.
go lb.fallbackToBackendsAfter(lb.fallbackTimeout)
}
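
The aggregateSubConnStates comment earlier in this file spells out a precedence when collapsing many SubConn states into one channel-level state: Ready wins, then Connecting, then Idle, and TransientFailure is reported only when nothing better is present. Below is a small hedged sketch of that precedence rule, not the actual grpclb code (which also keeps per-state counters on the balancer struct); the function name aggregate is invented for the example.

package main

import (
	"fmt"

	"google.golang.org/grpc/connectivity"
)

// aggregate collapses a set of SubConn states into a single state:
// Ready beats Connecting, Connecting beats Idle, and TransientFailure is
// returned only when no SubConn is in a better state.
func aggregate(states []connectivity.State) connectivity.State {
	var ready, connecting, idle int
	for _, s := range states {
		switch s {
		case connectivity.Ready:
			ready++
		case connectivity.Connecting:
			connecting++
		case connectivity.Idle:
			idle++
		}
	}
	switch {
	case ready > 0:
		return connectivity.Ready
	case connecting > 0:
		return connectivity.Connecting
	case idle > 0:
		return connectivity.Idle
	}
	return connectivity.TransientFailure
}

func main() {
	fmt.Println(aggregate([]connectivity.State{
		connectivity.TransientFailure,
		connectivity.Connecting,
	})) // prints CONNECTING
}
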


@ -21,14 +21,14 @@ package grpclb
import (
"encoding/json"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/serviceconfig"
)
const (
roundRobinName = roundrobin.Name
pickFirstName = pickfirst.Name
pickFirstName = grpc.PickFirstBalancerName
)
type grpclbServiceConfig struct {


@ -19,13 +19,13 @@
package grpclb
import (
rand "math/rand/v2"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/status"
)
@ -98,6 +98,15 @@ func (s *rpcStats) knownReceived() {
atomic.AddInt64(&s.numCallsFinished, 1)
}
type errPicker struct {
// Pick always returns this err.
err error
}
func (p *errPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
return balancer.PickResult{}, p.err
}
// rrPicker does roundrobin on subConns. It's typically used when there's no
// response from remote balancer, and grpclb falls back to the resolved
// backends.
@ -112,7 +121,7 @@ type rrPicker struct {
func newRRPicker(readySCs []balancer.SubConn) *rrPicker {
return &rrPicker{
subConns: readySCs,
subConnsNext: rand.IntN(len(readySCs)),
subConnsNext: grpcrand.Intn(len(readySCs)),
}
}
@ -147,7 +156,7 @@ func newLBPicker(serverList []*lbpb.Server, readySCs []balancer.SubConn, stats *
return &lbPicker{
serverList: serverList,
subConns: readySCs,
subConnsNext: rand.IntN(len(readySCs)),
subConnsNext: grpcrand.Intn(len(readySCs)),
stats: stats,
}
}
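
Both pickers constructed above start round-robin at a random offset so that independently created pickers do not all send their first pick to the same backend. Here is a hedged, simplified stand-in for that pattern using a plain string slice; it is not the package's rrPicker or lbPicker, which additionally handle drops and load reporting.

package main

import (
	"fmt"
	"math/rand"
	"sync"
)

// rrPicker cycles through its entries, starting at a random index so that
// separately built pickers begin at different positions.
type rrPicker struct {
	mu      sync.Mutex
	entries []string
	next    int
}

func newRRPicker(entries []string) *rrPicker {
	return &rrPicker{entries: entries, next: rand.Intn(len(entries))}
}

func (p *rrPicker) pick() string {
	p.mu.Lock()
	defer p.mu.Unlock()
	e := p.entries[p.next]
	p.next = (p.next + 1) % len(p.entries)
	return e
}

func main() {
	p := newRRPicker([]string{"backend-a", "backend-b", "backend-c"})
	for i := 0; i < 5; i++ {
		fmt.Println(p.pick())
	}
}
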


@ -26,8 +26,12 @@ import (
"sync"
"time"
"github.com/golang/protobuf/proto"
timestamppb "github.com/golang/protobuf/ptypes/timestamp"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/backoff"
@ -35,29 +39,13 @@ import (
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
)
func serverListEqual(a, b []*lbpb.Server) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if !proto.Equal(a[i], b[i]) {
return false
}
}
return true
}
// processServerList updates balancer's internal state, create/remove SubConns
// and regenerates picker using the received serverList.
func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
if lb.logger.V(2) {
lb.logger.Infof("Processing server list: %#v", l)
if logger.V(2) {
logger.Infof("lbBalancer: processing server list: %+v", l)
}
lb.mu.Lock()
defer lb.mu.Unlock()
@ -67,9 +55,9 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
lb.serverListReceived = true
// If the new server list == old server list, do nothing.
if serverListEqual(lb.fullServerList, l.Servers) {
if lb.logger.V(2) {
lb.logger.Infof("Ignoring new server list as it is the same as the previous one")
if cmp.Equal(lb.fullServerList, l.Servers, cmp.Comparer(proto.Equal)) {
if logger.V(2) {
logger.Infof("lbBalancer: new serverlist same as the previous one, ignoring")
}
return
}
@ -82,10 +70,17 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
}
md := metadata.Pairs(lbTokenKey, s.LoadBalanceToken)
ipStr := net.IP(s.IpAddress).String()
addr := imetadata.Set(resolver.Address{Addr: net.JoinHostPort(ipStr, fmt.Sprintf("%d", s.Port))}, md)
if lb.logger.V(2) {
lb.logger.Infof("Server list entry:|%d|, ipStr:|%s|, port:|%d|, load balancer token:|%v|", i, ipStr, s.Port, s.LoadBalanceToken)
ip := net.IP(s.IpAddress)
ipStr := ip.String()
if ip.To4() == nil {
// Add square brackets to ipv6 addresses, otherwise net.Dial() and
// net.SplitHostPort() will return too many colons error.
ipStr = fmt.Sprintf("[%s]", ipStr)
}
addr := imetadata.Set(resolver.Address{Addr: fmt.Sprintf("%s:%d", ipStr, s.Port)}, md)
if logger.V(2) {
logger.Infof("lbBalancer: server list entry[%d]: ipStr:|%s|, port:|%d|, load balancer token:|%v|",
i, ipStr, s.Port, s.LoadBalanceToken)
}
backendAddrs = append(backendAddrs, addr)
}
@ -118,6 +113,7 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fallback
}
balancingPolicyChanged := lb.usePickFirst != pickFirst
oldUsePickFirst := lb.usePickFirst
lb.usePickFirst = pickFirst
if fallbackModeChanged || balancingPolicyChanged {
@ -127,7 +123,13 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fallback
// For fallback mode switching with pickfirst, we want to recreate the
// SubConn because the creds could be different.
for a, sc := range lb.subConns {
sc.Shutdown()
if oldUsePickFirst {
// If old SubConn were created for pickfirst, bypass cache and
// remove directly.
lb.cc.cc.RemoveSubConn(sc)
} else {
lb.cc.RemoveSubConn(sc)
}
delete(lb.subConns, a)
}
}
@ -142,19 +144,18 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fallback
}
if sc != nil {
if len(backendAddrs) == 0 {
sc.Shutdown()
lb.cc.cc.RemoveSubConn(sc)
delete(lb.subConns, scKey)
return
}
lb.cc.ClientConn.UpdateAddresses(sc, backendAddrs)
lb.cc.cc.UpdateAddresses(sc, backendAddrs)
sc.Connect()
return
}
opts.StateListener = func(scs balancer.SubConnState) { lb.updateSubConnState(sc, scs) }
// This bypasses the cc wrapper with SubConn cache.
sc, err := lb.cc.ClientConn.NewSubConn(backendAddrs, opts)
sc, err := lb.cc.cc.NewSubConn(backendAddrs, opts)
if err != nil {
lb.logger.Warningf("Failed to create new SubConn: %v", err)
logger.Warningf("grpclb: failed to create new SubConn: %v", err)
return
}
sc.Connect()
@ -175,11 +176,9 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fallback
if _, ok := lb.subConns[addrWithoutAttrs]; !ok {
// Use addrWithMD to create the SubConn.
var sc balancer.SubConn
opts.StateListener = func(scs balancer.SubConnState) { lb.updateSubConnState(sc, scs) }
sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, opts)
if err != nil {
lb.logger.Warningf("Failed to create new SubConn: %v", err)
logger.Warningf("grpclb: failed to create new SubConn: %v", err)
continue
}
lb.subConns[addrWithoutAttrs] = sc // Use the addr without MD as key for the map.
@ -195,7 +194,7 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fallback
for a, sc := range lb.subConns {
// a was removed by resolver.
if _, ok := addrsSet[a]; !ok {
sc.Shutdown()
lb.cc.RemoveSubConn(sc)
delete(lb.subConns, a)
// Keep the state of this sc in b.scStates until sc's state becomes Shutdown.
// The entry will be deleted in UpdateSubConnState.
@ -222,7 +221,7 @@ type remoteBalancerCCWrapper struct {
wg sync.WaitGroup
}
func (lb *lbBalancer) newRemoteBalancerCCWrapper() error {
func (lb *lbBalancer) newRemoteBalancerCCWrapper() {
var dopts []grpc.DialOption
if creds := lb.opt.DialCreds; creds != nil {
dopts = append(dopts, grpc.WithTransportCredentials(creds))
@ -240,7 +239,7 @@ func (lb *lbBalancer) newRemoteBalancerCCWrapper() error {
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"pick_first"}`))
dopts = append(dopts, grpc.WithResolvers(lb.manualResolver))
dopts = append(dopts, grpc.WithChannelzParentID(lb.opt.ChannelzParent))
dopts = append(dopts, grpc.WithChannelzParentID(lb.opt.ChannelzParentID))
// Enable Keepalive for grpclb client.
dopts = append(dopts, grpc.WithKeepaliveParams(keepalive.ClientParameters{
@ -253,12 +252,10 @@ func (lb *lbBalancer) newRemoteBalancerCCWrapper() error {
//
// The grpclb server addresses will set field ServerName, and creds will
// receive ServerName as authority.
target := lb.manualResolver.Scheme() + ":///grpclb.subClientConn"
cc, err := grpc.NewClient(target, dopts...)
cc, err := grpc.DialContext(context.Background(), lb.manualResolver.Scheme()+":///grpclb.subClientConn", dopts...)
if err != nil {
return fmt.Errorf("grpc.NewClient(%s): %v", target, err)
logger.Fatalf("failed to dial: %v", err)
}
cc.Connect()
ccw := &remoteBalancerCCWrapper{
cc: cc,
lb: lb,
@ -268,7 +265,6 @@ func (lb *lbBalancer) newRemoteBalancerCCWrapper() error {
lb.ccRemoteLB = ccw
ccw.wg.Add(1)
go ccw.watchRemoteBalancer()
return nil
}
// close closed the ClientConn to remote balancer, and waits until all
@ -416,14 +412,14 @@ func (ccw *remoteBalancerCCWrapper) watchRemoteBalancer() {
default:
if err != nil {
if err == errServerTerminatedConnection {
ccw.lb.logger.Infof("Call to remote balancer failed: %v", err)
logger.Info(err)
} else {
ccw.lb.logger.Warningf("Call to remote balancer failed: %v", err)
logger.Warning(err)
}
}
}
// Trigger a re-resolve when the stream errors.
ccw.lb.cc.ClientConn.ResolveNow(resolver.ResolveNowOptions{})
ccw.lb.cc.cc.ResolveNow(resolver.ResolveNowOptions{})
ccw.lb.mu.Lock()
ccw.lb.remoteBalancerConnected = false
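
The two versions of processServerList shown above differ in how the balancer-provided IP bytes and port become an address string: the older code wraps IPv6 literals in square brackets by hand, while the newer code relies on net.JoinHostPort, which brackets IPv6 hosts automatically. A small illustrative snippet of that behavior, separate from the diff, using made-up addresses:

package main

import (
	"fmt"
	"net"
)

func main() {
	for _, raw := range [][]byte{
		{10, 0, 0, 1}, // IPv4
		{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, // IPv6
	} {
		ip := net.IP(raw)
		// net.JoinHostPort brackets IPv6 hosts, so the result is always
		// parseable by net.SplitHostPort and usable with net.Dial.
		fmt.Println(net.JoinHostPort(ip.String(), fmt.Sprintf("%d", 443)))
	}
	// Prints:
	// 10.0.0.1:443
	// [2001:db8::1]:443
}
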


@ -50,8 +50,8 @@ import (
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
durationpb "github.com/golang/protobuf/ptypes/duration"
lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
@ -126,7 +126,7 @@ func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
return &serverNameCheckCreds{}
}
func (c *serverNameCheckCreds) OverrideServerName(string) error {
func (c *serverNameCheckCreds) OverrideServerName(s string) error {
return nil
}
@ -307,7 +307,7 @@ type testServer struct {
const testmdkey = "testmd"
func (s *testServer) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.Internal, "failed to receive metadata")
@ -319,7 +319,7 @@ func (s *testServer) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Em
return &testpb.Empty{}, nil
}
func (s *testServer) FullDuplexCall(testgrpc.TestService_FullDuplexCallServer) error {
func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
return nil
}
@ -458,9 +458,9 @@ func (s) TestGRPCLB_Basic(t *testing.T) {
grpc.WithContextDialer(fakeNameDialer),
grpc.WithUserAgent(testUserAgent),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
@ -515,9 +515,9 @@ func (s) TestGRPCLB_Weighted(t *testing.T) {
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
@ -541,7 +541,7 @@ func (s) TestGRPCLB_Weighted(t *testing.T) {
tss.ls.sls <- &lbpb.ServerList{Servers: backends}
testC := testgrpc.NewTestServiceClient(cc)
if err := roundrobin.CheckWeightedRoundRobinRPCs(ctx, t, testC, wantAddrs); err != nil {
if err := roundrobin.CheckWeightedRoundRobinRPCs(ctx, testC, wantAddrs); err != nil {
t.Fatal(err)
}
}
@ -595,9 +595,9 @@ func (s) TestGRPCLB_DropRequest(t *testing.T) {
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testgrpc.NewTestServiceClient(cc)
@ -767,9 +767,9 @@ func (s) TestGRPCLB_BalancerDisconnects(t *testing.T) {
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testgrpc.NewTestServiceClient(cc)
@ -827,28 +827,28 @@ func (s) TestGRPCLB_Fallback(t *testing.T) {
defer stopBackends(standaloneBEs)
r := manual.NewBuilderWithScheme("whatever")
// Set the initial resolver state with fallback backend address stored in
// the `Addresses` field and an invalid remote balancer address stored in
// attributes, which will cause fallback behavior to be invoked.
rs := resolver.State{
Addresses: []resolver.Address{{Addr: beLis.Addr().String()}},
ServiceConfig: internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(grpclbConfig),
}
rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: "invalid.address", ServerName: lbServerName}}})
r.InitialState(rs)
dopts := []grpc.DialOption{
grpc.WithResolvers(r),
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create new client to the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testgrpc.NewTestServiceClient(cc)
// Push an update to the resolver with fallback backend address stored in
// the `Addresses` field and an invalid remote balancer address stored in
// attributes, which will cause fallback behavior to be invoked.
rs := resolver.State{
Addresses: []resolver.Address{{Addr: beLis.Addr().String()}},
ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig),
}
rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: "invalid.address", ServerName: lbServerName}}})
r.UpdateState(rs)
// Make an RPC and verify that it got routed to the fallback backend.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
@ -859,7 +859,7 @@ func (s) TestGRPCLB_Fallback(t *testing.T) {
// Push another update to the resolver, this time with a valid balancer
// address in the attributes field.
rs = resolver.State{
ServiceConfig: r.CC().ParseServiceConfig(grpclbConfig),
ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig),
Addresses: []resolver.Address{{Addr: beLis.Addr().String()}},
}
rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}})
@ -938,9 +938,9 @@ func (s) TestGRPCLB_ExplicitFallback(t *testing.T) {
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testgrpc.NewTestServiceClient(cc)
@ -973,7 +973,7 @@ func (s) TestGRPCLB_FallBackWithNoServerAddress(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer cancel()
if err := resolveNowCh.SendContext(ctx, nil); err != nil {
t.Error("timeout when attempting to send on resolverNowCh")
t.Error("timeout when attemping to send on resolverNowCh")
}
}
@ -1008,12 +1008,11 @@ func (s) TestGRPCLB_FallBackWithNoServerAddress(t *testing.T) {
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
cc.Connect()
testC := testgrpc.NewTestServiceClient(cc)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
@ -1023,7 +1022,7 @@ func (s) TestGRPCLB_FallBackWithNoServerAddress(t *testing.T) {
// fallback and use the fallback backend.
r.UpdateState(resolver.State{
Addresses: []resolver.Address{{Addr: beLis.Addr().String()}},
ServiceConfig: r.CC().ParseServiceConfig(grpclbConfig),
ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig),
})
sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
@ -1051,7 +1050,7 @@ func (s) TestGRPCLB_FallBackWithNoServerAddress(t *testing.T) {
// be used.
rs := resolver.State{
Addresses: []resolver.Address{{Addr: beLis.Addr().String()}},
ServiceConfig: r.CC().ParseServiceConfig(grpclbConfig),
ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig),
}
rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}})
r.UpdateState(rs)
@ -1103,16 +1102,15 @@ func (s) TestGRPCLB_PickFirst(t *testing.T) {
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend: %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
cc.Connect()
defer cc.Close()
// Push a service config with grpclb as the load balancing policy and
// configure pick_first as its child policy.
rs := resolver.State{ServiceConfig: r.CC().ParseServiceConfig(`{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}`)}
rs := resolver.State{ServiceConfig: r.CC.ParseServiceConfig(`{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}`)}
// Push a resolver update with the remote balancer address specified via
// attributes.
@ -1152,7 +1150,7 @@ func (s) TestGRPCLB_PickFirst(t *testing.T) {
},
},
}
rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC().ParseServiceConfig(grpclbConfig)}, s)
rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, s)
r.UpdateState(rs)
testC := testgrpc.NewTestServiceClient(cc)
if err := roundrobin.CheckRoundRobinRPCs(ctx, testC, beServerAddrs[1:]); err != nil {
@ -1189,22 +1187,23 @@ func (s) TestGRPCLB_BackendConnectionErrorPropagation(t *testing.T) {
standaloneBEs := startBackends(t, "arbitrary.invalid.name", true, beLis)
defer stopBackends(standaloneBEs)
rs := resolver.State{
Addresses: []resolver.Address{{Addr: beLis.Addr().String()}},
ServiceConfig: internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(grpclbConfig),
}
rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}})
r.InitialState(rs)
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName,
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName,
grpc.WithResolvers(r),
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to create a client for the backend: %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testgrpc.NewTestServiceClient(cc)
rs := resolver.State{
Addresses: []resolver.Address{{Addr: beLis.Addr().String()}},
ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig),
}
rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}})
r.UpdateState(rs)
// If https://github.com/grpc/grpc-go/blob/65cabd74d8e18d7347fecd414fa8d83a00035f5f/balancer/grpclb/grpclb_test.go#L103
// changes, then expectedErrMsg may need to be updated.
const expectedErrMsg = "received unexpected server name"
@ -1243,11 +1242,10 @@ func testGRPCLBEmptyServerList(t *testing.T, svcfg string) {
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, dopts...)
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, dopts...)
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
cc.Connect()
defer cc.Close()
testC := testgrpc.NewTestServiceClient(cc)
@ -1261,7 +1259,7 @@ func testGRPCLBEmptyServerList(t *testing.T, svcfg string) {
},
},
}
rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC().ParseServiceConfig(svcfg)}, s)
rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(svcfg)}, s)
r.UpdateState(rs)
t.Log("Perform an initial RPC and expect it to succeed...")
if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
@ -1314,22 +1312,21 @@ func (s) TestGRPCLBWithTargetNameFieldInConfig(t *testing.T) {
// Push the backend address to the remote balancer.
tss.ls.sls <- sl
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName,
cc, err := grpc.Dial(r.Scheme()+":///"+beServerName,
grpc.WithResolvers(r),
grpc.WithTransportCredentials(&serverNameCheckCreds{}),
grpc.WithContextDialer(fakeNameDialer),
grpc.WithUserAgent(testUserAgent))
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
cc.Connect()
testC := testgrpc.NewTestServiceClient(cc)
// Push a resolver update with grpclb configuration which does not contain the
// target_name field. Our fake remote balancer is configured to always
// expect `beServerName` as the server name in the initial request.
rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC().ParseServiceConfig(grpclbConfig)},
rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)},
&grpclbstate.State{BalancerAddresses: []resolver.Address{{
Addr: tss.lbAddr,
ServerName: lbServerName,
@ -1366,7 +1363,7 @@ func (s) TestGRPCLBWithTargetNameFieldInConfig(t *testing.T) {
},
},
}
rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC().ParseServiceConfig(lbCfg)}, s)
rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(lbCfg)}, s)
r.UpdateState(rs)
select {
case <-ctx.Done():
@ -1381,7 +1378,7 @@ func (s) TestGRPCLBWithTargetNameFieldInConfig(t *testing.T) {
type failPreRPCCred struct{}
func (failPreRPCCred) GetRequestMetadata(_ context.Context, uri ...string) (map[string]string, error) {
func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if strings.Contains(uri[0], failtosendURI) {
return nil, fmt.Errorf("rpc should fail to send")
}
@ -1422,21 +1419,22 @@ func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats,
tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{}
cc, err := grpc.NewClient(r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
grpc.WithTransportCredentials(&creds),
grpc.WithPerRPCCredentials(failPreRPCCred{}),
grpc.WithContextDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to create a client for the backend %v", err)
t.Fatalf("Failed to dial to the backend %v", err)
}
cc.Connect()
defer cc.Close()
rstate := resolver.State{ServiceConfig: r.CC().ParseServiceConfig(grpclbConfig)}
r.UpdateState(grpclbstate.Set(rstate, &grpclbstate.State{BalancerAddresses: []resolver.Address{{
r.UpdateState(resolver.State{Addresses: []resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}}}))
}}})
runRPCs(cc)
end := time.Now().Add(time.Second)
@ -1621,7 +1619,7 @@ func (s) TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
func (s) TestGRPCLBStatsQuashEmpty(t *testing.T) {
ch := make(chan *lbpb.ClientStats)
defer close(ch)
if err := runAndCheckStats(t, false, ch, func(*grpc.ClientConn) {
if err := runAndCheckStats(t, false, ch, func(cc *grpc.ClientConn) {
// Perform no RPCs; wait for load reports to start, which should be
// zero, then expect no other load report within 5x the update
// interval.


@ -27,15 +27,75 @@ import (
"google.golang.org/grpc/resolver"
)
// The parent ClientConn should re-resolve when grpclb loses connection to the
// remote balancer. When the ClientConn inside grpclb gets a TransientFailure,
// it calls lbManualResolver.ResolveNow(), which calls parent ClientConn's
// ResolveNow, and eventually results in re-resolve happening in parent
// ClientConn's resolver (DNS for example).
//
// parent
// ClientConn
// +-----------------------------------------------------------------+
// | parent +---------------------------------+ |
// | DNS ClientConn | grpclb | |
// | resolver balancerWrapper | | |
// | + + | grpclb grpclb | |
// | | | | ManualResolver ClientConn | |
// | | | | + + | |
// | | | | | | Transient | |
// | | | | | | Failure | |
// | | | | | <--------- | | |
// | | | <--------------- | ResolveNow | | |
// | | <--------- | ResolveNow | | | | |
// | | ResolveNow | | | | | |
// | | | | | | | |
// | + + | + + | |
// | +---------------------------------+ |
// +-----------------------------------------------------------------+
// lbManualResolver is used by the ClientConn inside grpclb. It's a manual
// resolver with a special ResolveNow() function.
//
// When ResolveNow() is called, it calls ResolveNow() on the parent ClientConn,
// so when grpclb client lose contact with remote balancers, the parent
// ClientConn's resolver will re-resolve.
type lbManualResolver struct {
scheme string
ccr resolver.ClientConn
ccb balancer.ClientConn
}
func (r *lbManualResolver) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
r.ccr = cc
return r, nil
}
func (r *lbManualResolver) Scheme() string {
return r.scheme
}
// ResolveNow calls resolveNow on the parent ClientConn.
func (r *lbManualResolver) ResolveNow(o resolver.ResolveNowOptions) {
r.ccb.ResolveNow(o)
}
// Close is a noop for Resolver.
func (*lbManualResolver) Close() {}
// UpdateState calls cc.UpdateState.
func (r *lbManualResolver) UpdateState(s resolver.State) {
r.ccr.UpdateState(s)
}
const subConnCacheTime = time.Second * 10
// lbCacheClientConn is a wrapper balancer.ClientConn with a SubConn cache.
// SubConns will be kept in cache for subConnCacheTime before being shut down.
// SubConns will be kept in cache for subConnCacheTime before being removed.
//
// Its NewSubconn and SubConn.Shutdown methods are updated to do cache first.
// Its new and remove methods are updated to do cache first.
type lbCacheClientConn struct {
balancer.ClientConn
cc balancer.ClientConn
timeout time.Duration
mu sync.Mutex
@ -53,7 +113,7 @@ type subConnCacheEntry struct {
func newLBCacheClientConn(cc balancer.ClientConn) *lbCacheClientConn {
return &lbCacheClientConn{
ClientConn: cc,
cc: cc,
timeout: subConnCacheTime,
subConnCache: make(map[resolver.Address]*subConnCacheEntry),
subConnToAddr: make(map[balancer.SubConn]resolver.Address),
@ -77,27 +137,16 @@ func (ccc *lbCacheClientConn) NewSubConn(addrs []resolver.Address, opts balancer
return entry.sc, nil
}
scNew, err := ccc.ClientConn.NewSubConn(addrs, opts)
scNew, err := ccc.cc.NewSubConn(addrs, opts)
if err != nil {
return nil, err
}
scNew = &lbCacheSubConn{SubConn: scNew, ccc: ccc}
ccc.subConnToAddr[scNew] = addrWithoutAttrs
return scNew, nil
}
func (ccc *lbCacheClientConn) RemoveSubConn(sc balancer.SubConn) {
logger.Errorf("RemoveSubConn(%v) called unexpectedly", sc)
}
type lbCacheSubConn struct {
balancer.SubConn
ccc *lbCacheClientConn
}
func (sc *lbCacheSubConn) Shutdown() {
ccc := sc.ccc
ccc.mu.Lock()
defer ccc.mu.Unlock()
addr, ok := ccc.subConnToAddr[sc]
@ -107,11 +156,11 @@ func (sc *lbCacheSubConn) Shutdown() {
if entry, ok := ccc.subConnCache[addr]; ok {
if entry.sc != sc {
// This could happen if NewSubConn was called multiple times for
// the same address, and those SubConns are all shut down. We
// remove sc immediately here.
// This could happen if NewSubConn was called multiple times for the
// same address, and those SubConns are all removed. We remove sc
// immediately here.
delete(ccc.subConnToAddr, sc)
sc.SubConn.Shutdown()
ccc.cc.RemoveSubConn(sc)
}
return
}
@ -127,7 +176,7 @@ func (sc *lbCacheSubConn) Shutdown() {
if entry.abortDeleting {
return
}
sc.SubConn.Shutdown()
ccc.cc.RemoveSubConn(sc)
delete(ccc.subConnToAddr, sc)
delete(ccc.subConnCache, addr)
})
@ -146,28 +195,14 @@ func (sc *lbCacheSubConn) Shutdown() {
}
func (ccc *lbCacheClientConn) UpdateState(s balancer.State) {
s.Picker = &lbCachePicker{Picker: s.Picker}
ccc.ClientConn.UpdateState(s)
ccc.cc.UpdateState(s)
}
func (ccc *lbCacheClientConn) close() {
ccc.mu.Lock()
defer ccc.mu.Unlock()
// Only cancel all existing timers. There's no need to shut down SubConns.
// Only cancel all existing timers. There's no need to remove SubConns.
for _, entry := range ccc.subConnCache {
entry.cancel()
}
}
type lbCachePicker struct {
balancer.Picker
}
func (cp *lbCachePicker) Pick(i balancer.PickInfo) (balancer.PickResult, error) {
res, err := cp.Picker.Pick(i)
if err != nil {
return res, err
}
res.SubConn = res.SubConn.(*lbCacheSubConn).SubConn
return res, nil
ccc.mu.Unlock()
}
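
The diagram near the top of this file explains why ResolveNow on the grpclb-internal resolver must be forwarded to the parent ClientConn: when the stream to the remote balancer breaks, re-resolution has to happen in the parent's own resolver (DNS, for example). Below is a hedged sketch of a resolver whose ResolveNow simply invokes a caller-supplied callback, similar in spirit to what the newer code gets from manual.Resolver's ResolveNowCallback; the type forwardingResolver and its scheme are invented for the example.

package main

import (
	"fmt"

	"google.golang.org/grpc/resolver"
)

// forwardingResolver does no resolution of its own; its ResolveNow just calls
// a callback supplied by its owner (e.g. the parent ClientConn's ResolveNow).
type forwardingResolver struct {
	scheme     string
	resolveNow func(resolver.ResolveNowOptions)
	cc         resolver.ClientConn
}

func (r *forwardingResolver) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
	r.cc = cc
	return r, nil
}

func (r *forwardingResolver) Scheme() string { return r.scheme }

// ResolveNow forwards the request instead of re-resolving locally.
func (r *forwardingResolver) ResolveNow(o resolver.ResolveNowOptions) { r.resolveNow(o) }

func (r *forwardingResolver) Close() {}

func main() {
	r := &forwardingResolver{
		scheme:     "grpclb-internal-demo",
		resolveNow: func(resolver.ResolveNowOptions) { fmt.Println("parent ClientConn asked to re-resolve") },
	}
	r.ResolveNow(resolver.ResolveNowOptions{})
}
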


@ -30,13 +30,6 @@ import (
type mockSubConn struct {
balancer.SubConn
mcc *mockClientConn
}
func (msc *mockSubConn) Shutdown() {
msc.mcc.mu.Lock()
defer msc.mcc.mu.Unlock()
delete(msc.mcc.subConns, msc)
}
type mockClientConn struct {
@ -52,8 +45,8 @@ func newMockClientConn() *mockClientConn {
}
}
func (mcc *mockClientConn) NewSubConn(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) {
sc := &mockSubConn{mcc: mcc}
func (mcc *mockClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
sc := &mockSubConn{}
mcc.mu.Lock()
defer mcc.mu.Unlock()
mcc.subConns[sc] = addrs[0]
@ -61,7 +54,9 @@ func (mcc *mockClientConn) NewSubConn(addrs []resolver.Address, _ balancer.NewSu
}
func (mcc *mockClientConn) RemoveSubConn(sc balancer.SubConn) {
panic(fmt.Sprintf("RemoveSubConn(%v) called unexpectedly", sc))
mcc.mu.Lock()
defer mcc.mu.Unlock()
delete(mcc.subConns, sc)
}
const testCacheTimeout = 100 * time.Millisecond
@ -87,7 +82,7 @@ func checkCacheCC(ccc *lbCacheClientConn, sccLen, sctaLen int) error {
return nil
}
// Test that SubConn won't be immediately shut down.
// Test that SubConn won't be immediately removed.
func (s) TestLBCacheClientConnExpire(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
@ -110,7 +105,7 @@ func (s) TestLBCacheClientConnExpire(t *testing.T) {
t.Fatal(err)
}
sc.Shutdown()
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
@ -138,7 +133,7 @@ func (s) TestLBCacheClientConnExpire(t *testing.T) {
}
}
// Test that NewSubConn with the same address of a SubConn being shut down will
// Test that NewSubConn with the same address of a SubConn being removed will
// reuse the SubConn and cancel the removing.
func (s) TestLBCacheClientConnReuse(t *testing.T) {
mcc := newMockClientConn()
@ -162,7 +157,7 @@ func (s) TestLBCacheClientConnReuse(t *testing.T) {
t.Fatal(err)
}
sc.Shutdown()
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
@ -195,8 +190,8 @@ func (s) TestLBCacheClientConnReuse(t *testing.T) {
t.Fatal(err)
}
// Call Shutdown again, will delete after timeout.
sc.Shutdown()
// Call remove again, will delete after timeout.
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
@ -223,9 +218,9 @@ func (s) TestLBCacheClientConnReuse(t *testing.T) {
}
}
// Test that if the timer to shut down a SubConn fires at the same time
// NewSubConn cancels the timer, it doesn't cause deadlock.
func (s) TestLBCache_ShutdownTimer_New_Race(t *testing.T) {
// Test that if the timer to remove a SubConn fires at the same time NewSubConn
// cancels the timer, it doesn't cause deadlock.
func (s) TestLBCache_RemoveTimer_New_Race(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
t.Fatal(err)
@ -251,9 +246,9 @@ func (s) TestLBCache_ShutdownTimer_New_Race(t *testing.T) {
go func() {
for i := 0; i < 1000; i++ {
// Shutdown starts a timer with 1 ns timeout, the NewSubConn will
// race with the timer.
sc.Shutdown()
// Remove starts a timer with 1 ns timeout, the NewSubConn will race
// with with the timer.
ccc.RemoveSubConn(sc)
sc, _ = ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
}
close(done)
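
The tests above exercise lbCacheClientConn, which keeps a SubConn alive for a grace period after shutdown and reuses it if the same address comes back before the timer fires. Below is a hedged, generic sketch of that delayed-release idea built on time.AfterFunc; it is keyed by plain address strings, omits the abortDeleting bookkeeping the real code needs, and all names (cache, Get, Release) are invented for the example.

package main

import (
	"fmt"
	"sync"
	"time"
)

// cache keeps released entries around for gracePeriod; Get within that window
// cancels the pending delete and reuses the existing entry.
type cache struct {
	mu          sync.Mutex
	gracePeriod time.Duration
	pending     map[string]*time.Timer
	live        map[string]string // address -> connection handle (stand-in)
}

func newCache(grace time.Duration) *cache {
	return &cache{
		gracePeriod: grace,
		pending:     make(map[string]*time.Timer),
		live:        make(map[string]string),
	}
}

// Get returns a cached connection if one is awaiting deletion, else creates one.
func (c *cache) Get(addr string) string {
	c.mu.Lock()
	defer c.mu.Unlock()
	if t, ok := c.pending[addr]; ok {
		t.Stop() // cancel the scheduled removal and reuse the entry
		delete(c.pending, addr)
		return c.live[addr]
	}
	conn := "conn-to-" + addr
	c.live[addr] = conn
	return conn
}

// Release schedules the entry for deletion after the grace period.
func (c *cache) Release(addr string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.pending[addr] = time.AfterFunc(c.gracePeriod, func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		delete(c.pending, addr)
		delete(c.live, addr)
	})
}

func main() {
	c := newCache(50 * time.Millisecond)
	first := c.Get("10.0.0.1:443")
	c.Release("10.0.0.1:443")
	second := c.Get("10.0.0.1:443") // reused: still within the grace period
	fmt.Println(first == second)    // true
	c.Release("10.0.0.1:443")
	time.Sleep(100 * time.Millisecond) // grace period elapses; entry is dropped
	fmt.Println(len(c.pending))        // 0
}
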


@ -1,157 +0,0 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package lazy contains a load balancer that starts in IDLE instead of
// CONNECTING. Once it starts connecting, it instantiates its delegate.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package lazy
import (
"fmt"
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
)
var (
logger = grpclog.Component("lazy-lb")
)
const (
logPrefix = "[lazy-lb %p] "
)
// ChildBuilderFunc creates a new balancer with the ClientConn. It has the same
// type as the balancer.Builder.Build method.
type ChildBuilderFunc func(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer
// NewBalancer is the constructor for the lazy balancer.
func NewBalancer(cc balancer.ClientConn, bOpts balancer.BuildOptions, childBuilder ChildBuilderFunc) balancer.Balancer {
b := &lazyBalancer{
cc: cc,
buildOptions: bOpts,
childBuilder: childBuilder,
}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Idle,
Picker: &idlePicker{exitIdle: sync.OnceFunc(func() {
// Call ExitIdle in a new goroutine to avoid deadlocks while calling
// back into the channel synchronously.
go b.ExitIdle()
})},
})
return b
}
type lazyBalancer struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
cc balancer.ClientConn
buildOptions balancer.BuildOptions
logger *internalgrpclog.PrefixLogger
childBuilder ChildBuilderFunc
// The following fields are accessed while handling calls to the idlePicker
// and when handling ClientConn state updates. They are guarded by a mutex.
mu sync.Mutex
delegate balancer.Balancer
latestClientConnState *balancer.ClientConnState
latestResolverError error
}
func (lb *lazyBalancer) Close() {
lb.mu.Lock()
defer lb.mu.Unlock()
if lb.delegate != nil {
lb.delegate.Close()
lb.delegate = nil
}
}
func (lb *lazyBalancer) ResolverError(err error) {
lb.mu.Lock()
defer lb.mu.Unlock()
if lb.delegate != nil {
lb.delegate.ResolverError(err)
return
}
lb.latestResolverError = err
}
func (lb *lazyBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
lb.mu.Lock()
defer lb.mu.Unlock()
if lb.delegate != nil {
return lb.delegate.UpdateClientConnState(ccs)
}
lb.latestClientConnState = &ccs
lb.latestResolverError = nil
return nil
}
// UpdateSubConnState implements balancer.Balancer.
func (lb *lazyBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) {
// UpdateSubConnState is deprecated.
}
func (lb *lazyBalancer) ExitIdle() {
lb.mu.Lock()
defer lb.mu.Unlock()
if lb.delegate != nil {
lb.delegate.ExitIdle()
return
}
lb.delegate = lb.childBuilder(lb.cc, lb.buildOptions)
if lb.latestClientConnState != nil {
if err := lb.delegate.UpdateClientConnState(*lb.latestClientConnState); err != nil {
if err == balancer.ErrBadResolverState {
lb.cc.ResolveNow(resolver.ResolveNowOptions{})
} else {
lb.logger.Warningf("Error from child policy on receiving initial state: %v", err)
}
}
lb.latestClientConnState = nil
}
if lb.latestResolverError != nil {
lb.delegate.ResolverError(lb.latestResolverError)
lb.latestResolverError = nil
}
}
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
exitIdle func()
}
func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.exitIdle()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}


@ -1,466 +0,0 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package lazy_test
import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/lazy"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)
const (
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// TestExitIdle creates a lazy balancer than manages a pickfirst child. The test
// calls Connect() on the channel which in turn calls ExitIdle on the lazy
// balancer. The test verifies that the channel enters READY.
func (s) TestExitIdle(t *testing.T) {
backend1 := stubserver.StartTestService(t, nil)
defer backend1.Stop()
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend1.Address}}},
},
})
bf := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = lazy.NewBalancer(bd.ClientConn, bd.BuildOptions, balancer.Get(pickfirstleaf.Name).Build)
},
ExitIdle: func(bd *stub.BalancerData) {
bd.ChildBalancer.ExitIdle()
},
ResolverError: func(bd *stub.BalancerData, err error) {
bd.ChildBalancer.ResolverError(err)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
stub.Register(t.Name(), bf)
json := fmt.Sprintf(`{"loadBalancingConfig": [{"%s": {}}]}`, t.Name())
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(json),
grpc.WithResolvers(mr),
}
cc, err := grpc.NewClient(mr.Scheme()+":///", opts...)
if err != nil {
t.Fatalf("grpc.NewClient(_) failed: %v", err)
}
defer cc.Close()
cc.Connect()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
// Send a resolver update to verify that the resolver state is correctly
// passed through to the leaf pickfirst balancer.
backend2 := stubserver.StartTestService(t, nil)
defer backend2.Stop()
mr.UpdateState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend2.Address}}},
},
})
var peer peer.Peer
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
t.Errorf("client.EmptyCall() returned unexpected error: %v", err)
}
if got, want := peer.Addr.String(), backend2.Address; got != want {
t.Errorf("EmptyCall() went to unexpected backend: got %q, want %q", got, want)
}
}
// TestPicker creates a lazy balancer under a stub balancer which blocks all
// calls to ExitIdle. This ensures that the only way to trigger the lazy
// balancer to exit idle is through the picker. The test makes an RPC and
// ensures it succeeds.
func (s) TestPicker(t *testing.T) {
backend := stubserver.StartTestService(t, nil)
defer backend.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
bf := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = lazy.NewBalancer(bd.ClientConn, bd.BuildOptions, balancer.Get(pickfirstleaf.Name).Build)
},
ExitIdle: func(*stub.BalancerData) {
t.Log("Ignoring call to ExitIdle, calling the picker should make the lazy balancer exit IDLE state.")
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
name := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "")
stub.Register(name, bf)
json := fmt.Sprintf(`{"loadBalancingConfig": [{%q: {}}]}`, name)
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(json),
}
cc, err := grpc.NewClient(backend.Address, opts...)
if err != nil {
t.Fatalf("grpc.NewClient(_) failed: %v", err)
}
defer cc.Close()
// The channel should remain in IDLE as the ExitIdle calls are not
// propagated to the lazy balancer from the stub balancer.
cc.Connect()
shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer shortCancel()
testutils.AwaitNoStateChange(shortCtx, t, cc, connectivity.Idle)
// The picker from the lazy balancer should be sent to the channel when the
// first resolver update is received by the lazy balancer. Making an RPC
// should trigger child creation.
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Errorf("client.EmptyCall() returned unexpected error: %v", err)
}
}
// Tests the scenario when a resolver produces a good state followed by a
// resolver error. The test verifies that the child balancer receives the good
// update followed by the error.
func (s) TestGoodUpdateThenResolverError(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
backend := stubserver.StartTestService(t, nil)
defer backend.Stop()
resolverStateReceived := false
resolverErrorReceived := grpcsync.NewEvent()
childBF := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = balancer.Get(pickfirstleaf.Name).Build(bd.ClientConn, bd.BuildOptions)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
if resolverErrorReceived.HasFired() {
t.Error("Received resolver error before resolver state.")
}
resolverStateReceived = true
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
ResolverError: func(bd *stub.BalancerData, err error) {
if !resolverStateReceived {
t.Error("Received resolver error before resolver state.")
}
resolverErrorReceived.Fire()
bd.ChildBalancer.ResolverError(err)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
childBalName := strings.ReplaceAll(strings.ToLower(t.Name())+"_child", "/", "")
stub.Register(childBalName, childBF)
topLevelBF := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = lazy.NewBalancer(bd.ClientConn, bd.BuildOptions, balancer.Get(childBalName).Build)
},
ExitIdle: func(*stub.BalancerData) {
t.Log("Ignoring call to ExitIdle to delay lazy child creation until RPC time.")
},
ResolverError: func(bd *stub.BalancerData, err error) {
bd.ChildBalancer.ResolverError(err)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
topLevelBalName := strings.ReplaceAll(strings.ToLower(t.Name())+"_top_level", "/", "")
stub.Register(topLevelBalName, topLevelBF)
json := fmt.Sprintf(`{"loadBalancingConfig": [{%q: {}}]}`, topLevelBalName)
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend.Address}}},
},
})
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(mr),
grpc.WithDefaultServiceConfig(json),
}
cc, err := grpc.NewClient(mr.Scheme()+":///whatever", opts...)
if err != nil {
t.Fatalf("grpc.NewClient(_) failed: %v", err)
}
defer cc.Close()
cc.Connect()
mr.CC().ReportError(errors.New("test error"))
// The channel should remain in IDLE as the ExitIdle calls are not
// propagated to the lazy balancer from the stub balancer.
shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer shortCancel()
testutils.AwaitNoStateChange(shortCtx, t, cc, connectivity.Idle)
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Errorf("client.EmptyCall() returned unexpected error: %v", err)
}
if !resolverStateReceived {
t.Fatalf("Child balancer did not receive resolver state.")
}
select {
case <-resolverErrorReceived.Done():
case <-ctx.Done():
t.Fatal("Context timed out waiting for resolver error to be delivered to child balancer.")
}
}
// Tests the scenario when a resolver produces an error followed by a list of
// endpoints. The test verifies that the child balancer receives only the good
// update.
func (s) TestResolverErrorThenGoodUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
backend := stubserver.StartTestService(t, nil)
defer backend.Stop()
childBF := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = balancer.Get(pickfirstleaf.Name).Build(bd.ClientConn, bd.BuildOptions)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
ResolverError: func(bd *stub.BalancerData, err error) {
t.Error("Received unexpected resolver error.")
bd.ChildBalancer.ResolverError(err)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
childBalName := strings.ReplaceAll(strings.ToLower(t.Name())+"_child", "/", "")
stub.Register(childBalName, childBF)
topLevelBF := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = lazy.NewBalancer(bd.ClientConn, bd.BuildOptions, balancer.Get(childBalName).Build)
},
ExitIdle: func(*stub.BalancerData) {
t.Log("Ignoring call to ExitIdle to delay lazy child creation until RPC time.")
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
topLevelBalName := strings.ReplaceAll(strings.ToLower(t.Name())+"_top_level", "/", "")
stub.Register(topLevelBalName, topLevelBF)
json := fmt.Sprintf(`{"loadBalancingConfig": [{%q: {}}]}`, topLevelBalName)
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend.Address}}},
},
})
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(mr),
grpc.WithDefaultServiceConfig(json),
}
cc, err := grpc.NewClient(mr.Scheme()+":///whatever", opts...)
if err != nil {
t.Fatalf("grpc.NewClient(_) failed: %v", err)
}
defer cc.Close()
cc.Connect()
// Send an error followed by a good update.
mr.CC().ReportError(errors.New("test error"))
mr.UpdateState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend.Address}}},
},
})
// The channel should remain in IDLE as the ExitIdle calls are not
// propagated to the lazy balancer from the stub balancer.
shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer shortCancel()
testutils.AwaitNoStateChange(shortCtx, t, cc, connectivity.Idle)
// An RPC would succeed only if the leaf pickfirst receives the endpoint
// list.
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Errorf("client.EmptyCall() returned unexpected error: %v", err)
}
}
// Tests that ExitIdle calls are correctly passed through to the child balancer.
// It starts a backend and ensures the channel connects to it. The test then
// stops the backend, making the channel enter IDLE. The test calls Connect on
// the channel and verifies that the child balancer exits idle.
func (s) TestExitIdlePassthrough(t *testing.T) {
backend1 := stubserver.StartTestService(t, nil)
defer backend1.Stop()
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend1.Address}}},
},
})
bf := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = lazy.NewBalancer(bd.ClientConn, bd.BuildOptions, balancer.Get(pickfirstleaf.Name).Build)
},
ExitIdle: func(bd *stub.BalancerData) {
bd.ChildBalancer.ExitIdle()
},
ResolverError: func(bd *stub.BalancerData, err error) {
bd.ChildBalancer.ResolverError(err)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
}
stub.Register(t.Name(), bf)
json := fmt.Sprintf(`{"loadBalancingConfig": [{"%s": {}}]}`, t.Name())
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(json),
grpc.WithResolvers(mr),
}
cc, err := grpc.NewClient(mr.Scheme()+":///", opts...)
if err != nil {
t.Fatalf("grpc.NewClient(_) failed: %v", err)
}
defer cc.Close()
cc.Connect()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
// Stopping the active backend should put the channel in IDLE.
backend1.Stop()
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
// Sending a new backend address should not kick the channel out of IDLE.
// On calling cc.Connect(), the channel should call ExitIdle on the lazy
// balancer which passes through the call to the leaf pickfirst.
backend2 := stubserver.StartTestService(t, nil)
defer backend2.Stop()
mr.UpdateState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend2.Address}}},
},
})
shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer shortCancel()
testutils.AwaitNoStateChange(shortCtx, t, cc, connectivity.Idle)
cc.Connect()
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
}

View File

@ -1,250 +0,0 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package leastrequest implements a least request load balancer.
package leastrequest
import (
"encoding/json"
"fmt"
rand "math/rand/v2"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
// Name is the name of the least request balancer.
const Name = "least_request_experimental"
var (
// randuint32 is a global to stub out in tests.
randuint32 = rand.Uint32
logger = grpclog.Component("least-request")
)
func init() {
balancer.Register(bb{})
}
// LBConfig is the balancer config for least_request_experimental balancer.
type LBConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// ChoiceCount is the number of random SubConns to sample to find the one
// with the fewest outstanding requests. If unset, defaults to 2. If set to
// < 2, the config will be rejected, and if set to > 10, will become 10.
ChoiceCount uint32 `json:"choiceCount,omitempty"`
}
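// As an illustrative sketch (an assumption added for clarity, not part of the
// original file), a service config selecting this policy with a custom choice
// count might look like:
//
//	{"loadBalancingConfig": [{"least_request_experimental": {"choiceCount": 3}}]}
//
// ParseConfig below rejects counts below 2 and clamps counts above 10 to 10.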
type bb struct{}
func (bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
lbConfig := &LBConfig{
ChoiceCount: 2,
}
if err := json.Unmarshal(s, lbConfig); err != nil {
return nil, fmt.Errorf("least-request: unable to unmarshal LBConfig: %v", err)
}
// "If `choice_count < 2`, the config will be rejected." - A48
if lbConfig.ChoiceCount < 2 {
return nil, fmt.Errorf("least-request: lbConfig.choiceCount: %v, must be >= 2", lbConfig.ChoiceCount)
}
// "If a LeastRequestLoadBalancingConfig with a choice_count > 10 is
// received, the least_request_experimental policy will set choice_count =
// 10." - A48
if lbConfig.ChoiceCount > 10 {
lbConfig.ChoiceCount = 10
}
return lbConfig, nil
}
func (bb) Name() string {
return Name
}
func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
b := &leastRequestBalancer{
ClientConn: cc,
endpointRPCCounts: resolver.NewEndpointMap[*atomic.Int32](),
}
b.child = endpointsharding.NewBalancer(b, bOpts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{})
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
b.logger.Infof("Created")
return b
}
type leastRequestBalancer struct {
// Embeds balancer.ClientConn because we need to intercept UpdateState
// calls from the child balancer.
balancer.ClientConn
child balancer.Balancer
logger *internalgrpclog.PrefixLogger
mu sync.Mutex
choiceCount uint32
// endpointRPCCounts holds per-endpoint RPC counts that are carried over to
// subsequent picker updates.
endpointRPCCounts *resolver.EndpointMap[*atomic.Int32]
}
func (lrb *leastRequestBalancer) Close() {
lrb.child.Close()
lrb.endpointRPCCounts = nil
}
func (lrb *leastRequestBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
lrb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
}
func (lrb *leastRequestBalancer) ResolverError(err error) {
// Will cause inline picker update from endpoint sharding.
lrb.child.ResolverError(err)
}
func (lrb *leastRequestBalancer) ExitIdle() {
lrb.child.ExitIdle()
}
func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
lrCfg, ok := ccs.BalancerConfig.(*LBConfig)
if !ok {
logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig)
return balancer.ErrBadResolverState
}
lrb.mu.Lock()
lrb.choiceCount = lrCfg.ChoiceCount
lrb.mu.Unlock()
return lrb.child.UpdateClientConnState(balancer.ClientConnState{
// Enable the health listener in pickfirst children for client side health
// checks and outlier detection, if configured.
ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
})
}
type endpointState struct {
picker balancer.Picker
numRPCs *atomic.Int32
}
func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
var readyEndpoints []endpointsharding.ChildState
for _, child := range endpointsharding.ChildStatesFromPicker(state.Picker) {
if child.State.ConnectivityState == connectivity.Ready {
readyEndpoints = append(readyEndpoints, child)
}
}
// If no ready pickers are present, simply defer to the round robin picker
// from endpoint sharding, which will round robin across the most relevant
// pick first children in the highest precedence connectivity state.
if len(readyEndpoints) == 0 {
lrb.ClientConn.UpdateState(state)
return
}
lrb.mu.Lock()
defer lrb.mu.Unlock()
if logger.V(2) {
lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyEndpoints)
}
// Reconcile endpoints.
newEndpoints := resolver.NewEndpointMap[any]()
for _, child := range readyEndpoints {
newEndpoints.Set(child.Endpoint, nil)
}
// If endpoints are no longer ready, no need to count their active RPCs.
for _, endpoint := range lrb.endpointRPCCounts.Keys() {
if _, ok := newEndpoints.Get(endpoint); !ok {
lrb.endpointRPCCounts.Delete(endpoint)
}
}
// Copy refs to counters into picker.
endpointStates := make([]endpointState, 0, len(readyEndpoints))
for _, child := range readyEndpoints {
counter, ok := lrb.endpointRPCCounts.Get(child.Endpoint)
if !ok {
// Create new counts if needed.
counter = new(atomic.Int32)
lrb.endpointRPCCounts.Set(child.Endpoint, counter)
}
endpointStates = append(endpointStates, endpointState{
picker: child.State.Picker,
numRPCs: counter,
})
}
lrb.ClientConn.UpdateState(balancer.State{
Picker: &picker{
choiceCount: lrb.choiceCount,
endpointStates: endpointStates,
},
ConnectivityState: connectivity.Ready,
})
}
type picker struct {
// choiceCount is the number of random endpoints to sample for choosing the
// one with the least requests.
choiceCount uint32
endpointStates []endpointState
}
func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) {
var pickedEndpointState *endpointState
var pickedEndpointNumRPCs int32
for i := 0; i < int(p.choiceCount); i++ {
index := randuint32() % uint32(len(p.endpointStates))
endpointState := p.endpointStates[index]
n := endpointState.numRPCs.Load()
if pickedEndpointState == nil || n < pickedEndpointNumRPCs {
pickedEndpointState = &endpointState
pickedEndpointNumRPCs = n
}
}
result, err := pickedEndpointState.picker.Pick(pInfo)
if err != nil {
return result, err
}
// "The counter for a subchannel should be atomically incremented by one
// after it has been successfully picked by the picker." - A48
pickedEndpointState.numRPCs.Add(1)
// "the picker should add a callback for atomically decrementing the
// subchannel counter once the RPC finishes (regardless of Status code)." -
// A48.
originalDone := result.Done
result.Done = func(info balancer.DoneInfo) {
pickedEndpointState.numRPCs.Add(-1)
if originalDone != nil {
originalDone(info)
}
}
return result, nil
}
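// Worked example (illustrative, not part of the original file): with
// choiceCount = 2 and three READY endpoints carrying 5, 1, and 3 in-flight
// RPCs, a Pick whose sampled indexes are 0 and 2 compares 5 against 3 and
// routes the RPC to the third endpoint, incrementing its counter to 4; the
// Done callback decrements it back when the RPC completes.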

View File

@ -1,770 +0,0 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package leastrequest
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
)
const (
defaultTestTimeout = 5 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
func (s) TestParseConfig(t *testing.T) {
parser := bb{}
tests := []struct {
name string
input string
wantCfg serviceconfig.LoadBalancingConfig
wantErr string
}{
{
name: "happy-case-default",
input: `{}`,
wantCfg: &LBConfig{
ChoiceCount: 2,
},
},
{
name: "happy-case-choice-count-set",
input: `{"choiceCount": 3}`,
wantCfg: &LBConfig{
ChoiceCount: 3,
},
},
{
name: "happy-case-choice-count-greater-than-ten",
input: `{"choiceCount": 11}`,
wantCfg: &LBConfig{
ChoiceCount: 10,
},
},
{
name: "choice-count-less-than-2",
input: `{"choiceCount": 1}`,
wantErr: "must be >= 2",
},
{
name: "invalid-json",
input: "{{invalidjson{{",
wantErr: "invalid character",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
gotCfg, gotErr := parser.ParseConfig(json.RawMessage(test.input))
// Substring match makes this very tightly coupled to the
// internalserviceconfig.BalancerConfig error strings. However, it
// is important to distinguish the different types of error messages
// possible as the parser has a few defined buckets of ways it can
// error out.
if (gotErr != nil) != (test.wantErr != "") {
t.Fatalf("ParseConfig(%v) = %v, wantErr %v", test.input, gotErr, test.wantErr)
}
if gotErr != nil && !strings.Contains(gotErr.Error(), test.wantErr) {
t.Fatalf("ParseConfig(%v) = %v, wantErr %v", test.input, gotErr, test.wantErr)
}
if test.wantErr != "" {
return
}
if diff := cmp.Diff(gotCfg, test.wantCfg); diff != "" {
t.Fatalf("ParseConfig(%v) got unexpected output, diff (-got +want): %v", test.input, diff)
}
})
}
}
func startBackends(t *testing.T, numBackends int) []*stubserver.StubServer {
backends := make([]*stubserver.StubServer, 0, numBackends)
// Construct and start working backends.
for i := 0; i < numBackends; i++ {
backend := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
<-stream.Context().Done()
return nil
},
}
if err := backend.StartServer(); err != nil {
t.Fatalf("Failed to start backend: %v", err)
}
t.Logf("Started good TestService backend at: %q", backend.Address)
t.Cleanup(func() { backend.Stop() })
backends = append(backends, backend)
}
return backends
}
// setupBackends spins up numBackends test backends, each listening on a port
// on localhost. The backends always reply with an empty response and no
// error, and for streaming RPCs block until the stream's context is done.
func setupBackends(t *testing.T, numBackends int) []string {
t.Helper()
addresses := make([]string, numBackends)
backends := startBackends(t, numBackends)
// Collect the addresses of the started backends.
for i := 0; i < numBackends; i++ {
addresses[i] = backends[i].Address
}
return addresses
}
// checkRoundRobinRPCs verifies that EmptyCall RPCs on the given ClientConn,
// connected to a server exposing the test.grpc_testing.TestService, are
// roundrobined across the given backend addresses.
//
// Returns a non-nil error if context deadline expires before RPCs start to get
// roundrobined across the given backends.
func checkRoundRobinRPCs(ctx context.Context, client testgrpc.TestServiceClient, addrs []resolver.Address) error {
wantAddrCount := make(map[string]int)
for _, addr := range addrs {
wantAddrCount[addr.Addr]++
}
gotAddrCount := make(map[string]int)
for ; ctx.Err() == nil; <-time.After(time.Millisecond) {
gotAddrCount = make(map[string]int)
// Perform 3 iterations.
var iterations [][]string
for i := 0; i < 3; i++ {
iteration := make([]string, len(addrs))
for c := 0; c < len(addrs); c++ {
var peer peer.Peer
client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer))
iteration[c] = peer.Addr.String()
}
iterations = append(iterations, iteration)
}
// Ensure the first iteration contains all addresses in addrs.
for _, addr := range iterations[0] {
gotAddrCount[addr]++
}
if !cmp.Equal(gotAddrCount, wantAddrCount) {
continue
}
// Ensure all three iterations contain the same addresses.
if !cmp.Equal(iterations[0], iterations[1]) || !cmp.Equal(iterations[0], iterations[2]) {
continue
}
return nil
}
return fmt.Errorf("timeout when waiting for roundrobin distribution of RPCs across addresses: %v; got: %v", addrs, gotAddrCount)
}
// TestLeastRequestE2E tests the Least Request LB policy in an e2e style. The
// Least Request balancer is configured as the top level balancer of the
// channel, and is passed three addresses. Eventually, the test creates a
// series of streams, which should land on specific backends according to the least request
// algorithm. The randomness in the picker is injected in the test to be
// deterministic, allowing the test to make assertions on the distribution.
func (s) TestLeastRequestE2E(t *testing.T) {
defer func(u func() uint32) {
randuint32 = u
}(randuint32)
var index int
indexes := []uint32{
0, 0, 1, 1, 2, 2, // Triggers a round robin distribution.
}
randuint32 = func() uint32 {
ret := indexes[index%len(indexes)]
index++
return ret
}
addresses := setupBackends(t, 3)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
// Configure least request as top level balancer of channel.
lrscJSON := `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 2
}
}
]
}`
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
firstThreeAddresses := []resolver.Address{
{Addr: addresses[0]},
{Addr: addresses[1]},
{Addr: addresses[2]},
}
mr.InitialState(resolver.State{
Addresses: firstThreeAddresses,
ServiceConfig: sc,
})
cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testServiceClient := testgrpc.NewTestServiceClient(cc)
// Wait for RPCs to round robin across all 3 backends. This happens because a
// SubConn transitioning into READY causes a new picker update. Once the
// picker update with all 3 backends is present, this test can start to make
// assertions based on those backends.
if err := checkRoundRobinRPCs(ctx, testServiceClient, firstThreeAddresses); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Map ordering of READY SubConns is non-deterministic. Thus, perform 3 RPCs,
// with the injected randomness selecting each index in turn, to learn the
// address of the SubConn at each index.
index = 0
peerAtIndex := make([]string, 3)
var peer0 peer.Peer
if _, err := testServiceClient.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer0)); err != nil {
t.Fatalf("testServiceClient.EmptyCall failed: %v", err)
}
peerAtIndex[0] = peer0.Addr.String()
if _, err := testServiceClient.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer0)); err != nil {
t.Fatalf("testServiceClient.EmptyCall failed: %v", err)
}
peerAtIndex[1] = peer0.Addr.String()
if _, err := testServiceClient.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer0)); err != nil {
t.Fatalf("testServiceClient.EmptyCall failed: %v", err)
}
peerAtIndex[2] = peer0.Addr.String()
// Start streaming RPCs, but do not finish them. Each subsequent stream
// should be started according to the least request algorithm, and chosen
// between the indexes provided.
index = 0
indexes = []uint32{
0, 0, // Causes first stream to be on first address.
0, 1, // Compares first address (one RPC) to second (no RPCs), so choose second.
1, 2, // Compares second address (one RPC) to third (no RPCs), so choose third.
0, 3, // Causes another stream on first address.
1, 0, // Compares second address (one RPC) to first (two RPCs), so choose second.
2, 0, // Compares third address (one RPC) to first (two RPCs), so choose third.
0, 0, // Causes another stream on first address.
2, 2, // Causes a stream on third address.
2, 1, // Compares third address (three RPCs) to second (two RPCs), so choose third.
}
wantIndex := []uint32{0, 1, 2, 0, 1, 2, 0, 2, 1}
// Start streaming RPCs, but do not finish them. Each created stream should
// be started based on the least request algorithm and injected randomness
// (see indexes slice above for exact expectations).
for _, wantIndex := range wantIndex {
stream, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
}
if p.Addr.String() != peerAtIndex[wantIndex] {
t.Fatalf("testServiceClient.FullDuplexCall's Peer got: %v, want: %v", p.Addr.String(), peerAtIndex[wantIndex])
}
}
}
// TestLeastRequestPersistsCounts tests that the Least Request Balancer persists
// counts once it gets a new picker update. It first updates the Least Request
// Balancer with two backends, and creates a bunch of streams on them. Then, it
// updates the Least Request Balancer with three backends, including the two
// previous. Any created streams should then be started on the new backend.
func (s) TestLeastRequestPersistsCounts(t *testing.T) {
defer func(u func() uint32) {
randuint32 = u
}(randuint32)
var index int
indexes := []uint32{
0, 0, 1, 1,
}
randuint32 = func() uint32 {
ret := indexes[index%len(indexes)]
index++
return ret
}
addresses := setupBackends(t, 3)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
// Configure least request as top level balancer of channel.
lrscJSON := `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 2
}
}
]
}`
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
firstTwoAddresses := []resolver.Address{
{Addr: addresses[0]},
{Addr: addresses[1]},
}
mr.InitialState(resolver.State{
Addresses: firstTwoAddresses,
ServiceConfig: sc,
})
cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testServiceClient := testgrpc.NewTestServiceClient(cc)
// Wait for RPCs to round robin across the two backends. This happens because a
// SubConn transitioning into READY causes a new picker update. Once the
// picker update with the two backends is present, this test can start to
// populate those backends with streams.
if err := checkRoundRobinRPCs(ctx, testServiceClient, firstTwoAddresses); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Start 50 streaming RPCs, and leave them unfinished for the duration of
// the test. This will populate the first two addresses with many active
// RPCs.
for i := 0; i < 50; i++ {
_, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
}
// Update the least request balancer to choice count 3. Also update the
// address list adding a third address. Alongside the injected randomness,
// this should trigger the least request balancer to search all created
// SubConns. Thus, since address 3 is the new address and the first two
// addresses are populated with RPCs, once the picker update of all 3 READY
// SubConns takes effect, all new streams should be started on address 3.
index = 0
indexes = []uint32{
0, 1, 2, 3, 4, 5,
}
lrscJSON = `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 3
}
}
]
}`
sc = internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
fullAddresses := []resolver.Address{
{Addr: addresses[0]},
{Addr: addresses[1]},
{Addr: addresses[2]},
}
mr.UpdateState(resolver.State{
Addresses: fullAddresses,
ServiceConfig: sc,
})
newAddress := fullAddresses[2]
// Poll for only address 3 to show up. This requires a polling loop because
// the picker update with all three SubConns doesn't take effect immediately;
// it needs the third SubConn to become READY.
if err := checkRoundRobinRPCs(ctx, testServiceClient, []resolver.Address{newAddress}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Start 25 RPCs, but don't finish them. They should all start on address 3,
// since the first two addresses both have 25 RPCs (and randomness
// injection/choiceCount causes all 3 to be compared every iteration).
for i := 0; i < 25; i++ {
stream, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
}
if p.Addr.String() != addresses[2] {
t.Fatalf("testServiceClient.FullDuplexCall's Peer got: %v, want: %v", p.Addr.String(), addresses[2])
}
}
// Now that 25 RPCs are active on each address, the next three RPCs should
// round robin, since choiceCount is three and the injected random indexes
// cause it to search all three addresses for the fewest outstanding requests
// on each iteration.
wantAddrCount := map[string]int{
addresses[0]: 1,
addresses[1]: 1,
addresses[2]: 1,
}
gotAddrCount := make(map[string]int)
for i := 0; i < len(addresses); i++ {
stream, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
}
if p.Addr != nil {
gotAddrCount[p.Addr.String()]++
}
}
if diff := cmp.Diff(gotAddrCount, wantAddrCount); diff != "" {
t.Fatalf("addr count (-got:, +want): %v", diff)
}
}
// TestConcurrentRPCs tests concurrent RPCs on the least request balancer. It
// configures a channel with a least request balancer as the top level balancer,
// and makes 100 RPCs asynchronously. This makes sure no race conditions happen
// in this scenario.
func (s) TestConcurrentRPCs(t *testing.T) {
addresses := setupBackends(t, 3)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
// Configure least request as top level balancer of channel.
lrscJSON := `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 2
}
}
]
}`
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
firstTwoAddresses := []resolver.Address{
{Addr: addresses[0]},
{Addr: addresses[1]},
}
mr.InitialState(resolver.State{
Addresses: firstTwoAddresses,
ServiceConfig: sc,
})
cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testServiceClient := testgrpc.NewTestServiceClient(cc)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 5; j++ {
testServiceClient.EmptyCall(ctx, &testpb.Empty{})
}
}()
}
wg.Wait()
}
// Tests that the least request balancer persists RPC counts once it gets new
// picker updates and backends within an endpoint go down. It first updates
// the balancer with two endpoints having two addresses each, and verifies
// that requests are round robined across the first address of each endpoint.
// It then stops the active backend in endpoint[0] and verifies that the
// balancer starts using the second address in endpoint[0]. The test then
// creates a bunch of streams on the two endpoints. Next, it updates the
// balancer with three endpoints, including the two previous ones. Any newly
// created streams should then be started on the new endpoint. Finally, the
// test shuts down the active backend in endpoint[1] and endpoint[2] and
// verifies that new RPCs are round robined across the remaining active
// backends in endpoint[1] and endpoint[2].
func (s) TestLeastRequestEndpoints_MultipleAddresses(t *testing.T) {
defer func(u func() uint32) {
randuint32 = u
}(randuint32)
var index int
indexes := []uint32{
0, 0, 1, 1,
}
randuint32 = func() uint32 {
ret := indexes[index%len(indexes)]
index++
return ret
}
backends := startBackends(t, 6)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
// Configure least request as top level balancer of channel.
lrscJSON := `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 2
}
}
]
}`
endpoints := []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backends[0].Address}, {Addr: backends[1].Address}}},
{Addresses: []resolver.Address{{Addr: backends[2].Address}, {Addr: backends[3].Address}}},
{Addresses: []resolver.Address{{Addr: backends[4].Address}, {Addr: backends[5].Address}}},
}
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
firstTwoEndpoints := []resolver.Endpoint{endpoints[0], endpoints[1]}
mr.InitialState(resolver.State{
Endpoints: firstTwoEndpoints,
ServiceConfig: sc,
})
cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testServiceClient := testgrpc.NewTestServiceClient(cc)
// Wait for RPCs to round robin across the two backends. This happens because a
// child pickfirst transitioning into READY causes a new picker update. Once
// the picker update with the two backends is present, this test can start
// to populate those backends with streams.
wantAddrs := []resolver.Address{
endpoints[0].Addresses[0],
endpoints[1].Addresses[0],
}
if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Shut down one of the addresses in endpoints[0], the child pickfirst
// should fallback to the next address in endpoints[0].
backends[0].Stop()
wantAddrs = []resolver.Address{
endpoints[0].Addresses[1],
endpoints[1].Addresses[0],
}
if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Start 50 streaming RPCs, and leave them unfinished for the duration of
// the test. This will populate the first two endpoints with many active
// RPCs.
for i := 0; i < 50; i++ {
_, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
}
// Update the least request balancer to choice count 3. Also update the
// address list adding a third endpoint. Alongside the injected randomness,
// this should trigger the least request balancer to search all created
// endpoints. Thus, since endpoint 3 is the new endpoint and the first two
// endpoints are populated with RPCs, once the picker update of all 3 READY
// pickfirsts takes effect, all new streams should be started on endpoint 3.
index = 0
indexes = []uint32{
0, 1, 2, 3, 4, 5,
}
lrscJSON = `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 3
}
}
]
}`
sc = internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
mr.UpdateState(resolver.State{
Endpoints: endpoints,
ServiceConfig: sc,
})
newAddress := endpoints[2].Addresses[0]
// Poll for only endpoint 3 to show up. This requires a polling loop because
// the picker update with all three endpoints doesn't take effect immediately;
// it needs the third pickfirst to become READY.
if err := checkRoundRobinRPCs(ctx, testServiceClient, []resolver.Address{newAddress}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Start 25 RPCs, but don't finish them. They should all start on endpoint 3,
// since the first two endpoints both have 25 RPCs (and randomness
// injection/choiceCount causes all 3 to be compared every iteration).
for i := 0; i < 25; i++ {
stream, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
}
if p.Addr.String() != newAddress.Addr {
t.Fatalf("testServiceClient.FullDuplexCall's Peer got: %v, want: %v", p.Addr.String(), newAddress)
}
}
// Now that 25 RPCs are active on each endpoint, the next three RPCs should
// round robin, since choiceCount is three and the injected random indexes
// cause it to search all three endpoints for the fewest outstanding requests
// on each iteration.
wantAddrCount := map[string]int{
endpoints[0].Addresses[1].Addr: 1,
endpoints[1].Addresses[0].Addr: 1,
endpoints[2].Addresses[0].Addr: 1,
}
gotAddrCount := make(map[string]int)
for i := 0; i < len(endpoints); i++ {
stream, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
}
if p.Addr != nil {
gotAddrCount[p.Addr.String()]++
}
}
if diff := cmp.Diff(gotAddrCount, wantAddrCount); diff != "" {
t.Fatalf("addr count (-got:, +want): %v", diff)
}
// Shut down the active address for endpoint[1] and endpoint[2]. This should
// result in their streams failing. Now the requests should round robin
// between the remaining addresses of endpoint[1] and endpoint[2].
backends[2].Stop()
backends[4].Stop()
index = 0
indexes = []uint32{
0, 1, 2, 2, 1, 0,
}
wantAddrs = []resolver.Address{
endpoints[1].Addresses[1],
endpoints[2].Addresses[1],
}
if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
}
// Tests that the least request balancer properly surfaces resolver errors.
func (s) TestLeastRequestEndpoints_ResolverError(t *testing.T) {
const sc = `{"loadBalancingConfig": [{"least_request_experimental": {}}]}`
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
cc, err := grpc.NewClient(
mr.Scheme()+":///",
grpc.WithResolvers(mr),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(sc),
)
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
// We need to pass an endpoint with a valid address to the resolver before
// reporting an error - otherwise endpointsharding does not report the
// error through.
lis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatalf("net.Listen() failed: %v", err)
}
// Act like a server that closes the connection without sending a server
// preface.
go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("Unexpected error when accepting a connection: %v", err)
}
conn.Close()
}()
mr.UpdateState(resolver.State{
Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}},
})
cc.Connect()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
// Report an error through the resolver
resolverErr := fmt.Errorf("simulated resolver error")
mr.CC().ReportError(resolverErr)
// Ensure the client returns the expected resolver error.
testServiceClient := testgrpc.NewTestServiceClient(cc)
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
_, err = testServiceClient.EmptyCall(ctx, &testpb.Empty{})
if strings.Contains(err.Error(), resolverErr.Error()) {
break
}
}
if ctx.Err() != nil {
t.Fatalf("Timeout when waiting for RPCs to fail with error containing %s. Last error: %v", resolverErr, err)
}
}

View File

@ -1,35 +0,0 @@
/*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal contains code internal to the pickfirst package.
package internal
import (
rand "math/rand/v2"
"time"
)
var (
// RandShuffle pseudo-randomizes the order of addresses.
RandShuffle = rand.Shuffle
// TimeAfterFunc allows mocking the timer for testing connection delay
// related functionality.
TimeAfterFunc = func(d time.Duration, f func()) func() {
timer := time.AfterFunc(d, f)
return func() { timer.Stop() }
}
)
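// Illustrative sketch (an assumption added for clarity, not part of the
// original file): tests typically override these hooks deterministically and
// restore them on exit, for example:
//
//	origShuf := pfinternal.RandShuffle
//	defer func() { pfinternal.RandShuffle = origShuf }()
//	pfinternal.RandShuffle = func(n int, swap func(int, int)) { swap(0, 1) }
//
// The pick_first tests below take a similar approach for address shuffling.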

View File

@ -1,965 +0,0 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pickfirst_test
import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
pfinternal "google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/pickfirst"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/status"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)
const (
pickFirstServiceConfig = `{"loadBalancingConfig": [{"pick_first":{}}]}`
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
)
func init() {
channelz.TurnOn()
}
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// parseServiceConfig is a test helper which uses the manual resolver to parse
// the given service config. It calls t.Fatal() if service config parsing fails.
func parseServiceConfig(t *testing.T, r *manual.Resolver, sc string) *serviceconfig.ParseResult {
t.Helper()
scpr := r.CC().ParseServiceConfig(sc)
if scpr.Err != nil {
t.Fatalf("Failed to parse service config %q: %v", sc, scpr.Err)
}
return scpr
}
// setupPickFirst performs steps required for pick_first tests. It starts a
// bunch of backends exporting the TestService, and creates a ClientConn to
// them with a service config specifying the use of the pick_first LB policy.
func setupPickFirst(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, []*stubserver.StubServer) {
t.Helper()
r := manual.NewBuilderWithScheme("whatever")
backends := make([]*stubserver.StubServer, backendCount)
addrs := make([]resolver.Address, backendCount)
for i := 0; i < backendCount; i++ {
backend := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
if err := backend.StartServer(); err != nil {
t.Fatalf("Failed to start backend: %v", err)
}
t.Logf("Started TestService backend at: %q", backend.Address)
t.Cleanup(func() { backend.Stop() })
backends[i] = backend
addrs[i] = resolver.Address{Addr: backend.Address}
}
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(pickFirstServiceConfig),
}
dopts = append(dopts, opts...)
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
t.Cleanup(func() { cc.Close() })
// At this point, the resolver has not returned any addresses to the channel.
// This RPC must block until the context expires.
sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer sCancel()
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(sCtx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = %s, want %s", status.Code(err), codes.DeadlineExceeded)
}
return cc, r, backends
}
// stubBackendsToResolverAddrs converts from a set of stub server backends to
// resolver addresses. Useful when pushing addresses to the manual resolver.
func stubBackendsToResolverAddrs(backends []*stubserver.StubServer) []resolver.Address {
addrs := make([]resolver.Address, len(backends))
for i, backend := range backends {
addrs[i] = resolver.Address{Addr: backend.Address}
}
return addrs
}
// TestPickFirst_OneBackend tests the most basic scenario for pick_first. It
// brings up a single backend and verifies that all RPCs get routed to it.
func (s) TestPickFirst_OneBackend(t *testing.T) {
cc, r, backends := setupPickFirst(t, 1)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
}
// TestPickFirst_MultipleBackends tests the scenario with multiple backends and
// verifies that all RPCs get routed to the first one.
func (s) TestPickFirst_MultipleBackends(t *testing.T) {
cc, r, backends := setupPickFirst(t, 2)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
}
// TestPickFirst_OneServerDown tests the scenario where we have multiple
// backends and pick_first is working as expected. Verifies that RPCs get routed
// to the next backend in the list when the first one goes down.
func (s) TestPickFirst_OneServerDown(t *testing.T) {
cc, r, backends := setupPickFirst(t, 2)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Stop the backend which is currently being used. RPCs should get routed to
// the next backend in the list.
backends[0].Stop()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
}
// TestPickFirst_AllServersDown tests the scenario where we have multiple
// backends and pick_first is working as expected. When all backends go down,
// the test verifies that RPCs fail with appropriate status code.
func (s) TestPickFirst_AllServersDown(t *testing.T) {
cc, r, backends := setupPickFirst(t, 2)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
for _, b := range backends {
b.Stop()
}
client := testgrpc.NewTestServiceClient(cc)
for {
if ctx.Err() != nil {
t.Fatalf("channel failed to move to Unavailable after all backends were stopped: %v", ctx.Err())
}
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.Unavailable {
return
}
time.Sleep(defaultTestShortTimeout)
}
}
// TestPickFirst_AddressesRemoved tests the scenario where we have multiple
// backends and pick_first is working as expected. It then verifies that when
// addresses are removed by the name resolver, RPCs get routed appropriately.
func (s) TestPickFirst_AddressesRemoved(t *testing.T) {
cc, r, backends := setupPickFirst(t, 3)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Remove the first backend from the list of addresses originally pushed.
// RPCs should get routed to the first backend in the new list.
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[1], addrs[2]}})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
// Append the backend that we just removed to the end of the list.
// Nothing should change.
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[1], addrs[2], addrs[0]}})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
// Remove the first backend from the existing list of addresses.
// RPCs should get routed to the first backend in the new list.
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[2], addrs[0]}})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[2]); err != nil {
t.Fatal(err)
}
// Remove the first backend from the existing list of addresses.
// RPCs should get routed to the first backend in the new list.
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0]}})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
}
// TestPickFirst_NewAddressWhileBlocking tests the case where pick_first is
// configured on a channel, things are working as expected, and then a resolver
// update removes all addresses. An RPC attempted at this point in time will be
// blocked because there are no valid backends. This test verifies that when new
// backends are added, the RPC is able to complete.
func (s) TestPickFirst_NewAddressWhileBlocking(t *testing.T) {
cc, r, backends := setupPickFirst(t, 2)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Send a resolver update with no addresses. This should push the channel into
// TransientFailure.
r.UpdateState(resolver.State{})
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
doneCh := make(chan struct{})
client := testgrpc.NewTestServiceClient(cc)
go func() {
// The channel is currently in TransientFailure and this RPC will block
// until the channel becomes Ready, which will only happen when we push a
// resolver update with a valid backend address.
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
t.Errorf("EmptyCall() = %v, want <nil>", err)
}
close(doneCh)
}()
// Make sure that there is one pending RPC on the ClientConn before attempting
// to push new addresses through the name resolver. If we don't do this, the
// resolver update can happen before the above goroutine gets to make the RPC.
for {
if err := ctx.Err(); err != nil {
t.Fatal(err)
}
tcs, _ := channelz.GetTopChannels(0, 0)
if len(tcs) != 1 {
t.Fatalf("there should only be one top channel, not %d", len(tcs))
}
started := tcs[0].ChannelMetrics.CallsStarted.Load()
completed := tcs[0].ChannelMetrics.CallsSucceeded.Load() + tcs[0].ChannelMetrics.CallsFailed.Load()
if (started - completed) == 1 {
break
}
time.Sleep(defaultTestShortTimeout)
}
// Send a resolver update with a valid backend to push the channel to Ready
// and unblock the above RPC.
r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0].Address}}})
select {
case <-ctx.Done():
t.Fatal("Timeout when waiting for blocked RPC to complete")
case <-doneCh:
}
}
// TestPickFirst_StickyTransientFailure tests the case where pick_first is
// configured on a channel, and the backend is configured to close incoming
// connections as soon as they are accepted. The test verifies that the channel
// enters TransientFailure and stays there. The test also verifies that the
// pick_first LB policy is constantly trying to reconnect to the backend.
func (s) TestPickFirst_StickyTransientFailure(t *testing.T) {
// Spin up a local server which closes the connection as soon as it receives
// one. It also sends a signal on a channel whenever it receives a connection.
lis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
t.Cleanup(func() { lis.Close() })
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
connCh := make(chan struct{}, 1)
go func() {
for {
conn, err := lis.Accept()
if err != nil {
return
}
select {
case connCh <- struct{}{}:
conn.Close()
case <-ctx.Done():
return
}
}
}()
// Dial the above server with a ConnectParams that does a constant backoff
// of defaultTestShortTimeout duration.
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(pickFirstServiceConfig),
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
BaseDelay: defaultTestShortTimeout,
Multiplier: float64(0),
Jitter: float64(0),
MaxDelay: defaultTestShortTimeout,
},
}),
}
cc, err := grpc.NewClient(lis.Addr().String(), dopts...)
if err != nil {
t.Fatalf("Failed to create new client: %v", err)
}
t.Cleanup(func() { cc.Close() })
cc.Connect()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
// Spawn a goroutine to ensure that the channel stays in TransientFailure.
// The call to cc.WaitForStateChange will return false when the main
// goroutine exits and the context is cancelled.
go func() {
if cc.WaitForStateChange(ctx, connectivity.TransientFailure) {
if state := cc.GetState(); state != connectivity.Shutdown {
t.Errorf("Unexpected state change from TransientFailure to %s", cc.GetState())
}
}
}()
// Ensures that the pick_first LB policy is constantly trying to reconnect.
for i := 0; i < 10; i++ {
select {
case <-connCh:
case <-time.After(2 * defaultTestShortTimeout):
t.Error("Timeout when waiting for pick_first to reconnect")
}
}
}
// Tests the PF LB policy with shuffling enabled.
func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
const serviceConfig = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}`
// Install a shuffler that always reverses two entries.
origShuf := pfinternal.RandShuffle
defer func() { pfinternal.RandShuffle = origShuf }()
pfinternal.RandShuffle = func(n int, f func(int, int)) {
if n != 2 {
t.Errorf("Shuffle called with n=%v; want 2", n)
return
}
f(0, 1) // reverse the two addresses
}
// Set up our backends.
cc, r, backends := setupPickFirst(t, 2)
addrs := stubBackendsToResolverAddrs(backends)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Push an update with both addresses and shuffling disabled. We should
// connect to backend 0.
r.UpdateState(resolver.State{Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{addrs[0]}},
{Addresses: []resolver.Address{addrs[1]}},
}})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Send a config with shuffling enabled. This will reverse the addresses,
// but the channel should still be connected to backend 0.
shufState := resolver.State{
ServiceConfig: parseServiceConfig(t, r, serviceConfig),
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{addrs[0]}},
{Addresses: []resolver.Address{addrs[1]}},
},
}
r.UpdateState(shufState)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Send a resolver update with no addresses. This should push the channel
// into TransientFailure.
r.UpdateState(resolver.State{})
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
// Send the same config as last time with shuffling enabled. Since we are
// not connected to backend 0, we should connect to backend 1.
r.UpdateState(shufState)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
}
// Tests pick_first service config parsing for configs with and without address shuffling.
func (s) TestPickFirst_ParseConfig_Success(t *testing.T) {
// Install a shuffler that always reverses two entries.
origShuf := pfinternal.RandShuffle
defer func() { pfinternal.RandShuffle = origShuf }()
pfinternal.RandShuffle = func(n int, f func(int, int)) {
if n != 2 {
t.Errorf("Shuffle called with n=%v; want 2", n)
return
}
f(0, 1) // reverse the two addresses
}
tests := []struct {
name string
serviceConfig string
wantFirstAddr bool
}{
{
name: "empty pickfirst config",
serviceConfig: `{"loadBalancingConfig": [{"pick_first":{}}]}`,
wantFirstAddr: true,
},
{
name: "pickfirst config with shuffling enabled",
serviceConfig: `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}`,
wantFirstAddr: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Set up our backends.
cc, r, backends := setupPickFirst(t, 2)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{
ServiceConfig: parseServiceConfig(t, r, test.serviceConfig),
Addresses: addrs,
})
// Some tests expect address shuffling to happen, and indicate that
// by setting wantFirstAddr to false (since our shuffling function
// defined at the top of this test simply reverses the list of
// addresses provided to it).
wantAddr := addrs[0]
if !test.wantFirstAddr {
wantAddr = addrs[1]
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, wantAddr); err != nil {
t.Fatal(err)
}
})
}
}
// Test config parsing for a bad service config.
func (s) TestPickFirst_ParseConfig_Failure(t *testing.T) {
// Service config parsing should fail with the below config. Name resolvers are
// expected to perform this parsing before they push the parsed service
// config to the channel.
const sc = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": 666 }}]}`
scpr := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(sc)
if scpr.Err == nil {
t.Fatalf("ParseConfig() succeeded and returned %+v, when expected to fail", scpr)
}
}
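// The sketch below is illustrative and not part of the original tests: it
// shows the complementary success case for the parser exercised above, i.e. a
// well-formed pick_first config that name resolvers are expected to accept.
// The function name is hypothetical.
func exampleParseGoodPickFirstConfig(t *testing.T) {
const sc = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}`
scpr := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(sc)
if scpr.Err != nil {
t.Fatalf("ParseConfig() failed for a valid config: %v", scpr.Err)
}
}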
// setupPickFirstWithListenerWrapper is very similar to setupPickFirst, but uses
// a wrapped listener that the test can use to track accepted connections.
func setupPickFirstWithListenerWrapper(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, []*stubserver.StubServer, []*testutils.ListenerWrapper) {
t.Helper()
backends := make([]*stubserver.StubServer, backendCount)
addrs := make([]resolver.Address, backendCount)
listeners := make([]*testutils.ListenerWrapper, backendCount)
for i := 0; i < backendCount; i++ {
lis := testutils.NewListenerWrapper(t, nil)
backend := &stubserver.StubServer{
Listener: lis,
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
if err := backend.StartServer(); err != nil {
t.Fatalf("Failed to start backend: %v", err)
}
t.Logf("Started TestService backend at: %q", backend.Address)
t.Cleanup(func() { backend.Stop() })
backends[i] = backend
addrs[i] = resolver.Address{Addr: backend.Address}
listeners[i] = lis
}
r := manual.NewBuilderWithScheme("whatever")
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(pickFirstServiceConfig),
}
dopts = append(dopts, opts...)
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
t.Cleanup(func() { cc.Close() })
// At this point, the resolver has not returned any addresses to the channel.
// This RPC must block until the context expires.
sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer sCancel()
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(sCtx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = %s, want %s", status.Code(err), codes.DeadlineExceeded)
}
return cc, r, backends, listeners
}
// TestPickFirst_AddressUpdateWithAttributes tests the case where an address
// update received by the pick_first LB policy differs in attributes. Addresses
// which differ in attributes are considered different from the perspective of
// subconn creation and connection establishment, and the test verifies that new
// connections are created when attributes change.
func (s) TestPickFirst_AddressUpdateWithAttributes(t *testing.T) {
cc, r, backends, listeners := setupPickFirstWithListenerWrapper(t, 2)
// Add a set of attributes to the addresses before pushing them to the
// pick_first LB policy through the manual resolver.
addrs := stubBackendsToResolverAddrs(backends)
for i := range addrs {
addrs[i].Attributes = addrs[i].Attributes.WithValue("test-attribute-1", fmt.Sprintf("%d", i))
}
r.UpdateState(resolver.State{Addresses: addrs})
// Ensure that RPCs succeed to the first backend in the list.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Grab the wrapped connection from the listener wrapper. This will be used
// to verify the connection is closed.
val, err := listeners[0].NewConnCh.Receive(ctx)
if err != nil {
t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
}
conn := val.(*testutils.ConnWrapper)
// Add another set of attributes to the addresses, and push them to the
// pick_first LB policy through the manual resolver. Leave the order of the
// addresses unchanged.
for i := range addrs {
addrs[i].Attributes = addrs[i].Attributes.WithValue("test-attribute-2", fmt.Sprintf("%d", i))
}
r.UpdateState(resolver.State{Addresses: addrs})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// A change in the address attributes results in the new address being
// considered different from the current address. This results in the old
// connection being closed and a new connection being established to the same
// backend (since the address order is not modified).
if _, err := conn.CloseCh.Receive(ctx); err != nil {
t.Fatalf("Timeout when expecting existing connection to be closed: %v", err)
}
val, err = listeners[0].NewConnCh.Receive(ctx)
if err != nil {
t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
}
conn = val.(*testutils.ConnWrapper)
// Add another set of attributes to the addresses, and push them to the
// pick_first LB policy through the manual resolver. Reverse of the order
// of addresses.
for i := range addrs {
addrs[i].Attributes = addrs[i].Attributes.WithValue("test-attribute-3", fmt.Sprintf("%d", i))
}
addrs[0], addrs[1] = addrs[1], addrs[0]
r.UpdateState(resolver.State{Addresses: addrs})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Ensure that the old connection is closed and a new connection is
// established to the first address in the new list.
if _, err := conn.CloseCh.Receive(ctx); err != nil {
t.Fatalf("Timeout when expecting existing connection to be closed: %v", err)
}
_, err = listeners[1].NewConnCh.Receive(ctx)
if err != nil {
t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
}
}
// TestPickFirst_AddressUpdateWithBalancerAttributes tests the case where an
// address update received by the pick_first LB policy differs in balancer
// attributes, which are meant only for consumption by LB policies. In this
// case, the test verifies that new connections are not created when the address
// update only changes the balancer attributes.
func (s) TestPickFirst_AddressUpdateWithBalancerAttributes(t *testing.T) {
cc, r, backends, listeners := setupPickFirstWithListenerWrapper(t, 2)
// Add a set of balancer attributes to the addresses before pushing them to
// the pick_first LB policy through the manual resolver.
addrs := stubBackendsToResolverAddrs(backends)
for i := range addrs {
addrs[i].BalancerAttributes = addrs[i].BalancerAttributes.WithValue("test-attribute-1", fmt.Sprintf("%d", i))
}
r.UpdateState(resolver.State{Addresses: addrs})
// Ensure that RPCs succeed to the expected backend.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Grab the wrapped connection from the listener wrapper. This will be used
// to verify the connection is not closed.
val, err := listeners[0].NewConnCh.Receive(ctx)
if err != nil {
t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
}
conn := val.(*testutils.ConnWrapper)
// Add a set of balancer attributes to the addresses before pushing them to
// the pick_first LB policy through the manual resolver. Leave the order of
// the addresses unchanged.
for i := range addrs {
addrs[i].BalancerAttributes = addrs[i].BalancerAttributes.WithValue("test-attribute-2", fmt.Sprintf("%d", i))
}
r.UpdateState(resolver.State{Addresses: addrs})
// Ensure that no new connection is established, and ensure that the old
// connection is not closed.
for i := range listeners {
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer sCancel()
if _, err := listeners[i].NewConnCh.Receive(sCtx); err != context.DeadlineExceeded {
t.Fatalf("Unexpected error when expecting no new connection: %v", err)
}
}
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer sCancel()
if _, err := conn.CloseCh.Receive(sCtx); err != context.DeadlineExceeded {
t.Fatalf("Unexpected error when expecting existing connection to stay active: %v", err)
}
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
// Add a set of balancer attributes to the addresses before pushing them to
// the pick_first LB policy through the manual resolver. Reverse of the
// order of addresses.
for i := range addrs {
addrs[i].BalancerAttributes = addrs[i].BalancerAttributes.WithValue("test-attribute-3", fmt.Sprintf("%d", i))
}
addrs[0], addrs[1] = addrs[1], addrs[0]
r.UpdateState(resolver.State{Addresses: addrs})
// Ensure that no new connection is established, and ensure that the old
// connection is not closed.
for i := range listeners {
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer sCancel()
if _, err := listeners[i].NewConnCh.Receive(sCtx); err != context.DeadlineExceeded {
t.Fatalf("Unexpected error when expecting no new connection: %v", err)
}
}
sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout)
defer sCancel()
if _, err := conn.CloseCh.Receive(sCtx); err != context.DeadlineExceeded {
t.Fatalf("Unexpected error when expecting existing connection to stay active: %v", err)
}
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
}
// Tests the case where the pick_first LB policy receives an error from the name
// resolver without previously receiving a good update. Verifies that the
// channel moves to TRANSIENT_FAILURE and that error received from the name
// resolver is propagated to the caller of an RPC.
func (s) TestPickFirst_ResolverError_NoPreviousUpdate(t *testing.T) {
cc, r, _ := setupPickFirst(t, 0)
nrErr := errors.New("error from name resolver")
r.CC().ReportError(nrErr)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
client := testgrpc.NewTestServiceClient(cc)
_, err := client.EmptyCall(ctx, &testpb.Empty{})
if err == nil {
t.Fatalf("EmptyCall() succeeded when expected to fail with error: %v", nrErr)
}
if !strings.Contains(err.Error(), nrErr.Error()) {
t.Fatalf("EmptyCall() failed with error: %v, want error: %v", err, nrErr)
}
}
// Tests the case where the pick_first LB policy receives an error from the name
// resolver after receiving a good update (and the channel is currently READY).
// The test verifies that the channel continues to use the previously received
// good update.
func (s) TestPickFirst_ResolverError_WithPreviousUpdate_Ready(t *testing.T) {
cc, r, backends := setupPickFirst(t, 1)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
nrErr := errors.New("error from name resolver")
r.CC().ReportError(nrErr)
// Ensure that RPCs continue to succeed for the next second.
client := testgrpc.NewTestServiceClient(cc)
for end := time.Now().Add(time.Second); time.Now().Before(end); <-time.After(defaultTestShortTimeout) {
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() failed: %v", err)
}
}
}
// Tests the case where the pick_first LB policy receives an error from the name
// resolver after receiving a good update (and the channel is currently in
// CONNECTING state). The test verifies that the channel continues to use the
// previously received good update, and that RPCs don't fail with the error
// received from the name resolver.
func (s) TestPickFirst_ResolverError_WithPreviousUpdate_Connecting(t *testing.T) {
lis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatalf("net.Listen() failed: %v", err)
}
// Listen on a local port and act like a server that waits until the channel
// reaches CONNECTING, and then closes the connection without sending a
// server preface.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
waitForConnecting := make(chan struct{})
go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("Unexpected error when accepting a connection: %v", err)
}
defer conn.Close()
select {
case <-waitForConnecting:
case <-ctx.Done():
t.Error("Timeout when waiting for channel to move to CONNECTING state")
}
}()
r := manual.NewBuilderWithScheme("whatever")
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(pickFirstServiceConfig),
}
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
t.Cleanup(func() { cc.Close() })
cc.Connect()
addrs := []resolver.Address{{Addr: lis.Addr().String()}}
r.UpdateState(resolver.State{Addresses: addrs})
testutils.AwaitState(ctx, t, cc, connectivity.Connecting)
nrErr := errors.New("error from name resolver")
r.CC().ReportError(nrErr)
// RPCs should fail with a deadline exceeded error as long as the channel is
// in CONNECTING, and not with the error returned by the name resolver.
client := testgrpc.NewTestServiceClient(cc)
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer sCancel()
if _, err := client.EmptyCall(sCtx, &testpb.Empty{}); !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) {
t.Fatalf("EmptyCall() failed with error: %v, want error: %v", err, context.DeadlineExceeded)
}
// Closing this channel leads to closing of the connection by our listener.
// gRPC should see this as a connection error.
close(waitForConnecting)
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
checkForConnectionError(ctx, t, cc)
}
// Tests the case where the pick_first LB policy receives an error from the name
// resolver after receiving a good update. The previous good update though has
// seen the channel move to TRANSIENT_FAILURE. The test verifies that the
// channel fails RPCs with the new error from the resolver.
func (s) TestPickFirst_ResolverError_WithPreviousUpdate_TransientFailure(t *testing.T) {
lis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatalf("net.Listen() failed: %v", err)
}
// Listen on a local port and act like a server that closes the connection
// without sending a server preface.
go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("Unexpected error when accepting a connection: %v", err)
}
conn.Close()
}()
r := manual.NewBuilderWithScheme("whatever")
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(pickFirstServiceConfig),
}
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
t.Cleanup(func() { cc.Close() })
cc.Connect()
addrs := []resolver.Address{{Addr: lis.Addr().String()}}
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
checkForConnectionError(ctx, t, cc)
// An error from the name resolver should result in RPCs failing with that
// error instead of the old error that caused the channel to move to
// TRANSIENT_FAILURE in the first place.
nrErr := errors.New("error from name resolver")
r.CC().ReportError(nrErr)
client := testgrpc.NewTestServiceClient(cc)
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); strings.Contains(err.Error(), nrErr.Error()) {
break
}
}
if ctx.Err() != nil {
t.Fatal("Timeout when waiting for RPCs to fail with error returned by the name resolver")
}
}
func checkForConnectionError(ctx context.Context, t *testing.T, cc *grpc.ClientConn) {
t.Helper()
// RPCs may fail on the client side in two ways, once the fake server closes
// the accepted connection:
// - writing the client preface succeeds, but not reading the server preface
// - writing the client preface fails
// In either case, we should see it fail with UNAVAILABLE.
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable {
t.Fatalf("EmptyCall() failed with error: %v, want code %v", err, codes.Unavailable)
}
}
// Tests the case where the pick_first LB policy receives an update from the
// name resolver with no addresses after receiving a good update. The test
// verifies that the channel fails RPCs with an error indicating the fact that
// the name resolver returned no addresses.
func (s) TestPickFirst_ResolverError_ZeroAddresses_WithPreviousUpdate(t *testing.T) {
cc, r, backends := setupPickFirst(t, 1)
addrs := stubBackendsToResolverAddrs(backends)
r.UpdateState(resolver.State{Addresses: addrs})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
r.UpdateState(resolver.State{})
wantErr := "produced zero addresses"
client := testgrpc.NewTestServiceClient(cc)
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); strings.Contains(err.Error(), wantErr) {
break
}
}
if ctx.Err() != nil {
t.Fatal("Timeout when waiting for RPCs to fail with error returned by the name resolver")
}
}

View File

@ -1,132 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pickfirst
import (
"context"
"errors"
"fmt"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
)
const (
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// TestPickFirst_InitialResolverError sends a resolver error to the balancer
// before a valid resolver update. It verifies that the clientconn state is
// updated to TRANSIENT_FAILURE.
func (s) TestPickFirst_InitialResolverError(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc := testutils.NewBalancerClientConn(t)
bal := balancer.Get(Name).Build(cc, balancer.BuildOptions{})
defer bal.Close()
bal.ResolverError(errors.New("resolution failed: test error"))
if err := cc.WaitForConnectivityState(ctx, connectivity.TransientFailure); err != nil {
t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.TransientFailure, err)
}
// After sending a valid update, the LB policy should report CONNECTING.
ccState := balancer.ClientConnState{
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}},
{Addresses: []resolver.Address{{Addr: "2.2.2.2:2"}}},
},
},
}
if err := bal.UpdateClientConnState(ccState); err != nil {
t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err)
}
if err := cc.WaitForConnectivityState(ctx, connectivity.Connecting); err != nil {
t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.Connecting, err)
}
}
// TestPickFirst_ResolverErrorinTF sends a resolver error to the balancer
// while the SubConn it is attempting to connect to is in TRANSIENT_FAILURE.
// It verifies that the picker is updated and the SubConn is not closed.
func (s) TestPickFirst_ResolverErrorinTF(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc := testutils.NewBalancerClientConn(t)
bal := balancer.Get(Name).Build(cc, balancer.BuildOptions{})
defer bal.Close()
// After sending a valid update, the LB policy should report CONNECTING.
ccState := balancer.ClientConnState{
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}},
},
},
}
if err := bal.UpdateClientConnState(ccState); err != nil {
t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err)
}
sc1 := <-cc.NewSubConnCh
if err := cc.WaitForConnectivityState(ctx, connectivity.Connecting); err != nil {
t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.Connecting, err)
}
scErr := fmt.Errorf("test error: connection refused")
sc1.UpdateState(balancer.SubConnState{
ConnectivityState: connectivity.TransientFailure,
ConnectionError: scErr,
})
if err := cc.WaitForPickerWithErr(ctx, scErr); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", scErr, err)
}
bal.ResolverError(errors.New("resolution failed: test error"))
if err := cc.WaitForErrPicker(ctx); err != nil {
t.Fatalf("cc.WaitForPickerWithErr() returned error: %v", err)
}
select {
case <-time.After(defaultTestShortTimeout):
case sc := <-cc.ShutdownSubConnCh:
t.Fatalf("Unexpected SubConn shutdown: %v", sc)
}
}

View File

@ -1,273 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pickfirstleaf_test
import (
"context"
"fmt"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/stats"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/stats/opentelemetry"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest"
)
var pfConfig string
func init() {
pfConfig = fmt.Sprintf(`{
"loadBalancingConfig": [
{
%q: {
}
}
]
}`, pickfirstleaf.Name)
}
// TestPickFirstMetrics tests pick first metrics. It configures a pick first
// balancer, causes it to connect and then disconnect, and expects the
// subsequent metrics to emit from that.
func (s) TestPickFirstMetrics(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ss := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
ss.StartServer()
defer ss.Stop()
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(pfConfig)
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{
ServiceConfig: sc,
Addresses: []resolver.Address{{Addr: ss.Address}}},
)
tmr := stats.NewTestMetricsRecorder()
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithStatsHandler(tmr), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
if err != nil {
t.Fatalf("NewClient() failed with error: %v", err)
}
defer cc.Close()
tsc := testgrpc.NewTestServiceClient(cc)
if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() failed: %v", err)
}
if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_succeeded"); got != 1 {
t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_succeeded", got, 1)
}
if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_failed"); got != 0 {
t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_failed", got, 0)
}
if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 0 {
t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 0)
}
ss.Stop()
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 1 {
t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 1)
}
}
// TestPickFirstMetricsFailure tests the connection attempts failed metric. It
// configures a channel and scenario that causes a pick first connection attempt
// to fail, and then expects that metric to emit.
func (s) TestPickFirstMetricsFailure(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(pfConfig)
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{
ServiceConfig: sc,
Addresses: []resolver.Address{{Addr: "bad address"}}},
)
grpcTarget := r.Scheme() + ":///"
tmr := stats.NewTestMetricsRecorder()
cc, err := grpc.NewClient(grpcTarget, grpc.WithStatsHandler(tmr), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
if err != nil {
t.Fatalf("NewClient() failed with error: %v", err)
}
defer cc.Close()
tsc := testgrpc.NewTestServiceClient(cc)
if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}); err == nil {
t.Fatalf("EmptyCall() passed when expected to fail")
}
if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_succeeded"); got != 0 {
t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_succeeded", got, 0)
}
if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_failed"); got != 1 {
t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_failed", got, 1)
}
if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 0 {
t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 0)
}
}
// TestPickFirstMetricsE2E tests the pick first metrics end to end. It
// configures a channel with an OpenTelemetry plugin, induces all 3 pick first
// metrics to emit, and makes sure the correct OpenTelemetry metrics atoms emit.
func (s) TestPickFirstMetricsE2E(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ss := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
ss.StartServer()
defer ss.Stop()
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(pfConfig)
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{
ServiceConfig: sc,
Addresses: []resolver.Address{{Addr: "bad address"}}},
) // Will trigger connection failed.
grpcTarget := r.Scheme() + ":///"
reader := metric.NewManualReader()
provider := metric.NewMeterProvider(metric.WithReader(reader))
mo := opentelemetry.MetricsOptions{
MeterProvider: provider,
Metrics: opentelemetry.DefaultMetrics().Add("grpc.lb.pick_first.disconnections", "grpc.lb.pick_first.connection_attempts_succeeded", "grpc.lb.pick_first.connection_attempts_failed"),
}
cc, err := grpc.NewClient(grpcTarget, opentelemetry.DialOption(opentelemetry.Options{MetricsOptions: mo}), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
if err != nil {
t.Fatalf("NewClient() failed with error: %v", err)
}
defer cc.Close()
tsc := testgrpc.NewTestServiceClient(cc)
if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}); err == nil {
t.Fatalf("EmptyCall() passed when expected to fail")
}
r.UpdateState(resolver.State{
ServiceConfig: sc,
Addresses: []resolver.Address{{Addr: ss.Address}},
}) // Will trigger successful connection metric.
if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
t.Fatalf("EmptyCall() failed: %v", err)
}
// Stop the server, that should send signal to disconnect, which will
// eventually emit disconnection metric before ClientConn goes IDLE.
ss.Stop()
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
wantMetrics := []metricdata.Metrics{
{
Name: "grpc.lb.pick_first.connection_attempts_succeeded",
Description: "EXPERIMENTAL. Number of successful connection attempts.",
Unit: "{attempt}",
Data: metricdata.Sum[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget)),
Value: 1,
},
},
Temporality: metricdata.CumulativeTemporality,
IsMonotonic: true,
},
},
{
Name: "grpc.lb.pick_first.connection_attempts_failed",
Description: "EXPERIMENTAL. Number of failed connection attempts.",
Unit: "{attempt}",
Data: metricdata.Sum[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget)),
Value: 1,
},
},
Temporality: metricdata.CumulativeTemporality,
IsMonotonic: true,
},
},
{
Name: "grpc.lb.pick_first.disconnections",
Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
Unit: "{disconnection}",
Data: metricdata.Sum[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget)),
Value: 1,
},
},
Temporality: metricdata.CumulativeTemporality,
IsMonotonic: true,
},
},
}
gotMetrics := metricsDataFromReader(ctx, reader)
for _, metric := range wantMetrics {
val, ok := gotMetrics[metric.Name]
if !ok {
t.Fatalf("Metric %v not present in recorded metrics", metric.Name)
}
if !metricdatatest.AssertEqual(t, metric, val, metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreExemplars()) {
t.Fatalf("Metrics data type not equal for metric: %v", metric.Name)
}
}
}
func metricsDataFromReader(ctx context.Context, reader *metric.ManualReader) map[string]metricdata.Metrics {
rm := &metricdata.ResourceMetrics{}
reader.Collect(ctx, rm)
gotMetrics := map[string]metricdata.Metrics{}
for _, sm := range rm.ScopeMetrics {
for _, m := range sm.Metrics {
gotMetrics[m.Name] = m
}
}
return gotMetrics
}

View File

@ -1,906 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package pickfirstleaf contains the pick_first load balancing policy which
// will be the universal leaf policy after dualstack changes are implemented.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package pickfirstleaf
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/connectivity"
expstats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
func init() {
if envconfig.NewPickFirstEnabled {
// Register as the default pick_first balancer.
Name = "pick_first"
}
balancer.Register(pickfirstBuilder{})
}
// enableHealthListenerKeyType is a unique key type used in resolver
// attributes to indicate whether the health listener usage is enabled.
type enableHealthListenerKeyType struct{}
var (
logger = grpclog.Component("pick-first-leaf-lb")
// Name is the name of the pick_first_leaf balancer.
// It is changed to "pick_first" in init() if this balancer is to be
// registered as the default pickfirst.
Name = "pick_first_leaf"
disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.disconnections",
Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
Unit: "{disconnection}",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_succeeded",
Description: "EXPERIMENTAL. Number of successful connection attempts.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_failed",
Description: "EXPERIMENTAL. Number of failed connection attempts.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
Default: false,
})
)
const (
// TODO: change to pick-first when this becomes the default pick_first policy.
logPrefix = "[pick-first-leaf-lb %p] "
// connectionDelayInterval is the time to wait for during the happy eyeballs
// pass before starting the next connection attempt.
connectionDelayInterval = 250 * time.Millisecond
)
type ipAddrFamily int
const (
// ipAddrFamilyUnknown represents strings that can't be parsed as an IP
// address.
ipAddrFamilyUnknown ipAddrFamily = iota
ipAddrFamilyV4
ipAddrFamilyV6
)
type pickfirstBuilder struct{}
func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{
cc: cc,
target: bo.Target.String(),
metricsRecorder: cc.MetricsRecorder(),
subConns: resolver.NewAddressMapV2[*scData](),
state: connectivity.Connecting,
cancelConnectionTimer: func() {},
}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b
}
func (b pickfirstBuilder) Name() string {
return Name
}
func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
}
return cfg, nil
}
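// Illustrative sketch, not part of the original file: the channel hands
// ParseConfig only the inner policy config object from the service config.
// The function name below is hypothetical.
func exampleParsePickFirstLeafConfig() (serviceconfig.LoadBalancingConfig, error) {
js := json.RawMessage(`{"shuffleAddressList": true}`)
// On success, the returned config is a pfConfig with ShuffleAddressList set.
return pickfirstBuilder{}.ParseConfig(js)
}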
// EnableHealthListener updates the state to configure pickfirst for using a
// generic health listener.
func EnableHealthListener(state resolver.State) resolver.State {
state.Attributes = state.Attributes.WithValue(enableHealthListenerKeyType{}, true)
return state
}
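// Illustrative sketch, not part of the original file: a caller that wants
// health updates delivered through the generic health listener wraps the
// resolver state before handing it to the balancer. The function name below
// is hypothetical.
func exampleEnableHealthListener(bal balancer.Balancer, state resolver.State) error {
// The returned state carries the attribute that UpdateClientConnState
// inspects when setting healthCheckingEnabled.
return bal.UpdateClientConnState(balancer.ClientConnState{
ResolverState: EnableHealthListener(state),
})
}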
type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list
// of endpoints received from the name resolver before attempting to
// connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"`
}
// scData keeps track of the current state of the subConn.
// It is not safe for concurrent access.
type scData struct {
// The following fields are initialized at build time and read-only after
// that.
subConn balancer.SubConn
addr resolver.Address
rawConnectivityState connectivity.State
// The effective connectivity state based on raw connectivity, health state
// and after following sticky TransientFailure behaviour defined in A62.
effectiveState connectivity.State
lastErr error
connectionFailedInFirstPass bool
}
func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) {
sd := &scData{
rawConnectivityState: connectivity.Idle,
effectiveState: connectivity.Idle,
addr: addr,
}
sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
StateListener: func(state balancer.SubConnState) {
b.updateSubConnState(sd, state)
},
})
if err != nil {
return nil, err
}
sd.subConn = sc
return sd, nil
}
type pickfirstBalancer struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
logger *internalgrpclog.PrefixLogger
cc balancer.ClientConn
target string
metricsRecorder expstats.MetricsRecorder // guaranteed to be non nil
// The mutex is used to ensure synchronization of updates triggered
// from the idle picker and the already serialized resolver,
// SubConn state updates.
mu sync.Mutex
// State reported to the channel based on SubConn states and resolver
// updates.
state connectivity.State
// scData for active subConns mapped by address.
subConns *resolver.AddressMapV2[*scData]
addressList addressList
firstPass bool
numTF int
cancelConnectionTimer func()
healthCheckingEnabled bool
}
// ResolverError is called by the ClientConn when the name resolver produces
// an error or when pickfirst determined the resolver update to be invalid.
func (b *pickfirstBalancer) ResolverError(err error) {
b.mu.Lock()
defer b.mu.Unlock()
b.resolverErrorLocked(err)
}
func (b *pickfirstBalancer) resolverErrorLocked(err error) {
if b.logger.V(2) {
b.logger.Infof("Received error from the name resolver: %v", err)
}
// The picker will not change since the balancer does not currently
// report an error. If the balancer hasn't received a single good resolver
// update yet, transition to TRANSIENT_FAILURE.
if b.state != connectivity.TransientFailure && b.addressList.size() > 0 {
if b.logger.V(2) {
b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.")
}
return
}
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)},
})
}
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
b.mu.Lock()
defer b.mu.Unlock()
b.cancelConnectionTimer()
if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
// Cleanup state pertaining to the previous resolver state.
// Treat an empty address list like an error by calling b.ResolverError.
b.closeSubConnsLocked()
b.addressList.updateAddrs(nil)
b.resolverErrorLocked(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}
b.healthCheckingEnabled = state.ResolverState.Attributes.Value(enableHealthListenerKeyType{}) != nil
cfg, ok := state.BalancerConfig.(pfConfig)
if state.BalancerConfig != nil && !ok {
return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState)
}
if b.logger.V(2) {
b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
}
var newAddrs []resolver.Address
if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
// Perform the optional shuffling described in gRFC A62. The shuffling
// will change the order of endpoints but not touch the order of the
// addresses within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
// "Flatten the list by concatenating the ordered list of addresses for
// each of the endpoints, in order." - A61
for _, endpoint := range endpoints {
newAddrs = append(newAddrs, endpoint.Addresses...)
}
} else {
// Endpoints are not set, so process addresses until resolver emissions are
// fully migrated to Endpoints. The top-level channel does wrap emitted
// addresses with endpoints, but some balancers, such as weighted target, do
// not forward the correct corresponding endpoints down or split endpoints
// properly. Once all balancers correctly forward endpoints down, this else
// branch can be deleted.
newAddrs = state.ResolverState.Addresses
if cfg.ShuffleAddressList {
newAddrs = append([]resolver.Address{}, newAddrs...)
internal.RandShuffle(len(newAddrs), func(i, j int) { newAddrs[i], newAddrs[j] = newAddrs[j], newAddrs[i] })
}
}
// If an address appears in multiple endpoints or in the same endpoint
// multiple times, we keep it only once. We will create only one SubConn
// for the address because an AddressMap is used to store SubConns.
// Not de-duplicating would result in attempting to connect to the same
// SubConn multiple times in the same pass. We don't want this.
newAddrs = deDupAddresses(newAddrs)
newAddrs = interleaveAddresses(newAddrs)
prevAddr := b.addressList.currentAddress()
prevSCData, found := b.subConns.Get(prevAddr)
prevAddrsCount := b.addressList.size()
isPrevRawConnectivityStateReady := found && prevSCData.rawConnectivityState == connectivity.Ready
b.addressList.updateAddrs(newAddrs)
// If the previous ready SubConn exists in new address list,
// keep this connection and don't create new SubConns.
if isPrevRawConnectivityStateReady && b.addressList.seekTo(prevAddr) {
return nil
}
b.reconcileSubConnsLocked(newAddrs)
// If it's the first resolver update or the balancer was already READY
// (but the new address list does not contain the ready SubConn) or
// CONNECTING, enter CONNECTING.
// We may be in TRANSIENT_FAILURE due to a previous empty address list,
// we should still enter CONNECTING because the sticky TF behaviour
// mentioned in A62 applies only when the TRANSIENT_FAILURE is reported
// due to connectivity failures.
if isPrevRawConnectivityStateReady || b.state == connectivity.Connecting || prevAddrsCount == 0 {
// Start connection attempt at first address.
b.forceUpdateConcludedStateLocked(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
b.startFirstPassLocked()
} else if b.state == connectivity.TransientFailure {
// If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until
// we're READY. See A62.
b.startFirstPassLocked()
}
return nil
}
// UpdateSubConnState is unused as a StateListener is always registered when
// creating SubConns.
func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state)
}
func (b *pickfirstBalancer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closeSubConnsLocked()
b.cancelConnectionTimer()
b.state = connectivity.Shutdown
}
// ExitIdle moves the balancer out of idle state. It can be called concurrently
// by the idlePicker and clientConn so access to variables should be
// synchronized.
func (b *pickfirstBalancer) ExitIdle() {
b.mu.Lock()
defer b.mu.Unlock()
if b.state == connectivity.Idle {
b.startFirstPassLocked()
}
}
func (b *pickfirstBalancer) startFirstPassLocked() {
b.firstPass = true
b.numTF = 0
// Reset the connection attempt record for existing SubConns.
for _, sd := range b.subConns.Values() {
sd.connectionFailedInFirstPass = false
}
b.requestConnectionLocked()
}
func (b *pickfirstBalancer) closeSubConnsLocked() {
for _, sd := range b.subConns.Values() {
sd.subConn.Shutdown()
}
b.subConns = resolver.NewAddressMapV2[*scData]()
}
// deDupAddresses ensures that each address appears only once in the slice.
func deDupAddresses(addrs []resolver.Address) []resolver.Address {
seenAddrs := resolver.NewAddressMapV2[*scData]()
retAddrs := []resolver.Address{}
for _, addr := range addrs {
if _, ok := seenAddrs.Get(addr); ok {
continue
}
// Mark the address as seen so that later duplicates are skipped.
seenAddrs.Set(addr, nil)
retAddrs = append(retAddrs, addr)
}
return retAddrs
}
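// Illustrative sketch, not part of the original file: duplicate addresses
// collapse to a single entry because seen addresses are recorded in the map.
// The function name below is hypothetical.
func exampleDeDupAddresses() int {
addrs := []resolver.Address{{Addr: "10.0.0.1:443"}, {Addr: "10.0.0.1:443"}}
return len(deDupAddresses(addrs)) // 1
}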
// interleaveAddresses interleaves addresses of both families (IPv4 and IPv6)
// as per RFC-8305 section 4.
// Whichever address family is first in the list is followed by an address of
// the other address family; that is, if the first address in the list is IPv6,
// then the first IPv4 address should be moved up in the list to be second in
// the list. It doesn't support configuring "First Address Family Count", i.e.
// there will always be a single member of the first address family at the
// beginning of the interleaved list.
// Addresses that are neither IPv4 nor IPv6 are treated as part of a third
// "unknown" family for interleaving.
// See: https://datatracker.ietf.org/doc/html/rfc8305#autoid-6
func interleaveAddresses(addrs []resolver.Address) []resolver.Address {
familyAddrsMap := map[ipAddrFamily][]resolver.Address{}
interleavingOrder := []ipAddrFamily{}
for _, addr := range addrs {
family := addressFamily(addr.Addr)
if _, found := familyAddrsMap[family]; !found {
interleavingOrder = append(interleavingOrder, family)
}
familyAddrsMap[family] = append(familyAddrsMap[family], addr)
}
interleavedAddrs := make([]resolver.Address, 0, len(addrs))
for curFamilyIdx := 0; len(interleavedAddrs) < len(addrs); curFamilyIdx = (curFamilyIdx + 1) % len(interleavingOrder) {
// Some IP types may have fewer addresses than others, so we look for
// the next type that has a remaining member to add to the interleaved
// list.
family := interleavingOrder[curFamilyIdx]
remainingMembers := familyAddrsMap[family]
if len(remainingMembers) > 0 {
interleavedAddrs = append(interleavedAddrs, remainingMembers[0])
familyAddrsMap[family] = remainingMembers[1:]
}
}
return interleavedAddrs
}
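// Illustrative sketch, not part of the original file: the interleaving keeps
// the first address family first and then alternates families, per RFC 8305.
// The function name and addresses below are hypothetical.
func exampleInterleaveAddresses() []resolver.Address {
in := []resolver.Address{
{Addr: "[2001:db8::1]:443"}, // IPv6
{Addr: "[2001:db8::2]:443"}, // IPv6
{Addr: "10.0.0.1:443"},      // IPv4
}
// Resulting order: [2001:db8::1]:443, 10.0.0.1:443, [2001:db8::2]:443.
return interleaveAddresses(in)
}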
// addressFamily returns the ipAddrFamily after parsing the address string.
// If the address isn't of the format "ip-address:port", it returns
// ipAddrFamilyUnknown. The address may be valid even if it's not an IP when
// using a resolver like passthrough where the address may be a hostname in
// some format that the dialer can resolve.
func addressFamily(address string) ipAddrFamily {
// Parse the IP after removing the port.
host, _, err := net.SplitHostPort(address)
if err != nil {
return ipAddrFamilyUnknown
}
ip, err := netip.ParseAddr(host)
if err != nil {
return ipAddrFamilyUnknown
}
switch {
case ip.Is4() || ip.Is4In6():
return ipAddrFamilyV4
case ip.Is6():
return ipAddrFamilyV6
default:
return ipAddrFamilyUnknown
}
}
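// Illustrative sketch, not part of the original file: expected classifications
// returned by addressFamily. The function name below is hypothetical.
func exampleAddressFamilies() [3]ipAddrFamily {
return [3]ipAddrFamily{
addressFamily("10.0.0.1:443"),      // ipAddrFamilyV4
addressFamily("[2001:db8::1]:443"), // ipAddrFamilyV6
addressFamily("example.com:443"),   // ipAddrFamilyUnknown: host is not an IP
}
}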
// reconcileSubConnsLocked updates the active subchannels based on a new address
// list from the resolver. It does this by:
// - closing subchannels: any existing subchannels associated with addresses
// that are no longer in the updated list are shut down.
// - removing subchannels: entries for these closed subchannels are removed
// from the subchannel map.
//
// This ensures that the subchannel map accurately reflects the current set of
// addresses received from the name resolver.
func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) {
newAddrsMap := resolver.NewAddressMapV2[bool]()
for _, addr := range newAddrs {
newAddrsMap.Set(addr, true)
}
for _, oldAddr := range b.subConns.Keys() {
if _, ok := newAddrsMap.Get(oldAddr); ok {
continue
}
val, _ := b.subConns.Get(oldAddr)
val.subConn.Shutdown()
b.subConns.Delete(oldAddr)
}
}
// shutdownRemainingLocked shuts down remaining subConns. Called when a subConn
// becomes ready, which means that all other subConns must be shut down.
func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) {
b.cancelConnectionTimer()
for _, sd := range b.subConns.Values() {
if sd.subConn != selected.subConn {
sd.subConn.Shutdown()
}
}
b.subConns = resolver.NewAddressMapV2[*scData]()
b.subConns.Set(selected.addr, selected)
}
// requestConnectionLocked starts connecting on the subchannel corresponding to
// the current address. If no subchannel exists, one is created. If the current
// subchannel is in TransientFailure, a connection to the next address is
// attempted until a subchannel is found.
func (b *pickfirstBalancer) requestConnectionLocked() {
if !b.addressList.isValid() {
return
}
var lastErr error
for valid := true; valid; valid = b.addressList.increment() {
curAddr := b.addressList.currentAddress()
sd, ok := b.subConns.Get(curAddr)
if !ok {
var err error
// We want to assign the new scData to sd from the outer scope,
// hence we can't use := below.
sd, err = b.newSCData(curAddr)
if err != nil {
// This should never happen, unless the clientConn is being shut
// down.
if b.logger.V(2) {
b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err)
}
// Do nothing, the LB policy will be closed soon.
return
}
b.subConns.Set(curAddr, sd)
}
switch sd.rawConnectivityState {
case connectivity.Idle:
sd.subConn.Connect()
b.scheduleNextConnectionLocked()
return
case connectivity.TransientFailure:
// The SubConn is being re-used and failed during a previous pass
// over the addressList. It has not completed backoff yet.
// Mark it as having failed and try the next address.
sd.connectionFailedInFirstPass = true
lastErr = sd.lastErr
continue
case connectivity.Connecting:
// Wait for the connection attempt to complete or the timer to fire
// before attempting the next address.
b.scheduleNextConnectionLocked()
return
default:
b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", sd.rawConnectivityState)
return
}
}
// All the remaining addresses in the list are in TRANSIENT_FAILURE, end the
// first pass if possible.
b.endFirstPassIfPossibleLocked(lastErr)
}
func (b *pickfirstBalancer) scheduleNextConnectionLocked() {
b.cancelConnectionTimer()
if !b.addressList.hasNext() {
return
}
curAddr := b.addressList.currentAddress()
cancelled := false // Access to this is protected by the balancer's mutex.
closeFn := internal.TimeAfterFunc(connectionDelayInterval, func() {
b.mu.Lock()
defer b.mu.Unlock()
// If the scheduled task is cancelled while acquiring the mutex, return.
if cancelled {
return
}
if b.logger.V(2) {
b.logger.Infof("Happy Eyeballs timer expired while waiting for connection to %q.", curAddr.Addr)
}
if b.addressList.increment() {
b.requestConnectionLocked()
}
})
// Access to the cancellation callback held by the balancer is guarded by
// the balancer's mutex, so it's safe to set the boolean from the callback.
b.cancelConnectionTimer = sync.OnceFunc(func() {
cancelled = true
closeFn()
})
}
func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) {
b.mu.Lock()
defer b.mu.Unlock()
oldState := sd.rawConnectivityState
sd.rawConnectivityState = newState.ConnectivityState
// Previously relevant SubConns can still callback with state updates.
// To prevent pickers from returning these obsolete SubConns, this logic
// is included to check if the current list of active SubConns includes this
// SubConn.
if !b.isActiveSCData(sd) {
return
}
if newState.ConnectivityState == connectivity.Shutdown {
sd.effectiveState = connectivity.Shutdown
return
}
// Record a failed connection attempt when transitioning to TRANSIENT_FAILURE.
if newState.ConnectivityState == connectivity.TransientFailure {
sd.connectionFailedInFirstPass = true
connectionAttemptsFailedMetric.Record(b.metricsRecorder, 1, b.target)
}
if newState.ConnectivityState == connectivity.Ready {
connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
b.shutdownRemainingLocked(sd)
if !b.addressList.seekTo(sd.addr) {
// This should not fail as we should have only one SubConn after
// entering READY. The SubConn should be present in the addressList.
b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses)
return
}
if !b.healthCheckingEnabled {
if b.logger.V(2) {
b.logger.Infof("SubConn %p reported connectivity state READY and the health listener is disabled. Transitioning SubConn to READY.", sd.subConn)
}
sd.effectiveState = connectivity.Ready
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Ready,
Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}},
})
return
}
if b.logger.V(2) {
b.logger.Infof("SubConn %p reported connectivity state READY. Registering health listener.", sd.subConn)
}
// Send a CONNECTING update to take the SubConn out of sticky-TF if
// required.
sd.effectiveState = connectivity.Connecting
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
sd.subConn.RegisterHealthListener(func(scs balancer.SubConnState) {
b.updateSubConnHealthState(sd, scs)
})
return
}
// If the LB policy is READY, and it receives a subchannel state change,
// it means that the READY subchannel has failed.
// A SubConn can also transition from CONNECTING directly to IDLE when
// a transport is successfully created, but the connection fails
// before the SubConn can send the notification for READY. We treat
// this as a successful connection and transition to IDLE.
// TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second
// part of the if condition below once the issue is fixed.
if oldState == connectivity.Ready || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) {
// Once a transport fails, the balancer enters IDLE and starts from
// the first address when the picker is used.
b.shutdownRemainingLocked(sd)
sd.effectiveState = newState.ConnectivityState
// A READY notification may have been missed when the SubConn went from
// CONNECTING straight to IDLE; account for the successful connection here.
if oldState == connectivity.Connecting {
// A known issue (https://github.com/grpc/grpc-go/issues/7862)
// causes a race that prevents the READY state change notification.
// This works around it.
connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
}
disconnectionsMetric.Record(b.metricsRecorder, 1, b.target)
b.addressList.reset()
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Idle,
Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)},
})
return
}
if b.firstPass {
switch newState.ConnectivityState {
case connectivity.Connecting:
// The effective state can be IDLE, CONNECTING or TRANSIENT_FAILURE.
// If it's TRANSIENT_FAILURE, stay in TRANSIENT_FAILURE until the SubConn
// becomes READY. See A62.
if sd.effectiveState != connectivity.TransientFailure {
sd.effectiveState = connectivity.Connecting
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
}
case connectivity.TransientFailure:
sd.lastErr = newState.ConnectionError
sd.effectiveState = connectivity.TransientFailure
// Since we're re-using common SubConns while handling resolver
// updates, we could receive an out of turn TRANSIENT_FAILURE from
// a pass over the previous address list. Happy Eyeballs will also
// cause out of order updates to arrive.
if curAddr := b.addressList.currentAddress(); equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) {
b.cancelConnectionTimer()
if b.addressList.increment() {
b.requestConnectionLocked()
return
}
}
// End the first pass if we've seen a TRANSIENT_FAILURE from all
// SubConns once.
b.endFirstPassIfPossibleLocked(newState.ConnectionError)
}
return
}
// We have finished the first pass, keep re-connecting failing SubConns.
switch newState.ConnectivityState {
case connectivity.TransientFailure:
b.numTF = (b.numTF + 1) % b.subConns.Len()
sd.lastErr = newState.ConnectionError
if b.numTF%b.subConns.Len() == 0 {
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: newState.ConnectionError},
})
}
// We don't need to request re-resolution since the SubConn already
// does that before reporting TRANSIENT_FAILURE.
// TODO: #7534 - Move re-resolution requests from SubConn into
// pick_first.
case connectivity.Idle:
sd.subConn.Connect()
}
}
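// After the first pass, the numTF counter above throttles picker updates: with
// N SubConns in the map, a fresh TRANSIENT_FAILURE picker carrying the latest
// connection error is published only once for every N failure reports, so with
// two SubConns every second failure refreshes the picker error (this is what
// TestPickFirstLeaf_TFPickerUpdate exercises).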
// endFirstPassIfPossibleLocked ends the first happy-eyeballs pass if all the
// addresses are tried and their SubConns have reported a failure.
func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) {
// An optimization to avoid iterating over the entire SubConn map.
if b.addressList.isValid() {
return
}
// Connect() has been called on all the SubConns. The first pass can be
// ended if all the SubConns have reported a failure.
for _, sd := range b.subConns.Values() {
if !sd.connectionFailedInFirstPass {
return
}
}
b.firstPass = false
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: lastErr},
})
// Start re-connecting all the SubConns that are already in IDLE.
for _, sd := range b.subConns.Values() {
if sd.rawConnectivityState == connectivity.Idle {
sd.subConn.Connect()
}
}
}
func (b *pickfirstBalancer) isActiveSCData(sd *scData) bool {
activeSD, found := b.subConns.Get(sd.addr)
return found && activeSD == sd
}
func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData, state balancer.SubConnState) {
b.mu.Lock()
defer b.mu.Unlock()
// Previously relevant SubConns can still callback with state updates.
// To prevent pickers from returning these obsolete SubConns, this logic
// is included to check if the current list of active SubConns includes
// this SubConn.
if !b.isActiveSCData(sd) {
return
}
sd.effectiveState = state.ConnectivityState
switch state.ConnectivityState {
case connectivity.Ready:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Ready,
Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}},
})
case connectivity.TransientFailure:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("pickfirst: health check failure: %v", state.ConnectionError)},
})
case connectivity.Connecting:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
default:
b.logger.Errorf("Got unexpected health update for SubConn %p: %v", state)
}
}
// updateBalancerState stores the state reported to the channel and calls
// ClientConn.UpdateState(). As an optimization, it avoids sending duplicate
// updates to the channel.
func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) {
// In case of TransientFailure, allow the picker to be updated so that the
// connectivity error is refreshed; in all other cases don't send duplicate
// state updates.
if newState.ConnectivityState == b.state && b.state != connectivity.TransientFailure {
return
}
b.forceUpdateConcludedStateLocked(newState)
}
// forceUpdateConcludedStateLocked stores the state reported to the channel and
// calls ClientConn.UpdateState().
// A separate function is defined to force-update the ClientConn state since
// the channel doesn't correctly assume that LB policies start in CONNECTING
// and instead relies on the LB policy to send an initial CONNECTING update.
func (b *pickfirstBalancer) forceUpdateConcludedStateLocked(newState balancer.State) {
b.state = newState.ConnectivityState
b.cc.UpdateState(newState)
}
type picker struct {
result balancer.PickResult
err error
}
func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
return p.result, p.err
}
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
exitIdle func()
}
func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.exitIdle()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
// addressList manages sequentially iterating over addresses present in a list
// of endpoints. It provides a one-dimensional view of the addresses present
// in the endpoints.
// This type is not safe for concurrent access.
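// A short, hypothetical sketch of how the iteration helpers below are used
// together (the balancer drives these calls from requestConnectionLocked and
// updateSubConnState; flattenedAddrs is an illustrative name):
//
//	al := addressList{}
//	al.updateAddrs(flattenedAddrs) // install a new list and reset the index.
//	for valid := al.isValid(); valid; valid = al.increment() {
//		addr := al.currentAddress()
//		// ... attempt a connection to addr ...
//	}
//	al.reset() // the next pass starts again from the first address.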
type addressList struct {
addresses []resolver.Address
idx int
}
func (al *addressList) isValid() bool {
return al.idx < len(al.addresses)
}
func (al *addressList) size() int {
return len(al.addresses)
}
// increment moves to the next index in the address list.
// This method returns false if it went off the list, true otherwise.
func (al *addressList) increment() bool {
if !al.isValid() {
return false
}
al.idx++
return al.idx < len(al.addresses)
}
// currentAddress returns the current address pointed to in the addressList.
// If the list is in an invalid state, it returns an empty address instead.
func (al *addressList) currentAddress() resolver.Address {
if !al.isValid() {
return resolver.Address{}
}
return al.addresses[al.idx]
}
func (al *addressList) reset() {
al.idx = 0
}
func (al *addressList) updateAddrs(addrs []resolver.Address) {
al.addresses = addrs
al.reset()
}
// seekTo sets the current index to the first entry equal to needle (ignoring
// balancer attributes) and returns true. It returns false if the needle was
// not found, leaving the current index unchanged.
func (al *addressList) seekTo(needle resolver.Address) bool {
for ai, addr := range al.addresses {
if !equalAddressIgnoringBalAttributes(&addr, &needle) {
continue
}
al.idx = ai
return true
}
return false
}
// hasNext returns whether there is another address after the current one, i.e.
// whether incrementing the addressList will keep it within bounds. If the list
// has already moved past the end, it returns false.
func (al *addressList) hasNext() bool {
if !al.isValid() {
return false
}
return al.idx+1 < len(al.addresses)
}
// equalAddressIgnoringBalAttributes returns true if a and b are considered
// equal. This is different from the Equal method on the resolver.Address type
// which considers all fields to determine equality. Here, we only consider
// fields that are meaningful to the SubConn.
func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
return a.Addr == b.Addr && a.ServerName == b.ServerName &&
a.Attributes.Equal(b.Attributes)
}

File diff suppressed because it is too large


@@ -1,246 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pickfirstleaf
import (
"context"
"fmt"
"testing"
"time"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
)
const (
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// TestAddressList_Iteration verifies the behaviour of the addressList while
// iterating through the entries.
func (s) TestAddressList_Iteration(t *testing.T) {
addrs := []resolver.Address{
{
Addr: "192.168.1.1",
ServerName: "test-host-1",
Attributes: attributes.New("key-1", "val-1"),
BalancerAttributes: attributes.New("bal-key-1", "bal-val-1"),
},
{
Addr: "192.168.1.2",
ServerName: "test-host-2",
Attributes: attributes.New("key-2", "val-2"),
BalancerAttributes: attributes.New("bal-key-2", "bal-val-2"),
},
{
Addr: "192.168.1.3",
ServerName: "test-host-3",
Attributes: attributes.New("key-3", "val-3"),
BalancerAttributes: attributes.New("bal-key-3", "bal-val-3"),
},
}
addressList := addressList{}
addressList.updateAddrs(addrs)
for i := 0; i < len(addrs); i++ {
if got, want := addressList.isValid(), true; got != want {
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) {
t.Errorf("addressList.currentAddress() = %v, want %v", got, want)
}
if got, want := addressList.increment(), i+1 < len(addrs); got != want {
t.Fatalf("addressList.increment() = %t, want %t", got, want)
}
}
if got, want := addressList.isValid(), false; got != want {
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
// increment an invalid address list.
if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
if got, want := addressList.isValid(), false; got != want {
t.Errorf("addressList.isValid() = %t, want %t", got, want)
}
addressList.reset()
for i := 0; i < len(addrs); i++ {
if got, want := addressList.isValid(), true; got != want {
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) {
t.Errorf("addressList.currentAddress() = %v, want %v", got, want)
}
if got, want := addressList.increment(), i+1 < len(addrs); got != want {
t.Fatalf("addressList.increment() = %t, want %t", got, want)
}
}
}
// TestAddressList_SeekTo verifies the behaviour of addressList.seekTo.
func (s) TestAddressList_SeekTo(t *testing.T) {
addrs := []resolver.Address{
{
Addr: "192.168.1.1",
ServerName: "test-host-1",
Attributes: attributes.New("key-1", "val-1"),
BalancerAttributes: attributes.New("bal-key-1", "bal-val-1"),
},
{
Addr: "192.168.1.2",
ServerName: "test-host-2",
Attributes: attributes.New("key-2", "val-2"),
BalancerAttributes: attributes.New("bal-key-2", "bal-val-2"),
},
{
Addr: "192.168.1.3",
ServerName: "test-host-3",
Attributes: attributes.New("key-3", "val-3"),
BalancerAttributes: attributes.New("bal-key-3", "bal-val-3"),
},
}
addressList := addressList{}
addressList.updateAddrs(addrs)
// Try finding an address in the list.
key := resolver.Address{
Addr: "192.168.1.2",
ServerName: "test-host-2",
Attributes: attributes.New("key-2", "val-2"),
BalancerAttributes: attributes.New("ignored", "bal-val-2"),
}
if got, want := addressList.seekTo(key), true; got != want {
t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
}
// It should be possible to increment once more now that the pointer has advanced.
if got, want := addressList.increment(), true; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
// Seek to the key again; it is behind the pointer now.
if got, want := addressList.seekTo(key), true; got != want {
t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
}
// Seek to a key not in the list.
key = resolver.Address{
Addr: "192.168.1.5",
ServerName: "test-host-5",
Attributes: attributes.New("key-5", "val-5"),
BalancerAttributes: attributes.New("ignored", "bal-val-5"),
}
if got, want := addressList.seekTo(key), false; got != want {
t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
}
// It should be possible to increment once more since the pointer has not advanced.
if got, want := addressList.increment(), true; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
}
// TestPickFirstLeaf_TFPickerUpdate sends TRANSIENT_FAILURE SubConn state updates
// for each SubConn managed by a pickfirst balancer. It verifies that the picker
// is updated with the expected frequency.
func (s) TestPickFirstLeaf_TFPickerUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc := testutils.NewBalancerClientConn(t)
bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{})
defer bal.Close()
ccState := balancer.ClientConnState{
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}},
{Addresses: []resolver.Address{{Addr: "2.2.2.2:2"}}},
},
},
}
if err := bal.UpdateClientConnState(ccState); err != nil {
t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err)
}
// PF should report TRANSIENT_FAILURE only once all the SubConns have failed
// once.
tfErr := fmt.Errorf("test err: connection refused")
sc1 := <-cc.NewSubConnCh
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: tfErr})
if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", balancer.ErrNoSubConnAvailable, err)
}
sc2 := <-cc.NewSubConnCh
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: tfErr})
if err := cc.WaitForPickerWithErr(ctx, tfErr); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", tfErr, err)
}
// Subsequent TRANSIENT_FAILUREs should be reported only after seeing "# of SubConns"
// TRANSIENT_FAILUREs.
newTfErr := fmt.Errorf("test err: unreachable")
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: newTfErr})
select {
case <-time.After(defaultTestShortTimeout):
case p := <-cc.NewPickerCh:
sc, err := p.Pick(balancer.PickInfo{})
t.Fatalf("Unexpected picker update: %v, %v", sc, err)
}
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: newTfErr})
if err := cc.WaitForPickerWithErr(ctx, newTfErr); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", newTfErr, err)
}
}


@@ -1,173 +0,0 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package ringhash
import (
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/internal/envconfig"
iringhash "google.golang.org/grpc/internal/ringhash"
"google.golang.org/grpc/internal/testutils"
)
func (s) TestParseConfig(t *testing.T) {
tests := []struct {
name string
js string
envConfigCap uint64
requestHeaderEnvVar bool
want *iringhash.LBConfig
wantErr bool
}{
{
name: "OK",
js: `{"minRingSize": 1, "maxRingSize": 2}`,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{MinRingSize: 1, MaxRingSize: 2},
},
{
name: "OK with default min",
js: `{"maxRingSize": 2000}`,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{MinRingSize: defaultMinSize, MaxRingSize: 2000},
},
{
name: "OK with default max",
js: `{"minRingSize": 2000}`,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{MinRingSize: 2000, MaxRingSize: defaultMaxSize},
},
{
name: "min greater than max",
js: `{"minRingSize": 10, "maxRingSize": 2}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "min greater than max greater than global limit",
js: `{"minRingSize": 6000, "maxRingSize": 5000}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "max greater than global limit",
js: `{"minRingSize": 1, "maxRingSize": 6000}`,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{MinRingSize: 1, MaxRingSize: 4096},
},
{
name: "min and max greater than global limit",
js: `{"minRingSize": 5000, "maxRingSize": 6000}`,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{MinRingSize: 4096, MaxRingSize: 4096},
},
{
name: "min and max less than raised global limit",
js: `{"minRingSize": 5000, "maxRingSize": 6000}`,
envConfigCap: 8000,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{MinRingSize: 5000, MaxRingSize: 6000},
},
{
name: "min and max greater than raised global limit",
js: `{"minRingSize": 10000, "maxRingSize": 10000}`,
envConfigCap: 8000,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{MinRingSize: 8000, MaxRingSize: 8000},
},
{
name: "min greater than upper bound",
js: `{"minRingSize": 8388610, "maxRingSize": 10}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "max greater than upper bound",
js: `{"minRingSize": 10, "maxRingSize": 8388610}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "request metadata key set",
js: `{"requestHashHeader": "x-foo"}`,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{
MinRingSize: defaultMinSize,
MaxRingSize: defaultMaxSize,
RequestHashHeader: "x-foo",
},
},
{
name: "request metadata key set with uppercase letters",
js: `{"requestHashHeader": "x-FOO"}`,
requestHeaderEnvVar: true,
want: &iringhash.LBConfig{
MinRingSize: defaultMinSize,
MaxRingSize: defaultMaxSize,
RequestHashHeader: "x-foo",
},
},
{
name: "invalid request hash header",
js: `{"requestHashHeader": "!invalid"}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "binary request hash header",
js: `{"requestHashHeader": "header-with-bin"}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "request hash header cleared when RingHashSetRequestHashKey env var is false",
js: `{"requestHashHeader": "x-foo"}`,
requestHeaderEnvVar: false,
want: &iringhash.LBConfig{
MinRingSize: defaultMinSize,
MaxRingSize: defaultMaxSize,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envConfigCap != 0 {
testutils.SetEnvConfig(t, &envconfig.RingHashCap, tt.envConfigCap)
}
testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, tt.requestHeaderEnvVar)
got, err := parseConfig(json.RawMessage(tt.js))
if (err != nil) != tt.wantErr {
t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("parseConfig() got unexpected output, diff (-got +want): %v", diff)
}
})
}
}


@@ -1,124 +0,0 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package ringhash
import (
"fmt"
"strings"
xxhash "github.com/cespare/xxhash/v2"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
iringhash "google.golang.org/grpc/internal/ringhash"
"google.golang.org/grpc/metadata"
)
type picker struct {
ring *ring
// endpointStates is a cache of endpoint states.
// The ringhash balancer stores endpoint states in a `resolver.EndpointMap`,
// with access guarded by `ringhashBalancer.mu`. The `endpointStates` cache
// in the picker helps avoid locking the ringhash balancer's mutex when
// reading the latest state at RPC time.
endpointStates map[string]endpointState // endpointState.hashKey -> endpointState
// requestHashHeader is the header key to look for the request hash. If it's
// empty, the request hash is expected to be set in the context via xDS.
// See gRFC A76.
requestHashHeader string
// hasEndpointInConnectingState is true if any of the endpoints is in
// CONNECTING.
hasEndpointInConnectingState bool
randUint64 func() uint64
}
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
usingRandomHash := false
var requestHash uint64
if p.requestHashHeader == "" {
var ok bool
if requestHash, ok = iringhash.XDSRequestHash(info.Ctx); !ok {
return balancer.PickResult{}, fmt.Errorf("ringhash: expected xDS config selector to set the request hash")
}
} else {
md, ok := metadata.FromOutgoingContext(info.Ctx)
if !ok || len(md.Get(p.requestHashHeader)) == 0 {
requestHash = p.randUint64()
usingRandomHash = true
} else {
values := strings.Join(md.Get(p.requestHashHeader), ",")
requestHash = xxhash.Sum64String(values)
}
}
e := p.ring.pick(requestHash)
ringSize := len(p.ring.items)
if !usingRandomHash {
// Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF,
// we ignore all TF subchannels and find the first ring entry in READY,
// CONNECTING or IDLE. If that entry is in IDLE, we need to initiate a
// connection. The idlePicker returned by the LazyLB or the new Pickfirst
// should do this automatically.
for i := 0; i < ringSize; i++ {
index := (e.idx + i) % ringSize
es := p.endpointState(p.ring.items[index])
switch es.state.ConnectivityState {
case connectivity.Ready, connectivity.Connecting, connectivity.Idle:
return es.state.Picker.Pick(info)
case connectivity.TransientFailure:
default:
panic(fmt.Sprintf("Found child balancer in unknown state: %v", es.state.ConnectivityState))
}
}
} else {
// If the picker has generated a random hash, it will walk the ring from
// this hash, and pick the first READY endpoint. If no endpoint is
// currently in CONNECTING state, it will trigger a connection attempt
// on at most one endpoint that is in IDLE state along the way. - A76
requestedConnection := p.hasEndpointInConnectingState
for i := 0; i < ringSize; i++ {
index := (e.idx + i) % ringSize
es := p.endpointState(p.ring.items[index])
if es.state.ConnectivityState == connectivity.Ready {
return es.state.Picker.Pick(info)
}
if !requestedConnection && es.state.ConnectivityState == connectivity.Idle {
requestedConnection = true
// If the SubChannel is in idle state, initiate a connection but
// continue to check other pickers to see if there is one in
// ready state.
es.balancer.ExitIdle()
}
}
if requestedConnection {
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
}
// All children are in transient failure. Return the first failure.
return p.endpointState(e).state.Picker.Pick(info)
}
func (p *picker) endpointState(e *ringEntry) endpointState {
return p.endpointStates[e.hashKey]
}
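// A minimal, hypothetical sketch of how the request hash is computed in the
// header-based branch of Pick above when requestHashHeader is configured; the
// header name "x-session-id" is illustrative only:
//
//	var requestHash uint64
//	md, ok := metadata.FromOutgoingContext(info.Ctx)
//	if !ok || len(md.Get("x-session-id")) == 0 {
//		requestHash = p.randUint64() // fall back to a random hash, per gRFC A76.
//	} else {
//		requestHash = xxhash.Sum64String(strings.Join(md.Get("x-session-id"), ","))
//	}
//	e := p.ring.pick(requestHash) // the ring entry to start walking from.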


@@ -1,311 +0,0 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package ringhash
import (
"context"
"errors"
"fmt"
"math"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
iringhash "google.golang.org/grpc/internal/ringhash"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
)
var (
testSubConns []*testutils.TestSubConn
errPicker = errors.New("picker in TransientFailure")
)
func init() {
for i := 0; i < 8; i++ {
testSubConns = append(testSubConns, testutils.NewTestSubConn(fmt.Sprint(i)))
}
}
// fakeChildPicker is used to mock pickers from child pickfirst balancers.
type fakeChildPicker struct {
connectivityState connectivity.State
subConn *testutils.TestSubConn
tfError error
}
func (p *fakeChildPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
switch p.connectivityState {
case connectivity.Idle:
p.subConn.Connect()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
case connectivity.Connecting:
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
case connectivity.Ready:
return balancer.PickResult{SubConn: p.subConn}, nil
default:
return balancer.PickResult{}, p.tfError
}
}
type fakeExitIdler struct {
sc *testutils.TestSubConn
}
func (ei *fakeExitIdler) ExitIdle() {
ei.sc.Connect()
}
func testRingAndEndpointStates(states []connectivity.State) (*ring, map[string]endpointState) {
var items []*ringEntry
epStates := map[string]endpointState{}
for i, st := range states {
testSC := testSubConns[i]
items = append(items, &ringEntry{
idx: i,
hash: math.MaxUint64 / uint64(len(states)) * uint64(i),
hashKey: testSC.String(),
})
epState := endpointState{
state: balancer.State{
ConnectivityState: st,
Picker: &fakeChildPicker{
connectivityState: st,
tfError: fmt.Errorf("%d: %w", i, errPicker),
subConn: testSC,
},
},
balancer: &fakeExitIdler{
sc: testSC,
},
}
epStates[testSC.String()] = epState
}
return &ring{items: items}, epStates
}
func (s) TestPickerPickFirstTwo(t *testing.T) {
tests := []struct {
name string
connectivityStates []connectivity.State
wantSC balancer.SubConn
wantErr error
wantSCToConnect balancer.SubConn
}{
{
name: "picked is Ready",
connectivityStates: []connectivity.State{connectivity.Ready, connectivity.Idle},
wantSC: testSubConns[0],
},
{
name: "picked is connecting, queue",
connectivityStates: []connectivity.State{connectivity.Connecting, connectivity.Idle},
wantErr: balancer.ErrNoSubConnAvailable,
},
{
name: "picked is Idle, connect and queue",
connectivityStates: []connectivity.State{connectivity.Idle, connectivity.Idle},
wantErr: balancer.ErrNoSubConnAvailable,
wantSCToConnect: testSubConns[0],
},
{
name: "picked is TransientFailure, next is ready, return",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Ready},
wantSC: testSubConns[1],
},
{
name: "picked is TransientFailure, next is connecting, queue",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Connecting},
wantErr: balancer.ErrNoSubConnAvailable,
},
{
name: "picked is TransientFailure, next is Idle, connect and queue",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Idle},
wantErr: balancer.ErrNoSubConnAvailable,
wantSCToConnect: testSubConns[1],
},
{
name: "all are in TransientFailure, return picked failure",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.TransientFailure},
wantErr: errPicker,
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ring, epStates := testRingAndEndpointStates(tt.connectivityStates)
p := &picker{
ring: ring,
endpointStates: epStates,
}
got, err := p.Pick(balancer.PickInfo{
Ctx: iringhash.SetXDSRequestHash(ctx, 0), // always pick the first endpoint on the ring.
})
if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got.SubConn != tt.wantSC {
t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tt.wantSC)
}
if sc := tt.wantSCToConnect; sc != nil {
select {
case <-sc.(*testutils.TestSubConn).ConnectCh:
case <-time.After(defaultTestShortTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc)
}
}
})
}
}
func (s) TestPickerNoRequestHash(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ring, epStates := testRingAndEndpointStates([]connectivity.State{connectivity.Ready})
p := &picker{
ring: ring,
endpointStates: epStates,
}
if _, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err == nil {
t.Errorf("Pick() should have failed with no request hash")
}
}
func (s) TestPickerRequestHashKey(t *testing.T) {
tests := []struct {
name string
headerValues []string
expectedPick int
}{
{
name: "header not set",
expectedPick: 0, // Random hash set to 0, which is within (MaxUint64 / 3 * 2, 0]
},
{
name: "header empty",
headerValues: []string{""},
expectedPick: 0, // xxhash.Sum64String("") is within (MaxUint64 / 3 * 2, 0]
},
{
name: "header set to one value",
headerValues: []string{"some-value"},
expectedPick: 1, // xxhash.Sum64String("some-value") is within (0, MaxUint64 / 3]
},
{
name: "header set to multiple values",
headerValues: []string{"value1", "value2"},
expectedPick: 2, // xxhash.Sum64String("value1,value2") is within (MaxUint64 / 3, MaxUint64 / 3 * 2]
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ring, epStates := testRingAndEndpointStates(
[]connectivity.State{
connectivity.Ready,
connectivity.Ready,
connectivity.Ready,
})
headerName := "some-header"
p := &picker{
ring: ring,
endpointStates: epStates,
requestHashHeader: headerName,
randUint64: func() uint64 { return 0 },
}
for _, v := range tt.headerValues {
ctx = metadata.AppendToOutgoingContext(ctx, headerName, v)
}
if res, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err != nil {
t.Errorf("Pick() failed: %v", err)
} else if res.SubConn != testSubConns[tt.expectedPick] {
t.Errorf("Pick() got = %v, want SubConn: %v", res.SubConn, testSubConns[tt.expectedPick])
}
})
}
}
func (s) TestPickerRandomHash(t *testing.T) {
tests := []struct {
name string
hash uint64
connectivityStates []connectivity.State
wantSC balancer.SubConn
wantErr error
wantSCToConnect balancer.SubConn
hasEndpointInConnectingState bool
}{
{
name: "header not set, picked is Ready",
connectivityStates: []connectivity.State{connectivity.Ready, connectivity.Idle},
wantSC: testSubConns[0],
},
{
name: "header not set, picked is Idle, another is Ready. Connect and pick Ready",
connectivityStates: []connectivity.State{connectivity.Idle, connectivity.Ready},
wantSC: testSubConns[1],
wantSCToConnect: testSubConns[0],
},
{
name: "header not set, picked is Idle, there is at least one Connecting",
connectivityStates: []connectivity.State{connectivity.Connecting, connectivity.Idle},
wantErr: balancer.ErrNoSubConnAvailable,
hasEndpointInConnectingState: true,
},
{
name: "header not set, all Idle or TransientFailure, connect",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Idle},
wantErr: balancer.ErrNoSubConnAvailable,
wantSCToConnect: testSubConns[1],
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ring, epStates := testRingAndEndpointStates(tt.connectivityStates)
p := &picker{
ring: ring,
endpointStates: epStates,
requestHashHeader: "some-header",
hasEndpointInConnectingState: tt.hasEndpointInConnectingState,
randUint64: func() uint64 { return 0 }, // always return the first endpoint on the ring.
}
if got, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err != tt.wantErr {
t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
return
} else if got.SubConn != tt.wantSC {
t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tt.wantSC)
}
if sc := tt.wantSCToConnect; sc != nil {
select {
case <-sc.(*testutils.TestSubConn).ConnectCh:
case <-time.After(defaultTestShortTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc)
}
}
})
}
}


@@ -1,408 +0,0 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package ringhash implements the ringhash balancer. See the following
// gRFCs for details:
// - https://github.com/grpc/proposal/blob/master/A42-xds-ring-hash-lb-policy.md
// - https://github.com/grpc/proposal/blob/master/A61-IPv4-IPv6-dualstack-backends.md#ring-hash
// - https://github.com/grpc/proposal/blob/master/A76-ring-hash-improvements.md
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package ringhash
import (
"encoding/json"
"errors"
"fmt"
"math/rand/v2"
"sort"
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/lazy"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/balancer/weight"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
iringhash "google.golang.org/grpc/internal/ringhash"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/ringhash"
"google.golang.org/grpc/serviceconfig"
)
// Name is the name of the ring_hash balancer.
const Name = "ring_hash_experimental"
func lazyPickFirstBuilder(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return lazy.NewBalancer(cc, opts, balancer.Get(pickfirstleaf.Name).Build)
}
func init() {
balancer.Register(bb{})
}
type bb struct{}
func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
b := &ringhashBalancer{
ClientConn: cc,
endpointStates: resolver.NewEndpointMap[*endpointState](),
}
esOpts := endpointsharding.Options{DisableAutoReconnect: true}
b.child = endpointsharding.NewBalancer(b, opts, lazyPickFirstBuilder, esOpts)
b.logger = prefixLogger(b)
b.logger.Infof("Created")
return b
}
func (bb) Name() string {
return Name
}
func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
return parseConfig(c)
}
type ringhashBalancer struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
// ClientConn is embedded to intercept UpdateState calls from the child
// endpointsharding balancer.
balancer.ClientConn
logger *grpclog.PrefixLogger
child balancer.Balancer
mu sync.Mutex
config *iringhash.LBConfig
inhibitChildUpdates bool
shouldRegenerateRing bool
endpointStates *resolver.EndpointMap[*endpointState]
// ring is always in sync with endpoints. When endpoints change, a new ring
// is generated. Note that address weights updates also regenerates the
// ring.
ring *ring
}
// hashKey returns the hash key to use for an endpoint. Per gRFC A61, each entry
// in the ring is a hash of the endpoint's hash key concatenated with a
// per-entry unique suffix.
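// For example (illustrative values only): an endpoint whose first address is
// "10.0.0.1:80" and which has no hash key attribute uses "10.0.0.1:80" as its
// hash key; if ringhash.HashKey returns a non-empty value for the endpoint
// (e.g. one derived from EDS endpoint metadata, per gRFC A76), that value is
// used instead.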
func hashKey(endpoint resolver.Endpoint) string {
if hk := ringhash.HashKey(endpoint); hk != "" {
return hk
}
// If no hash key is set, use the endpoint's first address as the hash key.
// This is the default behavior.
return endpoint.Addresses[0].Addr
}
// UpdateState intercepts child balancer state updates. It updates the
// per-endpoint state stored in the ring, and also the aggregated state based on
// the child picker. It also reconciles the endpoint list. It sets
// `b.shouldRegenerateRing` to true if the new endpoint list is different from
// the previous, i.e. any of the following is true:
// - an endpoint was added
// - an endpoint was removed
// - an endpoint's weight was updated
// - the first address of an endpoint has changed
func (b *ringhashBalancer) UpdateState(state balancer.State) {
b.mu.Lock()
defer b.mu.Unlock()
childStates := endpointsharding.ChildStatesFromPicker(state.Picker)
// endpointsSet is the set converted from endpoints, used for quick lookup.
endpointsSet := resolver.NewEndpointMap[bool]()
for _, childState := range childStates {
endpoint := childState.Endpoint
endpointsSet.Set(endpoint, true)
newWeight := getWeightAttribute(endpoint)
hk := hashKey(endpoint)
es, ok := b.endpointStates.Get(endpoint)
if !ok {
es := &endpointState{
balancer: childState.Balancer,
hashKey: hk,
weight: newWeight,
state: childState.State,
}
b.endpointStates.Set(endpoint, es)
b.shouldRegenerateRing = true
} else {
// We have seen this endpoint before and created an `endpointState`
// object for it. If the weight or the hash key of the endpoint has
// changed, update the endpoint state map with the new weight or
// hash key. This will be used when a new ring is created.
if oldWeight := es.weight; oldWeight != newWeight {
b.shouldRegenerateRing = true
es.weight = newWeight
}
if es.hashKey != hk {
b.shouldRegenerateRing = true
es.hashKey = hk
}
es.state = childState.State
}
}
for _, endpoint := range b.endpointStates.Keys() {
if _, ok := endpointsSet.Get(endpoint); ok {
continue
}
// endpoint was removed by resolver.
b.endpointStates.Delete(endpoint)
b.shouldRegenerateRing = true
}
b.updatePickerLocked()
}
func (b *ringhashBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
if b.logger.V(2) {
b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(ccs.BalancerConfig))
}
newConfig, ok := ccs.BalancerConfig.(*iringhash.LBConfig)
if !ok {
return fmt.Errorf("unexpected balancer config with type: %T", ccs.BalancerConfig)
}
b.mu.Lock()
b.inhibitChildUpdates = true
b.mu.Unlock()
defer func() {
b.mu.Lock()
b.inhibitChildUpdates = false
b.updatePickerLocked()
b.mu.Unlock()
}()
if err := b.child.UpdateClientConnState(balancer.ClientConnState{
// Make pickfirst children use health listeners for outlier detection
// and health checking to work.
ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
}); err != nil {
return err
}
b.mu.Lock()
// Ring updates can happen due to the following:
// 1. Addition or deletion of endpoints: The synchronous picker update from
// the child endpointsharding balancer would contain the list of updated
// endpoints. Updates triggered by the child after handling the
// `UpdateClientConnState` call will not change the endpoint list.
// 2. Change in the `LoadBalancerConfig`: Ring config such as max/min ring
// size.
// To avoid extra ring updates, a boolean is used to track the need for a
// ring update and the update is done only once at the end.
//
// If the ring configuration has changed, we need to regenerate the ring
// while sending a new picker.
if b.config == nil || b.config.MinRingSize != newConfig.MinRingSize || b.config.MaxRingSize != newConfig.MaxRingSize {
b.shouldRegenerateRing = true
}
b.config = newConfig
b.mu.Unlock()
return nil
}
func (b *ringhashBalancer) ResolverError(err error) {
b.child.ResolverError(err)
}
func (b *ringhashBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
}
func (b *ringhashBalancer) updatePickerLocked() {
state := b.aggregatedStateLocked()
// Start connecting to new endpoints if necessary.
if state == connectivity.Connecting || state == connectivity.TransientFailure {
// When overall state is TransientFailure, we need to make sure at least
// one endpoint is attempting to connect, otherwise this balancer may
// never get picks if the parent is priority.
//
// Because we report Connecting as the overall state when only one
// endpoint is in TransientFailure, we do the same check for Connecting
// here.
//
// Note that this check also covers deleting endpoints. E.g. if the
// endpoint attempting to connect is deleted, and the overall state is
// TF. Since there must be at least one endpoint attempting to connect,
// we need to trigger one.
//
// After calling `ExitIdle` on a child balancer, the child will send a
// picker update asynchronously. A race condition may occur if another
// picker update from endpointsharding arrives before the child's
// picker update. The received picker may trigger a re-execution of the
// loop below to find an idle child. Since map iteration order is
// non-deterministic, the list of `endpointState`s must be sorted to
// ensure `ExitIdle` is called on the same child, preventing unnecessary
// connections.
var endpointStates = make([]*endpointState, b.endpointStates.Len())
for i, s := range b.endpointStates.Values() {
endpointStates[i] = s
}
sort.Slice(endpointStates, func(i, j int) bool {
return endpointStates[i].hashKey < endpointStates[j].hashKey
})
var idleBalancer endpointsharding.ExitIdler
for _, es := range endpointStates {
connState := es.state.ConnectivityState
if connState == connectivity.Connecting {
idleBalancer = nil
break
}
if idleBalancer == nil && connState == connectivity.Idle {
idleBalancer = es.balancer
}
}
if idleBalancer != nil {
idleBalancer.ExitIdle()
}
}
if b.inhibitChildUpdates {
return
}
// Update the channel.
if b.endpointStates.Len() > 0 && b.shouldRegenerateRing {
// Regenerate the ring only when needed, and only with a non-empty list of
// endpoints.
b.ring = newRing(b.endpointStates, b.config.MinRingSize, b.config.MaxRingSize, b.logger)
}
b.shouldRegenerateRing = false
var newPicker balancer.Picker
if b.endpointStates.Len() == 0 {
newPicker = base.NewErrPicker(errors.New("produced zero addresses"))
} else {
newPicker = b.newPickerLocked()
}
b.ClientConn.UpdateState(balancer.State{
ConnectivityState: state,
Picker: newPicker,
})
}
func (b *ringhashBalancer) Close() {
b.logger.Infof("Shutdown")
b.child.Close()
}
func (b *ringhashBalancer) ExitIdle() {
// ExitIdle implementation is a no-op because connections are triggered either
// by picks or by child balancer state changes.
}
// newPickerLocked generates a picker. The picker copies the endpoint states
// over to avoid locking the mutex at RPC time. The picker should be
// re-generated every time an endpoint state is updated.
func (b *ringhashBalancer) newPickerLocked() *picker {
states := make(map[string]endpointState)
hasEndpointConnecting := false
for _, epState := range b.endpointStates.Values() {
// Copy the endpoint state to avoid races, since ring hash
// mutates the state, weight and hash key in place.
states[epState.hashKey] = *epState
if epState.state.ConnectivityState == connectivity.Connecting {
hasEndpointConnecting = true
}
}
return &picker{
ring: b.ring,
endpointStates: states,
requestHashHeader: b.config.RequestHashHeader,
hasEndpointInConnectingState: hasEndpointConnecting,
randUint64: rand.Uint64,
}
}
// aggregatedStateLocked returns the aggregated child balancers state
// based on the following rules.
// - If there is at least one endpoint in READY state, report READY.
// - If there are 2 or more endpoints in TRANSIENT_FAILURE state, report
// TRANSIENT_FAILURE.
// - If there is at least one endpoint in CONNECTING state, report CONNECTING.
// - If there is one endpoint in TRANSIENT_FAILURE and there is more than one
// endpoint, report state CONNECTING.
// - If there is at least one endpoint in Idle state, report Idle.
// - Otherwise, report TRANSIENT_FAILURE.
//
// Note that if there is 1 endpoint in CONNECTING and 2 in TRANSIENT_FAILURE,
// the overall state is TRANSIENT_FAILURE. This is because the second
// TRANSIENT_FAILURE is a fallback for the first failing endpoint, and we want
// to report TRANSIENT_FAILURE so that the parent can fail over to the lower
// priority.
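// As a further illustration of the rules above: endpoint states
// {READY, TRANSIENT_FAILURE, TRANSIENT_FAILURE} aggregate to READY,
// {IDLE, TRANSIENT_FAILURE} to CONNECTING, and {IDLE, IDLE} to IDLE.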
func (b *ringhashBalancer) aggregatedStateLocked() connectivity.State {
var nums [5]int
for _, es := range b.endpointStates.Values() {
nums[es.state.ConnectivityState]++
}
if nums[connectivity.Ready] > 0 {
return connectivity.Ready
}
if nums[connectivity.TransientFailure] > 1 {
return connectivity.TransientFailure
}
if nums[connectivity.Connecting] > 0 {
return connectivity.Connecting
}
if nums[connectivity.TransientFailure] == 1 && b.endpointStates.Len() > 1 {
return connectivity.Connecting
}
if nums[connectivity.Idle] > 0 {
return connectivity.Idle
}
return connectivity.TransientFailure
}
// getWeightAttribute is a convenience function which returns the value of the
// weight endpoint Attribute.
//
// When used in the xDS context, the weight attribute is guaranteed to be
// non-zero. But, when used in a non-xDS context, the weight attribute could be
// unset. A default of 1 is used in the latter case.
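// For example (illustrative only): with two endpoints weighted 3 and 1, ring
// entries are distributed roughly in a 3:1 ratio between them, so uniformly
// distributed request hashes land on the heavier endpoint about three times
// as often.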
func getWeightAttribute(e resolver.Endpoint) uint32 {
w := weight.FromEndpoint(e).Weight
if w == 0 {
return 1
}
return w
}
type endpointState struct {
// hashKey is the hash key of the endpoint. Per gRFC A61, each entry in the
// ring is an endpoint, positioned based on the hash of the endpoint's first
// address by default. Per gRFC A76, the hash key of an endpoint may be
// overridden, for example based on EDS endpoint metadata.
hashKey string
weight uint32
balancer endpointsharding.ExitIdler
// state is updated by the balancer while receiving resolver updates from
// the channel and picker updates from its children. Access to it is guarded
// by ringhashBalancer.mu.
state balancer.State
}

File diff suppressed because it is too large


@@ -1,737 +0,0 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package ringhash
import (
"context"
"fmt"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/balancer/weight"
"google.golang.org/grpc/internal/grpctest"
iringhash "google.golang.org/grpc/internal/ringhash"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
)
const (
defaultTestTimeout = 10 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
testBackendAddrsCount = 12
)
var (
testBackendAddrStrs []string
testConfig = &iringhash.LBConfig{MinRingSize: 1, MaxRingSize: 10}
)
func init() {
for i := 0; i < testBackendAddrsCount; i++ {
testBackendAddrStrs = append(testBackendAddrStrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i))
}
}
// setupTest creates the balancer, and does an initial sanity check.
func setupTest(t *testing.T, endpoints []resolver.Endpoint) (*testutils.BalancerClientConn, balancer.Balancer, balancer.Picker) {
t.Helper()
cc := testutils.NewBalancerClientConn(t)
builder := balancer.Get(Name)
b := builder.Build(cc, balancer.BuildOptions{})
if b == nil {
t.Fatalf("builder.Build(%s) failed and returned nil", Name)
}
if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Endpoints: endpoints},
BalancerConfig: testConfig,
}); err != nil {
t.Fatalf("UpdateClientConnState returned err: %v", err)
}
// The leaf pickfirst balancers are created lazily, only when their endpoint is
// picked or other endpoints are in TF. No SubConns should be created
// immediately.
select {
case sc := <-cc.NewSubConnCh:
t.Errorf("unexpected SubConn creation: %v", sc)
case <-time.After(defaultTestShortTimeout):
}
// Should also have a picker, with all endpoints in Idle.
p1 := <-cc.NewPickerCh
ringHashPicker := p1.(*picker)
if got, want := len(ringHashPicker.endpointStates), len(endpoints); got != want {
t.Errorf("Number of child balancers = %d, want = %d", got, want)
}
for firstAddr, bs := range ringHashPicker.endpointStates {
if got, want := bs.state.ConnectivityState, connectivity.Idle; got != want {
t.Errorf("Child balancer connectivity state for address %q = %v, want = %v", firstAddr, got, want)
}
}
return cc, b, p1
}
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// TestUpdateClientConnState_NewRingSize tests the scenario where the ringhash
// LB policy receives new configuration which specifies new values for the ring
// min and max sizes. The test verifies that a new ring is created and a new
// picker is sent to the ClientConn.
func (s) TestUpdateClientConnState_NewRingSize(t *testing.T) {
origMinRingSize, origMaxRingSize := 1, 10 // Configured from `testConfig` in `setupTest`
newMinRingSize, newMaxRingSize := 20, 100
endpoints := []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[0]}}}}
cc, b, p1 := setupTest(t, endpoints)
ring1 := p1.(*picker).ring
if ringSize := len(ring1.items); ringSize < origMinRingSize || ringSize > origMaxRingSize {
t.Fatalf("Ring created with size %d, want between [%d, %d]", ringSize, origMinRingSize, origMaxRingSize)
}
if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Endpoints: endpoints},
BalancerConfig: &iringhash.LBConfig{
MinRingSize: uint64(newMinRingSize),
MaxRingSize: uint64(newMaxRingSize),
},
}); err != nil {
t.Fatalf("UpdateClientConnState returned err: %v", err)
}
var ring2 *ring
select {
case <-time.After(defaultTestTimeout):
t.Fatal("Timeout when waiting for a picker update after a configuration update")
case p2 := <-cc.NewPickerCh:
ring2 = p2.(*picker).ring
}
if ringSize := len(ring2.items); ringSize < newMinRingSize || ringSize > newMaxRingSize {
t.Fatalf("Ring created with size %d, want between [%d, %d]", ringSize, newMinRingSize, newMaxRingSize)
}
}
func (s) TestOneEndpoint(t *testing.T) {
wantAddr1 := resolver.Address{Addr: testBackendAddrStrs[0]}
cc, _, p0 := setupTest(t, []resolver.Endpoint{{Addresses: []resolver.Address{wantAddr1}}})
ring0 := p0.(*picker).ring
firstHash := ring0.items[0].hash
// firstHash-1 will pick the first (and only) SubConn from the ring.
testHash := firstHash - 1
// The first pick should be queued, and should trigger a connection to the
// only Endpoint which has a single address.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
var sc0 *testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc0 = <-cc.NewSubConnCh:
}
if got, want := sc0.Addresses[0].Addr, wantAddr1.Addr; got != want {
t.Fatalf("SubConn.Addresses = %v, want = %v", got, want)
}
select {
case <-sc0.ConnectCh:
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc0)
}
// Send state updates to Ready.
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
if err := cc.WaitForConnectivityState(ctx, connectivity.Ready); err != nil {
t.Fatal(err)
}
// Test pick with one backend.
p1 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != sc0 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
}
}
}
// TestThreeSubConnsAffinity covers the case with three SubConns: RPCs with the
// same hash always pick the same SubConn. When the picked SubConn goes down,
// another one is picked.
func (s) TestThreeSubConnsAffinity(t *testing.T) {
endpoints := []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[0]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[1]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[2]}}},
}
remainingAddrs := map[string]bool{
testBackendAddrStrs[0]: true,
testBackendAddrStrs[1]: true,
testBackendAddrStrs[2]: true,
}
cc, _, p0 := setupTest(t, endpoints)
// This test doesn't update addresses, so this ring will be used by all the
// pickers.
ring := p0.(*picker).ring
firstHash := ring.items[0].hash
// firstHash+1 will pick the second endpoint from the ring.
testHash := firstHash + 1
// The first pick should be queued, and should trigger a connection to the
// picked endpoint's SubConn.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
// The picked endpoint should be the second in the ring.
var subConns [3]*testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case subConns[1] = <-cc.NewSubConnCh:
}
if got, want := subConns[1].Addresses[0].Addr, ring.items[1].hashKey; got != want {
t.Fatalf("SubConn.Address = %v, want = %v", got, want)
}
select {
case <-subConns[1].ConnectCh:
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", subConns[1])
}
delete(remainingAddrs, ring.items[1].hashKey)
// Turn down the subConn in use.
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
// This should trigger a connection to a new endpoint.
<-cc.NewPickerCh
var sc *testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc = <-cc.NewSubConnCh:
}
scAddr := sc.Addresses[0].Addr
if _, ok := remainingAddrs[scAddr]; !ok {
t.Fatalf("New SubConn created with previously used address: %q", scAddr)
}
delete(remainingAddrs, scAddr)
select {
case <-sc.ConnectCh:
case <-time.After(defaultTestTimeout):
		t.Errorf("timeout waiting for Connect() from SubConn %v", sc)
}
if scAddr == ring.items[0].hashKey {
subConns[0] = sc
} else if scAddr == ring.items[2].hashKey {
subConns[2] = sc
}
// Turning down the SubConn should cause creation of a connection to the
// final endpoint.
sc.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc = <-cc.NewSubConnCh:
}
scAddr = sc.Addresses[0].Addr
if _, ok := remainingAddrs[scAddr]; !ok {
t.Fatalf("New SubConn created with previously used address: %q", scAddr)
}
delete(remainingAddrs, scAddr)
select {
case <-sc.ConnectCh:
case <-time.After(defaultTestTimeout):
		t.Errorf("timeout waiting for Connect() from SubConn %v", sc)
}
if scAddr == ring.items[0].hashKey {
subConns[0] = sc
} else if scAddr == ring.items[2].hashKey {
subConns[2] = sc
}
sc.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
// All endpoints are in TransientFailure. Make the first endpoint in the
// ring report Ready. All picks should go to this endpoint which is two
// indexes away from the endpoint with the chosen hash.
subConns[0].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Idle})
subConns[0].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
subConns[0].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
if err := cc.WaitForConnectivityState(ctx, connectivity.Ready); err != nil {
t.Fatalf("Context timed out while waiting for channel to report Ready.")
}
p1 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != subConns[0] {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[0])
}
}
// Make the last endpoint in the ring report Ready. All picks should go to
// this endpoint since it is one index away from the chosen hash.
subConns[2].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Idle})
subConns[2].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
subConns[2].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
p2 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != subConns[2] {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[2])
}
}
// Make the second endpoint in the ring report Ready. All picks should go to
// this endpoint as it is the one with the chosen hash.
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Idle})
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
p3 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != subConns[1] {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[1])
}
}
}
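// The following helper is an editor-added illustrative sketch and is not part
// of the ringhash implementation under test. Under the assumption that a pick
// starts at the first ring entry whose hash is at or past the request hash and
// then walks clockwise until it finds a usable endpoint, it shows why the test
// above falls back to neighboring entries when the picked endpoint is down.
// The sketchRingEntry type and pickSketch function are hypothetical names.
type sketchRingEntry struct {
	hash  uint64
	addr  string
	ready bool
}

func pickSketch(ring []sketchRingEntry, requestHash uint64) (string, bool) {
	if len(ring) == 0 {
		return "", false
	}
	// Locate the first entry whose hash is >= requestHash; if no such entry
	// exists, wrap around to the first entry on the ring.
	start := 0
	for i, e := range ring {
		if e.hash >= requestHash {
			start = i
			break
		}
	}
	// Walk clockwise from the starting entry until a ready endpoint is found.
	for i := 0; i < len(ring); i++ {
		if e := ring[(start+i)%len(ring)]; e.ready {
			return e.addr, true
		}
	}
	return "", false
}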
// TestThreeBackendsAffinityMultiple covers the case where there are 3
// SubConns, and RPCs with the same hash always pick the same SubConn. It then
// uses a different hash to pick another backend, and verifies that the first
// hash still picks the first backend.
func (s) TestThreeBackendsAffinityMultiple(t *testing.T) {
wantEndpoints := []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[0]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[1]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[2]}}},
}
cc, _, p0 := setupTest(t, wantEndpoints)
// This test doesn't update addresses, so this ring will be used by all the
// pickers.
ring0 := p0.(*picker).ring
firstHash := ring0.items[0].hash
// firstHash+1 will pick the second SubConn from the ring.
testHash := firstHash + 1
	// The first pick should be queued, and should trigger creation of (and a
	// Connect() call on) a SubConn for the picked endpoint.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
// The picked SubConn should be the second in the ring.
var sc0 *testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc0 = <-cc.NewSubConnCh:
}
if got, want := sc0.Addresses[0].Addr, ring0.items[1].hashKey; got != want {
t.Fatalf("SubConn.Address = %v, want = %v", got, want)
}
select {
case <-sc0.ConnectCh:
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc0)
}
// Send state updates to Ready.
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
if err := cc.WaitForConnectivityState(ctx, connectivity.Ready); err != nil {
t.Fatal(err)
}
// First hash should always pick sc0.
p1 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != sc0 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
}
}
secondHash := ring0.items[1].hash
// secondHash+1 will pick the third SubConn from the ring.
testHash2 := secondHash + 1
if _, err := p0.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash2)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
var sc1 *testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc1 = <-cc.NewSubConnCh:
}
if got, want := sc1.Addresses[0].Addr, ring0.items[2].hashKey; got != want {
t.Fatalf("SubConn.Address = %v, want = %v", got, want)
}
select {
case <-sc1.ConnectCh:
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc1)
}
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
// With the new generated picker, hash2 always picks sc1.
p2 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash2)})
if gotSCSt.SubConn != sc1 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1)
}
}
// But the first hash still picks sc0.
for i := 0; i < 5; i++ {
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != sc0 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
}
}
}
// TestAddrWeightChange covers the following scenarios after setting up the
// balancer with 3 addresses [A, B, C]:
//   - updates balancer with [A, B, C] again, a new Picker is sent, but the
//     ring is unchanged.
//   - updates balancer with [A, B] (C removed), a new Picker is sent and the
//     ring is updated.
//   - updates balancer with [A, B], but B has a weight of 2, a new Picker is
//     sent, and the new ring contains the correct number of entries and
//     weights.
func (s) TestAddrWeightChange(t *testing.T) {
endpoints := []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[0]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[1]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[2]}}},
}
cc, b, p0 := setupTest(t, endpoints)
ring0 := p0.(*picker).ring
// Update with the same addresses, it will result in a new picker, but with
// the same ring.
if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Endpoints: endpoints},
BalancerConfig: testConfig,
}); err != nil {
t.Fatalf("UpdateClientConnState returned err: %v", err)
}
var p1 balancer.Picker
select {
case p1 = <-cc.NewPickerCh:
case <-time.After(defaultTestTimeout):
t.Fatalf("timeout waiting for picker after UpdateClientConn with same addresses")
}
ring1 := p1.(*picker).ring
if ring1 != ring0 {
t.Fatalf("new picker with same address has a different ring than before, want same")
}
// Delete an address, should send a new Picker.
if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Endpoints: endpoints[:2]},
BalancerConfig: testConfig,
}); err != nil {
t.Fatalf("UpdateClientConnState returned err: %v", err)
}
var p2 balancer.Picker
select {
case p2 = <-cc.NewPickerCh:
case <-time.After(defaultTestTimeout):
t.Fatalf("timeout waiting for picker after UpdateClientConn with different addresses")
}
ring2 := p2.(*picker).ring
if ring2 == ring0 {
t.Fatalf("new picker after removing address has the same ring as before, want different")
}
// Another update with the same addresses, but different weight.
if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Endpoints: []resolver.Endpoint{
endpoints[0],
weight.Set(endpoints[1], weight.EndpointInfo{Weight: 2}),
}},
BalancerConfig: testConfig,
}); err != nil {
t.Fatalf("UpdateClientConnState returned err: %v", err)
}
var p3 balancer.Picker
select {
case p3 = <-cc.NewPickerCh:
case <-time.After(defaultTestTimeout):
t.Fatalf("timeout waiting for picker after UpdateClientConn with different addresses")
}
if p3.(*picker).ring == ring2 {
t.Fatalf("new picker after changing address weight has the same ring as before, want different")
}
	// With the new update, the ring must contain 3 entries: one for
	// testBackendAddrStrs[0] with weight 1, and two for testBackendAddrStrs[1]
	// with weight 2.
if len(p3.(*picker).ring.items) != 3 {
t.Fatalf("new picker after changing address weight has %d entries, want 3", len(p3.(*picker).ring.items))
}
for _, i := range p3.(*picker).ring.items {
if i.hashKey == testBackendAddrStrs[0] {
if i.weight != 1 {
t.Fatalf("new picker after changing address weight has weight %d for %v, want 1", i.weight, i.hashKey)
}
}
if i.hashKey == testBackendAddrStrs[1] {
if i.weight != 2 {
t.Fatalf("new picker after changing address weight has weight %d for %v, want 2", i.weight, i.hashKey)
}
}
}
}
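// Editor-added illustrative sketch, not part of the balancer under test: it
// shows how an endpoint weight can translate into repeated ring entries,
// matching the expectation above that the weight-2 endpoint contributes two of
// the three entries. A real ring additionally scales weights to honor the
// configured min/max ring size and hashes each replica onto the ring; the
// buildSketchRingKeys name is hypothetical.
func buildSketchRingKeys(weights map[string]uint32) []string {
	var keys []string
	for addr, w := range weights {
		for i := uint32(0); i < w; i++ {
			// Each replica hashes a distinct key such as "<addr>_<i>" onto
			// the ring; only the keys are recorded here.
			keys = append(keys, fmt.Sprintf("%s_%d", addr, i))
		}
	}
	return keys
}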
// TestAutoConnectEndpointOnTransientFailure covers the situation when an
// endpoint fails. It verifies that a new endpoint is automatically tried
// (without a pick) when there is no endpoint already in Connecting state.
func (s) TestAutoConnectEndpointOnTransientFailure(t *testing.T) {
wantEndpoints := []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[0]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[1]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[2]}}},
{Addresses: []resolver.Address{{Addr: testBackendAddrStrs[3]}}},
}
cc, _, p0 := setupTest(t, wantEndpoints)
// ringhash won't tell SCs to connect until there is an RPC, so simulate
// one now.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
ctx = iringhash.SetXDSRequestHash(ctx, 0)
defer cancel()
p0.Pick(balancer.PickInfo{Ctx: ctx})
	// A SubConn should be created for the endpoint picked by the request hash.
var sc0 *testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc0 = <-cc.NewSubConnCh:
}
select {
case <-sc0.ConnectCh:
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc0)
}
// Turn the first subconn to transient failure. This should set the overall
// connectivity state to CONNECTING.
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
cc.WaitForConnectivityState(ctx, connectivity.Connecting)
	// This should trigger the second endpoint's SubConn to connect, since the
	// only SubConn created so far is in TransientFailure and no endpoint is
	// connecting.
var sc1 *testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc1 = <-cc.NewSubConnCh:
}
select {
case <-sc1.ConnectCh:
case <-time.After(defaultTestShortTimeout):
t.Fatalf("timeout waiting for Connect() from SubConn %v", sc1)
}
// Turn the second subconn to TF. This will set the overall state to TF.
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
cc.WaitForConnectivityState(ctx, connectivity.TransientFailure)
// It will trigger the third subconn to connect.
var sc2 *testutils.TestSubConn
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case sc2 = <-cc.NewSubConnCh:
}
select {
case <-sc2.ConnectCh:
case <-time.After(defaultTestShortTimeout):
t.Fatalf("timeout waiting for Connect() from SubConn %v", sc2)
}
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
	// Move the first SubConn back to CONNECTING. To do this, first make it
	// READY, then IDLE, and then trigger a pick so the balancer reconnects it.
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
cc.WaitForConnectivityState(ctx, connectivity.Ready)
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Idle})
// Since one endpoint is in TF and one in CONNECTING, the aggregated state
// will be CONNECTING.
cc.WaitForConnectivityState(ctx, connectivity.Connecting)
p1 := <-cc.NewPickerCh
p1.Pick(balancer.PickInfo{Ctx: ctx})
select {
case <-sc0.ConnectCh:
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc0)
}
sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
	// This will not trigger any new SubConns to be created, because sc0 is
	// still attempting to connect, and we only need one SubConn to connect.
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
select {
case sc := <-cc.NewSubConnCh:
t.Fatalf("unexpected SubConn creation: %v", sc)
case <-sc0.ConnectCh:
t.Fatalf("unexpected Connect() from SubConn %v", sc0)
case <-sc1.ConnectCh:
t.Fatalf("unexpected Connect() from SubConn %v", sc1)
case <-sc2.ConnectCh:
t.Fatalf("unexpected Connect() from SubConn %v", sc2)
case <-time.After(defaultTestShortTimeout):
}
}
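// Editor-added illustrative sketch, not part of the balancer under test: the
// auto-connect behavior exercised above, under the assumption that whenever no
// endpoint is connecting or ready, the balancer asks one idle endpoint to
// connect so the channel can recover without waiting for a new pick. The
// sketchEndpoint type and maybeAutoConnectSketch function are hypothetical.
type sketchEndpoint struct {
	state   connectivity.State
	connect func()
}

func maybeAutoConnectSketch(endpoints []*sketchEndpoint) {
	for _, e := range endpoints {
		if e.state == connectivity.Connecting || e.state == connectivity.Ready {
			// Some endpoint is already making progress; nothing to do.
			return
		}
	}
	for _, e := range endpoints {
		if e.state == connectivity.Idle {
			// Kick exactly one idle endpoint so that only a single connection
			// attempt is in flight at a time.
			e.connect()
			return
		}
	}
}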
func (s) TestAggregatedConnectivityState(t *testing.T) {
tests := []struct {
name string
endpointStates []connectivity.State
want connectivity.State
}{
{
name: "one ready",
endpointStates: []connectivity.State{connectivity.Ready},
want: connectivity.Ready,
},
{
name: "one connecting",
endpointStates: []connectivity.State{connectivity.Connecting},
want: connectivity.Connecting,
},
{
name: "one ready one transient failure",
endpointStates: []connectivity.State{connectivity.Ready, connectivity.TransientFailure},
want: connectivity.Ready,
},
{
name: "one connecting one transient failure",
endpointStates: []connectivity.State{connectivity.Connecting, connectivity.TransientFailure},
want: connectivity.Connecting,
},
{
name: "one connecting two transient failure",
endpointStates: []connectivity.State{connectivity.Connecting, connectivity.TransientFailure, connectivity.TransientFailure},
want: connectivity.TransientFailure,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bal := &ringhashBalancer{endpointStates: resolver.NewEndpointMap[*endpointState]()}
for i, cs := range tt.endpointStates {
es := &endpointState{
state: balancer.State{ConnectivityState: cs},
}
ep := resolver.Endpoint{Addresses: []resolver.Address{{Addr: fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)}}}
bal.endpointStates.Set(ep, es)
}
if got := bal.aggregatedStateLocked(); got != tt.want {
				t.Errorf("aggregatedStateLocked() = %v, want %v", got, tt.want)
}
})
}
}
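// Editor-added illustrative sketch, not part of the balancer under test: an
// aggregation rule consistent with the table-driven expectations above. READY
// wins, two or more endpoints in TRANSIENT_FAILURE force TRANSIENT_FAILURE,
// and otherwise a single CONNECTING endpoint keeps the channel in CONNECTING.
// The real balancer also handles IDLE and other states not exercised by the
// table; the aggregateSketch name is hypothetical.
func aggregateSketch(states []connectivity.State) connectivity.State {
	var ready, connecting, tf int
	for _, s := range states {
		switch s {
		case connectivity.Ready:
			ready++
		case connectivity.Connecting:
			connecting++
		case connectivity.TransientFailure:
			tf++
		}
	}
	switch {
	case ready > 0:
		return connectivity.Ready
	case tf >= 2:
		return connectivity.TransientFailure
	case connecting > 0:
		return connectivity.Connecting
	default:
		return connectivity.TransientFailure
	}
}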
type testKeyType string
const testKey testKeyType = "grpc.lb.ringhash.testKey"
type testAttribute struct {
content string
}
func setTestAttrAddr(addr resolver.Address, content string) resolver.Address {
addr.BalancerAttributes = addr.BalancerAttributes.WithValue(testKey, testAttribute{content})
return addr
}
func setTestAttrEndpoint(endpoint resolver.Endpoint, content string) resolver.Endpoint {
endpoint.Attributes = endpoint.Attributes.WithValue(testKey, testAttribute{content})
return endpoint
}
// TestAddrBalancerAttributesChange tests the case where the ringhash balancer
// receives a ClientConnUpdate with the same config and addresses as received in
// the previous update. Although the `BalancerAttributes` and endpoint
// attributes contents are the same, the pointers are different. This test
// verifies that subConns are not recreated in this scenario.
func (s) TestAddrBalancerAttributesChange(t *testing.T) {
content := "test"
addrs1 := []resolver.Address{setTestAttrAddr(resolver.Address{Addr: testBackendAddrStrs[0]}, content)}
wantEndpoints1 := []resolver.Endpoint{
setTestAttrEndpoint(resolver.Endpoint{Addresses: addrs1}, "content"),
}
cc, b, p0 := setupTest(t, wantEndpoints1)
ring0 := p0.(*picker).ring
firstHash := ring0.items[0].hash
// The first pick should be queued, and should trigger a connection to the
// only Endpoint which has a single address.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: iringhash.SetXDSRequestHash(ctx, firstHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
select {
case <-ctx.Done():
t.Fatalf("Timed out waiting for SubConn creation.")
case <-cc.NewSubConnCh:
}
addrs2 := []resolver.Address{setTestAttrAddr(resolver.Address{Addr: testBackendAddrStrs[0]}, content)}
wantEndpoints2 := []resolver.Endpoint{setTestAttrEndpoint(resolver.Endpoint{Addresses: addrs2}, content)}
if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Endpoints: wantEndpoints2},
BalancerConfig: testConfig,
}); err != nil {
t.Fatalf("UpdateClientConnState returned err: %v", err)
}
select {
case <-cc.NewSubConnCh:
t.Fatal("new subConn created for an update with the same addresses")
case <-time.After(defaultTestShortTimeout):
}
}

View File

@ -30,7 +30,6 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff"
@ -78,42 +77,6 @@ var (
clientConnUpdateHook = func() {}
dataCachePurgeHook = func() {}
resetBackoffHook = func() {}
cacheEntriesMetric = estats.RegisterInt64Gauge(estats.MetricDescriptor{
Name: "grpc.lb.rls.cache_entries",
Description: "EXPERIMENTAL. Number of entries in the RLS cache.",
Unit: "{entry}",
Labels: []string{"grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.instance_uuid"},
Default: false,
})
cacheSizeMetric = estats.RegisterInt64Gauge(estats.MetricDescriptor{
Name: "grpc.lb.rls.cache_size",
Description: "EXPERIMENTAL. The current size of the RLS cache.",
Unit: "By",
Labels: []string{"grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.instance_uuid"},
Default: false,
})
defaultTargetPicksMetric = estats.RegisterInt64Count(estats.MetricDescriptor{
Name: "grpc.lb.rls.default_target_picks",
Description: "EXPERIMENTAL. Number of LB picks sent to the default target.",
Unit: "{pick}",
Labels: []string{"grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.data_plane_target", "grpc.lb.pick_result"},
Default: false,
})
targetPicksMetric = estats.RegisterInt64Count(estats.MetricDescriptor{
Name: "grpc.lb.rls.target_picks",
Description: "EXPERIMENTAL. Number of LB picks sent to each RLS target. Note that if the default target is also returned by the RLS server, RPCs sent to that target from the cache will be counted in this metric, not in grpc.rls.default_target_picks.",
Unit: "{pick}",
Labels: []string{"grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.data_plane_target", "grpc.lb.pick_result"},
Default: false,
})
failedPicksMetric = estats.RegisterInt64Count(estats.MetricDescriptor{
Name: "grpc.lb.rls.failed_picks",
Description: "EXPERIMENTAL. Number of LB picks failed due to either a failed RLS request or the RLS channel being throttled.",
Unit: "{pick}",
Labels: []string{"grpc.target", "grpc.lb.rls.server_target"},
Default: false,
})
)
func init() {
@ -140,14 +103,9 @@ func (rlsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.
updateCh: buffer.NewUnbounded(),
}
lb.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-experimental-lb %p] ", lb))
lb.dataCache = newDataCache(maxCacheSize, lb.logger, cc.MetricsRecorder(), opts.Target.String())
lb.bg = balancergroup.New(balancergroup.Options{
CC: cc,
BuildOpts: opts,
StateAggregator: lb,
Logger: lb.logger,
SubBalancerCloseTimeout: time.Duration(0), // Disable caching of removed child policies
})
lb.dataCache = newDataCache(maxCacheSize, lb.logger)
lb.bg = balancergroup.New(cc, opts, lb, lb.logger)
lb.bg.Start()
go lb.run()
return lb
}
@ -321,27 +279,27 @@ func (b *rlsBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error
// Update the copy of the config in the LB policy before releasing the lock.
b.lbCfg = newCfg
b.stateMu.Unlock()
// We cannot do cache operations above because `cacheMu` needs to be grabbed
// before `stateMu` if we are to hold both locks at the same time.
b.cacheMu.Lock()
b.dataCache.updateRLSServerTarget(newCfg.lookupService)
if resizeCache {
		// If the new config change reduces the size of the data cache, we
// might have to evict entries to get the cache size down to the newly
// specified size. If we do evict an entry with valid backoff timer,
// the new picker needs to be sent to the channel to re-process any
// RPCs queued as a result of this backoff timer.
b.dataCache.resize(newCfg.cacheSizeBytes)
}
b.cacheMu.Unlock()
// Enqueue an event which will notify us when the above update has been
// propagated to all child policies, and the child policies have all
// processed their updates, and we have sent a picker update.
done := make(chan struct{})
b.updateCh.Put(resumePickerUpdates{done: done})
b.stateMu.Unlock()
<-done
if resizeCache {
		// If the new config change reduces the size of the data cache, we
// might have to evict entries to get the cache size down to the newly
// specified size.
//
// And we cannot do this operation above (where we compute the
// `resizeCache` boolean) because `cacheMu` needs to be grabbed before
// `stateMu` if we are to hold both locks at the same time.
b.cacheMu.Lock()
b.dataCache.resize(newCfg.cacheSizeBytes)
b.cacheMu.Unlock()
}
return nil
}
@ -478,7 +436,7 @@ func (b *rlsBalancer) ResolverError(err error) {
}
func (b *rlsBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
b.bg.UpdateSubConnState(sc, state)
}
func (b *rlsBalancer) Close() {
@ -526,19 +484,15 @@ func (b *rlsBalancer) sendNewPickerLocked() {
if b.defaultPolicy != nil {
b.defaultPolicy.acquireRef()
}
picker := &rlsPicker{
kbm: b.lbCfg.kbMap,
origEndpoint: b.bopts.Target.Endpoint(),
lb: b,
defaultPolicy: b.defaultPolicy,
ctrlCh: b.ctrlCh,
maxAge: b.lbCfg.maxAge,
staleAge: b.lbCfg.staleAge,
bg: b.bg,
rlsServerTarget: b.lbCfg.lookupService,
grpcTarget: b.bopts.Target.String(),
metricsRecorder: b.cc.MetricsRecorder(),
kbm: b.lbCfg.kbMap,
origEndpoint: b.bopts.Target.Endpoint(),
lb: b,
defaultPolicy: b.defaultPolicy,
ctrlCh: b.ctrlCh,
maxAge: b.lbCfg.maxAge,
staleAge: b.lbCfg.staleAge,
bg: b.bg,
}
picker.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-picker %p] ", picker))
state := balancer.State{

View File

@ -30,14 +30,13 @@ import (
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/balancer/rls/internal/test/e2e"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/internal/testutils"
rlstest "google.golang.org/grpc/internal/testutils/rls"
@ -46,8 +45,6 @@ import (
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/testdata"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
"google.golang.org/protobuf/types/known/durationpb"
)
@ -69,20 +66,20 @@ func (s) TestConfigUpdate_ControlChannel(t *testing.T) {
// Start a couple of test backends, and set up the fake RLS servers to return
// these as a target in the RLS response.
backendCh1, backendAddress1 := startBackend(t)
rlsServer1.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer1.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress1}}}
})
backendCh2, backendAddress2 := startBackend(t)
rlsServer2.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer2.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress2}}}
})
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -155,7 +152,7 @@ func (s) TestConfigUpdate_ControlChannelWithCreds(t *testing.T) {
// and set up the fake RLS server to return this as the target in the RLS
// response.
backendCh, backendAddress := startBackend(t, grpc.Creds(serverCreds))
rlsServer.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress}}}
})
@ -166,9 +163,9 @@ func (s) TestConfigUpdate_ControlChannelWithCreds(t *testing.T) {
// server certificate used for the RLS server and the backend specifies a
// DNS SAN of "*.test.example.com". Hence we use a dial target which is a
// subdomain of the same here.
cc, err := grpc.NewClient(r.Scheme()+":///rls.test.example.com", grpc.WithResolvers(r), grpc.WithTransportCredentials(clientCreds))
cc, err := grpc.Dial(r.Scheme()+":///rls.test.example.com", grpc.WithResolvers(r), grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -219,16 +216,16 @@ func (s) TestConfigUpdate_ControlChannelServiceConfig(t *testing.T) {
// Start a test backend, and set up the fake RLS server to return this as a
// target in the RLS response.
backendCh, backendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress}}}
})
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///rls.test.example.com", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///rls.test.example.com", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -263,9 +260,9 @@ func (s) TestConfigUpdate_DefaultTarget(t *testing.T) {
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -300,7 +297,7 @@ func (s) TestConfigUpdate_ChildPolicyConfigs(t *testing.T) {
testBackendCh, testBackendAddress := startBackend(t)
// Set up the RLS server to respond with the test backend.
rlsServer.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
@ -333,12 +330,11 @@ func (s) TestConfigUpdate_ChildPolicyConfigs(t *testing.T) {
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
cc.Connect()
// At this point, the RLS LB policy should have received its config, and
// should have created a child policy for the default target.
@ -449,12 +445,11 @@ func (s) TestConfigUpdate_ChildPolicyChange(t *testing.T) {
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
cc.Connect()
// At this point, the RLS LB policy should have received its config, and
// should have created a child policy for the default target.
@ -523,7 +518,7 @@ func (s) TestConfigUpdate_BadChildPolicyConfigs(t *testing.T) {
// Set up the RLS server to respond with a bad target field which is expected
// to cause the child policy's ParseTarget to fail and should result in the LB
// policy creating a lame child policy wrapper.
rlsServer.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{e2e.RLSChildPolicyBadTarget}}}
})
@ -539,9 +534,9 @@ func (s) TestConfigUpdate_BadChildPolicyConfigs(t *testing.T) {
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -592,7 +587,7 @@ func (s) TestConfigUpdate_DataCacheSizeDecrease(t *testing.T) {
// these as targets in the RLS response, based on request keys.
backendCh1, backendAddress1 := startBackend(t)
backendCh2, backendAddress2 := startBackend(t)
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(ctx context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
if req.KeyMap["k1"] == "v1" {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress1}}}
}
@ -605,12 +600,11 @@ func (s) TestConfigUpdate_DataCacheSizeDecrease(t *testing.T) {
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
cc.Connect()
<-clientConnUpdateDone
@ -655,178 +649,6 @@ func (s) TestConfigUpdate_DataCacheSizeDecrease(t *testing.T) {
verifyRLSRequest(t, rlsReqCh, true)
}
// Test that when a data cache entry is evicted due to config change
// in cache size, the picker is updated accordingly.
func (s) TestPickerUpdateOnDataCacheSizeDecrease(t *testing.T) {
// Override the clientConn update hook to get notified.
clientConnUpdateDone := make(chan struct{}, 1)
origClientConnUpdateHook := clientConnUpdateHook
clientConnUpdateHook = func() { clientConnUpdateDone <- struct{}{} }
defer func() { clientConnUpdateHook = origClientConnUpdateHook }()
// Override the cache entry size func, and always return 1.
origEntrySizeFunc := computeDataCacheEntrySize
computeDataCacheEntrySize = func(cacheKey, *cacheEntry) int64 { return 1 }
defer func() { computeDataCacheEntrySize = origEntrySizeFunc }()
// Override the backoff strategy to return a large backoff which
	// will make sure the data cache entry remains in backoff for the
// duration of the test.
origBackoffStrategy := defaultBackoffStrategy
defaultBackoffStrategy = &fakeBackoffStrategy{backoff: defaultTestTimeout}
defer func() { defaultBackoffStrategy = origBackoffStrategy }()
// Override the minEvictionDuration to ensure that when the config update
// reduces the cache size, the resize operation is not stopped because
// we find an entry whose minExpiryDuration has not elapsed.
origMinEvictDuration := minEvictDuration
minEvictDuration = time.Duration(0)
defer func() { minEvictDuration = origMinEvictDuration }()
// Register the top-level wrapping balancer which forwards calls to RLS.
topLevelBalancerName := t.Name() + "top-level"
var ccWrapper *testCCWrapper
stub.Register(topLevelBalancerName, stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
ccWrapper = &testCCWrapper{ClientConn: bd.ClientConn}
bd.ChildBalancer = balancer.Get(Name).Build(ccWrapper, bd.BuildOptions)
},
ParseConfig: func(sc json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
parser := balancer.Get(Name).(balancer.ConfigParser)
return parser.ParseConfig(sc)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
})
// Start an RLS server and set the throttler to never throttle requests.
rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil)
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Register an LB policy to act as the child policy for RLS LB policy.
childPolicyName := "test-child-policy" + t.Name()
e2e.RegisterRLSChildPolicy(childPolicyName, nil)
t.Logf("Registered child policy with name %q", childPolicyName)
// Start a couple of test backends, and set up the fake RLS server to return
// these as targets in the RLS response, based on request keys.
backendCh1, backendAddress1 := startBackend(t)
backendCh2, backendAddress2 := startBackend(t)
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
if req.KeyMap["k1"] == "v1" {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress1}}}
}
if req.KeyMap["k2"] == "v2" {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress2}}}
}
return &rlstest.RouteLookupResponse{Err: errors.New("no keys in request metadata")}
})
// Register a manual resolver and push the RLS service config through it.
r := manual.NewBuilderWithScheme("rls-e2e")
headers := `
[
{
"key": "k1",
"names": [
"n1"
]
},
{
"key": "k2",
"names": [
"n2"
]
}
]
`
configJSON := `
{
"loadBalancingConfig": [
{
"%s": {
"routeLookupConfig": {
"grpcKeybuilders": [{
"names": [{"service": "grpc.testing.TestService"}],
"headers": %s
}],
"lookupService": "%s",
"cacheSizeBytes": %d
},
"childPolicy": [{"%s": {}}],
"childPolicyConfigTargetFieldName": "Backend"
}
}
]
}`
scJSON := fmt.Sprintf(configJSON, topLevelBalancerName, headers, rlsServer.Address, 1000, childPolicyName)
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(scJSON)
r.InitialState(resolver.State{ServiceConfig: sc})
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
		t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
cc.Connect()
<-clientConnUpdateDone
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
	// Make an RPC call with empty metadata, which will eventually fail since
	// no metadata matches the rlsServer response callback defined above. This
	// causes the control channel to return an error and the corresponding
	// cache entry to enter backoff.
makeTestRPCAndVerifyError(ctx, t, cc, codes.Unavailable, nil)
ctxOutgoing := metadata.AppendToOutgoingContext(ctx, "n1", "v1")
makeTestRPCAndExpectItToReachBackend(ctxOutgoing, t, cc, backendCh1)
verifyRLSRequest(t, rlsReqCh, true)
ctxOutgoing = metadata.AppendToOutgoingContext(ctx, "n2", "v2")
makeTestRPCAndExpectItToReachBackend(ctxOutgoing, t, cc, backendCh2)
verifyRLSRequest(t, rlsReqCh, true)
initialStateCnt := len(ccWrapper.getStates())
	// Lowering the cache size will cause some of the entries to be
	// evicted.
scJSON1 := fmt.Sprintf(`
{
"loadBalancingConfig": [
{
"%s": {
"routeLookupConfig": {
"grpcKeybuilders": [{
"names": [{"service": "grpc.testing.TestService"}],
"headers": %s
}],
"lookupService": "%s",
"cacheSizeBytes": 2
},
"childPolicy": [{"%s": {}}],
"childPolicyConfigTargetFieldName": "Backend"
}
}
]
}`, topLevelBalancerName, headers, rlsServer.Address, childPolicyName)
sc1 := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(scJSON1)
r.UpdateState(resolver.State{ServiceConfig: sc1})
<-clientConnUpdateDone
finalStateCnt := len(ccWrapper.getStates())
if finalStateCnt != initialStateCnt+1 {
		t.Errorf("Unexpected balancer state count: got %v, want %v", finalStateCnt, initialStateCnt+1)
}
}
// TestDataCachePurging verifies that the LB policy periodically evicts expired
// entries from the data cache.
func (s) TestDataCachePurging(t *testing.T) {
@ -861,16 +683,16 @@ func (s) TestDataCachePurging(t *testing.T) {
// Start a test backend, and set up the fake RLS server to return this as a
// target in the RLS response.
backendCh, backendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress}}}
})
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -952,16 +774,16 @@ func (s) TestControlChannelConnectivityStateMonitoring(t *testing.T) {
// Start a test backend, and set up the fake RLS server to return this as a
// target in the RLS response.
backendCh, backendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backendAddress}}}
})
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -1016,31 +838,118 @@ func (s) TestControlChannelConnectivityStateMonitoring(t *testing.T) {
verifyRLSRequest(t, rlsReqCh, true)
}
// testCCWrapper wraps a balancer.ClientConn and overrides UpdateState and
// stores all state updates pushed by the RLS LB policy.
type testCCWrapper struct {
const wrappingTopLevelBalancerName = "wrapping-top-level-balancer"
const multipleUpdateStateChildBalancerName = "multiple-update-state-child-balancer"
type wrappingTopLevelBalancerBuilder struct {
balCh chan balancer.Balancer
}
func (w *wrappingTopLevelBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
tlb := &wrappingTopLevelBalancer{ClientConn: cc}
tlb.Balancer = balancer.Get(Name).Build(tlb, balancer.BuildOptions{})
w.balCh <- tlb
return tlb
}
func (w *wrappingTopLevelBalancerBuilder) Name() string {
return wrappingTopLevelBalancerName
}
func (w *wrappingTopLevelBalancerBuilder) ParseConfig(sc json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
parser := balancer.Get(Name).(balancer.ConfigParser)
return parser.ParseConfig(sc)
}
// wrappingTopLevelBalancer acts as the top-level LB policy on the channel and
// wraps an RLS LB policy. It forwards all balancer API calls unmodified to the
// underlying RLS LB policy. It overrides the UpdateState method on the
// balancer.ClientConn passed to the RLS LB policy and stores all state updates
// pushed by the latter.
type wrappingTopLevelBalancer struct {
balancer.ClientConn
balancer.Balancer
mu sync.Mutex
states []balancer.State
}
func (t *testCCWrapper) UpdateState(bs balancer.State) {
t.mu.Lock()
t.states = append(t.states, bs)
t.mu.Unlock()
t.ClientConn.UpdateState(bs)
func (w *wrappingTopLevelBalancer) UpdateState(bs balancer.State) {
w.mu.Lock()
w.states = append(w.states, bs)
w.mu.Unlock()
w.ClientConn.UpdateState(bs)
}
func (t *testCCWrapper) getStates() []balancer.State {
t.mu.Lock()
defer t.mu.Unlock()
func (w *wrappingTopLevelBalancer) getStates() []balancer.State {
w.mu.Lock()
defer w.mu.Unlock()
states := make([]balancer.State, len(t.states))
copy(states, t.states)
states := make([]balancer.State, len(w.states))
copy(states, w.states)
return states
}
// wrappedPickFirstBalancerBuilder builds a balancer which wraps a pickfirst
// balancer. The wrapping balancer receives addresses to be passed to the
// underlying pickfirst balancer as part of its configuration.
type wrappedPickFirstBalancerBuilder struct{}
func (wrappedPickFirstBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
builder := balancer.Get(grpc.PickFirstBalancerName)
wpfb := &wrappedPickFirstBalancer{
ClientConn: cc,
}
pf := builder.Build(wpfb, opts)
wpfb.Balancer = pf
return wpfb
}
func (wrappedPickFirstBalancerBuilder) Name() string {
return multipleUpdateStateChildBalancerName
}
type WrappedPickFirstBalancerConfig struct {
serviceconfig.LoadBalancingConfig
Backend string // The target for which this child policy was created.
}
func (wbb *wrappedPickFirstBalancerBuilder) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
cfg := &WrappedPickFirstBalancerConfig{}
if err := json.Unmarshal(c, cfg); err != nil {
return nil, err
}
return cfg, nil
}
// wrappedPickFirstBalancer wraps a pickfirst balancer and makes multiple calls
// to UpdateState when handling a config update in UpdateClientConnState. When
// this policy is used as a child policy of the RLS LB policy, it is expected
// that the latter suppress these updates and push a single picker update on the
// channel (after the config has been processed by all child policies).
type wrappedPickFirstBalancer struct {
balancer.Balancer
balancer.ClientConn
}
func (wb *wrappedPickFirstBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
wb.ClientConn.UpdateState(balancer.State{ConnectivityState: connectivity.Idle, Picker: &testutils.TestConstPicker{Err: balancer.ErrNoSubConnAvailable}})
wb.ClientConn.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &testutils.TestConstPicker{Err: balancer.ErrNoSubConnAvailable}})
cfg := ccs.BalancerConfig.(*WrappedPickFirstBalancerConfig)
return wb.Balancer.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Addresses: []resolver.Address{{Addr: cfg.Backend}}},
})
}
func (wb *wrappedPickFirstBalancer) UpdateState(state balancer.State) {
// Eat it if IDLE - allows it to switch over only on a READY SubConn.
if state.ConnectivityState == connectivity.Idle {
return
}
wb.ClientConn.UpdateState(state)
}
// TestUpdateStatePauses tests the scenario where a config update received by
// the RLS LB policy results in multiple UpdateState calls from the child
// policies. This test verifies that picker updates are paused when the config
@ -1063,60 +972,8 @@ func (s) TestUpdateStatePauses(t *testing.T) {
defer func() { clientConnUpdateHook = origClientConnUpdateHook }()
// Register the top-level wrapping balancer which forwards calls to RLS.
topLevelBalancerName := t.Name() + "top-level"
var ccWrapper *testCCWrapper
stub.Register(topLevelBalancerName, stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
ccWrapper = &testCCWrapper{ClientConn: bd.ClientConn}
bd.ChildBalancer = balancer.Get(Name).Build(ccWrapper, bd.BuildOptions)
},
ParseConfig: func(sc json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
parser := balancer.Get(Name).(balancer.ConfigParser)
return parser.ParseConfig(sc)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.ChildBalancer.UpdateClientConnState(ccs)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
})
// Register a child policy that wraps a pickfirst balancer and makes multiple calls
// to UpdateState when handling a config update in UpdateClientConnState. When
// this policy is used as a child policy of the RLS LB policy, it is expected
// that the latter suppress these updates and push a single picker update on the
// channel (after the config has been processed by all child policies).
childPolicyName := t.Name() + "child"
type childPolicyConfig struct {
serviceconfig.LoadBalancingConfig
Backend string // `json:"backend,omitempty"`
}
stub.Register(childPolicyName, stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.ChildBalancer = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions)
},
Close: func(bd *stub.BalancerData) {
bd.ChildBalancer.Close()
},
ParseConfig: func(sc json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
cfg := &childPolicyConfig{}
if err := json.Unmarshal(sc, cfg); err != nil {
return nil, err
}
return cfg, nil
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
bal := bd.ChildBalancer
bd.ClientConn.UpdateState(balancer.State{ConnectivityState: connectivity.Idle, Picker: &testutils.TestConstPicker{Err: balancer.ErrNoSubConnAvailable}})
bd.ClientConn.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &testutils.TestConstPicker{Err: balancer.ErrNoSubConnAvailable}})
cfg := ccs.BalancerConfig.(*childPolicyConfig)
return bal.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Addresses: []resolver.Address{{Addr: cfg.Backend}}},
})
},
})
bb := &wrappingTopLevelBalancerBuilder{balCh: make(chan balancer.Balancer, 1)}
balancer.Register(bb)
// Start an RLS server and set the throttler to never throttle requests.
rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil)
@ -1124,10 +981,14 @@ func (s) TestUpdateStatePauses(t *testing.T) {
// Start a test backend and set the RLS server to respond with it.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(_ context.Context, _ *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
// Register a child policy which wraps a pickfirst balancer and receives the
// backend address as part of its configuration.
balancer.Register(&wrappedPickFirstBalancerBuilder{})
// Register a manual resolver and push the RLS service config through it.
r := manual.NewBuilderWithScheme("rls-e2e")
scJSON := fmt.Sprintf(`
@ -1147,16 +1008,15 @@ func (s) TestUpdateStatePauses(t *testing.T) {
}
}
]
}`, topLevelBalancerName, rlsServer.Address, childPolicyName)
}`, wrappingTopLevelBalancerName, rlsServer.Address, multipleUpdateStateChildBalancerName)
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(scJSON)
r.InitialState(resolver.State{ServiceConfig: sc})
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
cc.Connect()
// Wait for the clientconn update to be processed by the RLS LB policy.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
@ -1166,6 +1026,16 @@ func (s) TestUpdateStatePauses(t *testing.T) {
case <-clientConnUpdateDone:
}
// Get the top-level LB policy configured on the channel, to be able to read
// the state updates pushed by its child (the RLS LB policy.)
var wb *wrappingTopLevelBalancer
select {
case <-ctx.Done():
t.Fatal("Timeout when waiting for state update on the top-level LB policy")
case b := <-bb.balCh:
wb = b.(*wrappingTopLevelBalancer)
}
// It is important to note that at this point no child policies have been
// created because we have not attempted any RPC so far. When we attempt an
// RPC (below), child policies will be created and their configs will be
@ -1209,10 +1079,14 @@ func (s) TestUpdateStatePauses(t *testing.T) {
// the test would fail. Waiting for the channel to become READY here
// ensures that the test does not flake because of this rare sequence of
// events.
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
for s := cc.GetState(); s != connectivity.Ready; s = cc.GetState() {
if !cc.WaitForStateChange(ctx, s) {
t.Fatal("Timeout when waiting for connectivity state to reach READY")
}
}
// Cache the state changes seen up to this point.
states0 := ccWrapper.getStates()
states0 := wb.getStates()
// Push an updated service config. As mentioned earlier, the previous config
// updates on the child policies did not happen in the context of a config
@ -1239,7 +1113,7 @@ func (s) TestUpdateStatePauses(t *testing.T) {
}
}
]
}`, topLevelBalancerName, rlsServer.Address, childPolicyName)
}`, wrappingTopLevelBalancerName, rlsServer.Address, multipleUpdateStateChildBalancerName)
sc = internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(scJSON)
r.UpdateState(resolver.State{ServiceConfig: sc})
@ -1253,7 +1127,7 @@ func (s) TestUpdateStatePauses(t *testing.T) {
// UpdateState as part of handling their configs, we expect the RLS policy
// to inhibit picker updates during this time frame, and send a single
// picker once the config update is completely handled.
states1 := ccWrapper.getStates()
states1 := wb.getStates()
if len(states1) != len(states0)+1 {
t.Fatalf("more than one state update seen. before %v, after %v", states0, states1)
}

View File

@ -22,8 +22,6 @@ import (
"container/list"
"time"
"github.com/google/uuid"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal/backoff"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcsync"
@ -49,7 +47,7 @@ type cacheEntry struct {
// headerData is received in the RLS response and is to be sent in the
// X-Google-RLS-Data header for matching RPCs.
headerData string
// expiryTime is the absolute time at which this cache entry stops
// expiryTime is the absolute time at which this cache entry entry stops
// being valid. When an RLS request succeeds, this is set to the current
// time plus the max_age field from the LB policy config.
expiryTime time.Time
@ -165,39 +163,24 @@ func (l *lru) getLeastRecentlyUsed() cacheKey {
//
// It is not safe for concurrent access.
type dataCache struct {
maxSize int64 // Maximum allowed size.
currentSize int64 // Current size.
keys *lru // Cache keys maintained in lru order.
entries map[cacheKey]*cacheEntry
logger *internalgrpclog.PrefixLogger
shutdown *grpcsync.Event
rlsServerTarget string
// Read only after initialization.
grpcTarget string
uuid string
metricsRecorder estats.MetricsRecorder
maxSize int64 // Maximum allowed size.
currentSize int64 // Current size.
keys *lru // Cache keys maintained in lru order.
entries map[cacheKey]*cacheEntry
logger *internalgrpclog.PrefixLogger
shutdown *grpcsync.Event
}
func newDataCache(size int64, logger *internalgrpclog.PrefixLogger, metricsRecorder estats.MetricsRecorder, grpcTarget string) *dataCache {
func newDataCache(size int64, logger *internalgrpclog.PrefixLogger) *dataCache {
return &dataCache{
maxSize: size,
keys: newLRU(),
entries: make(map[cacheKey]*cacheEntry),
logger: logger,
shutdown: grpcsync.NewEvent(),
grpcTarget: grpcTarget,
uuid: uuid.New().String(),
metricsRecorder: metricsRecorder,
maxSize: size,
keys: newLRU(),
entries: make(map[cacheKey]*cacheEntry),
logger: logger,
shutdown: grpcsync.NewEvent(),
}
}
// updateRLSServerTarget updates the RLS Server Target the RLS Balancer is
// configured with.
func (dc *dataCache) updateRLSServerTarget(rlsServerTarget string) {
dc.rlsServerTarget = rlsServerTarget
}
// resize changes the maximum allowed size of the data cache.
//
// The return value indicates if an entry with a valid backoff timer was
@ -240,7 +223,7 @@ func (dc *dataCache) resize(size int64) (backoffCancelled bool) {
backoffCancelled = true
}
}
dc.deleteAndCleanup(key, entry)
dc.deleteAndcleanup(key, entry)
}
dc.maxSize = size
return backoffCancelled
@ -266,7 +249,7 @@ func (dc *dataCache) evictExpiredEntries() bool {
if entry.expiryTime.After(now) || entry.backoffExpiryTime.After(now) {
continue
}
dc.deleteAndCleanup(key, entry)
dc.deleteAndcleanup(key, entry)
evicted = true
}
return evicted
@ -327,8 +310,6 @@ func (dc *dataCache) addEntry(key cacheKey, entry *cacheEntry) (backoffCancelled
if dc.currentSize > dc.maxSize {
backoffCancelled = dc.resize(dc.maxSize)
}
cacheSizeMetric.Record(dc.metricsRecorder, dc.currentSize, dc.grpcTarget, dc.rlsServerTarget, dc.uuid)
cacheEntriesMetric.Record(dc.metricsRecorder, int64(len(dc.entries)), dc.grpcTarget, dc.rlsServerTarget, dc.uuid)
return backoffCancelled, true
}
@ -338,7 +319,6 @@ func (dc *dataCache) updateEntrySize(entry *cacheEntry, newSize int64) {
dc.currentSize -= entry.size
entry.size = newSize
dc.currentSize += entry.size
cacheSizeMetric.Record(dc.metricsRecorder, dc.currentSize, dc.grpcTarget, dc.rlsServerTarget, dc.uuid)
}
func (dc *dataCache) getEntry(key cacheKey) *cacheEntry {
@ -359,7 +339,7 @@ func (dc *dataCache) removeEntryForTesting(key cacheKey) {
if !ok {
return
}
dc.deleteAndCleanup(key, entry)
dc.deleteAndcleanup(key, entry)
}
// deleteAndCleanup performs actions required at the time of deleting an entry
@ -367,17 +347,15 @@ func (dc *dataCache) removeEntryForTesting(key cacheKey) {
// - the entry is removed from the map of entries
// - current size of the data cache is updated
// - the key is removed from the LRU
func (dc *dataCache) deleteAndCleanup(key cacheKey, entry *cacheEntry) {
func (dc *dataCache) deleteAndcleanup(key cacheKey, entry *cacheEntry) {
delete(dc.entries, key)
dc.currentSize -= entry.size
dc.keys.removeEntry(key)
cacheSizeMetric.Record(dc.metricsRecorder, dc.currentSize, dc.grpcTarget, dc.rlsServerTarget, dc.uuid)
cacheEntriesMetric.Record(dc.metricsRecorder, int64(len(dc.entries)), dc.grpcTarget, dc.rlsServerTarget, dc.uuid)
}
func (dc *dataCache) stop() {
for key, entry := range dc.entries {
dc.deleteAndCleanup(key, entry)
dc.deleteAndcleanup(key, entry)
}
dc.shutdown.Fire()
}
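
The bookkeeping that deleteAndCleanup and resize perform above (map removal, size accounting, LRU maintenance) can be reduced to a toy cache. Everything below is illustrative; the real dataCache also tracks backoff timers and, in one of the two variants, records cache metrics.

package main

import (
    "container/list"
    "fmt"
)

// toyCache keeps entries in a map plus an LRU list and mirrors the three
// steps of deleteAndCleanup: remove from the map, adjust the current size,
// drop the key from the LRU. Names and types are illustrative only.
type toyCache struct {
    maxSize, currentSize int64
    keys                 *list.List               // LRU order, oldest at the front
    elems                map[string]*list.Element // key -> LRU element
    sizes                map[string]int64         // key -> entry size
}

func newToyCache(max int64) *toyCache {
    return &toyCache{maxSize: max, keys: list.New(), elems: map[string]*list.Element{}, sizes: map[string]int64{}}
}

func (c *toyCache) add(key string, size int64) {
    c.elems[key] = c.keys.PushBack(key)
    c.sizes[key] = size
    c.currentSize += size
    if c.currentSize > c.maxSize {
        c.resize(c.maxSize) // evict until the cache fits again
    }
}

func (c *toyCache) deleteAndCleanup(key string) {
    c.keys.Remove(c.elems[key])
    delete(c.elems, key)
    c.currentSize -= c.sizes[key]
    delete(c.sizes, key)
}

// resize evicts least-recently-used entries until the cache fits in size.
// The real cache additionally cancels backoff timers on evicted entries.
func (c *toyCache) resize(size int64) {
    for c.currentSize > size && c.keys.Len() > 0 {
        c.deleteAndCleanup(c.keys.Front().Value.(string))
    }
    c.maxSize = size
}

func main() {
    c := newToyCache(10)
    c.add("a", 4)
    c.add("b", 5)
    c.add("c", 4)                            // overflows, evicts "a"
    fmt.Println(c.currentSize, len(c.elems)) // 9 2
}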

View File

@ -25,7 +25,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/testutils/stats"
)
var (
@ -120,7 +119,7 @@ func (s) TestLRU_BasicOperations(t *testing.T) {
func (s) TestDataCache_BasicOperations(t *testing.T) {
initCacheEntries()
dc := newDataCache(5, nil, &stats.NoopMetricsRecorder{}, "")
dc := newDataCache(5, nil)
for i, k := range cacheKeys {
dc.addEntry(k, cacheEntries[i])
}
@ -134,7 +133,7 @@ func (s) TestDataCache_BasicOperations(t *testing.T) {
func (s) TestDataCache_AddForcesResize(t *testing.T) {
initCacheEntries()
dc := newDataCache(1, nil, &stats.NoopMetricsRecorder{}, "")
dc := newDataCache(1, nil)
// The first entry in cacheEntries has a minimum expiry time in the future.
// This entry would stop the resize operation since we do not evict entries
@ -163,7 +162,7 @@ func (s) TestDataCache_AddForcesResize(t *testing.T) {
func (s) TestDataCache_Resize(t *testing.T) {
initCacheEntries()
dc := newDataCache(5, nil, &stats.NoopMetricsRecorder{}, "")
dc := newDataCache(5, nil)
for i, k := range cacheKeys {
dc.addEntry(k, cacheEntries[i])
}
@ -194,7 +193,7 @@ func (s) TestDataCache_Resize(t *testing.T) {
func (s) TestDataCache_EvictExpiredEntries(t *testing.T) {
initCacheEntries()
dc := newDataCache(5, nil, &stats.NoopMetricsRecorder{}, "")
dc := newDataCache(5, nil)
for i, k := range cacheKeys {
dc.addEntry(k, cacheEntries[i])
}
@ -221,7 +220,7 @@ func (s) TestDataCache_ResetBackoffState(t *testing.T) {
}
initCacheEntries()
dc := newDataCache(5, nil, &stats.NoopMetricsRecorder{}, "")
dc := newDataCache(5, nil)
for i, k := range cacheKeys {
dc.addEntry(k, cacheEntries[i])
}
@ -242,61 +241,3 @@ func (s) TestDataCache_ResetBackoffState(t *testing.T) {
t.Fatalf("unexpected diff in backoffState for cache entry after dataCache.resetBackoffState(): %s", diff)
}
}
func (s) TestDataCache_Metrics(t *testing.T) {
cacheEntriesMetricsTests := []*cacheEntry{
{size: 1},
{size: 2},
{size: 3},
{size: 4},
{size: 5},
}
tmr := stats.NewTestMetricsRecorder()
dc := newDataCache(50, nil, tmr, "")
dc.updateRLSServerTarget("rls-server-target")
for i, k := range cacheKeys {
dc.addEntry(k, cacheEntriesMetricsTests[i])
}
const cacheEntriesKey = "grpc.lb.rls.cache_entries"
const cacheSizeKey = "grpc.lb.rls.cache_size"
// 5 total entries which add up to 15 size, so should record that.
if got, _ := tmr.Metric(cacheEntriesKey); got != 5 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheEntriesKey, got, 5)
}
if got, _ := tmr.Metric(cacheSizeKey); got != 15 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheSizeKey, got, 15)
}
// Resize down the cache to 2 entries (deterministic, based on LRU order).
dc.resize(9)
if got, _ := tmr.Metric(cacheEntriesKey); got != 2 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheEntriesKey, got, 2)
}
if got, _ := tmr.Metric(cacheSizeKey); got != 9 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheSizeKey, got, 9)
}
// Update an entry to have size 6. This should be reflected in the size
// metric, which will increase by 1 to 10, while the number of cache entries
// should stay the same. This write is deterministic and writes to the last
// entry.
dc.updateEntrySize(cacheEntriesMetricsTests[4], 6)
if got, _ := tmr.Metric(cacheEntriesKey); got != 2 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheEntriesKey, got, 2)
}
if got, _ := tmr.Metric(cacheSizeKey); got != 10 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheSizeKey, got, 10)
}
// Delete this scaled-up cache entry. This should shrink the cache to 1
// entry and remove 6 bytes of size, so the cache size should be 4.
dc.deleteAndCleanup(cacheKeys[4], cacheEntriesMetricsTests[4])
if got, _ := tmr.Metric(cacheEntriesKey); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheEntriesKey, got, 1)
}
if got, _ := tmr.Metric(cacheSizeKey); got != 4 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", cacheSizeKey, got, 4)
}
}
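
For reference, the arithmetic this metrics test relies on: the five entries have sizes 1+2+3+4+5 = 15; resize(9) evicts the three least recently used entries (sizes 1, 2 and 3), leaving two entries totalling 9; growing the last entry from size 5 to 6 raises the total to 10 without changing the entry count; and deleting that entry leaves a single entry of size 4.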

View File

@ -25,6 +25,8 @@ import (
"net/url"
"time"
"github.com/golang/protobuf/ptypes"
durationpb "github.com/golang/protobuf/ptypes/duration"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/keys"
"google.golang.org/grpc/internal"
@ -33,7 +35,6 @@ import (
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/durationpb"
)
const (
@ -143,10 +144,7 @@ type lbConfigJSON struct {
// - childPolicyConfigTargetFieldName:
// - must be set and non-empty
func (rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
if logger.V(2) {
logger.Infof("Received JSON service config: %v", pretty.ToJSON(c))
}
logger.Infof("Received JSON service config: %v", pretty.ToJSON(c))
cfgJSON := &lbConfigJSON{}
if err := json.Unmarshal(c, cfgJSON); err != nil {
return nil, fmt.Errorf("rls: json unmarshal failed for service config %+v: %v", string(c), err)
@ -220,43 +218,27 @@ func parseRLSProto(rlsProto *rlspb.RouteLookupConfig) (*lbConfig, error) {
// Validations performed here:
// - if `max_age` > 5m, it should be set to 5 minutes
// only if stale age is not set
// - if `stale_age` > `max_age`, ignore it
// - if `stale_age` is set, then `max_age` must also be set
maxAgeSet := false
maxAge, err := convertDuration(rlsProto.GetMaxAge())
if err != nil {
return nil, fmt.Errorf("rls: failed to parse max_age in route lookup config %+v: %v", rlsProto, err)
}
if maxAge == 0 {
maxAge = maxMaxAge
} else {
maxAgeSet = true
}
staleAgeSet := false
staleAge, err := convertDuration(rlsProto.GetStaleAge())
if err != nil {
return nil, fmt.Errorf("rls: failed to parse staleAge in route lookup config %+v: %v", rlsProto, err)
}
if staleAge == 0 {
staleAge = maxMaxAge
} else {
staleAgeSet = true
}
if staleAgeSet && !maxAgeSet {
if staleAge != 0 && maxAge == 0 {
return nil, fmt.Errorf("rls: stale_age is set, but max_age is not in route lookup config %+v", rlsProto)
}
if staleAge > maxMaxAge {
staleAge = maxMaxAge
if staleAge >= maxAge {
logger.Infof("rls: stale_age %v is not less than max_age %v, ignoring it", staleAge, maxAge)
staleAge = 0
}
if !staleAgeSet && maxAge > maxMaxAge {
if maxAge == 0 || maxAge > maxMaxAge {
logger.Infof("rls: max_age in route lookup config is %v, using %v", maxAge, maxMaxAge)
maxAge = maxMaxAge
}
if staleAge > maxAge {
staleAge = maxAge
}
// `cache_size_bytes` field must have a value greater than 0, and if its
// value is greater than 5M, we cap it at 5M
@ -326,5 +308,5 @@ func convertDuration(d *durationpb.Duration) (time.Duration, error) {
if d == nil {
return 0, nil
}
return d.AsDuration(), d.CheckValid()
return ptypes.Duration(d)
}

View File

@ -60,8 +60,8 @@ func (s) TestParseConfig(t *testing.T) {
// - A top-level unknown field should not fail.
// - An unknown field in routeLookupConfig proto should not fail.
// - lookupServiceTimeout is set to its default value, since it is not specified in the input.
// - maxAge is clamped to maxMaxAge if staleAge is not set.
// - staleAge is ignored because it is higher than maxAge in the input.
// - maxAge is set to maxMaxAge since the value is too large in the input.
// - staleAge is ignore because it is higher than maxAge in the input.
// - cacheSizeBytes is greater than the hard upper limit of 5MB
desc: "with transformations 1",
input: []byte(`{
@ -87,9 +87,9 @@ func (s) TestParseConfig(t *testing.T) {
}`),
wantCfg: &lbConfig{
lookupService: ":///target",
lookupServiceTimeout: 10 * time.Second, // This is the default value.
maxAge: 500 * time.Second, // Max age is not clamped when stale age is set.
staleAge: 300 * time.Second, // StaleAge is clamped because it was higher than maxMaxAge.
lookupServiceTimeout: 10 * time.Second, // This is the default value.
maxAge: 5 * time.Minute, // This is max maxAge.
staleAge: time.Duration(0), // StaleAge is ignore because it was higher than maxAge.
cacheSizeBytes: maxCacheSize,
defaultTarget: "passthrough:///default",
childPolicyName: "grpclb",
@ -100,69 +100,6 @@ func (s) TestParseConfig(t *testing.T) {
},
},
},
{
desc: "maxAge not clamped when staleAge is set",
input: []byte(`{
"routeLookupConfig": {
"grpcKeybuilders": [{
"names": [{"service": "service", "method": "method"}],
"headers": [{"key": "k1", "names": ["v1"]}]
}],
"lookupService": ":///target",
"maxAge" : "500s",
"staleAge": "200s",
"cacheSizeBytes": 100000000
},
"childPolicy": [
{"grpclb": {"childPolicy": [{"pickfirst": {}}]}}
],
"childPolicyConfigTargetFieldName": "serviceName"
}`),
wantCfg: &lbConfig{
lookupService: ":///target",
lookupServiceTimeout: 10 * time.Second, // This is the default value.
maxAge: 500 * time.Second, // Max age is not clamped when stale age is set.
staleAge: 200 * time.Second, // This is stale age within maxMaxAge.
cacheSizeBytes: maxCacheSize,
childPolicyName: "grpclb",
childPolicyTargetField: "serviceName",
childPolicyConfig: map[string]json.RawMessage{
"childPolicy": json.RawMessage(`[{"pickfirst": {}}]`),
"serviceName": json.RawMessage(childPolicyTargetFieldVal),
},
},
},
{
desc: "maxAge clamped when staleAge is not set",
input: []byte(`{
"routeLookupConfig": {
"grpcKeybuilders": [{
"names": [{"service": "service", "method": "method"}],
"headers": [{"key": "k1", "names": ["v1"]}]
}],
"lookupService": ":///target",
"maxAge" : "500s",
"cacheSizeBytes": 100000000
},
"childPolicy": [
{"grpclb": {"childPolicy": [{"pickfirst": {}}]}}
],
"childPolicyConfigTargetFieldName": "serviceName"
}`),
wantCfg: &lbConfig{
lookupService: ":///target",
lookupServiceTimeout: 10 * time.Second, // This is the default value.
maxAge: 300 * time.Second, // Max age is clamped when stale age is not set.
staleAge: 300 * time.Second,
cacheSizeBytes: maxCacheSize,
childPolicyName: "grpclb",
childPolicyTargetField: "serviceName",
childPolicyConfig: map[string]json.RawMessage{
"childPolicy": json.RawMessage(`[{"pickfirst": {}}]`),
"serviceName": json.RawMessage(childPolicyTargetFieldVal),
},
},
},
{
desc: "without transformations",
input: []byte(`{
@ -385,7 +322,7 @@ func (s) TestParseConfigErrors(t *testing.T) {
"childPolicy": [{"grpclb": {"childPolicy": [{"pickfirst": {}}]}}],
"childPolicyConfigTargetFieldName": "serviceName"
}`),
wantErr: "no supported policies found in config",
wantErr: "invalid loadBalancingConfig: no supported policies found",
},
{
desc: "no child policy",

View File

@ -29,9 +29,7 @@ import (
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/buffer"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/pretty"
rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
@ -57,12 +55,9 @@ type controlChannel struct {
// hammering the RLS service while it is overloaded or down.
throttler adaptiveThrottler
cc *grpc.ClientConn
client rlsgrpc.RouteLookupServiceClient
logger *internalgrpclog.PrefixLogger
connectivityStateCh *buffer.Unbounded
unsubscribe func()
monitorDoneCh chan struct{}
cc *grpc.ClientConn
client rlsgrpc.RouteLookupServiceClient
logger *internalgrpclog.PrefixLogger
}
// newControlChannel creates a controlChannel to rlsServerName and uses
@ -70,11 +65,9 @@ type controlChannel struct {
// gRPC channel.
func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Duration, bOpts balancer.BuildOptions, backToReadyFunc func()) (*controlChannel, error) {
ctrlCh := &controlChannel{
rpcTimeout: rpcTimeout,
backToReadyFunc: backToReadyFunc,
throttler: newAdaptiveThrottler(),
connectivityStateCh: buffer.NewUnbounded(),
monitorDoneCh: make(chan struct{}),
rpcTimeout: rpcTimeout,
backToReadyFunc: backToReadyFunc,
throttler: newAdaptiveThrottler(),
}
ctrlCh.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-control-channel %p] ", ctrlCh))
@ -82,28 +75,17 @@ func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Dura
if err != nil {
return nil, err
}
ctrlCh.cc, err = grpc.NewClient(rlsServerName, dopts...)
ctrlCh.cc, err = grpc.Dial(rlsServerName, dopts...)
if err != nil {
return nil, err
}
// Subscribe to connectivity state before connecting to avoid missing initial
// updates, which are only delivered to active subscribers.
ctrlCh.unsubscribe = internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(ctrlCh.cc, ctrlCh)
ctrlCh.cc.Connect()
ctrlCh.client = rlsgrpc.NewRouteLookupServiceClient(ctrlCh.cc)
ctrlCh.logger.Infof("Control channel created to RLS server at: %v", rlsServerName)
go ctrlCh.monitorConnectivityState()
return ctrlCh, nil
}
func (cc *controlChannel) OnMessage(msg any) {
st, ok := msg.(connectivity.State)
if !ok {
panic(fmt.Sprintf("Unexpected message type %T, wanted connectivity.State type", msg))
}
cc.connectivityStateCh.Put(st)
}
// dialOpts constructs the dial options for the control plane channel.
func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig string) ([]grpc.DialOption, error) {
// The control plane channel will use the same authority as the parent
@ -115,6 +97,7 @@ func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig st
if bOpts.Dialer != nil {
dopts = append(dopts, grpc.WithContextDialer(bOpts.Dialer))
}
// The control channel will use the channel credentials from the parent
// channel, including any call creds associated with the channel creds.
var credsOpt grpc.DialOption
@ -150,8 +133,6 @@ func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig st
func (cc *controlChannel) monitorConnectivityState() {
cc.logger.Infof("Starting connectivity state monitoring goroutine")
defer close(cc.monitorDoneCh)
// Since we use two mechanisms to deal with RLS server being down:
// - adaptive throttling for the channel as a whole
// - exponential backoff on a per-request basis
@ -173,45 +154,39 @@ func (cc *controlChannel) monitorConnectivityState() {
// returning only one new picker, regardless of how many backoff timers are
// cancelled.
// Wait for the control channel to become READY for the first time.
for s, ok := <-cc.connectivityStateCh.Get(); s != connectivity.Ready; s, ok = <-cc.connectivityStateCh.Get() {
if !ok {
return
}
cc.connectivityStateCh.Load()
if s == connectivity.Shutdown {
return
}
}
cc.connectivityStateCh.Load()
cc.logger.Infof("Connectivity state is READY")
// Using the background context is fine here since we check for the ClientConn
// entering SHUTDOWN and return early in that case.
ctx := context.Background()
first := true
for {
s, ok := <-cc.connectivityStateCh.Get()
if !ok {
return
// Wait for the control channel to become READY.
for s := cc.cc.GetState(); s != connectivity.Ready; s = cc.cc.GetState() {
if s == connectivity.Shutdown {
return
}
cc.cc.WaitForStateChange(ctx, s)
}
cc.connectivityStateCh.Load()
cc.logger.Infof("Connectivity state is READY")
if s == connectivity.Shutdown {
return
}
if s == connectivity.Ready {
if !first {
cc.logger.Infof("Control channel back to READY")
cc.backToReadyFunc()
}
first = false
cc.logger.Infof("Connectivity state is %s", s)
// Wait for the control channel to move out of READY.
cc.cc.WaitForStateChange(ctx, connectivity.Ready)
if cc.cc.GetState() == connectivity.Shutdown {
return
}
cc.logger.Infof("Connectivity state is %s", cc.cc.GetState())
}
}
func (cc *controlChannel) close() {
cc.unsubscribe()
cc.connectivityStateCh.Close()
<-cc.monitorDoneCh
cc.logger.Infof("Closing control channel")
cc.cc.Close()
cc.logger.Infof("Shutdown")
}
type lookupCallback func(targets []string, headerData string, err error)
@ -234,9 +209,7 @@ func (cc *controlChannel) lookup(reqKeys map[string]string, reason rlspb.RouteLo
Reason: reason,
StaleHeaderData: staleHeaders,
}
if cc.logger.V(2) {
cc.logger.Infof("Sending RLS request %+v", pretty.ToJSON(req))
}
cc.logger.Infof("Sending RLS request %+v", pretty.ToJSON(req))
ctx, cancel := context.WithTimeout(context.Background(), cc.rpcTimeout)
defer cancel()
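
The polling variant of the monitoring loop above can be written against the public ClientConn API alone. The sketch below mirrors its structure but is not the grpc-go implementation; the function and callback names are made up.

package example

import (
    "context"
    "log"

    "google.golang.org/grpc"
    "google.golang.org/grpc/connectivity"
)

// monitorReady blocks until cc first reaches READY, then invokes backToReady
// every time the channel returns to READY after having left it. It exits when
// the channel is shut down.
func monitorReady(cc *grpc.ClientConn, backToReady func()) {
    ctx := context.Background()
    // Wait for the first READY.
    for s := cc.GetState(); s != connectivity.Ready; s = cc.GetState() {
        if s == connectivity.Shutdown {
            return
        }
        cc.WaitForStateChange(ctx, s)
    }
    log.Println("control channel is READY")
    for {
        // Wait for the channel to leave READY ...
        cc.WaitForStateChange(ctx, connectivity.Ready)
        if cc.GetState() == connectivity.Shutdown {
            return
        }
        // ... and then to come back to READY.
        for s := cc.GetState(); s != connectivity.Ready; s = cc.GetState() {
            if s == connectivity.Shutdown {
                return
            }
            cc.WaitForStateChange(ctx, s)
        }
        backToReady()
    }
}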

View File

@ -62,7 +62,7 @@ func (s) TestControlChannelThrottled(t *testing.T) {
select {
case <-rlsReqCh:
t.Fatal("RouteLookup RPC invoked when control channel is throttled")
t.Fatal("RouteLookup RPC invoked when control channel is throtlled")
case <-time.After(defaultTestShortTimeout):
}
}
@ -74,7 +74,7 @@ func (s) TestLookupFailure(t *testing.T) {
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Setup the RLS server to respond with errors.
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Err: errors.New("rls failure")}
})
@ -109,7 +109,7 @@ func (s) TestLookupFailure(t *testing.T) {
// respond within the configured rpc timeout.
func (s) TestLookupDeadlineExceeded(t *testing.T) {
// A unary interceptor which returns a status error with DeadlineExceeded.
interceptor := func(context.Context, any, *grpc.UnaryServerInfo, grpc.UnaryHandler) (resp any, err error) {
interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
return nil, status.Error(codes.DeadlineExceeded, "deadline exceeded")
}
@ -191,7 +191,7 @@ func (f *testPerRPCCredentials) RequireTransportSecurity() bool {
// Unary server interceptor which validates if the RPC contains call credentials
// which match `perRPCCredsData`.
func callCredsValidatingServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
func callCredsValidatingServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context")
@ -260,7 +260,7 @@ func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bop
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Setup the RLS server to respond with a valid response.
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return lookupResponse
})

View File

@ -28,6 +28,7 @@ import (
"google.golang.org/grpc/balancer/rls/internal/test/e2e"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancergroup"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpctest"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
@ -47,6 +48,10 @@ const (
defaultTestShortTimeout = 100 * time.Millisecond
)
func init() {
balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond
}
type s struct {
grpctest.Tester
}
@ -61,7 +66,7 @@ type fakeBackoffStrategy struct {
backoff time.Duration
}
func (f *fakeBackoffStrategy) Backoff(int) time.Duration {
func (f *fakeBackoffStrategy) Backoff(retries int) time.Duration {
return f.backoff
}
@ -171,7 +176,7 @@ func startBackend(t *testing.T, sopts ...grpc.ServerOption) (rpcCh chan struct{}
rpcCh = make(chan struct{}, 1)
backend := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
select {
case rpcCh <- struct{}{}:
default:

View File

@ -20,15 +20,16 @@
package adaptive
import (
rand "math/rand/v2"
"sync"
"time"
"google.golang.org/grpc/internal/grpcrand"
)
// For overriding in unittests.
var (
timeNowFunc = time.Now
randFunc = rand.Float64
timeNowFunc = func() time.Time { return time.Now() }
randFunc = func() float64 { return grpcrand.Float64() }
)
const (
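
Because timeNowFunc and randFunc are package-level variables, a test in the same package can pin them to fixed values and make every throttling decision reproducible. A minimal sketch, assuming it lives in the adaptive package (the test name is made up):

package adaptive

import (
    "testing"
    "time"
)

func TestThrottlerDeterministic(t *testing.T) {
    // Save the real hooks and restore them when the test finishes.
    origNow, origRand := timeNowFunc, randFunc
    t.Cleanup(func() { timeNowFunc, randFunc = origNow, origRand })

    fixed := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
    timeNowFunc = func() time.Time { return fixed } // freeze the clock
    randFunc = func() float64 { return 0.5 }        // always the same draw

    // With time and randomness pinned, repeated runs observe identical
    // throttler state, so assertions on accepts/throttles cannot flake.
}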

View File

@ -25,13 +25,13 @@ import (
)
// stats returns a tuple with accepts, throttles for the current time.
func (t *Throttler) stats() (int64, int64) {
func (th *Throttler) stats() (int64, int64) {
now := timeNowFunc()
t.mu.Lock()
a, th := t.accepts.sum(now), t.throttles.sum(now)
t.mu.Unlock()
return a, th
th.mu.Lock()
a, t := th.accepts.sum(now), th.throttles.sum(now)
th.mu.Unlock()
return a, t
}
// Enums for responses.

View File

@ -82,3 +82,10 @@ func (l *lookback) advance(t time.Time) int64 {
l.head = nh
return nh
}
func min(x int64, y int64) int64 {
if x < y {
return x
}
return y
}
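
The min helper above is only needed on Go versions prior to 1.21; since Go 1.21 the language provides built-in min and max functions for ordered types, so the same comparison can be written directly as, for example, smaller := min(int64(3), int64(7)).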

View File

@ -189,7 +189,7 @@ func (b builder) Equal(a builder) bool {
// Protobuf serialization maintains the order of repeated fields. Matchers
// are specified as a repeated field inside the KeyBuilder proto. If the
// order changes, it means that the order in the protobuf changed. We report
// this case as not being equal even though the builders could possibly be
// this case as not being equal even though the builders could possible be
// functionally equal.
for i, bMatcher := range b.headerKeys {
aMatcher := a.headerKeys[i]
@ -218,7 +218,7 @@ type matcher struct {
names []string
}
// Equal reports if m and a are equivalent headerKeys.
// Equal reports if m and are are equivalent headerKeys.
func (m matcher) Equal(a matcher) bool {
if m.key != a.key {
return false

View File

@ -23,8 +23,8 @@ import (
"errors"
"fmt"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
@ -68,7 +68,7 @@ type bb struct {
func (bb bb) Name() string { return bb.name }
func (bb bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
pf := balancer.Get(pickfirst.Name)
pf := balancer.Get(grpc.PickFirstBalancerName)
b := &bal{
Balancer: pf.Build(cc, opts),
bf: bb.bf,
@ -125,7 +125,7 @@ func (b *bal) Close() {
// run is a dummy goroutine to make sure that child policies are closed at the
// end of tests. If they are not closed, these goroutines will be picked up by
// the leak checker and tests will fail.
// the leakcheker and tests will fail.
func (b *bal) run() {
<-b.done.Done()
}

View File

@ -1,367 +0,0 @@
/*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rls
import (
"context"
"math/rand"
"testing"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/internal/stubserver"
rlstest "google.golang.org/grpc/internal/testutils/rls"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/stats/opentelemetry"
)
func metricsDataFromReader(ctx context.Context, reader *metric.ManualReader) map[string]metricdata.Metrics {
rm := &metricdata.ResourceMetrics{}
reader.Collect(ctx, rm)
gotMetrics := map[string]metricdata.Metrics{}
for _, sm := range rm.ScopeMetrics {
for _, m := range sm.Metrics {
gotMetrics[m.Name] = m
}
}
return gotMetrics
}
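
metricsDataFromReader is meant to be paired with a manual OpenTelemetry reader wired into the MeterProvider that the channel's stats handler uses. Condensed from the tests below, and relying on this file's imports, the pattern looks roughly like:

reader := metric.NewManualReader()
provider := metric.NewMeterProvider(metric.WithReader(reader))
mo := opentelemetry.MetricsOptions{
    MeterProvider: provider,
    Metrics:       opentelemetry.DefaultMetrics().Add("grpc.lb.rls.target_picks"),
}
// Create the ClientConn with opentelemetry.DialOption(opentelemetry.Options{MetricsOptions: mo}),
// make an RPC, then collect:
gotMetrics := metricsDataFromReader(ctx, reader)
if _, ok := gotMetrics["grpc.lb.rls.target_picks"]; !ok {
    // the metric was never recorded
}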
// TestRLSTargetPickMetric tests RLS Metrics in the case an RLS Balancer picks a
// target from an RLS Response for an RPC. This should emit a
// "grpc.lb.rls.target_picks" with certain labels and cache metrics with certain
// labels.
func (s) TestRLSTargetPickMetric(t *testing.T) {
// Overwrite the uuid random number generator to be deterministic.
uuid.SetRand(rand.New(rand.NewSource(1)))
defer uuid.SetRand(nil)
rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
backend := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
if err := backend.StartServer(); err != nil {
t.Fatalf("Failed to start backend: %v", err)
}
t.Logf("Started TestService backend at: %q", backend.Address)
defer backend.Stop()
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{backend.Address}}}
})
r := startManualResolverWithConfig(t, rlsConfig)
reader := metric.NewManualReader()
provider := metric.NewMeterProvider(metric.WithReader(reader))
mo := opentelemetry.MetricsOptions{
MeterProvider: provider,
Metrics: opentelemetry.DefaultMetrics().Add("grpc.lb.rls.cache_entries", "grpc.lb.rls.cache_size", "grpc.lb.rls.default_target_picks", "grpc.lb.rls.target_picks", "grpc.lb.rls.failed_picks"),
}
grpcTarget := r.Scheme() + ":///"
cc, err := grpc.NewClient(grpcTarget, grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()), opentelemetry.DialOption(opentelemetry.Options{MetricsOptions: mo}))
if err != nil {
t.Fatalf("Failed to dial local test server: %v", err)
}
defer cc.Close()
wantMetrics := []metricdata.Metrics{
{
Name: "grpc.lb.rls.target_picks",
Description: "EXPERIMENTAL. Number of LB picks sent to each RLS target. Note that if the default target is also returned by the RLS server, RPCs sent to that target from the cache will be counted in this metric, not in grpc.rls.default_target_picks.",
Unit: "{pick}",
Data: metricdata.Sum[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.data_plane_target", backend.Address), attribute.String("grpc.lb.pick_result", "complete")),
Value: 1,
},
},
Temporality: metricdata.CumulativeTemporality,
IsMonotonic: true,
},
},
// Receives an empty RLS Response, so a single cache entry with no size.
{
Name: "grpc.lb.rls.cache_entries",
Description: "EXPERIMENTAL. Number of entries in the RLS cache.",
Unit: "{entry}",
Data: metricdata.Gauge[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.instance_uuid", "52fdfc07-2182-454f-963f-5f0f9a621d72")),
Value: 1,
},
},
},
},
{
Name: "grpc.lb.rls.cache_size",
Description: "EXPERIMENTAL. The current size of the RLS cache.",
Unit: "By",
Data: metricdata.Gauge[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.instance_uuid", "52fdfc07-2182-454f-963f-5f0f9a621d72")),
Value: 35,
},
},
},
},
}
client := testgrpc.NewTestServiceClient(cc)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err = client.EmptyCall(ctx, &testpb.Empty{})
if err != nil {
t.Fatalf("client.EmptyCall failed with error: %v", err)
}
gotMetrics := metricsDataFromReader(ctx, reader)
for _, metric := range wantMetrics {
val, ok := gotMetrics[metric.Name]
if !ok {
t.Fatalf("Metric %v not present in recorded metrics", metric.Name)
}
if !metricdatatest.AssertEqual(t, metric, val, metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreExemplars()) {
t.Fatalf("Metrics data type not equal for metric: %v", metric.Name)
}
}
// Only one pick was made, which was a target pick, so no default target
// pick or failed pick metric should emit.
for _, metric := range []string{"grpc.lb.rls.default_target_picks", "grpc.lb.rls.failed_picks"} {
if _, ok := gotMetrics[metric]; ok {
t.Fatalf("Metric %v present in recorded metrics", metric)
}
}
}
// TestRLSDefaultTargetPickMetric tests RLS Metrics in the case an RLS Balancer
// falls back to the default target for an RPC. This should emit a
// "grpc.lb.rls.default_target_picks" with certain labels and cache metrics with
// certain labels.
func (s) TestRLSDefaultTargetPickMetric(t *testing.T) {
// Overwrite the uuid random number generator to be deterministic.
uuid.SetRand(rand.New(rand.NewSource(1)))
defer uuid.SetRand(nil)
rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
// Build RLS service config with a default target.
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
backend := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
if err := backend.StartServer(); err != nil {
t.Fatalf("Failed to start backend: %v", err)
}
t.Logf("Started TestService backend at: %q", backend.Address)
defer backend.Stop()
rlsConfig.RouteLookupConfig.DefaultTarget = backend.Address
r := startManualResolverWithConfig(t, rlsConfig)
reader := metric.NewManualReader()
provider := metric.NewMeterProvider(metric.WithReader(reader))
mo := opentelemetry.MetricsOptions{
MeterProvider: provider,
Metrics: opentelemetry.DefaultMetrics().Add("grpc.lb.rls.cache_entries", "grpc.lb.rls.cache_size", "grpc.lb.rls.default_target_picks", "grpc.lb.rls.target_picks", "grpc.lb.rls.failed_picks"),
}
grpcTarget := r.Scheme() + ":///"
cc, err := grpc.NewClient(grpcTarget, grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()), opentelemetry.DialOption(opentelemetry.Options{MetricsOptions: mo}))
if err != nil {
t.Fatalf("Failed to dial local test server: %v", err)
}
defer cc.Close()
wantMetrics := []metricdata.Metrics{
{
Name: "grpc.lb.rls.default_target_picks",
Description: "EXPERIMENTAL. Number of LB picks sent to the default target.",
Unit: "{pick}",
Data: metricdata.Sum[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.data_plane_target", backend.Address), attribute.String("grpc.lb.pick_result", "complete")),
Value: 1,
},
},
Temporality: metricdata.CumulativeTemporality,
IsMonotonic: true,
},
},
// Receives an RLS Response with target information, so a single cache
// entry with a certain size.
{
Name: "grpc.lb.rls.cache_entries",
Description: "EXPERIMENTAL. Number of entries in the RLS cache.",
Unit: "{entry}",
Data: metricdata.Gauge[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.instance_uuid", "52fdfc07-2182-454f-963f-5f0f9a621d72")),
Value: 1,
},
},
},
},
{
Name: "grpc.lb.rls.cache_size",
Description: "EXPERIMENTAL. The current size of the RLS cache.",
Unit: "By",
Data: metricdata.Gauge[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.instance_uuid", "52fdfc07-2182-454f-963f-5f0f9a621d72")),
Value: 0,
},
},
},
},
}
client := testgrpc.NewTestServiceClient(cc)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("client.EmptyCall failed with error: %v", err)
}
gotMetrics := metricsDataFromReader(ctx, reader)
for _, metric := range wantMetrics {
val, ok := gotMetrics[metric.Name]
if !ok {
t.Fatalf("Metric %v not present in recorded metrics", metric.Name)
}
if !metricdatatest.AssertEqual(t, metric, val, metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreExemplars()) {
t.Fatalf("Metrics data type not equal for metric: %v", metric.Name)
}
}
// No target picks and failed pick metrics should be emitted, as the test
// made only one RPC which recorded as a default target pick.
for _, metric := range []string{"grpc.lb.rls.target_picks", "grpc.lb.rls.failed_picks"} {
if _, ok := gotMetrics[metric]; ok {
t.Fatalf("Metric %v present in recorded metrics", metric)
}
}
}
// TestRLSFailedRPCMetric tests RLS Metrics in the case an RLS Balancer fails an
// RPC due to an RLS failure. This should emit a
// "grpc.lb.rls.failed_picks" metric with certain labels and cache metrics with
// certain labels.
func (s) TestRLSFailedRPCMetric(t *testing.T) {
// Overwrite the uuid random number generator to be deterministic.
uuid.SetRand(rand.New(rand.NewSource(1)))
defer uuid.SetRand(nil)
rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
// Build an RLS config without a default target.
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
reader := metric.NewManualReader()
provider := metric.NewMeterProvider(metric.WithReader(reader))
mo := opentelemetry.MetricsOptions{
MeterProvider: provider,
Metrics: opentelemetry.DefaultMetrics().Add("grpc.lb.rls.cache_entries", "grpc.lb.rls.cache_size", "grpc.lb.rls.default_target_picks", "grpc.lb.rls.target_picks", "grpc.lb.rls.failed_picks"),
}
grpcTarget := r.Scheme() + ":///"
cc, err := grpc.NewClient(grpcTarget, grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()), opentelemetry.DialOption(opentelemetry.Options{MetricsOptions: mo}))
if err != nil {
t.Fatalf("Failed to dial local test server: %v", err)
}
defer cc.Close()
wantMetrics := []metricdata.Metrics{
{
Name: "grpc.lb.rls.failed_picks",
Description: "EXPERIMENTAL. Number of LB picks failed due to either a failed RLS request or the RLS channel being throttled.",
Unit: "{pick}",
Data: metricdata.Sum[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address)),
Value: 1,
},
},
Temporality: metricdata.CumulativeTemporality,
IsMonotonic: true,
},
},
// Receives an empty RLS Response, so a single cache entry with no size.
{
Name: "grpc.lb.rls.cache_entries",
Description: "EXPERIMENTAL. Number of entries in the RLS cache.",
Unit: "{entry}",
Data: metricdata.Gauge[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.instance_uuid", "52fdfc07-2182-454f-963f-5f0f9a621d72")),
Value: 1,
},
},
},
},
{
Name: "grpc.lb.rls.cache_size",
Description: "EXPERIMENTAL. The current size of the RLS cache.",
Unit: "By",
Data: metricdata.Gauge[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget), attribute.String("grpc.lb.rls.server_target", rlsServer.Address), attribute.String("grpc.lb.rls.instance_uuid", "52fdfc07-2182-454f-963f-5f0f9a621d72")),
Value: 0,
},
},
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err == nil {
t.Fatalf("client.EmptyCall error = %v, expected a non nil error", err)
}
gotMetrics := metricsDataFromReader(ctx, reader)
for _, metric := range wantMetrics {
val, ok := gotMetrics[metric.Name]
if !ok {
t.Fatalf("Metric %v not present in recorded metrics", metric.Name)
}
if !metricdatatest.AssertEqual(t, metric, val, metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreExemplars()) {
t.Fatalf("Metrics data type not equal for metric: %v", metric.Name)
}
}
// Only one RPC was made, which was a failed pick due to an RLS failure, so
// no metrics for target picks or default target picks should have been emitted.
for _, metric := range []string{"grpc.lb.rls.target_picks", "grpc.lb.rls.default_target_picks"} {
if _, ok := gotMetrics[metric]; ok {
t.Fatalf("Metric %v present in recorded metrics", metric)
}
}
}

View File

@ -29,7 +29,6 @@ import (
"google.golang.org/grpc/balancer/rls/internal/keys"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
estats "google.golang.org/grpc/experimental/stats"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/metadata"
@ -62,15 +61,12 @@ type rlsPicker struct {
// The picker is given its own copy of the below fields from the RLS LB policy
// to avoid having to grab the mutex on the latter.
rlsServerTarget string
grpcTarget string
metricsRecorder estats.MetricsRecorder
defaultPolicy *childPolicyWrapper // Child policy for the default target.
ctrlCh *controlChannel // Control channel to the RLS server.
maxAge time.Duration // Cache max age from LB config.
staleAge time.Duration // Cache stale age from LB config.
bg exitIdler
logger *internalgrpclog.PrefixLogger
defaultPolicy *childPolicyWrapper // Child policy for the default target.
ctrlCh *controlChannel // Control channel to the RLS server.
maxAge time.Duration // Cache max age from LB config.
staleAge time.Duration // Cache stale age from LB config.
bg exitIdler
logger *internalgrpclog.PrefixLogger
}
// isFullMethodNameValid returns true if name is of the form `/service/method`.
@ -89,17 +85,7 @@ func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
reqKeys := p.kbm.RLSKey(md, p.origEndpoint, info.FullMethodName)
p.lb.cacheMu.Lock()
var pr balancer.PickResult
var err error
// Record metrics without the cache mutex held, to prevent lock contention
// between concurrent RPCs and their Pick calls. Recording metrics can
// potentially be expensive.
metricsCallback := func() {}
defer func() {
p.lb.cacheMu.Unlock()
metricsCallback()
}()
defer p.lb.cacheMu.Unlock()
// Lookup data cache and pending request map using request path and keys.
cacheKey := cacheKey{path: info.FullMethodName, keys: reqKeys.Str}
@ -112,8 +98,7 @@ func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
case dcEntry == nil && pendingEntry == nil:
throttled := p.sendRouteLookupRequestLocked(cacheKey, &backoffState{bs: defaultBackoffStrategy}, reqKeys.Map, rlspb.RouteLookupRequest_REASON_MISS, "")
if throttled {
pr, metricsCallback, err = p.useDefaultPickIfPossible(info, errRLSThrottled)
return pr, err
return p.useDefaultPickIfPossible(info, errRLSThrottled)
}
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
@ -128,8 +113,8 @@ func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
p.sendRouteLookupRequestLocked(cacheKey, dcEntry.backoffState, reqKeys.Map, rlspb.RouteLookupRequest_REASON_STALE, dcEntry.headerData)
}
// Delegate to child policies.
pr, metricsCallback, err = p.delegateToChildPoliciesLocked(dcEntry, info)
return pr, err
res, err := p.delegateToChildPoliciesLocked(dcEntry, info)
return res, err
}
// We get here only if the data cache entry has expired. If entry is in
@ -141,108 +126,67 @@ func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
// message received from the control plane is still fine, as it could be
// useful for debugging purposes.
st := dcEntry.status
pr, metricsCallback, err = p.useDefaultPickIfPossible(info, status.Error(codes.Unavailable, fmt.Sprintf("most recent error from RLS server: %v", st.Error())))
return pr, err
return p.useDefaultPickIfPossible(info, status.Error(codes.Unavailable, fmt.Sprintf("most recent error from RLS server: %v", st.Error())))
}
// We get here only if the entry has expired and is not in backoff.
throttled := p.sendRouteLookupRequestLocked(cacheKey, dcEntry.backoffState, reqKeys.Map, rlspb.RouteLookupRequest_REASON_MISS, "")
if throttled {
pr, metricsCallback, err = p.useDefaultPickIfPossible(info, errRLSThrottled)
return pr, err
return p.useDefaultPickIfPossible(info, errRLSThrottled)
}
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
// Data cache hit. Pending request exists.
default:
if dcEntry.expiryTime.After(now) {
pr, metricsCallback, err = p.delegateToChildPoliciesLocked(dcEntry, info)
return pr, err
res, err := p.delegateToChildPoliciesLocked(dcEntry, info)
return res, err
}
// Data cache entry has expired and pending request exists. Queue pick.
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
}
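
To summarize the branches above: a cache miss with no pending request triggers an RLS lookup and queues the pick (falling back to the default target only if the lookup is throttled); a valid cache entry delegates to the child policies, additionally firing a proactive lookup once the entry has gone stale; an expired entry whose backoff has not run out uses the default target or fails with the cached error; an expired entry that is out of backoff triggers a fresh lookup and queues (again defaulting if throttled); and an expired entry with a lookup already pending simply queues the pick.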
// errToPickResult is a helper function which converts the error value returned
// by Pick() to a string that represents the pick result.
func errToPickResult(err error) string {
if err == nil {
return "complete"
}
if errors.Is(err, balancer.ErrNoSubConnAvailable) {
return "queue"
}
if _, ok := status.FromError(err); ok {
return "drop"
}
return "fail"
}
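
For reference, the mapping this helper produces, derived directly from the code above:

errToPickResult(nil)                                  // "complete"
errToPickResult(balancer.ErrNoSubConnAvailable)       // "queue"
errToPickResult(status.Error(codes.Unavailable, "x")) // "drop" (any gRPC status error)
errToPickResult(errors.New("conn reset"))             // "fail" (any other error)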
// delegateToChildPoliciesLocked is a helper function which iterates through the
// list of child policy wrappers in a cache entry and attempts to find a child
// policy to which this RPC can be routed to. If all child policies are in
// TRANSIENT_FAILURE, we delegate to the last child policy arbitrarily. Returns
// a function to be invoked to record metrics.
func (p *rlsPicker) delegateToChildPoliciesLocked(dcEntry *cacheEntry, info balancer.PickInfo) (balancer.PickResult, func(), error) {
// TRANSIENT_FAILURE, we delegate to the last child policy arbitrarily.
func (p *rlsPicker) delegateToChildPoliciesLocked(dcEntry *cacheEntry, info balancer.PickInfo) (balancer.PickResult, error) {
const rlsDataHeaderName = "x-google-rls-data"
for i, cpw := range dcEntry.childPolicyWrappers {
state := (*balancer.State)(atomic.LoadPointer(&cpw.state))
// Delegate to the child policy if it is not in TRANSIENT_FAILURE, or if
// it is the last one (which handles the case of delegating to the last
// child picker if all child policies are in TRANSIENT_FAILURE).
// child picker if all child polcies are in TRANSIENT_FAILURE).
if state.ConnectivityState != connectivity.TransientFailure || i == len(dcEntry.childPolicyWrappers)-1 {
// Any header data received from the RLS server is stored in the
// cache entry and needs to be sent to the actual backend in the
// X-Google-RLS-Data header.
res, err := state.Picker.Pick(info)
if err != nil {
pr := errToPickResult(err)
return res, func() {
if pr == "queue" {
// Don't record metrics for queued Picks.
return
}
targetPicksMetric.Record(p.metricsRecorder, 1, p.grpcTarget, p.rlsServerTarget, cpw.target, pr)
}, err
return res, err
}
if res.Metadata == nil {
res.Metadata = metadata.Pairs(rlsDataHeaderName, dcEntry.headerData)
} else {
res.Metadata.Append(rlsDataHeaderName, dcEntry.headerData)
}
return res, func() {
targetPicksMetric.Record(p.metricsRecorder, 1, p.grpcTarget, p.rlsServerTarget, cpw.target, "complete")
}, nil
return res, nil
}
}
// In the unlikely event that we have a cache entry with no targets, we end up
// queueing the RPC.
return balancer.PickResult{}, func() {}, balancer.ErrNoSubConnAvailable
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
// useDefaultPickIfPossible is a helper method which delegates to the default
// target if one is configured, or fails the pick with the given error. Returns
// a function to be invoked to record metrics.
func (p *rlsPicker) useDefaultPickIfPossible(info balancer.PickInfo, errOnNoDefault error) (balancer.PickResult, func(), error) {
// target if one is configured, or fails the pick with the given error.
func (p *rlsPicker) useDefaultPickIfPossible(info balancer.PickInfo, errOnNoDefault error) (balancer.PickResult, error) {
if p.defaultPolicy != nil {
state := (*balancer.State)(atomic.LoadPointer(&p.defaultPolicy.state))
res, err := state.Picker.Pick(info)
pr := errToPickResult(err)
return res, func() {
if pr == "queue" {
// Don't record metrics for queued Picks.
return
}
defaultTargetPicksMetric.Record(p.metricsRecorder, 1, p.grpcTarget, p.rlsServerTarget, p.defaultPolicy.target, pr)
}, err
return state.Picker.Pick(info)
}
return balancer.PickResult{}, func() {
failedPicksMetric.Record(p.metricsRecorder, 1, p.grpcTarget, p.rlsServerTarget)
}, errOnNoDefault
return balancer.PickResult{}, errOnNoDefault
}
// sendRouteLookupRequestLocked adds an entry to the pending request map and
@ -308,16 +252,6 @@ func (p *rlsPicker) handleRouteLookupResponse(cacheKey cacheKey, targets []strin
// entry would be used until expiration, and a new picker would be sent upon
// backoff expiry.
now := time.Now()
// "An RLS request is considered to have failed if it returns a non-OK
// status or the RLS response's targets list is empty." - RLS LB Policy
// design.
if len(targets) == 0 && err == nil {
err = fmt.Errorf("RLS response's target list does not contain any entries for key %+v", cacheKey)
// If err is set, rpc error from the control plane and no control plane
// configuration is why no targets were passed into this helper, no need
// to specify and tell the user this information.
}
if err != nil {
dcEntry.status = err
pendingEntry := p.lb.pendingMap[cacheKey]

View File

@ -20,19 +20,16 @@ package rls
import (
"context"
"errors"
"fmt"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/stubserver"
rlstest "google.golang.org/grpc/internal/testutils/rls"
"google.golang.org/grpc/internal/testutils/stats"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
@ -42,40 +39,6 @@ import (
testpb "google.golang.org/grpc/interop/grpc_testing"
)
// TestNoNonEmptyTargetsReturnsError tests the case where the RLS Server returns
// a response with no non-empty targets. This should be treated as a Control
// Plane RPC failure, and thus fail Data Plane RPCs with an error specifying
// that the RLS response's target list contained no non-empty entries.
func (s) TestNoNonEmptyTargetsReturnsError(t *testing.T) {
// Setup RLS Server to return a response with an empty target string.
rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{}}
})
// Register a manual resolver and push the RLS service config through it.
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
}
defer cc.Close()
// Make an RPC and expect it to fail with an error specifying RLS response's
// target list does not contain any non empty entries.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
makeTestRPCAndVerifyError(ctx, t, cc, codes.Unavailable, errors.New("RLS response's target list does not contain any entries for key"))
// Make sure an RLS request is sent out. Even though the RLS Server will
// return no targets, the request should still hit the server.
verifyRLSRequest(t, rlsReqCh, true)
}
// Test verifies the scenario where there is no matching entry in the data cache
// and no pending request either, and the ensuing RLS request is throttled.
func (s) TestPick_DataCacheMiss_NoPendingEntry_ThrottledWithDefaultTarget(t *testing.T) {
@ -91,9 +54,9 @@ func (s) TestPick_DataCacheMiss_NoPendingEntry_ThrottledWithDefaultTarget(t *tes
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -121,10 +84,10 @@ func (s) TestPick_DataCacheMiss_NoPendingEntry_ThrottledWithoutDefaultTarget(t *
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -140,7 +103,7 @@ func (s) TestPick_DataCacheMiss_NoPendingEntry_ThrottledWithoutDefaultTarget(t *
// Test verifies the scenario where there is no matching entry in the data cache
// and no pending request either, and the ensuing RLS request is not throttled.
// The RLS response does not contain any backends, so the RPC fails with a
// unavailable error.
// deadline exceeded error.
func (s) TestPick_DataCacheMiss_NoPendingEntry_NotThrottled(t *testing.T) {
// Start an RLS server and set the throttler to never throttle requests.
rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil)
@ -152,10 +115,10 @@ func (s) TestPick_DataCacheMiss_NoPendingEntry_NotThrottled(t *testing.T) {
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -163,7 +126,7 @@ func (s) TestPick_DataCacheMiss_NoPendingEntry_NotThrottled(t *testing.T) {
// smaller timeout to ensure that the test doesn't run very long.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer cancel()
makeTestRPCAndVerifyError(ctx, t, cc, codes.Unavailable, errors.New("RLS response's target list does not contain any entries for key"))
makeTestRPCAndVerifyError(ctx, t, cc, codes.DeadlineExceeded, context.DeadlineExceeded)
// Make sure an RLS request is sent out.
verifyRLSRequest(t, rlsReqCh, true)
@ -195,7 +158,7 @@ func (s) TestPick_DataCacheMiss_PendingEntryExists(t *testing.T) {
// also lead to creation of a pending entry, and further RPCs by the
// client should not result in RLS requests being sent out.
rlsReqCh := make(chan struct{}, 1)
interceptor := func(ctx context.Context, _ any, _ *grpc.UnaryServerInfo, _ grpc.UnaryHandler) (resp any, err error) {
interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
rlsReqCh <- struct{}{}
<-ctx.Done()
return nil, ctx.Err()
@ -216,10 +179,10 @@ func (s) TestPick_DataCacheMiss_PendingEntryExists(t *testing.T) {
// through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -248,133 +211,6 @@ func (s) TestPick_DataCacheMiss_PendingEntryExists(t *testing.T) {
}
}
// Test_RLSDefaultTargetPicksMetric tests the default target picks metric. It
// configures an RLS Balancer which specifies to route to the default target in
// the RLS Configuration, and makes an RPC on a Channel containing this RLS
// Balancer. This test then asserts a default target picks metric is emitted,
// and target pick or failed pick metric is not emitted.
func (s) Test_RLSDefaultTargetPicksMetric(t *testing.T) {
// Start an RLS server and set the throttler to always throttle requests.
rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
overrideAdaptiveThrottler(t, alwaysThrottlingThrottler())
// Build RLS service config with a default target.
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
defBackendCh, defBackendAddress := startBackend(t)
rlsConfig.RouteLookupConfig.DefaultTarget = defBackendAddress
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
tmr := stats.NewTestMetricsRecorder()
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithStatsHandler(tmr))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
// Make an RPC and ensure it gets routed to the default target.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
makeTestRPCAndExpectItToReachBackend(ctx, t, cc, defBackendCh)
if got, _ := tmr.Metric("grpc.lb.rls.default_target_picks"); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.rls.default_target_picks", got, 1)
}
if _, ok := tmr.Metric("grpc.lb.rls.target_picks"); ok {
t.Fatalf("Data is present for metric %v", "grpc.lb.rls.target_picks")
}
if _, ok := tmr.Metric("grpc.lb.rls.failed_picks"); ok {
t.Fatalf("Data is present for metric %v", "grpc.lb.rls.failed_picks")
}
}
// Test_RLSTargetPicksMetric tests the target picks metric. It configures an RLS
// Balancer which specifies to route to a target through a RouteLookupResponse,
// and makes an RPC on a Channel containing this RLS Balancer. This test then
// asserts a target picks metric is emitted, and default target pick or failed
// pick metric is not emitted.
func (s) Test_RLSTargetPicksMetric(t *testing.T) {
// Start an RLS server and set the throttler to never throttle requests.
rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Build the RLS config without a default target.
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
// Start a test backend, and setup the fake RLS server to return this as a
// target in the RLS response.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
tmr := stats.NewTestMetricsRecorder()
// Dial the backend.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithStatsHandler(tmr))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
// Make an RPC and ensure it gets routed to the test backend.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
makeTestRPCAndExpectItToReachBackend(ctx, t, cc, testBackendCh)
if got, _ := tmr.Metric("grpc.lb.rls.target_picks"); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.rls.target_picks", got, 1)
}
if _, ok := tmr.Metric("grpc.lb.rls.default_target_picks"); ok {
t.Fatalf("Data is present for metric %v", "grpc.lb.rls.default_target_picks")
}
if _, ok := tmr.Metric("grpc.lb.rls.failed_picks"); ok {
t.Fatalf("Data is present for metric %v", "grpc.lb.rls.failed_picks")
}
}
// Test_RLSFailedPicksMetric tests the failed picks metric. It configures an RLS
// Balancer to fail a pick with unavailable, and makes an RPC on a Channel
// containing this RLS Balancer. This test then asserts a failed picks metric is
// emitted, and default target pick or target pick metric is not emitted.
func (s) Test_RLSFailedPicksMetric(t *testing.T) {
// Start an RLS server and set the throttler to never throttle requests.
rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Build an RLS config without a default target.
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
tmr := stats.NewTestMetricsRecorder()
// Dial the backend.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithStatsHandler(tmr))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
// Make an RPC and expect it to fail with deadline exceeded error. We use a
// smaller timeout to ensure that the test doesn't run very long.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer cancel()
makeTestRPCAndVerifyError(ctx, t, cc, codes.Unavailable, errors.New("RLS response's target list does not contain any entries for key"))
if got, _ := tmr.Metric("grpc.lb.rls.failed_picks"); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.rls.failed_picks", got, 1)
}
if _, ok := tmr.Metric("grpc.lb.rls.target_picks"); ok {
t.Fatalf("Data is present for metric %v", "grpc.lb.rls.target_picks")
}
if _, ok := tmr.Metric("grpc.lb.rls.default_target_picks"); ok {
t.Fatalf("Data is present for metric %v", "grpc.lb.rls.default_target_picks")
}
}
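The three metrics tests above repeat the same assertion shape: exactly one of the pick metrics must be recorded once and the other two must be absent. A hedged sketch of a helper capturing that pattern is below; the metricReader interface and its float64 return type are assumptions made here so the snippet stays self-contained, while the real tests call stats.NewTestMetricsRecorder directly.

package rls_test

import "testing"

// metricReader is assumed here to model the accessor the tests above use on
// the test metrics recorder: Metric returns the recorded value and whether
// any data exists for the named metric.
type metricReader interface {
    Metric(name string) (float64, bool)
}

// assertExactlyOnePickMetric checks that the wanted metric was recorded with
// value 1 and that none of the other named metrics have any data.
func assertExactlyOnePickMetric(t *testing.T, tmr metricReader, want string, absent ...string) {
    t.Helper()
    if got, _ := tmr.Metric(want); got != 1 {
        t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", want, got, 1)
    }
    for _, name := range absent {
        if _, ok := tmr.Metric(name); ok {
            t.Fatalf("Data is present for metric %v", name)
        }
    }
}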
// Test verifies the scenario where there is a matching entry in the data cache
// which is valid and there is no pending request. The pick is expected to be
// delegated to the child policy.
@ -385,20 +221,21 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ValidEntry(t *testing.T) {
// Build the RLS config without a default target.
rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
// Start a test backend, and setup the fake RLS server to return this as a
// target in the RLS response.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -432,7 +269,7 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ValidEntry_WithHeaderData(t *testi
// RLS server to be part of RPC metadata as X-Google-RLS-Data header.
const headerDataContents = "foo,bar,baz"
backend := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
gotHeaderData := metadata.ValueFromIncomingContext(ctx, "x-google-rls-data")
if len(gotHeaderData) != 1 || gotHeaderData[0] != headerDataContents {
return nil, fmt.Errorf("got metadata in `X-Google-RLS-Data` is %v, want %s", gotHeaderData, headerDataContents)
@ -448,7 +285,7 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ValidEntry_WithHeaderData(t *testi
// Setup the fake RLS server to return the above backend as a target in the
// RLS response. Also, populate the header data field in the response.
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{
Targets: []string{backend.Address},
HeaderData: headerDataContents,
@ -458,10 +295,10 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ValidEntry_WithHeaderData(t *testi
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -517,7 +354,7 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_StaleEntry(t *testing.T) {
// Start a test backend, and setup the fake RLS server to return
// this as a target in the RLS response.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
@ -525,10 +362,10 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_StaleEntry(t *testing.T) {
// through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -626,7 +463,7 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ExpiredEntry(t *testing.T) {
// Start a test backend, and setup the fake RLS server to return
// this as a target in the RLS response.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
@ -634,10 +471,10 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ExpiredEntry(t *testing.T) {
// through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -725,17 +562,17 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ExpiredEntryInBackoff(t *testing.T
// Start a test backend, and set up the fake RLS server to return this as
// a target in the RLS response.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
// Register a manual resolver and push the RLS service config through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -750,7 +587,7 @@ func (s) TestPick_DataCacheHit_NoPendingEntry_ExpiredEntryInBackoff(t *testing.T
// Set up the fake RLS server to return errors. This will push the cache
// entry into backoff.
var rlsLastErr = status.Error(codes.DeadlineExceeded, "last RLS request failed")
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Err: rlsLastErr}
})
@ -796,7 +633,7 @@ func (s) TestPick_DataCacheHit_PendingEntryExists_StaleEntry(t *testing.T) {
// expired entry and a pending entry in the cache.
rlsReqCh := make(chan struct{}, 1)
firstRPCDone := grpcsync.NewEvent()
interceptor := func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
select {
case rlsReqCh <- struct{}{}:
default:
@ -826,7 +663,7 @@ func (s) TestPick_DataCacheHit_PendingEntryExists_StaleEntry(t *testing.T) {
// Start a test backend, and setup the fake RLS server to return
// this as a target in the RLS response.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
@ -834,10 +671,10 @@ func (s) TestPick_DataCacheHit_PendingEntryExists_StaleEntry(t *testing.T) {
// through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -896,7 +733,7 @@ func (s) TestPick_DataCacheHit_PendingEntryExists_ExpiredEntry(t *testing.T) {
// expired entry and a pending entry in the cache.
rlsReqCh := make(chan struct{}, 1)
firstRPCDone := grpcsync.NewEvent()
interceptor := func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
select {
case rlsReqCh <- struct{}{}:
default:
@ -924,7 +761,7 @@ func (s) TestPick_DataCacheHit_PendingEntryExists_ExpiredEntry(t *testing.T) {
// Start a test backend, and setup the fake RLS server to return
// this as a target in the RLS response.
testBackendCh, testBackendAddress := startBackend(t)
rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
return &rlstest.RouteLookupResponse{Resp: &rlspb.RouteLookupResponse{Targets: []string{testBackendAddress}}}
})
@ -932,10 +769,10 @@ func (s) TestPick_DataCacheHit_PendingEntryExists_ExpiredEntry(t *testing.T) {
// through it.
r := startManualResolverWithConfig(t, rlsConfig)
// Create new client.
cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
// Dial the backend.
cc, err := grpc.Dial(r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Failed to create gRPC client: %v", err)
t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()
@ -1009,41 +846,3 @@ func TestIsFullMethodNameValid(t *testing.T) {
})
}
}
// Tests the conversion of the child pickers error to the pick result attribute.
func (s) TestChildPickResultError(t *testing.T) {
tests := []struct {
name string
err error
want string
}{
{
name: "nil",
err: nil,
want: "complete",
},
{
name: "errNoSubConnAvailable",
err: balancer.ErrNoSubConnAvailable,
want: "queue",
},
{
name: "status error",
err: status.Error(codes.Unimplemented, "unimplemented"),
want: "drop",
},
{
name: "other error",
err: errors.New("some error"),
want: "fail",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if got := errToPickResult(test.err); got != test.want {
t.Fatalf("errToPickResult(%q) = %v, want %v", test.err, got, test.want)
}
})
}
}
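The table-driven test above pins down the four pick-result labels without showing errToPickResult itself, which lives elsewhere in the RLS balancer. The sketch below is only an illustration of a mapping consistent with those test cases, not the actual implementation.

package rls

import (
    "errors"

    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/status"
)

// errToPickResult sketches a mapping consistent with TestChildPickResultError
// above; treat it as illustrative only.
func errToPickResult(err error) string {
    if err == nil {
        return "complete"
    }
    if errors.Is(err, balancer.ErrNoSubConnAvailable) {
        return "queue"
    }
    // Status-coded errors are treated as drops; anything else is a plain failure.
    if _, ok := status.FromError(err); ok {
        return "drop"
    }
    return "fail"
}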

View File

@ -22,13 +22,12 @@
package roundrobin
import (
"fmt"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/grpclog"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcrand"
)
// Name is the name of round_robin balancer.
@ -36,37 +35,47 @@ const Name = "round_robin"
var logger = grpclog.Component("roundrobin")
// newBuilder creates a new roundrobin balancer builder.
func newBuilder() balancer.Builder {
return base.NewBalancerBuilder(Name, &rrPickerBuilder{}, base.Config{HealthCheck: true})
}
func init() {
balancer.Register(builder{})
balancer.Register(newBuilder())
}
type builder struct{}
type rrPickerBuilder struct{}
func (bb builder) Name() string {
return Name
}
func (bb builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
childBuilder := balancer.Get(pickfirstleaf.Name).Build
bal := &rrBalancer{
cc: cc,
Balancer: endpointsharding.NewBalancer(cc, opts, childBuilder, endpointsharding.Options{}),
func (*rrPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
logger.Infof("roundrobinPicker: Build called with info: %v", info)
if len(info.ReadySCs) == 0 {
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
}
scs := make([]balancer.SubConn, 0, len(info.ReadySCs))
for sc := range info.ReadySCs {
scs = append(scs, sc)
}
return &rrPicker{
subConns: scs,
// Start at a random index, as the same RR balancer rebuilds a new
// picker when SubConn states change, and we don't want to apply excess
// load to the first server in the list.
next: uint32(grpcrand.Intn(len(scs))),
}
bal.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", bal))
bal.logger.Infof("Created")
return bal
}
type rrBalancer struct {
balancer.Balancer
cc balancer.ClientConn
logger *internalgrpclog.PrefixLogger
type rrPicker struct {
// subConns is the snapshot of the roundrobin balancer when this picker was
// created. The slice is immutable. Each Pick() will do a round robin
// selection from it and return the selected SubConn.
subConns []balancer.SubConn
next uint32
}
func (b *rrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
return b.Balancer.UpdateClientConnState(balancer.ClientConnState{
// Enable the health listener in pickfirst children for client side health
// checks and outlier detection, if configured.
ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
})
func (p *rrPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
subConnsLen := uint32(len(p.subConns))
nextIndex := atomic.AddUint32(&p.next, 1)
sc := p.subConns[nextIndex%subConnsLen]
return balancer.PickResult{SubConn: sc}, nil
}
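The picker above rotates through a fixed snapshot of READY SubConns using an atomically incremented counter, so concurrent picks never race and the index simply wraps. The standalone sketch below reproduces just that rotation with made-up backend names.

package main

import (
    "fmt"
    "sync/atomic"
)

// roundRobin mirrors rrPicker: an immutable snapshot plus an atomic counter
// taken modulo the snapshot length on every pick.
type roundRobin struct {
    items []string
    next  uint32
}

func (r *roundRobin) pick() string {
    n := atomic.AddUint32(&r.next, 1)
    return r.items[n%uint32(len(r.items))]
}

func main() {
    rr := &roundRobin{items: []string{"backend-1", "backend-2", "backend-3"}}
    for i := 0; i < 6; i++ {
        // With next starting at 0, the rotation begins at backend-2 and wraps.
        fmt.Println(rr.pick())
    }
}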

View File

@ -1,134 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package balancer
import (
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/resolver"
)
// A SubConn represents a single connection to a gRPC backend service.
//
// All SubConns start in IDLE, and will not try to connect. To trigger a
// connection attempt, Balancers must call Connect.
//
// If the connection attempt fails, the SubConn will transition to
// TRANSIENT_FAILURE for a backoff period, and then return to IDLE. If the
// connection attempt succeeds, it will transition to READY.
//
// If a READY SubConn becomes disconnected, the SubConn will transition to IDLE.
//
// If a connection re-enters IDLE, Balancers must call Connect again to trigger
// a new connection attempt.
//
// Each SubConn contains a list of addresses. gRPC will try to connect to the
// addresses in sequence, and stop trying the remainder once the first
// connection is successful. However, this behavior is deprecated. SubConns
// should only use a single address.
//
// NOTICE: This interface is intended to be implemented by gRPC, or intercepted
// by custom load balancing polices. Users should not need their own complete
// implementation of this interface -- they should always delegate to a SubConn
// returned by ClientConn.NewSubConn() by embedding it in their implementations.
// An embedded SubConn must never be nil, or runtime panics will occur.
type SubConn interface {
// UpdateAddresses updates the addresses used in this SubConn.
// gRPC checks if currently-connected address is still in the new list.
// If it's in the list, the connection will be kept.
// If it's not in the list, the connection will gracefully close, and
// a new connection will be created.
//
// This will trigger a state transition for the SubConn.
//
// Deprecated: this method will be removed. Create new SubConns for new
// addresses instead.
UpdateAddresses([]resolver.Address)
// Connect starts the connecting for this SubConn.
Connect()
// GetOrBuildProducer returns a reference to the existing Producer for this
// ProducerBuilder in this SubConn, or, if one does not currently exist,
// creates a new one and returns it. Returns a close function which may be
// called when the Producer is no longer needed. Otherwise the producer
// will automatically be closed upon connection loss or subchannel close.
// Should only be called on a SubConn in state Ready. Otherwise the
// producer will be unable to create streams.
GetOrBuildProducer(ProducerBuilder) (p Producer, close func())
// Shutdown shuts down the SubConn gracefully. Any started RPCs will be
// allowed to complete. No future calls should be made on the SubConn.
// One final state update will be delivered to the StateListener (or
// UpdateSubConnState; deprecated) with ConnectivityState of Shutdown to
// indicate the shutdown operation. This may be delivered before
// in-progress RPCs are complete and the actual connection is closed.
Shutdown()
// RegisterHealthListener registers a health listener that receives health
// updates for a Ready SubConn. Only one health listener can be registered
// at a time. A health listener should be registered each time the SubConn's
// connectivity state changes to READY. Registering a health listener when
// the connectivity state is not READY may result in undefined behaviour.
// This method must not be called synchronously while handling an update
// from a previously registered health listener.
RegisterHealthListener(func(SubConnState))
// EnforceSubConnEmbedding is included to force implementers to embed
// another implementation of this interface, allowing gRPC to add methods
// without breaking users.
internal.EnforceSubConnEmbedding
}
// A ProducerBuilder is a simple constructor for a Producer. It is used by the
// SubConn to create producers when needed.
type ProducerBuilder interface {
// Build creates a Producer. The first parameter is always a
// grpc.ClientConnInterface (a type to allow creating RPCs/streams on the
// associated SubConn), but is declared as `any` to avoid a dependency
// cycle. Build also returns a close function that will be called when all
// references to the Producer have been given up for a SubConn, or when a
// connectivity state change occurs on the SubConn. The close function
// should always block until all asynchronous cleanup work is completed.
Build(grpcClientConnInterface any) (p Producer, close func())
}
// SubConnState describes the state of a SubConn.
type SubConnState struct {
// ConnectivityState is the connectivity state of the SubConn.
ConnectivityState connectivity.State
// ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error
// connectedAddr contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
}
// connectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
func connectedAddress(scs SubConnState) resolver.Address {
return scs.connectedAddress
}
// setConnectedAddress sets the connected address for a SubConnState.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
scs.connectedAddress = addr
}
// A Producer is a type shared among potentially many consumers. It is
// associated with a SubConn, and an implementation will typically contain
// other methods to provide additional functionality, e.g. configuration or
// subscription registration.
type Producer any
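As the NOTICE above says, custom policies should wrap a SubConn returned by ClientConn.NewSubConn rather than implement the interface from scratch. Below is a minimal sketch of that embedding pattern, assuming a hypothetical intercepting balancer; the extra weight field is illustrative only.

package custombalancer

import (
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/resolver"
)

// wrappedSubConn embeds the gRPC-provided SubConn (it must never be nil) and
// attaches per-connection state owned by the intercepting policy.
type wrappedSubConn struct {
    balancer.SubConn
    weight float64 // hypothetical extra state
}

// interceptingBalancer embeds the real ClientConn so it can intercept
// NewSubConn and hand wrapped SubConns to its own logic.
type interceptingBalancer struct {
    balancer.ClientConn
}

func (b *interceptingBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
    sc, err := b.ClientConn.NewSubConn(addrs, opts)
    if err != nil {
        return nil, err
    }
    return &wrappedSubConn{SubConn: sc, weight: 1}, nil
}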

View File

@ -16,35 +16,24 @@
*
*/
// Package weightedroundrobin provides an implementation of the weighted round
// robin LB policy, as defined in [gRFC A58].
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
//
// [gRFC A58]: https://github.com/grpc/proposal/blob/master/A58-client-side-weighted-round-robin-lb-policy.md
package weightedroundrobin
import (
"context"
"encoding/json"
"errors"
"fmt"
rand "math/rand/v2"
"sync"
"sync/atomic"
"time"
"unsafe"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/weightedroundrobin/internal"
"google.golang.org/grpc/balancer/weightedtarget"
"google.golang.org/grpc/connectivity"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcrand"
iserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/orca"
"google.golang.org/grpc/resolver"
@ -54,44 +43,7 @@ import (
)
// Name is the name of the weighted round robin balancer.
const Name = "weighted_round_robin"
var (
rrFallbackMetric = estats.RegisterInt64Count(estats.MetricDescriptor{
Name: "grpc.lb.wrr.rr_fallback",
Description: "EXPERIMENTAL. Number of scheduler updates in which there were not enough endpoints with valid weight, which caused the WRR policy to fall back to RR behavior.",
Unit: "{update}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.locality"},
Default: false,
})
endpointWeightNotYetUsableMetric = estats.RegisterInt64Count(estats.MetricDescriptor{
Name: "grpc.lb.wrr.endpoint_weight_not_yet_usable",
Description: "EXPERIMENTAL. Number of endpoints from each scheduler update that don't yet have usable weight information (i.e., either the load report has not yet been received, or it is within the blackout period).",
Unit: "{endpoint}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.locality"},
Default: false,
})
endpointWeightStaleMetric = estats.RegisterInt64Count(estats.MetricDescriptor{
Name: "grpc.lb.wrr.endpoint_weight_stale",
Description: "EXPERIMENTAL. Number of endpoints from each scheduler update whose latest weight is older than the expiration period.",
Unit: "{endpoint}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.locality"},
Default: false,
})
endpointWeightsMetric = estats.RegisterFloat64Histo(estats.MetricDescriptor{
Name: "grpc.lb.wrr.endpoint_weights",
Description: "EXPERIMENTAL. Weight of each endpoint, recorded on every scheduler update. Endpoints without usable weights will be recorded as weight 0.",
Unit: "{endpoint}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.locality"},
Default: false,
})
)
const Name = "weighted_round_robin_experimental"
func init() {
balancer.Register(bb{})
@ -101,15 +53,12 @@ type bb struct{}
func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
b := &wrrBalancer{
ClientConn: cc,
target: bOpts.Target.String(),
metricsRecorder: cc.MetricsRecorder(),
addressWeights: resolver.NewAddressMapV2[*endpointWeight](),
endpointToWeight: resolver.NewEndpointMap[*endpointWeight](),
scToWeight: make(map[balancer.SubConn]*endpointWeight),
cc: cc,
subConns: resolver.NewAddressMap(),
csEvltr: &balancer.ConnectivityStateEvaluator{},
scMap: make(map[balancer.SubConn]*weightedSubConn),
connectivityState: connectivity.Connecting,
}
b.child = endpointsharding.NewBalancer(b, bOpts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{})
b.logger = prefixLogger(b)
b.logger.Infof("Created")
return b
@ -150,312 +99,248 @@ func (bb) Name() string {
return Name
}
// updateEndpointsLocked updates endpoint weight state based on the new update,
// starting and clearing any endpoint weights as needed.
//
// Caller must hold b.mu.
func (b *wrrBalancer) updateEndpointsLocked(endpoints []resolver.Endpoint) {
endpointSet := resolver.NewEndpointMap[*endpointWeight]()
addressSet := resolver.NewAddressMapV2[*endpointWeight]()
for _, endpoint := range endpoints {
endpointSet.Set(endpoint, nil)
for _, addr := range endpoint.Addresses {
addressSet.Set(addr, nil)
}
ew, ok := b.endpointToWeight.Get(endpoint)
if !ok {
ew = &endpointWeight{
logger: b.logger,
connectivityState: connectivity.Connecting,
// Initially, we set load reports to off, because they are not
// running upon initial endpointWeight creation.
cfg: &lbConfig{EnableOOBLoadReport: false},
metricsRecorder: b.metricsRecorder,
target: b.target,
locality: b.locality,
}
for _, addr := range endpoint.Addresses {
b.addressWeights.Set(addr, ew)
}
b.endpointToWeight.Set(endpoint, ew)
}
ew.updateConfig(b.cfg)
}
for _, endpoint := range b.endpointToWeight.Keys() {
if _, ok := endpointSet.Get(endpoint); ok {
// Existing endpoint also in new endpoint list; skip.
continue
}
b.endpointToWeight.Delete(endpoint)
for _, addr := range endpoint.Addresses {
if _, ok := addressSet.Get(addr); !ok { // old endpoints to be deleted can share addresses with new endpoints, so only delete if necessary
b.addressWeights.Delete(addr)
}
}
// The SubConn map will get handled in updateSubConnState
// when it receives the SHUTDOWN signal.
}
}
// wrrBalancer implements the weighted round robin LB policy.
type wrrBalancer struct {
// The following fields are set at initialization time and read only after that,
// so they do not need to be protected by a mutex.
child balancer.Balancer
balancer.ClientConn // Embed to intercept NewSubConn operation
logger *grpclog.PrefixLogger
target string
metricsRecorder estats.MetricsRecorder
cc balancer.ClientConn
logger *grpclog.PrefixLogger
mu sync.Mutex
cfg *lbConfig // active config
locality string
stopPicker *grpcsync.Event
addressWeights *resolver.AddressMapV2[*endpointWeight]
endpointToWeight *resolver.EndpointMap[*endpointWeight]
scToWeight map[balancer.SubConn]*endpointWeight
// The following fields are only accessed on calls into the LB policy, and
// do not need a mutex.
cfg *lbConfig // active config
subConns *resolver.AddressMap // active weightedSubConns mapped by address
scMap map[balancer.SubConn]*weightedSubConn
connectivityState connectivity.State // aggregate state
csEvltr *balancer.ConnectivityStateEvaluator
resolverErr error // the last error reported by the resolver; cleared on successful resolution
connErr error // the last connection error; cleared upon leaving TransientFailure
stopPicker func()
}
func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
if b.logger.V(2) {
b.logger.Infof("UpdateCCS: %v", ccs)
}
b.logger.Infof("UpdateCCS: %v", ccs)
b.resolverErr = nil
cfg, ok := ccs.BalancerConfig.(*lbConfig)
if !ok {
return fmt.Errorf("wrr: received nil or illegal BalancerConfig (type %T): %v", ccs.BalancerConfig, ccs.BalancerConfig)
}
// Note: empty endpoints and duplicate addresses across endpoints won't
// explicitly error but will have undefined behavior.
b.mu.Lock()
b.cfg = cfg
b.locality = weightedtarget.LocalityFromResolverState(ccs.ResolverState)
b.updateEndpointsLocked(ccs.ResolverState.Endpoints)
b.mu.Unlock()
b.updateAddresses(ccs.ResolverState.Addresses)
// This causes the child to update its picker inline, which in turn
// triggers an inline picker update here.
return b.child.UpdateClientConnState(balancer.ClientConnState{
// Make pickfirst children use health listeners for outlier detection to
// work.
ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
})
}
func (b *wrrBalancer) UpdateState(state balancer.State) {
b.mu.Lock()
defer b.mu.Unlock()
if b.stopPicker != nil {
b.stopPicker.Fire()
b.stopPicker = nil
if len(ccs.ResolverState.Addresses) == 0 {
b.ResolverError(errors.New("resolver produced zero addresses")) // will call regeneratePicker
return balancer.ErrBadResolverState
}
childStates := endpointsharding.ChildStatesFromPicker(state.Picker)
b.regeneratePicker()
var readyPickersWeight []pickerWeightedEndpoint
return nil
}
for _, childState := range childStates {
if childState.State.ConnectivityState == connectivity.Ready {
ew, ok := b.endpointToWeight.Get(childState.Endpoint)
if !ok {
// Should never happen, simply continue and ignore this endpoint
// for READY pickers.
func (b *wrrBalancer) updateAddresses(addrs []resolver.Address) {
addrsSet := resolver.NewAddressMap()
// Loop through new address list and create subconns for any new addresses.
for _, addr := range addrs {
if _, ok := addrsSet.Get(addr); ok {
// Redundant address; skip.
continue
}
addrsSet.Set(addr, nil)
var wsc *weightedSubConn
wsci, ok := b.subConns.Get(addr)
if ok {
wsc = wsci.(*weightedSubConn)
} else {
// addr is a new address (not existing in b.subConns).
sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
if err != nil {
b.logger.Warningf("Failed to create new SubConn for address %v: %v", addr, err)
continue
}
readyPickersWeight = append(readyPickersWeight, pickerWeightedEndpoint{
picker: childState.State.Picker,
weightedEndpoint: ew,
})
wsc = &weightedSubConn{
SubConn: sc,
logger: b.logger,
connectivityState: connectivity.Idle,
// Initially, we set load reports to off, because they are not
// running upon initial weightedSubConn creation.
cfg: &lbConfig{EnableOOBLoadReport: false},
}
b.subConns.Set(addr, wsc)
b.scMap[sc] = wsc
b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle)
sc.Connect()
}
}
// If no ready pickers are present, simply defer to the round robin picker
// from endpoint sharding, which will round robin across the most relevant
// pick first children in the highest precedence connectivity state.
if len(readyPickersWeight) == 0 {
b.ClientConn.UpdateState(balancer.State{
ConnectivityState: state.ConnectivityState,
Picker: state.Picker,
})
return
// Update config for existing weightedSubConn or send update for first
// time to new one. Ensures an OOB listener is running if needed
// (and stops the existing one if applicable).
wsc.updateConfig(b.cfg)
}
p := &picker{
v: rand.Uint32(), // start the scheduler at a random point
cfg: b.cfg,
weightedPickers: readyPickersWeight,
metricsRecorder: b.metricsRecorder,
locality: b.locality,
target: b.target,
// Loop through existing subconns and remove ones that are not in addrs.
for _, addr := range b.subConns.Keys() {
if _, ok := addrsSet.Get(addr); ok {
// Existing address also in new address list; skip.
continue
}
// addr was removed by resolver. Remove.
wsci, _ := b.subConns.Get(addr)
wsc := wsci.(*weightedSubConn)
b.cc.RemoveSubConn(wsc.SubConn)
b.subConns.Delete(addr)
}
b.stopPicker = grpcsync.NewEvent()
p.start(b.stopPicker)
b.ClientConn.UpdateState(balancer.State{
ConnectivityState: state.ConnectivityState,
Picker: p,
})
}
type pickerWeightedEndpoint struct {
picker balancer.Picker
weightedEndpoint *endpointWeight
}
func (b *wrrBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
addr := addrs[0] // The new pick first policy for DualStack will only ever create a SubConn with one address.
var sc balancer.SubConn
oldListener := opts.StateListener
opts.StateListener = func(state balancer.SubConnState) {
b.updateSubConnState(sc, state)
oldListener(state)
}
b.mu.Lock()
defer b.mu.Unlock()
ewi, ok := b.addressWeights.Get(addr)
if !ok {
// SubConn state updates can come in for a no longer relevant endpoint
// weight (from the old system after a new config update is applied).
return nil, fmt.Errorf("balancer is being closed; no new SubConns allowed")
}
sc, err := b.ClientConn.NewSubConn([]resolver.Address{addr}, opts)
if err != nil {
return nil, err
}
b.scToWeight[sc] = ewi
return sc, nil
}
func (b *wrrBalancer) ResolverError(err error) {
// Will cause inline picker update from endpoint sharding.
b.child.ResolverError(err)
b.resolverErr = err
if b.subConns.Len() == 0 {
b.connectivityState = connectivity.TransientFailure
}
if b.connectivityState != connectivity.TransientFailure {
// No need to update the picker since no error is being returned.
return
}
b.regeneratePicker()
}
func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
}
func (b *wrrBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
b.mu.Lock()
ew := b.scToWeight[sc]
// Update from a no-longer-relevant SubConn: nothing to do here except
// forward the state to the state listener, which happens in the wrapped
// listener. The entry is eventually cleared from scToWeight once a Shutdown
// signal is received.
if ew == nil {
b.mu.Unlock()
wsc := b.scMap[sc]
if wsc == nil {
b.logger.Errorf("UpdateSubConnState called with an unknown SubConn: %p, %v", sc, state)
return
}
if state.ConnectivityState == connectivity.Shutdown {
delete(b.scToWeight, sc)
}
b.mu.Unlock()
// On the first READY SubConn/transition for an endpoint, set pickedSC,
// clear the endpoint's weight-tracking state, and potentially start an OOB watch.
if state.ConnectivityState == connectivity.Ready && ew.pickedSC == nil {
ew.pickedSC = sc
ew.mu.Lock()
ew.nonEmptySince = time.Time{}
ew.lastUpdated = time.Time{}
cfg := ew.cfg
ew.mu.Unlock()
ew.updateORCAListener(cfg)
return
if b.logger.V(2) {
logger.Infof("UpdateSubConnState(%+v, %+v)", sc, state)
}
// If the pickedSC (the one pick first uses for an endpoint) transitions out
// of READY, stop the OOB listener if needed and clear pickedSC so that the
// next SubConn created for the endpoint that goes READY will be chosen as
// the endpoint's active SubConn.
if state.ConnectivityState != connectivity.Ready && ew.pickedSC == sc {
// The first SubConn that goes READY for an endpoint is what pick first
// will pick. Only once that SubConn leaves READY will pick first restart
// this cycle of creating SubConns and using the first READY one. The
// lower-level endpoint sharding will call ExitIdle on the pick first child
// once this occurs, which will trigger a connection attempt.
if ew.stopORCAListener != nil {
ew.stopORCAListener()
}
ew.pickedSC = nil
cs := state.ConnectivityState
if cs == connectivity.TransientFailure {
// Save error to be reported via picker.
b.connErr = state.ConnectionError
}
if cs == connectivity.Shutdown {
delete(b.scMap, sc)
// The subconn was removed from b.subConns when the address was removed
// in updateAddresses.
}
oldCS := wsc.updateConnectivityState(cs)
b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs)
// Regenerate picker when one of the following happens:
// - this sc entered or left ready
// - the aggregated state of balancer is TransientFailure
// (may need to update error message)
if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) ||
b.connectivityState == connectivity.TransientFailure {
b.regeneratePicker()
}
}
// Close stops the balancer. It cancels any ongoing scheduler updates and
// stops any ORCA listeners.
func (b *wrrBalancer) Close() {
b.mu.Lock()
if b.stopPicker != nil {
b.stopPicker.Fire()
b.stopPicker()
b.stopPicker = nil
}
b.mu.Unlock()
// Ensure any lingering OOB watchers are stopped.
for _, ew := range b.endpointToWeight.Values() {
if ew.stopORCAListener != nil {
ew.stopORCAListener()
}
for _, wsc := range b.scMap {
// Ensure any lingering OOB watchers are stopped.
wsc.updateConnectivityState(connectivity.Shutdown)
}
b.child.Close()
}
func (b *wrrBalancer) ExitIdle() {
b.child.ExitIdle()
// ExitIdle is ignored; we always connect to all backends.
func (b *wrrBalancer) ExitIdle() {}
func (b *wrrBalancer) readySubConns() []*weightedSubConn {
var ret []*weightedSubConn
for _, v := range b.subConns.Values() {
wsc := v.(*weightedSubConn)
if wsc.connectivityState == connectivity.Ready {
ret = append(ret, wsc)
}
}
return ret
}
// mergeErrors builds an error from the last connection error and the last
// resolver error. Must only be called if b.connectivityState is
// TransientFailure.
func (b *wrrBalancer) mergeErrors() error {
// connErr must always be non-nil unless there are no SubConns, in which
// case resolverErr must be non-nil.
if b.connErr == nil {
return fmt.Errorf("last resolver error: %v", b.resolverErr)
}
if b.resolverErr == nil {
return fmt.Errorf("last connection error: %v", b.connErr)
}
return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr)
}
func (b *wrrBalancer) regeneratePicker() {
if b.stopPicker != nil {
b.stopPicker()
b.stopPicker = nil
}
switch b.connectivityState {
case connectivity.TransientFailure:
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: base.NewErrPicker(b.mergeErrors()),
})
return
case connectivity.Connecting, connectivity.Idle:
// Idle could happen very briefly if all subconns are Idle and we've
// asked them to connect but they haven't reported Connecting yet.
// Report the same as Connecting since this is temporary.
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable),
})
return
case connectivity.Ready:
b.connErr = nil
}
p := &picker{
v: grpcrand.Uint32(), // start the scheduler at a random point
cfg: b.cfg,
subConns: b.readySubConns(),
}
var ctx context.Context
ctx, b.stopPicker = context.WithCancel(context.Background())
p.start(ctx)
b.cc.UpdateState(balancer.State{
ConnectivityState: b.connectivityState,
Picker: p,
})
}
// picker is the WRR policy's picker. It uses live-updating backend weights to
// update the scheduler periodically and ensure picks are routed proportional
// to those weights.
type picker struct {
scheduler unsafe.Pointer // *scheduler; accessed atomically
v uint32 // incrementing value used by the scheduler; accessed atomically
cfg *lbConfig // active config when picker created
weightedPickers []pickerWeightedEndpoint // all READY pickers
// The following fields are immutable.
target string
locality string
metricsRecorder estats.MetricsRecorder
scheduler unsafe.Pointer // *scheduler; accessed atomically
v uint32 // incrementing value used by the scheduler; accessed atomically
cfg *lbConfig // active config when picker created
subConns []*weightedSubConn // all READY subconns
}
func (p *picker) endpointWeights(recordMetrics bool) []float64 {
wp := make([]float64, len(p.weightedPickers))
// scWeights returns a slice containing the weights from p.subConns in the same
// order as p.subConns.
func (p *picker) scWeights() []float64 {
ws := make([]float64, len(p.subConns))
now := internal.TimeNow()
for i, wpi := range p.weightedPickers {
wp[i] = wpi.weightedEndpoint.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod), recordMetrics)
for i, wsc := range p.subConns {
ws[i] = wsc.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod))
}
return wp
}
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
// Read the scheduler atomically. All scheduler operations are threadsafe,
// and if the scheduler is replaced during this usage, we want to use the
// scheduler that was live when the pick started.
sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler))
pickedPicker := p.weightedPickers[sched.nextIndex()]
pr, err := pickedPicker.picker.Pick(info)
if err != nil {
logger.Errorf("ready picker returned error: %v", err)
return balancer.PickResult{}, err
}
if !p.cfg.EnableOOBLoadReport {
oldDone := pr.Done
pr.Done = func(info balancer.DoneInfo) {
if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil {
pickedPicker.weightedEndpoint.OnLoadReport(load)
}
if oldDone != nil {
oldDone(info)
}
}
}
return pr, nil
return ws
}
func (p *picker) inc() uint32 {
@ -463,23 +348,21 @@ func (p *picker) inc() uint32 {
}
func (p *picker) regenerateScheduler() {
s := p.newScheduler(true)
s := newScheduler(p.scWeights(), p.inc)
atomic.StorePointer(&p.scheduler, unsafe.Pointer(&s))
}
func (p *picker) start(stopPicker *grpcsync.Event) {
func (p *picker) start(ctx context.Context) {
p.regenerateScheduler()
if len(p.weightedPickers) == 1 {
if len(p.subConns) == 1 {
// No need to regenerate weights with only one backend.
return
}
go func() {
ticker := time.NewTicker(time.Duration(p.cfg.WeightUpdatePeriod))
defer ticker.Stop()
for {
select {
case <-stopPicker.Done():
case <-ctx.Done():
return
case <-ticker.C:
p.regenerateScheduler()
@ -488,27 +371,36 @@ func (p *picker) start(stopPicker *grpcsync.Event) {
}()
}
// endpointWeight is the weight for an endpoint. It tracks the SubConn that will
// be picked for the endpoint, and other parameters relevant to computing the
// effective weight. When needed, it also tracks connectivity state, listens for
// metrics updates by implementing the orca.OOBListener interface and manages
// that listener.
type endpointWeight struct {
// The following fields are immutable.
logger *grpclog.PrefixLogger
target string
metricsRecorder estats.MetricsRecorder
locality string
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
// Read the scheduler atomically. All scheduler operations are threadsafe,
// and if the scheduler is replaced during this usage, we want to use the
// scheduler that was live when the pick started.
sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler))
pickedSC := p.subConns[sched.nextIndex()]
pr := balancer.PickResult{SubConn: pickedSC.SubConn}
if !p.cfg.EnableOOBLoadReport {
pr.Done = func(info balancer.DoneInfo) {
if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil {
pickedSC.OnLoadReport(load)
}
}
}
return pr, nil
}
// weightedSubConn is the wrapper of a subconn that holds the subconn and its
// weight (and other parameters relevant to computing the effective weight).
// When needed, it also tracks connectivity state, listens for metrics updates
// by implementing the orca.OOBListener interface and manages that listener.
type weightedSubConn struct {
balancer.SubConn
logger *grpclog.PrefixLogger
// The following fields are only accessed on calls into the LB policy, and
// do not need a mutex.
connectivityState connectivity.State
stopORCAListener func()
// The first SubConn for the endpoint to go READY when the endpoint has no
// READY SubConns yet; cleared when that SubConn disconnects (i.e. leaves
// READY). Represents what pick first will use as its picked SubConn for
// this endpoint.
pickedSC balancer.SubConn
// The following fields are accessed asynchronously and are protected by
// mu. Note that mu may not be held when calling into the stopORCAListener
@ -522,18 +414,18 @@ type endpointWeight struct {
cfg *lbConfig
}
func (w *endpointWeight) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
if w.logger.V(2) {
w.logger.Infof("Received load report for subchannel %v: %v", w.pickedSC, load)
w.logger.Infof("Received load report for subchannel %v: %v", w.SubConn, load)
}
// Update weights of this endpoint according to the reported load.
// Update weights of this subchannel according to the reported load
utilization := load.ApplicationUtilization
if utilization == 0 {
utilization = load.CpuUtilization
}
if utilization == 0 || load.RpsFractional == 0 {
if w.logger.V(2) {
w.logger.Infof("Ignoring empty load report for subchannel %v", w.pickedSC)
w.logger.Infof("Ignoring empty load report for subchannel %v", w.SubConn)
}
return
}
@ -544,36 +436,34 @@ func (w *endpointWeight) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
errorRate := load.Eps / load.RpsFractional
w.weightVal = load.RpsFractional / (utilization + errorRate*w.cfg.ErrorUtilizationPenalty)
if w.logger.V(2) {
w.logger.Infof("New weight for subchannel %v: %v", w.pickedSC, w.weightVal)
w.logger.Infof("New weight for subchannel %v: %v", w.SubConn, w.weightVal)
}
w.lastUpdated = internal.TimeNow()
if w.nonEmptySince.Equal(time.Time{}) {
if w.nonEmptySince == (time.Time{}) {
w.nonEmptySince = w.lastUpdated
}
}
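OnLoadReport above converts a load report into a weight as qps / (utilization + errorRate * errorUtilizationPenalty), where utilization falls back from application to CPU utilization and errorRate = eps/qps. The standalone sketch below reproduces only that arithmetic with made-up numbers; the real code reads these values from the ORCA load report.

package main

import "fmt"

// endpointWeightFromLoad mirrors the arithmetic in OnLoadReport: utilization
// falls back from application to CPU utilization, the error rate is eps/qps,
// and the weight is qps over the penalized utilization. It returns 0 for an
// empty report; the real balancer simply ignores such reports.
func endpointWeightFromLoad(qps, eps, appUtil, cpuUtil, errorUtilizationPenalty float64) float64 {
    utilization := appUtil
    if utilization == 0 {
        utilization = cpuUtil
    }
    if utilization == 0 || qps == 0 {
        return 0
    }
    errorRate := eps / qps
    return qps / (utilization + errorRate*errorUtilizationPenalty)
}

func main() {
    // 100 QPS, 10 errors/s, 50% application utilization, penalty 1.0:
    // errorRate = 0.1, weight = 100 / (0.5 + 0.1) ≈ 166.7
    fmt.Println(endpointWeightFromLoad(100, 10, 0.5, 0, 1.0))
}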
// updateConfig updates the parameters of the WRR policy and
// stops/starts/restarts the ORCA OOB listener.
func (w *endpointWeight) updateConfig(cfg *lbConfig) {
func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
w.mu.Lock()
oldCfg := w.cfg
w.cfg = cfg
w.mu.Unlock()
newPeriod := cfg.OOBReportingPeriod
if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport &&
cfg.OOBReportingPeriod == oldCfg.OOBReportingPeriod {
newPeriod == oldCfg.OOBReportingPeriod {
// Load reporting wasn't enabled before or after, or load reporting was
// enabled before and after, and had the same period. (Note that with
// load reporting disabled, OOBReportingPeriod is always 0.)
return
}
// (Re)start the listener to use the new config's settings for OOB
// reporting.
w.updateORCAListener(cfg)
}
// (Optionally stop and) start the listener to use the new config's
// settings for OOB reporting.
func (w *endpointWeight) updateORCAListener(cfg *lbConfig) {
if w.stopORCAListener != nil {
w.stopORCAListener()
}
@ -581,56 +471,67 @@ func (w *endpointWeight) updateORCAListener(cfg *lbConfig) {
w.stopORCAListener = nil
return
}
if w.pickedSC == nil { // No picked SC for this endpoint yet, nothing to listen on.
return
}
if w.logger.V(2) {
w.logger.Infof("Registering ORCA listener for %v with interval %v", w.pickedSC, cfg.OOBReportingPeriod)
w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod)
}
opts := orca.OOBListenerOptions{ReportInterval: time.Duration(cfg.OOBReportingPeriod)}
w.stopORCAListener = orca.RegisterOOBListener(w.pickedSC, w, opts)
opts := orca.OOBListenerOptions{ReportInterval: time.Duration(newPeriod)}
w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts)
}
// weight returns the current effective weight of the endpoint, taking into
func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State {
switch cs {
case connectivity.Idle:
// Always reconnect when idle.
w.SubConn.Connect()
case connectivity.Ready:
// If we transition back to READY state, reset nonEmptySince so that we
// apply the blackout period after we start receiving load data. Note
// that we cannot guarantee that we will never receive lingering
// callbacks for backend metric reports from the previous connection
// after the new connection has been established, but they should be
// masked by new backend metric reports from the new connection by the
// time the blackout period ends.
w.mu.Lock()
w.nonEmptySince = time.Time{}
w.mu.Unlock()
case connectivity.Shutdown:
if w.stopORCAListener != nil {
w.stopORCAListener()
}
}
oldCS := w.connectivityState
if oldCS == connectivity.TransientFailure &&
(cs == connectivity.Connecting || cs == connectivity.Idle) {
// Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or
// CONNECTING transitions to prevent the aggregated state from being
// always CONNECTING when many backends exist but are all down.
return oldCS
}
w.connectivityState = cs
return oldCS
}
// weight returns the current effective weight of the subconn, taking into
// account the parameters. Returns 0 for blacked out or expired data, which
// will cause the backend weight to be treated as the mean of the weights of the
// other backends. If recordMetrics is set to true, this function will emit
// metrics through the metrics registry.
func (w *endpointWeight) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration, recordMetrics bool) (weight float64) {
// will cause the backend weight to be treated as the mean of the weights of
// the other backends.
func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration) float64 {
w.mu.Lock()
defer w.mu.Unlock()
if recordMetrics {
defer func() {
endpointWeightsMetric.Record(w.metricsRecorder, weight, w.target, w.locality)
}()
}
// The endpoint has not received a load report (i.e. just turned READY with
// no load report).
if w.lastUpdated.Equal(time.Time{}) {
endpointWeightNotYetUsableMetric.Record(w.metricsRecorder, 1, w.target, w.locality)
return 0
}
// If the most recent update was longer ago than the expiration period,
// reset nonEmptySince so that we apply the blackout period again if we
// start getting data again in the future, and return 0.
if now.Sub(w.lastUpdated) >= weightExpirationPeriod {
if recordMetrics {
endpointWeightStaleMetric.Record(w.metricsRecorder, 1, w.target, w.locality)
}
w.nonEmptySince = time.Time{}
return 0
}
// If we don't have at least blackoutPeriod worth of data, return 0.
if blackoutPeriod != 0 && (w.nonEmptySince.Equal(time.Time{}) || now.Sub(w.nonEmptySince) < blackoutPeriod) {
if recordMetrics {
endpointWeightNotYetUsableMetric.Record(w.metricsRecorder, 1, w.target, w.locality)
}
if blackoutPeriod != 0 && (w.nonEmptySince == (time.Time{}) || now.Sub(w.nonEmptySince) < blackoutPeriod) {
return 0
}
return w.weightVal
}
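weight above returns 0 for blacked-out or stale data, and its doc comment says such endpoints are then treated as having the mean of the other weights. The scheduler that applies that rule is not part of this diff; the sketch below is only an assumed illustration of the mean-substitution step.

package main

import "fmt"

// fillZeroWeights replaces unusable (zero) weights with the mean of the
// usable ones so blacked-out or stale endpoints neither starve nor dominate.
func fillZeroWeights(ws []float64) []float64 {
    var sum float64
    var n int
    for _, w := range ws {
        if w > 0 {
            sum += w
            n++
        }
    }
    if n == 0 {
        return ws // nothing usable; the balancer falls back to round robin
    }
    mean := sum / float64(n)
    out := make([]float64, len(ws))
    for i, w := range ws {
        if w > 0 {
            out[i] = w
        } else {
            out[i] = mean
        }
    }
    return out
}

func main() {
    fmt.Println(fillZeroWeights([]float64{4, 0, 8})) // [4 6 8]
}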

View File

@ -32,7 +32,6 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils/roundrobin"
"google.golang.org/grpc/internal/testutils/stats"
"google.golang.org/grpc/orca"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
@ -82,14 +81,6 @@ var (
WeightUpdatePeriod: stringp(".050s"),
ErrorUtilizationPenalty: float64p(0),
}
testMetricsConfig = iwrr.LBConfig{
EnableOOBLoadReport: boolp(false),
OOBReportingPeriod: stringp("0.005s"),
BlackoutPeriod: stringp("0s"),
WeightExpirationPeriod: stringp("60s"),
WeightUpdatePeriod: stringp("30s"),
ErrorUtilizationPenalty: float64p(0),
}
)
type testServer struct {
@ -115,7 +106,7 @@ func startServer(t *testing.T, r reportType) *testServer {
cmr := orca.NewServerMetricsRecorder().(orca.CallMetricsRecorder)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
if r := orca.CallMetricsRecorderFromContext(ctx); r != nil {
// Copy metrics from what the test set in cmr into r.
sm := cmr.(orca.ServerMetricsProvider).ServerMetrics()
@ -138,7 +129,7 @@ func startServer(t *testing.T, r reportType) *testServer {
MinReportingInterval: 10 * time.Millisecond,
}
internal.ORCAAllowAnyMinReportingInterval.(func(so *orca.ServiceOptions))(&oso)
sopts = append(sopts, stubserver.RegisterServiceServerOption(func(s grpc.ServiceRegistrar) {
sopts = append(sopts, stubserver.RegisterServiceServerOption(func(s *grpc.Server) {
if err := orca.Register(s, oso); err != nil {
t.Fatalf("Failed to register orca service: %v", err)
}
@ -205,51 +196,6 @@ func (s) TestBalancer_OneAddress(t *testing.T) {
}
}
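startServer above registers the ORCA out-of-band metrics service so the WRR balancer can subscribe to periodic load reports. A minimal server-side sketch of that registration follows; the listener address, reporting interval, and the exact ServiceOptions field names are assumptions here and may differ slightly between gRPC-Go versions.

package main

import (
    "log"
    "net"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/orca"
)

func main() {
    lis, err := net.Listen("tcp", "localhost:0")
    if err != nil {
        log.Fatalf("net.Listen() failed: %v", err)
    }
    s := grpc.NewServer()

    // The recorder is where the application reports its load; the ORCA
    // service streams these metrics to subscribed clients.
    smr := orca.NewServerMetricsRecorder()
    if err := orca.Register(s, orca.ServiceOptions{
        ServerMetricsProvider: smr,
        MinReportingInterval:  time.Second,
    }); err != nil {
        log.Fatalf("orca.Register() failed: %v", err)
    }
    smr.SetCPUUtilization(0.5) // report whatever the application measures

    log.Fatal(s.Serve(lis))
}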
// TestWRRMetricsBasic tests metrics emitted from the WRR balancer. It
// configures a weighted round robin balancer as the top level balancer of a
// ClientConn, and configures a fake stats handler on the ClientConn to receive
// metrics. It verifies stats emitted from the Weighted Round Robin Balancer on
// balancer startup case which triggers the first picker and scheduler update
// before any load reports are received.
//
// Note that in this test and others, metrics emission assertions are a snapshot
// of the most recently emitted metrics. This is due to the nondeterminism of
// scheduler updates with respect to test bodies, so the assertions made are
// from the most recently synced state of the system (picker/scheduler) from the
// test body.
func (s) TestWRRMetricsBasic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv := startServer(t, reportCall)
sc := svcConfig(t, testMetricsConfig)
tmr := stats.NewTestMetricsRecorder()
if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc), grpc.WithStatsHandler(tmr)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
srv.callMetrics.SetQPS(float64(1))
if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("Error from EmptyCall: %v", err)
}
if got, _ := tmr.Metric("grpc.lb.wrr.rr_fallback"); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.rr_fallback", got, 1)
}
if got, _ := tmr.Metric("grpc.lb.wrr.endpoint_weight_stale"); got != 0 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.endpoint_weight_stale", got, 0)
}
if got, _ := tmr.Metric("grpc.lb.wrr.endpoint_weight_not_yet_usable"); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.endpoint_weight_not_yet_usable", got, 1)
}
// Unusable, so no endpoint weight. Due to only one SubConn, this will never
// update the weight. Thus, this will stay 0.
if got, _ := tmr.Metric("grpc.lb.wrr.endpoint_weight_stale"); got != 0 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.endpoint_weight_stale", got, 0)
}
}
// Tests two addresses with ORCA reporting disabled (should fall back to pure
// RR).
func (s) TestBalancer_TwoAddresses_ReportingDisabled(t *testing.T) {
@ -449,7 +395,7 @@ func (s) TestBalancer_TwoAddresses_OOBThenPerCall(t *testing.T) {
// Update to per-call weights.
c := svcConfig(t, perCallConfig)
parsedCfg := srv1.R.CC().ParseServiceConfig(c)
parsedCfg := srv1.R.CC.ParseServiceConfig(c)
if parsedCfg.Err != nil {
panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err))
}
@ -460,65 +406,6 @@ func (s) TestBalancer_TwoAddresses_OOBThenPerCall(t *testing.T) {
checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1})
}
// TestEndpoints_SharedAddress tests the case where two endpoints have the same
// address. The expected behavior is undefined; however, the program should not
// crash.
func (s) TestEndpoints_SharedAddress(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv := startServer(t, reportCall)
sc := svcConfig(t, perCallConfig)
if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
endpointsSharedAddress := []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: srv.Address}}}, {Addresses: []resolver.Address{{Addr: srv.Address}}}}
srv.R.UpdateState(resolver.State{Endpoints: endpointsSharedAddress})
// Make some RPCs and make sure they don't crash. Each RPC should go to one of
// the endpoints' addresses; it's undefined which one will be chosen, and the
// load reporting might not work, but the client should still be able to make RPCs.
for i := 0; i < 10; i++ {
if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall failed with err: %v", err)
}
}
}
// TestEndpoints_MultipleAddresses tests WRR on endpoints with numerous
// addresses. It configures WRR with two endpoints with one bad address followed
// by a good address. It configures two backends that each report per call
// metrics, each corresponding to one endpoint's good address. It then
// asserts load is distributed as expected corresponding to the call metrics
// received.
func (s) TestEndpoints_MultipleAddresses(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportCall)
srv2 := startServer(t, reportCall)
srv1.callMetrics.SetQPS(10.0)
srv1.callMetrics.SetApplicationUtilization(.1)
srv2.callMetrics.SetQPS(10.0)
srv2.callMetrics.SetApplicationUtilization(1.0)
sc := svcConfig(t, perCallConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
twoEndpoints := []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: "bad-address-1"}, {Addr: srv1.Address}}}, {Addresses: []resolver.Address{{Addr: "bad-address-2"}, {Addr: srv2.Address}}}}
srv1.R.UpdateState(resolver.State{Endpoints: twoEndpoints})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1})
}
// Tests two addresses with OOB ORCA reporting enabled and a non-zero error
// penalty applied.
func (s) TestBalancer_TwoAddresses_ErrorPenalty(t *testing.T) {
@ -563,7 +450,7 @@ func (s) TestBalancer_TwoAddresses_ErrorPenalty(t *testing.T) {
newCfg := oobConfig
newCfg.ErrorUtilizationPenalty = float64p(0.9)
c := svcConfig(t, newCfg)
parsedCfg := srv1.R.CC().ParseServiceConfig(c)
parsedCfg := srv1.R.CC.ParseServiceConfig(c)
if parsedCfg.Err != nil {
panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err))
}
@ -806,7 +693,7 @@ type srvWeight struct {
const rrIterations = 100
// checkWeights does rrIterations RPCs and expects the different backends to be
// routed in a ratio as determined by the srvWeights passed in. Allows for
// routed in a ratio as deterimined by the srvWeights passed in. Allows for
// some variance (+/- 2 RPCs per backend).
func checkWeights(ctx context.Context, t *testing.T, sws ...srvWeight) {
t.Helper()

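The tests above exercise both per-call and out-of-band (OOB) ORCA load reporting through the stubserver helpers. For reference, a minimal, self-contained sketch of the OOB wiring using the orca package APIs that appear in this diff (NewServerMetricsRecorder, Register); the listen address and the metric values are illustrative only:

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
	"google.golang.org/grpc/orca"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:0") // illustrative address
	if err != nil {
		log.Fatalf("net.Listen() failed: %v", err)
	}
	s := grpc.NewServer()
	// The recorder is written to by the application and read by the ORCA
	// service whenever a client opens the OOB reporting stream.
	smr := orca.NewServerMetricsRecorder()
	if err := orca.Register(s, orca.ServiceOptions{ServerMetricsProvider: smr}); err != nil {
		log.Fatalf("orca.Register() failed: %v", err)
	}
	smr.SetQPS(10) // illustrative metric values
	smr.SetApplicationUtilization(0.1)
	log.Printf("serving with OOB ORCA reporting on %v", lis.Addr())
	if err := s.Serve(lis); err != nil {
		log.Fatalf("Serve() failed: %v", err)
	}
}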
View File

@ -1,173 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package weightedroundrobin
import (
"testing"
"time"
"google.golang.org/grpc/internal/grpctest"
iserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/internal/testutils/stats"
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// TestWRR_Metrics_SubConnWeight tests different scenarios for the weight call
// on a weighted SubConn, and expects certain metrics for each of these
// scenarios.
func (s) TestWRR_Metrics_SubConnWeight(t *testing.T) {
tests := []struct {
name string
weightExpirationPeriod time.Duration
blackoutPeriod time.Duration
lastUpdated time.Time
nonEmpty time.Time
nowTime time.Time
endpointWeightStaleWant float64
endpointWeightNotYetUsableWant float64
endpointWeightWant float64
}{
// The weighted SubConn's lastUpdated field hasn't been set, so this
// SubConn's weight is not yet usable. Thus, should emit that endpoint
// weight is not yet usable, and 0 for weight.
{
name: "no weight set",
weightExpirationPeriod: time.Second,
blackoutPeriod: time.Second,
nowTime: time.Now(),
endpointWeightStaleWant: 0,
endpointWeightNotYetUsableWant: 1,
endpointWeightWant: 0,
},
{
name: "weight expiration",
lastUpdated: time.Now(),
weightExpirationPeriod: 2 * time.Second,
blackoutPeriod: time.Second,
nowTime: time.Now().Add(100 * time.Second),
endpointWeightStaleWant: 1,
endpointWeightNotYetUsableWant: 0,
endpointWeightWant: 0,
},
{
name: "in blackout period",
lastUpdated: time.Now(),
weightExpirationPeriod: time.Minute,
blackoutPeriod: 10 * time.Second,
nowTime: time.Now(),
endpointWeightStaleWant: 0,
endpointWeightNotYetUsableWant: 1,
endpointWeightWant: 0,
},
{
name: "normal weight",
lastUpdated: time.Now(),
nonEmpty: time.Now(),
weightExpirationPeriod: time.Minute,
blackoutPeriod: time.Second,
nowTime: time.Now().Add(10 * time.Second),
endpointWeightStaleWant: 0,
endpointWeightNotYetUsableWant: 0,
endpointWeightWant: 3,
},
{
name: "weight expiration takes precedence over blackout",
lastUpdated: time.Now(),
nonEmpty: time.Now(),
weightExpirationPeriod: time.Second,
blackoutPeriod: time.Minute,
nowTime: time.Now().Add(10 * time.Second),
endpointWeightStaleWant: 1,
endpointWeightNotYetUsableWant: 0,
endpointWeightWant: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tmr := stats.NewTestMetricsRecorder()
wsc := &endpointWeight{
metricsRecorder: tmr,
weightVal: 3,
lastUpdated: test.lastUpdated,
nonEmptySince: test.nonEmpty,
}
wsc.weight(test.nowTime, test.weightExpirationPeriod, test.blackoutPeriod, true)
if got, _ := tmr.Metric("grpc.lb.wrr.endpoint_weight_stale"); got != test.endpointWeightStaleWant {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.endpoint_weight_stale", got, test.endpointWeightStaleWant)
}
if got, _ := tmr.Metric("grpc.lb.wrr.endpoint_weight_not_yet_usable"); got != test.endpointWeightNotYetUsableWant {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.endpoint_weight_not_yet_usable", got, test.endpointWeightNotYetUsableWant)
}
if got, _ := tmr.Metric("grpc.lb.wrr.endpoint_weight_stale"); got != test.endpointWeightStaleWant {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.endpoint_weight_stale", got, test.endpointWeightStaleWant)
}
})
}
}
// TestWRR_Metrics_Scheduler_RR_Fallback tests the round robin fallback metric
// for scheduler updates. It tests the case with one SubConn, and two SubConns
// with no weights. Both of these should emit a count metric for round robin
// fallback.
func (s) TestWRR_Metrics_Scheduler_RR_Fallback(t *testing.T) {
tmr := stats.NewTestMetricsRecorder()
ew := &endpointWeight{
metricsRecorder: tmr,
weightVal: 0,
}
p := &picker{
cfg: &lbConfig{
BlackoutPeriod: iserviceconfig.Duration(10 * time.Second),
WeightExpirationPeriod: iserviceconfig.Duration(3 * time.Minute),
},
weightedPickers: []pickerWeightedEndpoint{{weightedEndpoint: ew}},
metricsRecorder: tmr,
}
// There is only one SubConn, so no matter if the SubConn has a weight or
// not will fallback to round robin.
p.regenerateScheduler()
if got, _ := tmr.Metric("grpc.lb.wrr.rr_fallback"); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.rr_fallback", got, 1)
}
tmr.ClearMetrics()
// With two SubConns, if neither of them have weights, it will also fallback
// to round robin.
ew2 := &endpointWeight{
target: "target",
metricsRecorder: tmr,
weightVal: 0,
}
p.weightedPickers = append(p.weightedPickers, pickerWeightedEndpoint{weightedEndpoint: ew2})
p.regenerateScheduler()
if got, _ := tmr.Metric("grpc.lb.wrr.rr_fallback"); got != 1 {
t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.rr_fallback", got, 1)
}
}
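The table above exercises the weight() decision in three regimes: no report received yet, report older than the expiration period, and report still inside the blackout window, with expiration taking precedence over blackout. A compact, self-contained sketch of that decision under a simplified signature (the real method also records the metrics asserted on above):

package main

import (
	"fmt"
	"time"
)

// effectiveWeight reports the usable weight and a short reason, mirroring the
// outcomes asserted in the test table: stale, not yet usable, or usable.
func effectiveWeight(weight float64, lastUpdated, nonEmptySince, now time.Time, expiration, blackout time.Duration) (float64, string) {
	if lastUpdated.IsZero() {
		return 0, "not yet usable: no load report received"
	}
	if now.Sub(lastUpdated) >= expiration {
		return 0, "stale: last load report is too old"
	}
	if nonEmptySince.IsZero() || now.Sub(nonEmptySince) < blackout {
		return 0, "not yet usable: still inside the blackout period"
	}
	return weight, "usable"
}

func main() {
	now := time.Now()
	// "normal weight" case: recent report, past the blackout period.
	fmt.Println(effectiveWeight(3, now, now, now.Add(10*time.Second), time.Minute, time.Second))
	// "weight expiration takes precedence over blackout" case.
	fmt.Println(effectiveWeight(3, now, now, now.Add(10*time.Second), time.Second, time.Minute))
}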

View File

@ -26,27 +26,23 @@ type scheduler interface {
nextIndex() int
}
// newScheduler uses scWeights to create a new scheduler for selecting endpoints
// newScheduler uses scWeights to create a new scheduler for selecting subconns
// in a picker. It will return a round robin implementation if at least
// len(scWeights)-1 are zero or there is only a single endpoint, otherwise it
// len(scWeights)-1 are zero or there is only a single subconn, otherwise it
// will return an Earliest Deadline First (EDF) scheduler implementation that
// selects the endpoints according to their weights.
func (p *picker) newScheduler(recordMetrics bool) scheduler {
epWeights := p.endpointWeights(recordMetrics)
n := len(epWeights)
// selects the subchannels according to their weights.
func newScheduler(scWeights []float64, inc func() uint32) scheduler {
n := len(scWeights)
if n == 0 {
return nil
}
if n == 1 {
if recordMetrics {
rrFallbackMetric.Record(p.metricsRecorder, 1, p.target, p.locality)
}
return &rrScheduler{numSCs: 1, inc: p.inc}
return &rrScheduler{numSCs: 1, inc: inc}
}
sum := float64(0)
numZero := 0
max := float64(0)
for _, w := range epWeights {
for _, w := range scWeights {
sum += w
if w > max {
max = w
@ -55,12 +51,8 @@ func (p *picker) newScheduler(recordMetrics bool) scheduler {
numZero++
}
}
if numZero >= n-1 {
if recordMetrics {
rrFallbackMetric.Record(p.metricsRecorder, 1, p.target, p.locality)
}
return &rrScheduler{numSCs: uint32(n), inc: p.inc}
return &rrScheduler{numSCs: uint32(n), inc: inc}
}
unscaledMean := sum / float64(n-numZero)
scalingFactor := maxWeight / max
@ -68,7 +60,7 @@ func (p *picker) newScheduler(recordMetrics bool) scheduler {
weights := make([]uint16, n)
allEqual := true
for i, w := range epWeights {
for i, w := range scWeights {
if w == 0 {
// Backends with weight = 0 use the mean.
weights[i] = mean
@ -82,11 +74,11 @@ func (p *picker) newScheduler(recordMetrics bool) scheduler {
}
if allEqual {
return &rrScheduler{numSCs: uint32(n), inc: p.inc}
return &rrScheduler{numSCs: uint32(n), inc: inc}
}
logger.Infof("using edf scheduler with weights: %v", weights)
return &edfScheduler{weights: weights, inc: p.inc}
return &edfScheduler{weights: weights, inc: inc}
}
const maxWeight = math.MaxUint16
@ -133,7 +125,7 @@ func (s *edfScheduler) nextIndex() int {
}
// A simple RR scheduler to use for fallback when fewer than two backends have
// non-zero weights, or all backends have the same weight, or when only one
// non-zero weights, or all backends have the the same weight, or when only one
// subconn exists.
type rrScheduler struct {
inc func() uint32

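newScheduler above chooses between a plain round-robin scheduler and an earliest-deadline-first (EDF) one driven by the weights. The following is a self-contained illustration of the EDF idea only, not this package's implementation: each index is armed with a deadline of 1/weight, the entry with the earliest deadline is served next, and the served entry is re-armed.

package main

import (
	"container/heap"
	"fmt"
)

type edfEntry struct {
	deadline float64
	weight   float64
	index    int
}

type edfHeap []*edfEntry

func (h edfHeap) Len() int           { return len(h) }
func (h edfHeap) Less(i, j int) bool { return h[i].deadline < h[j].deadline }
func (h edfHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *edfHeap) Push(x any)        { *h = append(*h, x.(*edfEntry)) }
func (h *edfHeap) Pop() any {
	old := *h
	e := old[len(old)-1]
	*h = old[:len(old)-1]
	return e
}

func main() {
	weights := []float64{1, 2, 4} // hypothetical endpoint weights
	h := &edfHeap{}
	for i, w := range weights {
		heap.Push(h, &edfEntry{deadline: 1 / w, weight: w, index: i})
	}
	counts := make([]int, len(weights))
	for i := 0; i < 700; i++ {
		next := (*h)[0]      // entry with the earliest deadline
		counts[next.index]++ // "pick" it
		next.deadline += 1 / next.weight
		heap.Fix(h, 0) // re-arm and restore the heap invariant
	}
	fmt.Println(counts) // roughly 100:200:400, i.e. proportional to the weights
}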
View File

@ -0,0 +1,69 @@
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package weightedroundrobin provides an implementation of the weighted round
// robin LB policy, as defined in [gRFC A58].
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
//
// [gRFC A58]: https://github.com/grpc/proposal/blob/master/A58-client-side-weighted-round-robin-lb-policy.md
package weightedroundrobin
import (
"fmt"
"google.golang.org/grpc/resolver"
)
// attributeKey is the type used as the key to store AddrInfo in the
// BalancerAttributes field of resolver.Address.
type attributeKey struct{}
// AddrInfo will be stored in the BalancerAttributes field of Address in order
// to use weighted roundrobin balancer.
type AddrInfo struct {
Weight uint32
}
// Equal allows the values to be compared by Attributes.Equal.
func (a AddrInfo) Equal(o interface{}) bool {
oa, ok := o.(AddrInfo)
return ok && oa.Weight == a.Weight
}
// SetAddrInfo returns a copy of addr in which the BalancerAttributes field is
// updated with addrInfo.
func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address {
addr.BalancerAttributes = addr.BalancerAttributes.WithValue(attributeKey{}, addrInfo)
return addr
}
// GetAddrInfo returns the AddrInfo stored in the BalancerAttributes field of
// addr.
func GetAddrInfo(addr resolver.Address) AddrInfo {
v := addr.BalancerAttributes.Value(attributeKey{})
ai, _ := v.(AddrInfo)
return ai
}
func (a AddrInfo) String() string {
return fmt.Sprintf("Weight: %d", a.Weight)
}
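A short usage sketch of the helpers above, assuming the public google.golang.org/grpc/balancer/weightedroundrobin import path; the address and weight are made up for illustration:

package main

import (
	"fmt"

	"google.golang.org/grpc/balancer/weightedroundrobin"
	"google.golang.org/grpc/resolver"
)

func main() {
	addr := resolver.Address{Addr: "10.0.0.1:50051"} // illustrative address
	addr = weightedroundrobin.SetAddrInfo(addr, weightedroundrobin.AddrInfo{Weight: 8})
	fmt.Println(weightedroundrobin.GetAddrInfo(addr)) // Weight: 8
}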

View File

@ -0,0 +1,82 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package weightedroundrobin
import (
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/resolver"
)
func TestAddrInfoToAndFromAttributes(t *testing.T) {
tests := []struct {
desc string
inputAddrInfo AddrInfo
inputAttributes *attributes.Attributes
wantAddrInfo AddrInfo
}{
{
desc: "empty attributes",
inputAddrInfo: AddrInfo{Weight: 100},
inputAttributes: nil,
wantAddrInfo: AddrInfo{Weight: 100},
},
{
desc: "non-empty attributes",
inputAddrInfo: AddrInfo{Weight: 100},
inputAttributes: attributes.New("foo", "bar"),
wantAddrInfo: AddrInfo{Weight: 100},
},
{
desc: "addrInfo not present in empty attributes",
inputAddrInfo: AddrInfo{},
inputAttributes: nil,
wantAddrInfo: AddrInfo{},
},
{
desc: "addrInfo not present in non-empty attributes",
inputAddrInfo: AddrInfo{},
inputAttributes: attributes.New("foo", "bar"),
wantAddrInfo: AddrInfo{},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
addr := resolver.Address{Attributes: test.inputAttributes}
addr = SetAddrInfo(addr, test.inputAddrInfo)
gotAddrInfo := GetAddrInfo(addr)
if !cmp.Equal(gotAddrInfo, test.wantAddrInfo) {
t.Errorf("gotAddrInfo: %v, wantAddrInfo: %v", gotAddrInfo, test.wantAddrInfo)
}
})
}
}
func TestGetAddInfoEmpty(t *testing.T) {
addr := resolver.Address{}
gotAddrInfo := GetAddrInfo(addr)
wantAddrInfo := AddrInfo{}
if !cmp.Equal(gotAddrInfo, wantAddrInfo) {
t.Errorf("gotAddrInfo: %v, wantAddrInfo: %v", gotAddrInfo, wantAddrInfo)
}
}

View File

@ -26,7 +26,6 @@
package weightedaggregator
import (
"errors"
"fmt"
"sync"
@ -90,7 +89,7 @@ func New(cc balancer.ClientConn, logger *grpclog.PrefixLogger, newWRR func() wrr
}
// Start starts the aggregator. It can be called after Stop to restart the
// aggregator.
// aggretator.
func (wbsa *Aggregator) Start() {
wbsa.mu.Lock()
defer wbsa.mu.Unlock()
@ -252,14 +251,6 @@ func (wbsa *Aggregator) buildAndUpdateLocked() {
func (wbsa *Aggregator) build() balancer.State {
wbsa.logger.Infof("Child pickers with config: %+v", wbsa.idToPickerState)
if len(wbsa.idToPickerState) == 0 {
// This is the case when all sub-balancers are removed.
return balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: base.NewErrPicker(errors.New("weighted-target: no targets to pick from")),
}
}
// Make sure picker's return error is consistent with the aggregatedState.
pickers := make([]weightedPickerState, 0, len(wbsa.idToPickerState))

View File

@ -24,7 +24,6 @@ package weightedtarget
import (
"encoding/json"
"fmt"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/weightedtarget/weightedaggregator"
@ -55,13 +54,8 @@ func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Ba
b.logger = prefixLogger(b)
b.stateAggregator = weightedaggregator.New(cc, b.logger, NewRandomWRR)
b.stateAggregator.Start()
b.bg = balancergroup.New(balancergroup.Options{
CC: cc,
BuildOpts: bOpts,
StateAggregator: b.stateAggregator,
Logger: b.logger,
SubBalancerCloseTimeout: time.Duration(0), // Disable caching of removed child policies
})
b.bg = balancergroup.New(cc, bOpts, b.stateAggregator, b.logger)
b.bg.Start()
b.logger.Infof("Created")
return b
}
@ -83,31 +77,16 @@ type weightedTargetBalancer struct {
targets map[string]Target
}
type localityKeyType string
const localityKey = localityKeyType("locality")
// LocalityFromResolverState returns the locality from the resolver.State
// provided, or an empty string if not present.
func LocalityFromResolverState(state resolver.State) string {
locality, _ := state.Attributes.Value(localityKey).(string)
return locality
}
// UpdateClientConnState takes the new targets in the balancer group,
// creates/deletes sub-balancers, and sends them updates. Addresses are split
// into groups based on their hierarchy paths.
func (b *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
if b.logger.V(2) {
b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig))
}
b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig))
newConfig, ok := s.BalancerConfig.(*LBConfig)
if !ok {
return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig)
}
addressesSplit := hierarchy.Group(s.ResolverState.Addresses)
endpointsSplit := hierarchy.GroupEndpoints(s.ResolverState.Endpoints)
b.stateAggregator.PauseStateUpdates()
defer b.stateAggregator.ResumeStateUpdates()
@ -155,9 +134,8 @@ func (b *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat
_ = b.bg.UpdateClientConnState(name, balancer.ClientConnState{
ResolverState: resolver.State{
Addresses: addressesSplit[name],
Endpoints: endpointsSplit[name],
ServiceConfig: s.ResolverState.ServiceConfig,
Attributes: s.ResolverState.Attributes.WithValue(localityKey, name),
Attributes: s.ResolverState.Attributes,
},
BalancerConfig: newT.ChildPolicy.Config,
})
@ -185,7 +163,7 @@ func (b *weightedTargetBalancer) ResolverError(err error) {
}
func (b *weightedTargetBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
b.bg.UpdateSubConnState(sc, state)
}
func (b *weightedTargetBalancer) Close() {

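The locality attribute above is one instance of a general pattern: a parent policy stashes per-child metadata in resolver.State attributes before forwarding the update to a child. A minimal sketch of that pattern with a hypothetical key type (not the unexported localityKey used here):

package main

import (
	"fmt"

	"google.golang.org/grpc/resolver"
)

type childNameKey struct{} // hypothetical attribute key type

func withChildName(state resolver.State, name string) resolver.State {
	state.Attributes = state.Attributes.WithValue(childNameKey{}, name)
	return state
}

func childNameFrom(state resolver.State) string {
	name, _ := state.Attributes.Value(childNameKey{}).(string)
	return name
}

func main() {
	s := withChildName(resolver.State{}, "locality-a")
	fmt.Println(childNameFrom(s)) // locality-a
}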
File diff suppressed because it is too large

459
balancer_conn_wrappers.go Normal file
View File

@ -0,0 +1,459 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"fmt"
"strings"
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver"
)
type ccbMode int
const (
ccbModeActive = iota
ccbModeIdle
ccbModeClosed
ccbModeExitingIdle
)
// ccBalancerWrapper sits between the ClientConn and the Balancer.
//
// ccBalancerWrapper implements methods corresponding to the ones on the
// balancer.Balancer interface. The ClientConn is free to call these methods
// concurrently and the ccBalancerWrapper ensures that calls from the ClientConn
// to the Balancer happen synchronously and in order.
//
// ccBalancerWrapper also implements the balancer.ClientConn interface and is
// passed to the Balancer implementations. It invokes unexported methods on the
// ClientConn to handle these calls from the Balancer.
//
// It uses the gracefulswitch.Balancer internally to ensure that balancer
// switches happen in a graceful manner.
type ccBalancerWrapper struct {
// The following fields are initialized when the wrapper is created and are
// read-only afterwards, and therefore can be accessed without a mutex.
cc *ClientConn
opts balancer.BuildOptions
// Outgoing (gRPC --> balancer) calls are guaranteed to execute in a
// mutually exclusive manner as they are scheduled in the serializer. Fields
// accessed *only* in these serializer callbacks, can therefore be accessed
// without a mutex.
balancer *gracefulswitch.Balancer
curBalancerName string
// mu guards access to the below fields. Access to the serializer and its
// cancel function needs to be mutex protected because they are overwritten
// when the wrapper exits idle mode.
mu sync.Mutex
serializer *grpcsync.CallbackSerializer // To serialize all outgoing calls.
serializerCancel context.CancelFunc // To close the serializer at close/enterIdle time.
mode ccbMode // Tracks the current mode of the wrapper.
}
// newCCBalancerWrapper creates a new balancer wrapper. The underlying balancer
// is not created until the switchTo() method is invoked.
func newCCBalancerWrapper(cc *ClientConn, bopts balancer.BuildOptions) *ccBalancerWrapper {
ctx, cancel := context.WithCancel(context.Background())
ccb := &ccBalancerWrapper{
cc: cc,
opts: bopts,
serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel,
}
ccb.balancer = gracefulswitch.NewBalancer(ccb, bopts)
return ccb
}
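A minimal sketch of the serialization idea described in the comment above: callbacks are queued and drained by a single goroutine, so they run one at a time and in order. This is an illustration only, not the grpcsync.CallbackSerializer implementation; the queue size and the final sleep are arbitrary.

package main

import (
	"context"
	"fmt"
	"time"
)

type serializer struct {
	ctx       context.Context
	callbacks chan func(context.Context)
}

func newSerializer(ctx context.Context) *serializer {
	s := &serializer{ctx: ctx, callbacks: make(chan func(context.Context), 16)}
	go func() {
		for {
			select {
			case cb := <-s.callbacks:
				cb(ctx) // callbacks run one at a time, in FIFO order
			case <-ctx.Done():
				return
			}
		}
	}()
	return s
}

// schedule enqueues cb; it reports false if the serializer is shutting down.
func (s *serializer) schedule(cb func(context.Context)) bool {
	select {
	case s.callbacks <- cb:
		return true
	case <-s.ctx.Done():
		return false
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	s := newSerializer(ctx)
	for i := 0; i < 3; i++ {
		i := i
		s.schedule(func(context.Context) { fmt.Println("callback", i) })
	}
	time.Sleep(100 * time.Millisecond) // give the queued callbacks time to drain
}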
// updateClientConnState is invoked by grpc to push a ClientConnState update to
// the underlying balancer.
func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error {
ccb.mu.Lock()
errCh := make(chan error, 1)
// Here and everywhere else where Schedule() is called, it is done with the
// lock held. But the lock guards only the scheduling part. The actual
// callback is called asynchronously without the lock being held.
ok := ccb.serializer.Schedule(func(_ context.Context) {
// If the addresses specified in the update contain addresses of type
// "grpclb" and the selected LB policy is not "grpclb", these addresses
// will be filtered out and ccs will be modified with the updated
// address list.
if ccb.curBalancerName != grpclbName {
var addrs []resolver.Address
for _, addr := range ccs.ResolverState.Addresses {
if addr.Type == resolver.GRPCLB {
continue
}
addrs = append(addrs, addr)
}
ccs.ResolverState.Addresses = addrs
}
errCh <- ccb.balancer.UpdateClientConnState(*ccs)
})
if !ok {
// If we are unable to schedule a function with the serializer, it
// indicates that it has been closed. A serializer is only closed when
// the wrapper is closed or is in idle.
ccb.mu.Unlock()
return fmt.Errorf("grpc: cannot send state update to a closed or idle balancer")
}
ccb.mu.Unlock()
// We get here only if the above call to Schedule succeeds, in which case it
// is guaranteed that the scheduled function will run. Therefore it is safe
// to block on this channel.
err := <-errCh
if logger.V(2) && err != nil {
logger.Infof("error from balancer.UpdateClientConnState: %v", err)
}
return err
}
// updateSubConnState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (ccb *ccBalancerWrapper) updateSubConnState(sc balancer.SubConn, s connectivity.State, err error) {
ccb.mu.Lock()
ccb.serializer.Schedule(func(_ context.Context) {
ccb.balancer.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
})
ccb.mu.Unlock()
}
func (ccb *ccBalancerWrapper) resolverError(err error) {
ccb.mu.Lock()
ccb.serializer.Schedule(func(_ context.Context) {
ccb.balancer.ResolverError(err)
})
ccb.mu.Unlock()
}
// switchTo is invoked by grpc to instruct the balancer wrapper to switch to the
// LB policy identified by name.
//
// ClientConn calls newCCBalancerWrapper() at creation time. Upon receipt of the
// first good update from the name resolver, it determines the LB policy to use
// and invokes the switchTo() method. Upon receipt of every subsequent update
// from the name resolver, it invokes this method.
//
// the ccBalancerWrapper keeps track of the current LB policy name, and skips
// the graceful balancer switching process if the name does not change.
func (ccb *ccBalancerWrapper) switchTo(name string) {
ccb.mu.Lock()
ccb.serializer.Schedule(func(_ context.Context) {
// TODO: Other languages use case-sensitive balancer registries. We should
// switch as well. See: https://github.com/grpc/grpc-go/issues/5288.
if strings.EqualFold(ccb.curBalancerName, name) {
return
}
ccb.buildLoadBalancingPolicy(name)
})
ccb.mu.Unlock()
}
// buildLoadBalancingPolicy performs the following:
// - retrieve a balancer builder for the given name. Use the default LB
// policy, pick_first, if no LB policy with name is found in the registry.
// - instruct the gracefulswitch balancer to switch to the above builder. This
// will actually build the new balancer.
// - update the `curBalancerName` field
//
// Must be called from a serializer callback.
func (ccb *ccBalancerWrapper) buildLoadBalancingPolicy(name string) {
builder := balancer.Get(name)
if builder == nil {
channelz.Warningf(logger, ccb.cc.channelzID, "Channel switches to new LB policy %q, since the specified LB policy %q was not registered", PickFirstBalancerName, name)
builder = newPickfirstBuilder()
} else {
channelz.Infof(logger, ccb.cc.channelzID, "Channel switches to new LB policy %q", name)
}
if err := ccb.balancer.SwitchTo(builder); err != nil {
channelz.Errorf(logger, ccb.cc.channelzID, "Channel failed to build new LB policy %q: %v", name, err)
return
}
ccb.curBalancerName = builder.Name()
}
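buildLoadBalancingPolicy falls back to pick_first when the requested policy is not registered. A small sketch of that lookup against the public balancer registry; "my_policy" is a hypothetical, unregistered name, and the blank grpc import is only there so the built-in policies are registered:

package main

import (
	"fmt"

	"google.golang.org/grpc/balancer"

	_ "google.golang.org/grpc" // registers the built-in policies, including pick_first
)

func builderOrPickFirst(name string) balancer.Builder {
	if b := balancer.Get(name); b != nil {
		return b
	}
	return balancer.Get("pick_first") // fall back to the default policy
}

func main() {
	fmt.Println(builderOrPickFirst("my_policy").Name()) // pick_first
}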
func (ccb *ccBalancerWrapper) close() {
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: closing")
ccb.closeBalancer(ccbModeClosed)
}
// enterIdleMode is invoked by grpc when the channel enters idle mode upon
// expiry of idle_timeout. This call blocks until the balancer is closed.
func (ccb *ccBalancerWrapper) enterIdleMode() {
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: entering idle mode")
ccb.closeBalancer(ccbModeIdle)
}
// closeBalancer is invoked when the channel is being closed or when it enters
// idle mode upon expiry of idle_timeout.
func (ccb *ccBalancerWrapper) closeBalancer(m ccbMode) {
ccb.mu.Lock()
if ccb.mode == ccbModeClosed || ccb.mode == ccbModeIdle {
ccb.mu.Unlock()
return
}
ccb.mode = m
done := ccb.serializer.Done
b := ccb.balancer
ok := ccb.serializer.Schedule(func(_ context.Context) {
// Close the serializer to ensure that no more calls from gRPC are sent
// to the balancer.
ccb.serializerCancel()
// Empty the current balancer name because we don't have a balancer
// anymore and also so that we act on the next call to switchTo by
// creating a new balancer specified by the new resolver.
ccb.curBalancerName = ""
})
if !ok {
ccb.mu.Unlock()
return
}
ccb.mu.Unlock()
// Give enqueued callbacks a chance to finish.
<-done
// Spawn a goroutine to close the balancer (since it may block trying to
// cleanup all allocated resources) and return early.
go b.Close()
}
// exitIdleMode is invoked by grpc when the channel exits idle mode either
// because of an RPC or because of an invocation of the Connect() API. This
// recreates the balancer that was closed previously when entering idle mode.
//
// If the channel is not in idle mode, we know for a fact that we are here as a
// result of the user calling the Connect() method on the ClientConn. In this
// case, we can simply forward the call to the underlying balancer, instructing
// it to reconnect to the backends.
func (ccb *ccBalancerWrapper) exitIdleMode() {
ccb.mu.Lock()
if ccb.mode == ccbModeClosed {
// Request to exit idle is a no-op when wrapper is already closed.
ccb.mu.Unlock()
return
}
if ccb.mode == ccbModeIdle {
// Recreate the serializer which was closed when we entered idle.
ctx, cancel := context.WithCancel(context.Background())
ccb.serializer = grpcsync.NewCallbackSerializer(ctx)
ccb.serializerCancel = cancel
}
// The ClientConn guarantees mutual exclusion between close() and
// exitIdleMode(), and since we just created a new serializer, we can be
// sure that the below function will be scheduled.
done := make(chan struct{})
ccb.serializer.Schedule(func(_ context.Context) {
defer close(done)
ccb.mu.Lock()
defer ccb.mu.Unlock()
if ccb.mode != ccbModeIdle {
ccb.balancer.ExitIdle()
return
}
// Gracefulswitch balancer does not support a switchTo operation after
// being closed. Hence we need to create a new one here.
ccb.balancer = gracefulswitch.NewBalancer(ccb, ccb.opts)
ccb.mode = ccbModeActive
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: exiting idle mode")
})
ccb.mu.Unlock()
<-done
}
func (ccb *ccBalancerWrapper) isIdleOrClosed() bool {
ccb.mu.Lock()
defer ccb.mu.Unlock()
return ccb.mode == ccbModeIdle || ccb.mode == ccbModeClosed
}
func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
if ccb.isIdleOrClosed() {
return nil, fmt.Errorf("grpc: cannot create SubConn when balancer is closed or idle")
}
if len(addrs) == 0 {
return nil, fmt.Errorf("grpc: cannot create SubConn with empty address list")
}
ac, err := ccb.cc.newAddrConn(addrs, opts)
if err != nil {
channelz.Warningf(logger, ccb.cc.channelzID, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err)
return nil, err
}
acbw := &acBalancerWrapper{ac: ac, producers: make(map[balancer.ProducerBuilder]*refCountedProducer)}
ac.acbw = acbw
return acbw, nil
}
func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
if ccb.isIdleOrClosed() {
// It is safe to ignore this call when the balancer is closed or in idle
// because the ClientConn takes care of closing the connections.
//
// Not returning early from here when the balancer is closed or in idle
// leads to a deadlock though, because of the following sequence of
// calls when holding cc.mu:
// cc.exitIdleMode --> ccb.enterIdleMode --> gsw.Close -->
// ccb.RemoveAddrConn --> cc.removeAddrConn
return
}
acbw, ok := sc.(*acBalancerWrapper)
if !ok {
return
}
ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
}
func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
if ccb.isIdleOrClosed() {
return
}
acbw, ok := sc.(*acBalancerWrapper)
if !ok {
return
}
acbw.UpdateAddresses(addrs)
}
func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) {
if ccb.isIdleOrClosed() {
return
}
// Update picker before updating state. Even though the ordering here does
// not matter, it can lead to multiple calls of Pick in the common start-up
// case where we wait for ready and then perform an RPC. If the picker is
// updated later, we could call the "connecting" picker when the state is
// updated, and then call the "ready" picker after the picker gets updated.
ccb.cc.blockingpicker.updatePicker(s.Picker)
ccb.cc.csMgr.updateState(s.ConnectivityState)
}
func (ccb *ccBalancerWrapper) ResolveNow(o resolver.ResolveNowOptions) {
if ccb.isIdleOrClosed() {
return
}
ccb.cc.resolveNow(o)
}
func (ccb *ccBalancerWrapper) Target() string {
return ccb.cc.target
}
// acBalancerWrapper is a wrapper on top of ac for balancers.
// It implements balancer.SubConn interface.
type acBalancerWrapper struct {
ac *addrConn // read-only
mu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
}
func (acbw *acBalancerWrapper) String() string {
return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int())
}
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
acbw.ac.updateAddrs(addrs)
}
func (acbw *acBalancerWrapper) Connect() {
go acbw.ac.connect()
}
// NewStream begins a streaming RPC on the addrConn. If the addrConn is not
// ready, blocks until it is or ctx expires. Returns an error when the context
// expires or the addrConn is shut down.
func (acbw *acBalancerWrapper) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
transport, err := acbw.ac.getTransport(ctx)
if err != nil {
return nil, err
}
return newNonRetryClientStream(ctx, desc, method, transport, acbw.ac, opts...)
}
// Invoke performs a unary RPC. If the addrConn is not ready, returns
// errSubConnNotReady.
func (acbw *acBalancerWrapper) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...CallOption) error {
cs, err := acbw.NewStream(ctx, unaryStreamDesc, method, opts...)
if err != nil {
return err
}
if err := cs.SendMsg(args); err != nil {
return err
}
return cs.RecvMsg(reply)
}
type refCountedProducer struct {
producer balancer.Producer
refs int // number of current refs to the producer
close func() // underlying producer's close function
}
func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (balancer.Producer, func()) {
acbw.mu.Lock()
defer acbw.mu.Unlock()
// Look up existing producer from this builder.
pData := acbw.producers[pb]
if pData == nil {
// Not found; create a new one and add it to the producers map.
p, close := pb.Build(acbw)
pData = &refCountedProducer{producer: p, close: close}
acbw.producers[pb] = pData
}
// Account for this new reference.
pData.refs++
// Return a cleanup function wrapped in a OnceFunc to remove this reference
// and delete the refCountedProducer from the map if the total reference
// count goes to zero.
unref := func() {
acbw.mu.Lock()
pData.refs--
if pData.refs == 0 {
defer pData.close() // Run outside the acbw mutex
delete(acbw.producers, pb)
}
acbw.mu.Unlock()
}
return pData.producer, grpcsync.OnceFunc(unref)
}
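GetOrBuildProducer above applies a common ref-counting pattern: build on first use, share on later calls, and tear down only when the last reference is released, with a OnceFunc guarding against double release. A self-contained sketch of that pattern with a hypothetical string-keyed cache:

package main

import (
	"fmt"
	"sync"
)

type entry struct {
	value string
	refs  int
}

type refCache struct {
	mu      sync.Mutex
	entries map[string]*entry
}

// get returns the cached value for key, building it on first use, along with a
// release func. The entry is evicted only when the last reference is released;
// sync.OnceFunc makes calling release more than once harmless.
func (c *refCache) get(key string, build func() string) (string, func()) {
	c.mu.Lock()
	defer c.mu.Unlock()
	e := c.entries[key]
	if e == nil {
		e = &entry{value: build()}
		c.entries[key] = e
	}
	e.refs++
	release := func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		e.refs--
		if e.refs == 0 {
			delete(c.entries, key)
			fmt.Println("evicted", key)
		}
	}
	return e.value, sync.OnceFunc(release)
}

func main() {
	c := &refCache{entries: make(map[string]*entry)}
	v1, rel1 := c.get("producer", func() string { return "shared" })
	v2, rel2 := c.get("producer", func() string { return "shared" })
	fmt.Println(v1, v2) // shared shared
	rel1()
	rel1() // no-op: already released once
	rel2() // last reference dropped -> "evicted producer"
}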

View File

@ -1,520 +0,0 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"context"
"fmt"
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
var (
setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
// noOpRegisterHealthListenerFn is used when client side health checking is
// disabled. It sends a single READY update on the registered listener.
noOpRegisterHealthListenerFn = func(_ context.Context, listener func(balancer.SubConnState)) func() {
listener(balancer.SubConnState{ConnectivityState: connectivity.Ready})
return func() {}
}
)
// ccBalancerWrapper sits between the ClientConn and the Balancer.
//
// ccBalancerWrapper implements methods corresponding to the ones on the
// balancer.Balancer interface. The ClientConn is free to call these methods
// concurrently and the ccBalancerWrapper ensures that calls from the ClientConn
// to the Balancer happen in order by performing them in the serializer, without
// any mutexes held.
//
// ccBalancerWrapper also implements the balancer.ClientConn interface and is
// passed to the Balancer implementations. It invokes unexported methods on the
// ClientConn to handle these calls from the Balancer.
//
// It uses the gracefulswitch.Balancer internally to ensure that balancer
// switches happen in a graceful manner.
type ccBalancerWrapper struct {
internal.EnforceClientConnEmbedding
// The following fields are initialized when the wrapper is created and are
// read-only afterwards, and therefore can be accessed without a mutex.
cc *ClientConn
opts balancer.BuildOptions
serializer *grpcsync.CallbackSerializer
serializerCancel context.CancelFunc
// The following fields are only accessed within the serializer or during
// initialization.
curBalancerName string
balancer *gracefulswitch.Balancer
// The following field is protected by mu. Caller must take cc.mu before
// taking mu.
mu sync.Mutex
closed bool
}
// newCCBalancerWrapper creates a new balancer wrapper in idle state. The
// underlying balancer is not created until the updateClientConnState() method
// is invoked.
func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
ctx, cancel := context.WithCancel(cc.ctx)
ccb := &ccBalancerWrapper{
cc: cc,
opts: balancer.BuildOptions{
DialCreds: cc.dopts.copts.TransportCredentials,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
Authority: cc.authority,
CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParent: cc.channelz,
Target: cc.parsedTarget,
},
serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel,
}
ccb.balancer = gracefulswitch.NewBalancer(ccb, ccb.opts)
return ccb
}
func (ccb *ccBalancerWrapper) MetricsRecorder() stats.MetricsRecorder {
return ccb.cc.metricsRecorderList
}
// updateClientConnState is invoked by grpc to push a ClientConnState update to
// the underlying balancer. This is always executed from the serializer, so
// it is safe to call into the balancer here.
func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error {
errCh := make(chan error)
uccs := func(ctx context.Context) {
defer close(errCh)
if ctx.Err() != nil || ccb.balancer == nil {
return
}
name := gracefulswitch.ChildName(ccs.BalancerConfig)
if ccb.curBalancerName != name {
ccb.curBalancerName = name
channelz.Infof(logger, ccb.cc.channelz, "Channel switches to new LB policy %q", name)
}
err := ccb.balancer.UpdateClientConnState(*ccs)
if logger.V(2) && err != nil {
logger.Infof("error from balancer.UpdateClientConnState: %v", err)
}
errCh <- err
}
onFailure := func() { close(errCh) }
// UpdateClientConnState can race with Close, and when the latter wins, the
// serializer is closed, and the attempt to schedule the callback will fail.
// It is acceptable to ignore this failure. But since we want to handle the
// state update in a blocking fashion (when we successfully schedule the
// callback), we have to use the ScheduleOr method and not the MaybeSchedule
// method on the serializer.
ccb.serializer.ScheduleOr(uccs, onFailure)
return <-errCh
}
// resolverError is invoked by grpc to push a resolver error to the underlying
// balancer. The call to the balancer is executed from the serializer.
func (ccb *ccBalancerWrapper) resolverError(err error) {
ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil {
return
}
ccb.balancer.ResolverError(err)
})
}
// close initiates async shutdown of the wrapper. cc.mu must be held when
// calling this function. To determine the wrapper has finished shutting down,
// the channel should block on ccb.serializer.Done() without cc.mu held.
func (ccb *ccBalancerWrapper) close() {
ccb.mu.Lock()
ccb.closed = true
ccb.mu.Unlock()
channelz.Info(logger, ccb.cc.channelz, "ccBalancerWrapper: closing")
ccb.serializer.TrySchedule(func(context.Context) {
if ccb.balancer == nil {
return
}
ccb.balancer.Close()
ccb.balancer = nil
})
ccb.serializerCancel()
}
// exitIdle invokes the balancer's exitIdle method in the serializer.
func (ccb *ccBalancerWrapper) exitIdle() {
ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil {
return
}
ccb.balancer.ExitIdle()
})
}
func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
ccb.cc.mu.Lock()
defer ccb.cc.mu.Unlock()
ccb.mu.Lock()
if ccb.closed {
ccb.mu.Unlock()
return nil, fmt.Errorf("balancer is being closed; no new SubConns allowed")
}
ccb.mu.Unlock()
if len(addrs) == 0 {
return nil, fmt.Errorf("grpc: cannot create SubConn with empty address list")
}
ac, err := ccb.cc.newAddrConnLocked(addrs, opts)
if err != nil {
channelz.Warningf(logger, ccb.cc.channelz, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err)
return nil, err
}
acbw := &acBalancerWrapper{
ccb: ccb,
ac: ac,
producers: make(map[balancer.ProducerBuilder]*refCountedProducer),
stateListener: opts.StateListener,
healthData: newHealthData(connectivity.Idle),
}
ac.acbw = acbw
return acbw, nil
}
func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
// The graceful switch balancer will never call this.
logger.Errorf("ccb RemoveSubConn(%v) called unexpectedly", sc)
}
func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
acbw, ok := sc.(*acBalancerWrapper)
if !ok {
return
}
acbw.UpdateAddresses(addrs)
}
func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) {
ccb.cc.mu.Lock()
defer ccb.cc.mu.Unlock()
if ccb.cc.conns == nil {
// The CC has been closed; ignore this update.
return
}
ccb.mu.Lock()
if ccb.closed {
ccb.mu.Unlock()
return
}
ccb.mu.Unlock()
// Update picker before updating state. Even though the ordering here does
// not matter, it can lead to multiple calls of Pick in the common start-up
// case where we wait for ready and then perform an RPC. If the picker is
// updated later, we could call the "connecting" picker when the state is
// updated, and then call the "ready" picker after the picker gets updated.
// Note that there is no need to check if the balancer wrapper was closed,
// as we know the graceful switch LB policy will not call cc if it has been
// closed.
ccb.cc.pickerWrapper.updatePicker(s.Picker)
ccb.cc.csMgr.updateState(s.ConnectivityState)
}
func (ccb *ccBalancerWrapper) ResolveNow(o resolver.ResolveNowOptions) {
ccb.cc.mu.RLock()
defer ccb.cc.mu.RUnlock()
ccb.mu.Lock()
if ccb.closed {
ccb.mu.Unlock()
return
}
ccb.mu.Unlock()
ccb.cc.resolveNowLocked(o)
}
func (ccb *ccBalancerWrapper) Target() string {
return ccb.cc.target
}
// acBalancerWrapper is a wrapper on top of ac for balancers.
// It implements balancer.SubConn interface.
type acBalancerWrapper struct {
internal.EnforceSubConnEmbedding
ac *addrConn // read-only
ccb *ccBalancerWrapper // read-only
stateListener func(balancer.SubConnState)
producersMu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
// Access to healthData is protected by healthMu.
healthMu sync.Mutex
// healthData is stored as a pointer to detect when the health listener is
// dropped or updated. This is required as closures can't be compared for
// equality.
healthData *healthData
}
// healthData holds data related to health state reporting.
type healthData struct {
// connectivityState stores the most recent connectivity state delivered
// to the LB policy. This is stored to avoid sending updates when the
// SubConn has already exited connectivity state READY.
connectivityState connectivity.State
// closeHealthProducer stores function to close the ref counted health
// producer. The health producer is automatically closed when the SubConn
// state changes.
closeHealthProducer func()
}
func newHealthData(s connectivity.State) *healthData {
return &healthData{
connectivityState: s,
closeHealthProducer: func() {},
}
}
// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Invalidate all producers on any state change.
acbw.closeProducers()
// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed.
scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
setConnectedAddress(&scs, curAddr)
}
// Invalidate the health listener by updating the healthData.
acbw.healthMu.Lock()
// A race may occur if a health listener is registered soon after the
// connectivity state is set but before the stateListener is called.
// Two cases may arise:
// 1. The new state is not READY: RegisterHealthListener has checks to
// ensure no updates are sent when the connectivity state is not
// READY.
// 2. The new state is READY: This means that the old state wasn't Ready.
// The RegisterHealthListener API mentions that a health listener
// must not be registered when a SubConn is not ready to avoid such
// races. When this happens, the LB policy would get health updates
// on the old listener. When the LB policy registers a new listener
// on receiving the connectivity update, the health updates will be
// sent to the new health listener.
acbw.healthData = newHealthData(scs.ConnectivityState)
acbw.healthMu.Unlock()
acbw.stateListener(scs)
})
}
func (acbw *acBalancerWrapper) String() string {
return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelz.ID)
}
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
acbw.ac.updateAddrs(addrs)
}
func (acbw *acBalancerWrapper) Connect() {
go acbw.ac.connect()
}
func (acbw *acBalancerWrapper) Shutdown() {
acbw.closeProducers()
acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
}
// NewStream begins a streaming RPC on the addrConn. If the addrConn is not
// ready, blocks until it is or ctx expires. Returns an error when the context
// expires or the addrConn is shut down.
func (acbw *acBalancerWrapper) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
transport := acbw.ac.getReadyTransport()
if transport == nil {
return nil, status.Errorf(codes.Unavailable, "SubConn state is not Ready")
}
return newNonRetryClientStream(ctx, desc, method, transport, acbw.ac, opts...)
}
// Invoke performs a unary RPC. If the addrConn is not ready, returns
// errSubConnNotReady.
func (acbw *acBalancerWrapper) Invoke(ctx context.Context, method string, args any, reply any, opts ...CallOption) error {
cs, err := acbw.NewStream(ctx, unaryStreamDesc, method, opts...)
if err != nil {
return err
}
if err := cs.SendMsg(args); err != nil {
return err
}
return cs.RecvMsg(reply)
}
type refCountedProducer struct {
producer balancer.Producer
refs int // number of current refs to the producer
close func() // underlying producer's close function
}
func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (balancer.Producer, func()) {
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()
// Look up existing producer from this builder.
pData := acbw.producers[pb]
if pData == nil {
// Not found; create a new one and add it to the producers map.
p, closeFn := pb.Build(acbw)
pData = &refCountedProducer{producer: p, close: closeFn}
acbw.producers[pb] = pData
}
// Account for this new reference.
pData.refs++
// Return a cleanup function wrapped in a OnceFunc to remove this reference
// and delete the refCountedProducer from the map if the total reference
// count goes to zero.
unref := func() {
acbw.producersMu.Lock()
// If closeProducers has already closed this producer instance, refs is
// set to 0, so the check after decrementing will never pass, and the
// producer will not be double-closed.
pData.refs--
if pData.refs == 0 {
defer pData.close() // Run outside the acbw mutex
delete(acbw.producers, pb)
}
acbw.producersMu.Unlock()
}
return pData.producer, sync.OnceFunc(unref)
}
func (acbw *acBalancerWrapper) closeProducers() {
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()
for pb, pData := range acbw.producers {
pData.refs = 0
pData.close()
delete(acbw.producers, pb)
}
}
// healthProducerRegisterFn is a type alias for the health producer's function
// for registering listeners.
type healthProducerRegisterFn = func(context.Context, balancer.SubConn, string, func(balancer.SubConnState)) func()
// healthListenerRegFn returns a function to register a listener for health
// updates. If client side health checks are disabled, the registered listener
// will get a single READY (raw connectivity state) update.
//
// Client side health checking is enabled when all the following
// conditions are satisfied:
// 1. Health checking is not disabled using the dial option.
// 2. The health package is imported.
// 3. The health check config is present in the service config.
func (acbw *acBalancerWrapper) healthListenerRegFn() func(context.Context, func(balancer.SubConnState)) func() {
if acbw.ccb.cc.dopts.disableHealthCheck {
return noOpRegisterHealthListenerFn
}
regHealthLisFn := internal.RegisterClientHealthCheckListener
if regHealthLisFn == nil {
// The health package is not imported.
return noOpRegisterHealthListenerFn
}
cfg := acbw.ac.cc.healthCheckConfig()
if cfg == nil {
return noOpRegisterHealthListenerFn
}
return func(ctx context.Context, listener func(balancer.SubConnState)) func() {
return regHealthLisFn.(healthProducerRegisterFn)(ctx, acbw, cfg.ServiceName, listener)
}
}
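A sketch of a client that satisfies the three conditions listed above: no WithDisableHealthCheck dial option, the health package blank-imported, and a healthCheckConfig in the default service config. The target, service name, and choice of round_robin (a policy that opts in to health checking) are illustrative assumptions:

package main

import (
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	_ "google.golang.org/grpc/health" // condition 2: the health package is imported
)

func main() {
	// Condition 3: a healthCheckConfig in the (default) service config. The
	// service name is illustrative.
	svcCfg := `{
		"healthCheckConfig": {"serviceName": "example.Service"},
		"loadBalancingConfig": [{"round_robin": {}}]
	}`
	// Condition 1: WithDisableHealthCheck is not used.
	cc, err := grpc.NewClient("localhost:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(svcCfg),
	)
	if err != nil {
		log.Fatalf("grpc.NewClient() failed: %v", err)
	}
	defer cc.Close()
	cc.Connect()
}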
// RegisterHealthListener accepts a health listener from the LB policy. It sends
// updates to the health listener as long as the SubConn's connectivity state
// doesn't change and a new health listener is not registered. To invalidate
// the currently registered health listener, acbw updates the healthData. If a
// nil listener is registered, the active health listener is dropped.
func (acbw *acBalancerWrapper) RegisterHealthListener(listener func(balancer.SubConnState)) {
acbw.healthMu.Lock()
defer acbw.healthMu.Unlock()
acbw.healthData.closeHealthProducer()
// listeners should not be registered when the connectivity state
// isn't Ready. This may happen when the balancer registers a listener
// after the connectivityState is updated, but before it is notified
// of the update.
if acbw.healthData.connectivityState != connectivity.Ready {
return
}
// Replace the health data to stop sending updates to any previously
// registered health listeners.
hd := newHealthData(connectivity.Ready)
acbw.healthData = hd
if listener == nil {
return
}
registerFn := acbw.healthListenerRegFn()
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Don't send updates if a new listener is registered.
acbw.healthMu.Lock()
defer acbw.healthMu.Unlock()
if acbw.healthData != hd {
return
}
// Serialize the health updates from the health producer with
// other calls into the LB policy.
listenerWrapper := func(scs balancer.SubConnState) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
acbw.healthMu.Lock()
defer acbw.healthMu.Unlock()
if acbw.healthData != hd {
return
}
listener(scs)
})
}
hd.closeHealthProducer = registerFn(ctx, listenerWrapper)
})
}

View File

@ -1,82 +0,0 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"strings"
"sync"
"testing"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpcsync"
)
// TestBalancer_StateListenerBeforeConnect tries to stimulate a race between
// NewSubConn and ClientConn.Close. In no cases should the SubConn's
// StateListener be invoked, because Connect was never called.
func (s) TestBalancer_StateListenerBeforeConnect(t *testing.T) {
// started is fired after cc is set so cc can be used in the balancer.
started := grpcsync.NewEvent()
var cc *ClientConn
wg := sync.WaitGroup{}
wg.Add(2)
// Create a balancer that calls NewSubConn and cc.Close at approximately the
// same time.
bf := stub.BalancerFuncs{
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
go func() {
// Wait for cc to be valid after the channel is created.
<-started.Done()
// In a goroutine, create the subconn.
go func() {
_, err := bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, balancer.NewSubConnOptions{
StateListener: func(scs balancer.SubConnState) {
t.Error("Unexpected call to StateListener with:", scs)
},
})
if err != nil && !strings.Contains(err.Error(), "connection is closing") && !strings.Contains(err.Error(), "is deleted") && !strings.Contains(err.Error(), "is closed or idle") && !strings.Contains(err.Error(), "balancer is being closed") {
t.Error("Unexpected error creating subconn:", err)
}
wg.Done()
}()
// At approximately the same time, close the channel.
cc.Close()
wg.Done()
}()
return nil
},
}
stub.Register(t.Name(), bf)
svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name())
cc, err := NewClient("passthrough:///test.server", WithTransportCredentials(insecure.NewCredentials()), WithDefaultServiceConfig(svcCfg))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
cc.Connect()
started.Fire()
// Wait for the LB policy to call NewSubConn and cc.Close.
wg.Wait()
}

View File

@ -47,13 +47,11 @@ import (
"fmt"
"io"
"log"
rand "math/rand/v2"
"net"
"os"
"reflect"
"runtime"
"runtime/pprof"
"strconv"
"strings"
"sync"
"sync/atomic"
@ -61,16 +59,14 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark"
bm "google.golang.org/grpc/benchmark"
"google.golang.org/grpc/benchmark/flags"
"google.golang.org/grpc/benchmark/latency"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/test/bufconn"
@ -84,8 +80,7 @@ var (
traceMode = flags.StringWithAllowedValues("trace", toggleModeOff,
fmt.Sprintf("Trace mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes)
preloaderMode = flags.StringWithAllowedValues("preloader", toggleModeOff,
fmt.Sprintf("Preloader mode - One of: %v, preloader works only in streaming and unconstrained modes and will be ignored in unary mode",
strings.Join(allToggleModes, ", ")), allToggleModes)
fmt.Sprintf("Preloader mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes)
channelzOn = flags.StringWithAllowedValues("channelz", toggleModeOff,
fmt.Sprintf("Channelz mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes)
compressorMode = flags.StringWithAllowedValues("compression", compModeOff,
@ -110,15 +105,10 @@ var (
useBufconn = flag.Bool("bufconn", false, "Use in-memory connection instead of system network I/O")
enableKeepalive = flag.Bool("enable_keepalive", false, "Enable client keepalive. \n"+
"Keepalive.Time is set to 10s, Keepalive.Timeout is set to 1s, Keepalive.PermitWithoutStream is set to true.")
clientReadBufferSize = flags.IntSlice("clientReadBufferSize", []int{-1}, "Configures the client read buffer size in bytes. If negative, use the default - may be a comma-separated list")
clientWriteBufferSize = flags.IntSlice("clientWriteBufferSize", []int{-1}, "Configures the client write buffer size in bytes. If negative, use the default - may be a comma-separated list")
serverReadBufferSize = flags.IntSlice("serverReadBufferSize", []int{-1}, "Configures the server read buffer size in bytes. If negative, use the default - may be a comma-separated list")
serverWriteBufferSize = flags.IntSlice("serverWriteBufferSize", []int{-1}, "Configures the server write buffer size in bytes. If negative, use the default - may be a comma-separated list")
sleepBetweenRPCs = flags.DurationSlice("sleepBetweenRPCs", []time.Duration{0}, "Configures the maximum amount of time the client should sleep between consecutive RPCs - may be a comma-separated list")
connections = flag.Int("connections", 1, "The number of connections. Each connection will handle maxConcurrentCalls RPC streams")
recvBufferPool = flags.StringWithAllowedValues("recvBufferPool", recvBufferPoolNil, "Configures the shared receive buffer pool. One of: nil, simple, all", allRecvBufferPools)
sharedWriteBuffer = flags.StringWithAllowedValues("sharedWriteBuffer", toggleModeOff,
fmt.Sprintf("Configures both client and server to share write buffer - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes)
clientReadBufferSize = flags.IntSlice("clientReadBufferSize", []int{-1}, "Configures the client read buffer size in bytes. If negative, use the default - may be a comma-separated list")
clientWriteBufferSize = flags.IntSlice("clientWriteBufferSize", []int{-1}, "Configures the client write buffer size in bytes. If negative, use the default - may be a comma-separated list")
serverReadBufferSize = flags.IntSlice("serverReadBufferSize", []int{-1}, "Configures the server read buffer size in bytes. If negative, use the default - may be a comma-separated list")
serverWriteBufferSize = flags.IntSlice("serverWriteBufferSize", []int{-1}, "Configures the server write buffer size in bytes. If negative, use the default - may be a comma-separated list")
logger = grpclog.Component("benchmark")
)
@ -143,49 +133,17 @@ const (
networkModeLAN = "LAN"
networkModeWAN = "WAN"
networkLongHaul = "Longhaul"
// Shared recv buffer pool
recvBufferPoolNil = "nil"
recvBufferPoolSimple = "simple"
recvBufferPoolAll = "all"
numStatsBuckets = 10
warmupCallCount = 10
warmuptime = time.Second
)
var useNopBufferPool atomic.Bool
type swappableBufferPool struct {
mem.BufferPool
}
func (p swappableBufferPool) Get(length int) *[]byte {
var pool mem.BufferPool
if useNopBufferPool.Load() {
pool = mem.NopBufferPool{}
} else {
pool = p.BufferPool
}
return pool.Get(length)
}
func (p swappableBufferPool) Put(i *[]byte) {
if useNopBufferPool.Load() {
return
}
p.BufferPool.Put(i)
}
func init() {
internal.SetDefaultBufferPoolForTesting.(func(mem.BufferPool))(swappableBufferPool{mem.DefaultBufferPool()})
}
var (
allWorkloads = []string{workloadsUnary, workloadsStreaming, workloadsUnconstrained, workloadsAll}
allCompModes = []string{compModeOff, compModeGzip, compModeNop, compModeAll}
allToggleModes = []string{toggleModeOff, toggleModeOn, toggleModeBoth}
allNetworkModes = []string{networkModeNone, networkModeLocal, networkModeLAN, networkModeWAN, networkLongHaul}
allRecvBufferPools = []string{recvBufferPoolNil, recvBufferPoolSimple, recvBufferPoolAll}
defaultReadLatency = []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay.
defaultReadKbps = []int{0, 10240} // if non-positive, infinite
defaultReadMTU = []int{0} // if non-positive, infinite
@ -236,9 +194,9 @@ func runModesFromWorkloads(workload string) runModes {
type startFunc func(mode string, bf stats.Features)
type stopFunc func(count uint64)
type ucStopFunc func(req uint64, resp uint64)
type rpcCallFunc func(cn, pos int)
type rpcSendFunc func(cn, pos int)
type rpcRecvFunc func(cn, pos int)
type rpcCallFunc func(pos int)
type rpcSendFunc func(pos int)
type rpcRecvFunc func(pos int)
type rpcCleanupFunc func()
func unaryBenchmark(start startFunc, stop stopFunc, bf stats.Features, s *stats.Stats) {
@ -275,46 +233,40 @@ func unconstrainedStreamBenchmark(start startFunc, stop ucStopFunc, bf stats.Fea
bmEnd := time.Now().Add(bf.BenchTime + warmuptime)
var wg sync.WaitGroup
wg.Add(2 * bf.Connections * bf.MaxConcurrentCalls)
maxSleep := int(bf.SleepBetweenRPCs)
for cn := 0; cn < bf.Connections; cn++ {
for pos := 0; pos < bf.MaxConcurrentCalls; pos++ {
go func(cn, pos int) {
defer wg.Done()
for {
if maxSleep > 0 {
time.Sleep(time.Duration(rand.IntN(maxSleep)))
}
t := time.Now()
if t.After(bmEnd) {
return
}
sender(cn, pos)
atomic.AddUint64(&req, 1)
wg.Add(2 * bf.MaxConcurrentCalls)
for i := 0; i < bf.MaxConcurrentCalls; i++ {
go func(pos int) {
defer wg.Done()
for {
t := time.Now()
if t.After(bmEnd) {
return
}
}(cn, pos)
go func(cn, pos int) {
defer wg.Done()
for {
t := time.Now()
if t.After(bmEnd) {
return
}
recver(cn, pos)
atomic.AddUint64(&resp, 1)
sender(pos)
atomic.AddUint64(&req, 1)
}
}(i)
go func(pos int) {
defer wg.Done()
for {
t := time.Now()
if t.After(bmEnd) {
return
}
}(cn, pos)
}
recver(pos)
atomic.AddUint64(&resp, 1)
}
}(i)
}
wg.Wait()
stop(req, resp)
}
// makeClients returns a gRPC client (or multiple clients) for the grpc.testing.BenchmarkService
// makeClient returns a gRPC client for the grpc.testing.BenchmarkService
// service. The client is configured using the different options in the passed
// 'bf'. Also returns a cleanup function to close the client and release
// resources.
func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func()) {
func makeClient(bf stats.Features) (testgrpc.BenchmarkServiceClient, func()) {
nw := &latency.Network{Kbps: bf.Kbps, Latency: bf.Latency, MTU: bf.MTU}
opts := []grpc.DialOption{}
sopts := []grpc.ServerOption{}
@ -329,8 +281,13 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func())
)
}
if bf.ModeCompressor == compModeGzip {
sopts = append(sopts,
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
)
opts = append(opts,
grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name)),
grpc.WithCompressor(grpc.NewGZIPCompressor()),
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
if bf.EnableKeepalive {
@ -361,21 +318,9 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func())
if bf.ServerReadBufferSize >= 0 {
sopts = append(sopts, grpc.ReadBufferSize(bf.ServerReadBufferSize))
}
if bf.SharedWriteBuffer {
opts = append(opts, grpc.WithSharedWriteBuffer(true))
sopts = append(sopts, grpc.SharedWriteBuffer(true))
}
if bf.ServerWriteBufferSize >= 0 {
sopts = append(sopts, grpc.WriteBufferSize(bf.ServerWriteBufferSize))
}
switch bf.RecvBufferPool {
case recvBufferPoolNil:
useNopBufferPool.Store(true)
case recvBufferPoolSimple:
// Do nothing as buffering is enabled by default.
default:
logger.Fatalf("Unknown shared recv buffer pool type: %v", bf.RecvBufferPool)
}
sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(bf.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@ -384,7 +329,7 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func())
if bf.UseBufConn {
bcLis := bufconn.Listen(256 * 1024)
lis = bcLis
opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) {
return nw.ContextDialer(func(context.Context, string, string) (net.Conn, error) {
return bcLis.Dial()
})(ctx, "", "")
@ -395,30 +340,22 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func())
if err != nil {
logger.Fatalf("Failed to listen: %v", err)
}
opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return nw.ContextDialer((internal.NetDialerWithTCPKeepalive().DialContext))(ctx, "tcp", lis.Addr().String())
opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) {
return nw.ContextDialer((&net.Dialer{}).DialContext)(ctx, "tcp", lis.Addr().String())
}))
}
lis = nw.Listener(lis)
stopper := benchmark.StartServer(benchmark.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conns := make([]*grpc.ClientConn, bf.Connections)
clients := make([]testgrpc.BenchmarkServiceClient, bf.Connections)
for cn := 0; cn < bf.Connections; cn++ {
conns[cn] = benchmark.NewClientConn("passthrough://" /* target not used */, opts...)
clients[cn] = testgrpc.NewBenchmarkServiceClient(conns[cn])
}
return clients, func() {
for _, conn := range conns {
conn.Close()
}
stopper := bm.StartServer(bm.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conn := bm.NewClientConn("" /* target not used */, opts...)
return testgrpc.NewBenchmarkServiceClient(conn), func() {
conn.Close()
stopper()
}
}
func makeFuncUnary(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) {
clients, cleanup := makeClients(bf)
return func(cn, _ int) {
tc, cleanup := makeClient(bf)
return func(int) {
reqSizeBytes := bf.ReqSizeBytes
respSizeBytes := bf.RespSizeBytes
if bf.ReqPayloadCurve != nil {
@ -427,19 +364,23 @@ func makeFuncUnary(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) {
if bf.RespPayloadCurve != nil {
respSizeBytes = bf.RespPayloadCurve.ChooseRandom()
}
unaryCaller(clients[cn], reqSizeBytes, respSizeBytes)
unaryCaller(tc, reqSizeBytes, respSizeBytes)
}, cleanup
}
func makeFuncStream(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) {
streams, req, cleanup := setupStream(bf, false)
tc, cleanup := makeClient(bf)
var preparedMsg [][]*grpc.PreparedMsg
if bf.EnablePreloader {
preparedMsg = prepareMessages(streams, req)
streams := make([]testgrpc.BenchmarkService_StreamingCallClient, bf.MaxConcurrentCalls)
for i := 0; i < bf.MaxConcurrentCalls; i++ {
stream, err := tc.StreamingCall(context.Background())
if err != nil {
logger.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
streams[i] = stream
}
return func(cn, pos int) {
return func(pos int) {
reqSizeBytes := bf.ReqSizeBytes
respSizeBytes := bf.RespSizeBytes
if bf.ReqPayloadCurve != nil {
@ -448,69 +389,54 @@ func makeFuncStream(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) {
if bf.RespPayloadCurve != nil {
respSizeBytes = bf.RespPayloadCurve.ChooseRandom()
}
var req any
if bf.EnablePreloader {
req = preparedMsg[cn][pos]
} else {
pl := benchmark.NewPayload(testpb.PayloadType_COMPRESSABLE, reqSizeBytes)
req = &testpb.SimpleRequest{
ResponseType: pl.Type,
ResponseSize: int32(respSizeBytes),
Payload: pl,
}
}
streamCaller(streams[cn][pos], req)
streamCaller(streams[pos], reqSizeBytes, respSizeBytes)
}, cleanup
}
func makeFuncUnconstrainedStreamPreloaded(bf stats.Features) (rpcSendFunc, rpcRecvFunc, rpcCleanupFunc) {
streams, req, cleanup := setupStream(bf, true)
streams, req, cleanup := setupUnconstrainedStream(bf)
preparedMsg := prepareMessages(streams, req)
preparedMsg := make([]*grpc.PreparedMsg, len(streams))
for i, stream := range streams {
preparedMsg[i] = &grpc.PreparedMsg{}
err := preparedMsg[i].Encode(stream, req)
if err != nil {
logger.Fatalf("%v.Encode(%v, %v) = %v", preparedMsg[i], req, stream, err)
}
}
return func(cn, pos int) {
streams[cn][pos].SendMsg(preparedMsg[cn][pos])
}, func(cn, pos int) {
streams[cn][pos].Recv()
return func(pos int) {
streams[pos].SendMsg(preparedMsg[pos])
}, func(pos int) {
streams[pos].Recv()
}, cleanup
}
func makeFuncUnconstrainedStream(bf stats.Features) (rpcSendFunc, rpcRecvFunc, rpcCleanupFunc) {
streams, req, cleanup := setupStream(bf, true)
streams, req, cleanup := setupUnconstrainedStream(bf)
return func(cn, pos int) {
streams[cn][pos].Send(req)
}, func(cn, pos int) {
streams[cn][pos].Recv()
return func(pos int) {
streams[pos].Send(req)
}, func(pos int) {
streams[pos].Recv()
}, cleanup
}
func setupStream(bf stats.Features, unconstrained bool) ([][]testgrpc.BenchmarkService_StreamingCallClient, *testpb.SimpleRequest, rpcCleanupFunc) {
clients, cleanup := makeClients(bf)
func setupUnconstrainedStream(bf stats.Features) ([]testgrpc.BenchmarkService_StreamingCallClient, *testpb.SimpleRequest, rpcCleanupFunc) {
tc, cleanup := makeClient(bf)
streams := make([][]testgrpc.BenchmarkService_StreamingCallClient, bf.Connections)
ctx := context.Background()
if unconstrained {
md := metadata.Pairs(benchmark.UnconstrainedStreamingHeader, "1", benchmark.UnconstrainedStreamingDelayHeader, bf.SleepBetweenRPCs.String())
ctx = metadata.NewOutgoingContext(ctx, md)
}
if bf.EnablePreloader {
md := metadata.Pairs(benchmark.PreloadMsgSizeHeader, strconv.Itoa(bf.RespSizeBytes), benchmark.UnconstrainedStreamingDelayHeader, bf.SleepBetweenRPCs.String())
ctx = metadata.NewOutgoingContext(ctx, md)
}
for cn := 0; cn < bf.Connections; cn++ {
tc := clients[cn]
streams[cn] = make([]testgrpc.BenchmarkService_StreamingCallClient, bf.MaxConcurrentCalls)
for pos := 0; pos < bf.MaxConcurrentCalls; pos++ {
stream, err := tc.StreamingCall(ctx)
if err != nil {
logger.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
streams[cn][pos] = stream
streams := make([]testgrpc.BenchmarkService_StreamingCallClient, bf.MaxConcurrentCalls)
md := metadata.Pairs(benchmark.UnconstrainedStreamingHeader, "1")
ctx := metadata.NewOutgoingContext(context.Background(), md)
for i := 0; i < bf.MaxConcurrentCalls; i++ {
stream, err := tc.StreamingCall(ctx)
if err != nil {
logger.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
streams[i] = stream
}
pl := benchmark.NewPayload(testpb.PayloadType_COMPRESSABLE, bf.ReqSizeBytes)
pl := bm.NewPayload(testpb.PayloadType_COMPRESSABLE, bf.ReqSizeBytes)
req := &testpb.SimpleRequest{
ResponseType: pl.Type,
ResponseSize: int32(bf.RespSizeBytes),
@ -520,74 +446,47 @@ func setupStream(bf stats.Features, unconstrained bool) ([][]testgrpc.BenchmarkS
return streams, req, cleanup
}
func prepareMessages(streams [][]testgrpc.BenchmarkService_StreamingCallClient, req *testpb.SimpleRequest) [][]*grpc.PreparedMsg {
preparedMsg := make([][]*grpc.PreparedMsg, len(streams))
for cn, connStreams := range streams {
preparedMsg[cn] = make([]*grpc.PreparedMsg, len(connStreams))
for pos, stream := range connStreams {
preparedMsg[cn][pos] = &grpc.PreparedMsg{}
if err := preparedMsg[cn][pos].Encode(stream, req); err != nil {
logger.Fatalf("%v.Encode(%v, %v) = %v", preparedMsg[cn][pos], req, stream, err)
}
}
}
return preparedMsg
}
// Makes a UnaryCall gRPC request using the given BenchmarkServiceClient and
// request and response sizes.
func unaryCaller(client testgrpc.BenchmarkServiceClient, reqSize, respSize int) {
if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil {
if err := bm.DoUnaryCall(client, reqSize, respSize); err != nil {
logger.Fatalf("DoUnaryCall failed: %v", err)
}
}
func streamCaller(stream testgrpc.BenchmarkService_StreamingCallClient, req any) {
if err := benchmark.DoStreamingRoundTripPreloaded(stream, req); err != nil {
func streamCaller(stream testgrpc.BenchmarkService_StreamingCallClient, reqSize, respSize int) {
if err := bm.DoStreamingRoundTrip(stream, reqSize, respSize); err != nil {
logger.Fatalf("DoStreamingRoundTrip failed: %v", err)
}
}
func runBenchmark(caller rpcCallFunc, start startFunc, stop stopFunc, bf stats.Features, s *stats.Stats, mode string) {
// if SleepBetweenRPCs > 0 we skip the warmup because otherwise
// we are going to send a set of simultaneous requests on every connection,
// which is something we are trying to avoid when using SleepBetweenRPCs.
if bf.SleepBetweenRPCs == 0 {
// Warm up connections.
for i := 0; i < warmupCallCount; i++ {
for cn := 0; cn < bf.Connections; cn++ {
caller(cn, 0)
}
}
// Warm up connection.
for i := 0; i < warmupCallCount; i++ {
caller(0)
}
// Run benchmark.
start(mode, bf)
var wg sync.WaitGroup
wg.Add(bf.Connections * bf.MaxConcurrentCalls)
wg.Add(bf.MaxConcurrentCalls)
bmEnd := time.Now().Add(bf.BenchTime)
maxSleep := int(bf.SleepBetweenRPCs)
var count uint64
for cn := 0; cn < bf.Connections; cn++ {
for pos := 0; pos < bf.MaxConcurrentCalls; pos++ {
go func(cn, pos int) {
defer wg.Done()
for {
if maxSleep > 0 {
time.Sleep(time.Duration(rand.IntN(maxSleep)))
}
t := time.Now()
if t.After(bmEnd) {
return
}
start := time.Now()
caller(cn, pos)
elapse := time.Since(start)
atomic.AddUint64(&count, 1)
s.AddDuration(elapse)
for i := 0; i < bf.MaxConcurrentCalls; i++ {
go func(pos int) {
defer wg.Done()
for {
t := time.Now()
if t.After(bmEnd) {
return
}
}(cn, pos)
}
start := time.Now()
caller(pos)
elapse := time.Since(start)
atomic.AddUint64(&count, 1)
s.AddDuration(elapse)
}
}(i)
}
wg.Wait()
stop(count)
@ -605,7 +504,6 @@ type benchOpts struct {
benchmarkResultFile string
useBufconn bool
enableKeepalive bool
connections int
features *featureOpts
}
@ -630,9 +528,6 @@ type featureOpts struct {
clientWriteBufferSize []int
serverReadBufferSize []int
serverWriteBufferSize []int
sleepBetweenRPCs []time.Duration
recvBufferPools []string
sharedWriteBuffer []bool
}
// makeFeaturesNum returns a slice of ints of size 'maxFeatureIndex' where each
@ -677,12 +572,6 @@ func makeFeaturesNum(b *benchOpts) []int {
featuresNum[i] = len(b.features.serverReadBufferSize)
case stats.ServerWriteBufferSize:
featuresNum[i] = len(b.features.serverWriteBufferSize)
case stats.SleepBetweenRPCs:
featuresNum[i] = len(b.features.sleepBetweenRPCs)
case stats.RecvBufferPool:
featuresNum[i] = len(b.features.recvBufferPools)
case stats.SharedWriteBuffer:
featuresNum[i] = len(b.features.sharedWriteBuffer)
default:
log.Fatalf("Unknown feature index %v in generateFeatures. maxFeatureIndex is %v", i, stats.MaxFeatureIndex)
}
@ -736,7 +625,6 @@ func (b *benchOpts) generateFeatures(featuresNum []int) []stats.Features {
UseBufConn: b.useBufconn,
EnableKeepalive: b.enableKeepalive,
BenchTime: b.benchTime,
Connections: b.connections,
// These features can potentially change for each iteration.
EnableTrace: b.features.enableTrace[curPos[stats.EnableTraceIndex]],
Latency: b.features.readLatencies[curPos[stats.ReadLatenciesIndex]],
@ -750,9 +638,6 @@ func (b *benchOpts) generateFeatures(featuresNum []int) []stats.Features {
ClientWriteBufferSize: b.features.clientWriteBufferSize[curPos[stats.ClientWriteBufferSize]],
ServerReadBufferSize: b.features.serverReadBufferSize[curPos[stats.ServerReadBufferSize]],
ServerWriteBufferSize: b.features.serverWriteBufferSize[curPos[stats.ServerWriteBufferSize]],
SleepBetweenRPCs: b.features.sleepBetweenRPCs[curPos[stats.SleepBetweenRPCs]],
RecvBufferPool: b.features.recvBufferPools[curPos[stats.RecvBufferPool]],
SharedWriteBuffer: b.features.sharedWriteBuffer[curPos[stats.SharedWriteBuffer]],
}
if len(b.features.reqPayloadCurves) == 0 {
f.ReqSizeBytes = b.features.reqSizeBytes[curPos[stats.ReqSizeBytesIndex]]
@ -808,7 +693,6 @@ func processFlags() *benchOpts {
benchmarkResultFile: *benchmarkResultFile,
useBufconn: *useBufconn,
enableKeepalive: *enableKeepalive,
connections: *connections,
features: &featureOpts{
enableTrace: setToggleMode(*traceMode),
readLatencies: append([]time.Duration(nil), *readLatency...),
@ -824,9 +708,6 @@ func processFlags() *benchOpts {
clientWriteBufferSize: append([]int(nil), *clientWriteBufferSize...),
serverReadBufferSize: append([]int(nil), *serverReadBufferSize...),
serverWriteBufferSize: append([]int(nil), *serverWriteBufferSize...),
sleepBetweenRPCs: append([]time.Duration(nil), *sleepBetweenRPCs...),
recvBufferPools: setRecvBufferPool(*recvBufferPool),
sharedWriteBuffer: setToggleMode(*sharedWriteBuffer),
},
}
@ -838,9 +719,6 @@ func processFlags() *benchOpts {
if len(opts.features.reqSizeBytes) != 0 {
log.Fatalf("you may not specify -reqPayloadCurveFiles and -reqSizeBytes at the same time")
}
if len(opts.features.enablePreloader) != 0 {
log.Fatalf("you may not specify -reqPayloadCurveFiles and -preloader at the same time")
}
for _, file := range *reqPayloadCurveFiles {
pc, err := stats.NewPayloadCurve(file)
if err != nil {
@ -858,9 +736,6 @@ func processFlags() *benchOpts {
if len(opts.features.respSizeBytes) != 0 {
log.Fatalf("you may not specify -respPayloadCurveFiles and -respSizeBytes at the same time")
}
if len(opts.features.enablePreloader) != 0 {
log.Fatalf("you may not specify -respPayloadCurveFiles and -preloader at the same time")
}
for _, file := range *respPayloadCurveFiles {
pc, err := stats.NewPayloadCurve(file)
if err != nil {
@ -908,19 +783,6 @@ func setCompressorMode(val string) []string {
}
}
func setRecvBufferPool(val string) []string {
switch val {
case recvBufferPoolNil, recvBufferPoolSimple:
return []string{val}
case recvBufferPoolAll:
return []string{recvBufferPoolNil, recvBufferPoolSimple}
default:
// This should never happen because a wrong value passed to this flag would
// be caught during flag.Parse().
return []string{}
}
}
func main() {
opts := processFlags()
before(opts)
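As an aside for readers skimming this diff, the following self-contained sketch reproduces the worker-loop pattern that runBenchmark uses above: spawn one goroutine per concurrent call, optionally sleep a random duration between RPCs, and stop once the benchmark deadline passes. The doCall function and all other names here are placeholders for illustration, not part of the benchmark code.

package main

import (
	"fmt"
	"math/rand"
	"sync"
	"sync/atomic"
	"time"
)

// doCall is a placeholder for the real per-RPC caller used by the benchmark.
func doCall(pos int) { time.Sleep(time.Millisecond) }

// runWorkers mirrors the loop in runBenchmark: each worker repeatedly issues
// calls until the deadline, optionally sleeping a random duration first, and
// the total call count is tracked atomically.
func runWorkers(concurrency int, benchTime, maxSleep time.Duration) uint64 {
	var (
		wg    sync.WaitGroup
		count uint64
	)
	bmEnd := time.Now().Add(benchTime)
	wg.Add(concurrency)
	for pos := 0; pos < concurrency; pos++ {
		go func(pos int) {
			defer wg.Done()
			for {
				if maxSleep > 0 {
					time.Sleep(time.Duration(rand.Int63n(int64(maxSleep))))
				}
				if time.Now().After(bmEnd) {
					return
				}
				doCall(pos)
				atomic.AddUint64(&count, 1)
			}
		}(pos)
	}
	wg.Wait()
	return count
}

func main() {
	fmt.Println("calls:", runWorkers(4, 100*time.Millisecond, 0))
}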

Some files were not shown because too many files have changed in this diff.