boulder/vendor/github.com/google/safebrowsing/database.go

536 lines
15 KiB
Go

// Copyright 2016 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package safebrowsing
import (
"bytes"
"compress/gzip"
"encoding/gob"
"errors"
"log"
"math/rand"
"os"
"sync"
"time"
pb "github.com/google/safebrowsing/internal/safebrowsing_proto"
)
// Timing parameters used for update retries and staleness detection.
const (
	// maxRetryDelay is the upper bound on the backoff delay that Update
	// returns after consecutive API failures.
	maxRetryDelay = 24 * time.Hour
	// baseRetryDelay is the base unit of the backoff delay. Update scales it
	// by a power of two and a random factor on each consecutive API failure.
	baseRetryDelay = 15 * time.Minute
	// jitter is the maximum amount of time that we expect an API list update
	// to actually take. We add this time to the update period time to give
	// some leeway before declaring the database as stale.
	jitter = 30 * time.Second
)
// database tracks the state of the threat lists published by the Safe Browsing
// API. Since the global blacklist is constantly changing, the contents of the
// database needs to be periodically synced with the Safe Browsing servers in
// order to provide protection for the latest threats.
//
// The process for updating the database is as follows:
// * At startup, if a database file is provided, then load it. If loaded
// properly (not corrupted and not stale), then set tfu as the contents.
// Otherwise, pull a new threat list from the Safe Browsing API.
// * Periodically, synchronize the database with the Safe Browsing API.
// This uses the State fields to update only parts of the threat list that have
// changed since the last sync.
// * Anytime tfu is updated, generate a new tfl.
//
// The process for querying the database is as follows:
// * Check if the requested full hash matches any partial hash in tfl.
// If a match is found, return a set of ThreatDescriptors with a partial match.
type database struct {
	config *Config // Configuration handed to Init

	// threatsForUpdate maps ThreatDescriptors to lists of partial hashes.
	// This data structure is in a format that is easily updated by the API.
	// It is also the form that is written to disk.
	tfu threatsForUpdate
	mu  sync.Mutex // Protects tfu

	// threatsForLookup maps ThreatDescriptors to sets of partial hashes.
	// This data structure is in a format that is easily queried.
	tfl threatsForLookup
	ml  sync.RWMutex // Protects tfl, err, and last

	err error // Last error encountered
	// readyCh is used for waiting until not in an error state. A fresh
	// channel is created whenever the database goes from healthy to faulted
	// (setError, setStale) and it is closed by clearError.
	readyCh chan struct{}
	last    time.Time // Last time the threat lists were synced

	// updateAPIErrors counts consecutive failed attempts to contact the API.
	// Update increments it on failure and resets it to zero on success.
	updateAPIErrors uint
	log             *log.Logger // Destination for diagnostic messages
}
// threatsForUpdate associates each ThreatDescriptor with its partialHashes
// entry (hash prefixes, their checksum, and the API synchronization state).
type threatsForUpdate map[ThreatDescriptor]partialHashes
// partialHashes is the per-threat-list record kept in threatsForUpdate.
type partialHashes struct {
	// Since the Hashes field is only needed when storing to disk and when
	// updating, this field is cleared except for when it is in use.
	// This is done to reduce memory usage as the contents of this can be
	// regenerated from the tfl.
	Hashes hashPrefixes

	SHA256 []byte // The SHA256 over Hashes
	State  []byte // Arbitrary binary blob to synchronize state with API
}
// threatsForLookup associates each ThreatDescriptor with the hashSet that
// Lookup queries for that threat list.
type threatsForLookup map[ThreatDescriptor]hashSet
// databaseFormat is a light struct used only for gob encoding and decoding.
// As written to disk, the format of the database file is basically the gzip
// compressed version of the gob encoding of databaseFormat.
type databaseFormat struct {
	// Table is the set of threat lists, including their hash prefixes.
	Table threatsForUpdate
	// Time is the timestamp associated with the stored lists. Update records
	// the time taken just before the API request was issued, and Init uses
	// it to decide whether a loaded database is stale.
	Time time.Time
}
// Init initializes the database from the specified file in config.DBPath.
// It reports true if the database was successfully loaded.
//
// Init reports false, and leaves the database in an error state that is
// visible through Status, when no path is configured, when the file cannot be
// loaded, when the stored lists are stale, or when the file lacks one of the
// threat lists named in config.ThreatLists.
func (db *database) Init(config *Config, logger *log.Logger) bool {
	db.mu.Lock()
	defer db.mu.Unlock()

	// Start in a faulted state; only a fully successful load clears it.
	db.setError(errors.New("not initialized"))
	db.config = config
	db.log = logger

	if db.config.DBPath == "" {
		db.log.Printf("no database file specified")
		db.setError(errors.New("no database loaded"))
		return false
	}

	dbf, err := loadDatabase(db.config.DBPath)
	if err != nil {
		db.log.Printf("load failure: %v", err)
		db.setError(err)
		return false
	}

	// Validate that the database threat list stored on disk is not too stale.
	if db.isStale(dbf.Time) {
		db.log.Printf("database loaded is stale")
		// setStale requires db.ml to be held by the caller.
		db.ml.Lock()
		db.setStale()
		db.ml.Unlock()
		return false
	}

	// Validate that the database threat list stored on disk is at least a
	// superset of the specified configuration. Lists present on disk but not
	// requested by the configuration are dropped.
	tfuNew := make(threatsForUpdate)
	for _, td := range db.config.ThreatLists {
		row, ok := dbf.Table[td]
		if !ok {
			db.log.Printf("database configuration mismatch, missing %v", td)
			db.setError(errors.New("database configuration mismatch"))
			return false
		}
		tfuNew[td] = row
	}

	db.tfu = tfuNew
	// Builds the lookup tables and, because the database was faulted above,
	// also clears the error state.
	db.generateThreatsForLookups(dbf.Time)
	return true
}
// Status reports the health of the database. The database is considered faulted
// if there was an error during update or if the last update has gone stale. If
// in a faulted state, the db may repair itself on the next Update.
//
// It returns nil when the database is healthy, and otherwise the last
// recorded error (errStale when the data has merely aged out).
func (db *database) Status() error {
	// Fast path: inspecting the state only needs the read lock, so concurrent
	// Status and Lookup calls do not serialize while the database is healthy.
	db.ml.RLock()
	err, last := db.err, db.last
	db.ml.RUnlock()
	if err != nil {
		return err
	}
	if !db.isStale(last) {
		return nil
	}

	// Slow path: marking the database stale writes db.err and db.readyCh,
	// which must happen under the write lock. Doing it while holding only the
	// read lock would race with other readers.
	db.ml.Lock()
	defer db.ml.Unlock()

	// Re-check, since the state may have changed while no lock was held.
	if db.err != nil {
		return db.err
	}
	if db.isStale(db.last) {
		db.setStale()
		return db.err
	}
	return nil
}
// UpdateLag reports how far the database is behind schedule: the time elapsed
// since the last update minus the configured update period. It is zero while
// the elapsed time is still within the update period.
func (db *database) UpdateLag() time.Duration {
	elapsed := db.SinceLastUpdate()
	if period := db.config.UpdatePeriod; elapsed >= period {
		return elapsed - period
	}
	return 0
}
// SinceLastUpdate reports how much time has passed, according to the
// configured clock, since the threat lists were last synced.
func (db *database) SinceLastUpdate() time.Duration {
	// Only the read of db.last needs the lock.
	db.ml.RLock()
	lastSync := db.last
	db.ml.RUnlock()

	return db.config.now().Sub(lastSync)
}
// Ready returns a channel that's closed when the database is ready for queries.
//
// The returned channel is a snapshot: each time the database goes from
// healthy to faulted a fresh channel is installed, so callers should obtain
// the channel again rather than cache it across state changes.
func (db *database) Ready() <-chan struct{} {
	// readyCh is replaced under db.ml by setError and setStale, so it must be
	// read under the same lock to avoid a data race.
	db.ml.RLock()
	defer db.ml.RUnlock()
	return db.readyCh
}
// Update synchronizes the local threat lists with those maintained by the
// global Safe Browsing API servers. If the update is successful, Status should
// report a nil error.
//
// It returns how long the caller should wait before the next call and whether
// this update succeeded. After an API failure the wait is a randomized
// exponential backoff capped at maxRetryDelay. Otherwise it is the configured
// update period with up to 30 seconds of jitter either way, raised to the
// server's minimum wait duration when that is longer.
func (db *database) Update(api api) (time.Duration, bool) {
	db.mu.Lock()
	defer db.mu.Unlock()

	// Construct the request, sending the stored state of each list (if any)
	// so that the server can answer with only what changed.
	var numTypes int
	var s []*pb.FetchThreatListUpdatesRequest_ListUpdateRequest
	for _, td := range db.config.ThreatLists {
		var state []byte
		if row, ok := db.tfu[td]; ok {
			state = row.State
		}

		s = append(s, &pb.FetchThreatListUpdatesRequest_ListUpdateRequest{
			ThreatType:      pb.ThreatType(td.ThreatType),
			PlatformType:    pb.PlatformType(td.PlatformType),
			ThreatEntryType: pb.ThreatEntryType(td.ThreatEntryType),
			Constraints: &pb.FetchThreatListUpdatesRequest_ListUpdateRequest_Constraints{
				SupportedCompressions: db.config.compressionTypes},
			State: state,
		})
		numTypes++
	}

	req := &pb.FetchThreatListUpdatesRequest{
		Client: &pb.ClientInfo{
			ClientId:      db.config.ID,
			ClientVersion: db.config.Version,
		},
		ListUpdateRequests: s,
	}

	// Query the API for the threat list and update the database. The
	// timestamp is taken before the request is issued.
	last := db.config.now()
	resp, err := api.ListUpdate(req)
	if err != nil {
		db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
		db.setError(err)

		// Backoff strategy: MIN((2**(N-1) * 15 minutes) * (RAND + 1), 24 hours),
		// where N counts consecutive failures including this one.
		//
		// The exponent is clamped before shifting. Without the clamp, a long
		// enough run of failures shifts the bit out of the integer, making
		// the multiplier negative or zero; the resulting delay would then be
		// <= 0 and slip past the maxRetryDelay cap. Any exponent at or above
		// the clamp already exceeds maxRetryDelay, so the result is the same
		// for every failure count that did not overflow.
		const maxBackoffShift = 16
		shift := db.updateAPIErrors
		if shift > maxBackoffShift {
			shift = maxBackoffShift
		}
		n := 1 << shift
		delay := time.Duration(float64(n) * (rand.Float64() + 1) * float64(baseRetryDelay))
		if delay > maxRetryDelay {
			delay = maxRetryDelay
		}
		db.updateAPIErrors++
		return delay, false
	}
	db.updateAPIErrors = 0

	// add jitter to wait time to avoid all servers lining up
	nextUpdateWait := db.config.UpdatePeriod + time.Duration(rand.Int31n(60)-30)*time.Second
	if resp.MinimumWaitDuration != nil {
		serverMinWait := time.Duration(resp.MinimumWaitDuration.Seconds)*time.Second + time.Duration(resp.MinimumWaitDuration.Nanos)
		if serverMinWait > nextUpdateWait {
			nextUpdateWait = serverMinWait
			db.log.Printf("Server requested next update in %v", nextUpdateWait)
		}
	}

	// The server must answer with exactly one response per requested list.
	if len(resp.ListUpdateResponses) != numTypes {
		db.setError(errors.New("safebrowsing: threat list count mismatch"))
		db.log.Printf("invalid server response: got %d, want %d threat lists",
			len(resp.ListUpdateResponses), numTypes)
		return nextUpdateWait, false
	}

	// Update the threat database with the response. The hash lists are first
	// restored from the lookup tables, since they are not kept in tfu between
	// updates.
	db.generateThreatsForUpdate()
	if err := db.tfu.update(resp); err != nil {
		db.setError(err)
		db.log.Printf("update failure: %v", err)
		db.tfu = nil
		return nextUpdateWait, false
	}
	dbf := databaseFormat{make(threatsForUpdate), last}
	for td, phs := range db.tfu {
		// Copy of partialHashes before generateThreatsForLookups clobbers it.
		dbf.Table[td] = phs
	}
	db.generateThreatsForLookups(last)

	// Regenerate the database and store it.
	if db.config.DBPath != "" {
		// Semantically, we ignore save errors, but we do log them.
		if err := saveDatabase(db.config.DBPath, dbf); err != nil {
			db.log.Printf("save failure: %v", err)
		}
	}

	return nextUpdateWait, true
}
// Lookup checks the full hash against every threat list and returns the
// ThreatDescriptors of the lists that contain a prefix of it, together with
// the matching partial hash. When several lists match, the returned prefix is
// the one from the last matching list visited. With no match both results are
// zero values.
//
// Lookup panics if hash is not a full-length hash.
func (db *database) Lookup(hash hashPrefix) (h hashPrefix, tds []ThreatDescriptor) {
	if !hash.IsFull() {
		panic("hash is not full")
	}

	db.ml.RLock()
	defer db.ml.RUnlock()

	for desc, set := range db.tfl {
		prefixLen := set.Lookup(hash)
		if prefixLen <= 0 {
			continue
		}
		h = hash[:prefixLen]
		tds = append(tds, desc)
	}
	return h, tds
}
// setError discards the database contents (both the update and the lookup
// tables, plus the last-sync time) and records err as the last error. If the
// database was healthy beforehand, a new ready channel is installed for
// waiters to block on.
//
// This assumes that the db.mu lock is already held.
func (db *database) setError(err error) {
	db.tfu = nil

	db.ml.Lock()
	defer db.ml.Unlock()

	if db.err == nil {
		db.readyCh = make(chan struct{})
	}
	db.tfl = nil
	db.err = err
	db.last = time.Time{}
}
// isStale reports whether an update performed at lastUpdate is too old to be
// trusted. The cutoff is twice the sum of the configured update period and
// jitter, measured against the configured clock.
func (db *database) isStale(lastUpdate time.Time) bool {
	age := db.config.now().Sub(lastUpdate)
	limit := 2 * (db.config.UpdatePeriod + jitter)
	return age > limit
}
// setStale marks the database as stale by recording errStale as the current
// error. Unlike setError it leaves the stored threat lists untouched. If the
// database was healthy beforehand, a new ready channel is installed.
//
// This assumes that the db.ml lock is already held.
func (db *database) setStale() {
	wasHealthy := db.err == nil
	db.err = errStale
	if wasHealthy {
		db.readyCh = make(chan struct{})
	}
}
// clearError returns the database to the healthy state. If an error was
// recorded, the ready channel is closed, releasing anything blocked on the
// channel handed out by Ready. It does nothing when no error is set.
//
// This assumes that the db.mu lock is already held.
func (db *database) clearError() {
	db.ml.Lock()
	defer db.ml.Unlock()

	if db.err == nil {
		return
	}
	close(db.readyCh)
	db.err = nil
}
// generateThreatsForUpdate refills the Hashes field of every threatsForUpdate
// entry by exporting the hashes held in threatsForLookup. The hash lists are
// not kept in threatsForUpdate between updates, to avoid needlessly occupying
// memory, so they have to be restored before an update is applied.
//
// This assumes that the db.mu lock is already held.
func (db *database) generateThreatsForUpdate() {
	if db.tfu == nil {
		db.tfu = make(threatsForUpdate)
	}

	db.ml.RLock()
	defer db.ml.RUnlock()

	for desc, set := range db.tfl {
		entry := db.tfu[desc]
		entry.Hashes = set.Export()
		db.tfu[desc] = entry
	}
}
// generateThreatsForLookups rebuilds threatsForLookup from threatsForUpdate
// and records last as the time of the last sync. Because the hashes then live
// inside the lookup sets, the hash slices in threatsForUpdate are dropped so
// that they can be garbage collected. If the database was in an error state,
// the error is cleared afterwards.
//
// This assumes that the db.mu lock is already held.
func (db *database) generateThreatsForLookups(last time.Time) {
	lookup := make(threatsForLookup, len(db.tfu))
	for desc, entry := range db.tfu {
		var set hashSet
		set.Import(entry.Hashes)
		lookup[desc] = set

		// Clear hashes to keep memory usage low.
		entry.Hashes = nil
		db.tfu[desc] = entry
	}

	// Publish the new tables and remember whether an error was pending.
	db.ml.Lock()
	recovering := db.err != nil
	db.tfl = lookup
	db.last = last
	db.ml.Unlock()

	if !recovering {
		return
	}
	db.clearError()
	db.log.Printf("database is now healthy")
}
// saveDatabase writes db to the file at path, creating or truncating it. The
// contents are the gob encoding of db, gzip-compressed at the best
// compression level.
//
// The returned error is the first failure among creating the file, setting up
// compression, encoding, and closing the gzip stream and the file.
func saveDatabase(path string, db databaseFormat) (err error) {
	file, err := os.Create(path)
	if err != nil {
		return err
	}
	// Closing can fail too; report that only if nothing failed earlier.
	defer func() {
		closeErr := file.Close()
		if err == nil {
			err = closeErr
		}
	}()

	gz, err := gzip.NewWriterLevel(file, gzip.BestCompression)
	if err != nil {
		return err
	}
	// Runs before the file is closed, flushing the compressed stream.
	defer func() {
		closeErr := gz.Close()
		if err == nil {
			err = closeErr
		}
	}()

	return gob.NewEncoder(gz).Encode(db)
}
// loadDatabase reads the database stored at path: a gzip-compressed gob
// encoding of databaseFormat. After decoding, the recorded SHA256 of every
// threat list is compared with the checksum of its hashes.
//
// A non-nil error is returned when the file cannot be opened, decompressed,
// or decoded, when a checksum does not match, or when closing fails. The
// returned value should only be used when the error is nil.
func loadDatabase(path string) (db databaseFormat, err error) {
	file, err := os.Open(path)
	if err != nil {
		return db, err
	}
	// Closing can fail too; report that only if nothing failed earlier.
	defer func() {
		closeErr := file.Close()
		if err == nil {
			err = closeErr
		}
	}()

	gz, err := gzip.NewReader(file)
	if err != nil {
		return db, err
	}
	defer func() {
		closeErr := gz.Close()
		if err == nil {
			err = closeErr
		}
	}()

	if err = gob.NewDecoder(gz).Decode(&db); err != nil {
		return db, err
	}

	// Reject the file if any list's content disagrees with its checksum.
	for _, row := range db.Table {
		if bytes.Equal(row.SHA256, row.Hashes.SHA256()) {
			continue
		}
		return db, errors.New("safebrowsing: threat list SHA256 mismatch")
	}
	return db, nil
}
// update applies the API response to the threat lists, one list at a time:
// removals first, then additions, followed by validation of the resulting
// hashes against the checksum.
//
// An error is returned for a partial update to a list that is not present, a
// full update that carries removals, an unknown response type, an invalid
// removal index, undecodable or invalid hashes, or a checksum mismatch. Lists
// handled before the failing one have already been stored in tfu.
func (tfu threatsForUpdate) update(resp *pb.FetchThreatListUpdatesResponse) error {
	for _, lur := range resp.GetListUpdateResponses() {
		desc := ThreatDescriptor{
			PlatformType:    PlatformType(lur.PlatformType),
			ThreatType:      ThreatType(lur.ThreatType),
			ThreatEntryType: ThreatEntryType(lur.ThreatEntryType),
		}

		entry, known := tfu[desc]
		switch lur.ResponseType {
		case pb.FetchThreatListUpdatesResponse_ListUpdateResponse_PARTIAL_UPDATE:
			if !known {
				return errors.New("safebrowsing: partial update received for non-existent key")
			}
		case pb.FetchThreatListUpdatesResponse_ListUpdateResponse_FULL_UPDATE:
			if len(lur.Removals) > 0 {
				return errors.New("safebrowsing: indices to be removed included in a full update")
			}
			// A full update starts over from an empty list.
			entry = partialHashes{}
		default:
			return errors.New("safebrowsing: unknown response type")
		}

		// Hashes must be sorted for removal logic to work properly.
		entry.Hashes.Sort()

		// Mark every removed position with the empty hash.
		for _, removal := range lur.Removals {
			indices, err := decodeIndices(removal)
			if err != nil {
				return err
			}
			for _, idx := range indices {
				if !(idx >= 0 && idx < int32(len(entry.Hashes))) {
					return errors.New("safebrowsing: invalid removal index")
				}
				entry.Hashes[idx] = ""
			}
		}

		// If any removal was performed, compact the list of hashes in place
		// by dropping the marked positions.
		if len(lur.Removals) > 0 {
			kept := entry.Hashes[:0]
			for _, hash := range entry.Hashes {
				if hash == "" {
					continue
				}
				kept = append(kept, hash)
			}
			entry.Hashes = kept
		}

		for _, addition := range lur.Additions {
			added, err := decodeHashes(addition)
			if err != nil {
				return err
			}
			entry.Hashes = append(entry.Hashes, added...)
		}

		// Hashes must be sorted for SHA256 checksum to be correct.
		entry.Hashes.Sort()
		if err := entry.Hashes.Validate(); err != nil {
			return err
		}

		// Without a checksum in the response, the previously stored one is
		// what the new hashes are checked against.
		if checksum := lur.GetChecksum(); checksum != nil {
			entry.SHA256 = checksum.Sha256
		}
		if !bytes.Equal(entry.SHA256, entry.Hashes.SHA256()) {
			return errors.New("safebrowsing: threat list SHA256 mismatch")
		}

		entry.State = lur.NewClientState
		tfu[desc] = entry
	}
	return nil
}