va: replacing error assertions with errors.As (#5136)

errors.As checks for a specific error in a wrapped error chain
(see https://golang.org/pkg/errors/#As) as opposed to asserting
that an error is of a specific type.

Part of #5010
This commit is contained in:
Samantha 2020-10-30 15:51:29 -07:00 committed by GitHub
parent 4c96164e53
commit 387e94407c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 31 deletions

View File

@ -3,6 +3,7 @@ package va
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -414,18 +415,10 @@ func fallbackErr(err error) bool {
if err == nil { if err == nil {
return false return false
} }
// Net OpErrors are fallback errs only if the operation was a "dial"
switch err := err.(type) { // All other errs are not fallback errs
case *url.Error: var netOpError *net.OpError
// URL Errors should be unwrapped and tested return errors.As(err, &netOpError) && netOpError.Op == "dial"
return fallbackErr(err.Err)
case *net.OpError:
// Net OpErrors are fallback errs only if the operation was a "dial"
return err.Op == "dial"
default:
// All other errs are not fallback errs
return false
}
} }
// processHTTPValidation performs an HTTP validation for the given host, port // processHTTPValidation performs an HTTP validation for the given host, port

View File

@ -254,38 +254,43 @@ type verificationRequestEvent struct {
// passing through the detailed message. // passing through the detailed message.
func detailedError(err error) *probs.ProblemDetails { func detailedError(err error) *probs.ProblemDetails {
// net/http wraps net.OpError in a url.Error. Unwrap them. // net/http wraps net.OpError in a url.Error. Unwrap them.
if urlErr, ok := err.(*url.Error); ok { var urlErr *url.Error
if errors.As(err, &urlErr) {
prob := detailedError(urlErr.Err) prob := detailedError(urlErr.Err)
prob.Detail = fmt.Sprintf("Fetching %s: %s", urlErr.URL, prob.Detail) prob.Detail = fmt.Sprintf("Fetching %s: %s", urlErr.URL, prob.Detail)
return prob return prob
} }
if tlsErr, ok := err.(tls.RecordHeaderError); ok && bytes.Compare(tlsErr.RecordHeader[:], badTLSHeader) == 0 { var tlsErr tls.RecordHeaderError
if errors.As(err, &tlsErr) && bytes.Compare(tlsErr.RecordHeader[:], badTLSHeader) == 0 {
return probs.Malformed("Server only speaks HTTP, not TLS") return probs.Malformed("Server only speaks HTTP, not TLS")
} }
var netErr *net.OpError var netOpErr *net.OpError
if errors.As(err, &netErr) { if errors.As(err, &netOpErr) {
if fmt.Sprintf("%T", netErr.Err) == "tls.alert" { if fmt.Sprintf("%T", netOpErr.Err) == "tls.alert" {
// All the tls.alert error strings are reasonable to hand back to a // All the tls.alert error strings are reasonable to hand back to a
// user. Confirmed against Go 1.8. // user. Confirmed against Go 1.8.
return probs.TLSError(netErr.Error()) return probs.TLSError(netOpErr.Error())
} else if syscallErr, ok := netErr.Err.(*os.SyscallError); ok && } else if netOpErr.Timeout() && netOpErr.Op == "dial" {
syscallErr.Err == syscall.ECONNREFUSED {
return probs.ConnectionFailure("Connection refused")
} else if syscallErr, ok := netErr.Err.(*os.SyscallError); ok &&
syscallErr.Err == syscall.ENETUNREACH {
return probs.ConnectionFailure("Network unreachable")
} else if syscallErr, ok := netErr.Err.(*os.SyscallError); ok &&
syscallErr.Err == syscall.ECONNRESET {
return probs.ConnectionFailure("Connection reset by peer")
} else if netErr.Timeout() && netErr.Op == "dial" {
return probs.ConnectionFailure("Timeout during connect (likely firewall problem)") return probs.ConnectionFailure("Timeout during connect (likely firewall problem)")
} else if netErr.Timeout() { } else if netOpErr.Timeout() {
return probs.ConnectionFailure(fmt.Sprintf("Timeout during %s (your server may be slow or overloaded)", netErr.Op)) return probs.ConnectionFailure(fmt.Sprintf("Timeout during %s (your server may be slow or overloaded)", netOpErr.Op))
} }
} }
if err, ok := err.(net.Error); ok && err.Timeout() { var syscallErr *os.SyscallError
if errors.As(err, &syscallErr) {
switch syscallErr.Err {
case syscall.ECONNREFUSED:
return probs.ConnectionFailure("Connection refused")
case syscall.ENETUNREACH:
return probs.ConnectionFailure("Network unreachable")
case syscall.ECONNRESET:
return probs.ConnectionFailure("Connection reset by peer")
}
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return probs.ConnectionFailure("Timeout after connect (your server may be slow or overloaded)") return probs.ConnectionFailure("Timeout after connect (your server may be slow or overloaded)")
} }
if berrors.Is(err, berrors.ConnectionFailure) { if berrors.Is(err, berrors.ConnectionFailure) {