536 lines
15 KiB
Go
536 lines
15 KiB
Go
// Copyright 2016 Google Inc. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package safebrowsing
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"encoding/gob"
|
|
"errors"
|
|
"log"
|
|
"math/rand"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
pb "github.com/google/safebrowsing/internal/safebrowsing_proto"
|
|
)
|
|
|
|
// jitter is the maximum amount of time that we expect an API list update to
// actually take. We add this time to the update period time to give some
// leeway before declaring the database as stale.
const (
	// maxRetryDelay caps the exponential backoff delay between failed
	// update attempts (see database.Update).
	maxRetryDelay = 24 * time.Hour
	// baseRetryDelay is the backoff delay after the first failed update
	// attempt; it effectively doubles with each consecutive failure.
	baseRetryDelay = 15 * time.Minute
	jitter         = 30 * time.Second
)
|
|
|
|
// database tracks the state of the threat lists published by the Safe Browsing
|
|
// API. Since the global blacklist is constantly changing, the contents of the
|
|
// database needs to be periodically synced with the Safe Browsing servers in
|
|
// order to provide protection for the latest threats.
|
|
//
|
|
// The process for updating the database is as follows:
|
|
// * At startup, if a database file is provided, then load it. If loaded
|
|
// properly (not corrupted and not stale), then set tfu as the contents.
|
|
// Otherwise, pull a new threat list from the Safe Browsing API.
|
|
// * Periodically, synchronize the database with the Safe Browsing API.
|
|
// This uses the State fields to update only parts of the threat list that have
|
|
// changed since the last sync.
|
|
// * Anytime tfu is updated, generate a new tfl.
|
|
//
|
|
// The process for querying the database is as follows:
|
|
// * Check if the requested full hash matches any partial hash in tfl.
|
|
// If a match is found, return a set of ThreatDescriptors with a partial match.
|
|
type database struct {
	config *Config

	// threatsForUpdate maps ThreatDescriptors to lists of partial hashes.
	// This data structure is in a format that is easily updated by the API.
	// It is also the form that is written to disk.
	tfu threatsForUpdate
	mu  sync.Mutex // Protects tfu

	// threatsForLookup maps ThreatDescriptors to sets of partial hashes.
	// This data structure is in a format that is easily queried.
	tfl threatsForLookup
	ml  sync.RWMutex // Protects tfl, err, and last
	// Lock ordering: methods that take both locks (Init, Update via
	// setError/generateThreatsForLookups) acquire mu before ml.

	err error // Last error encountered

	// readyCh is replaced with a fresh channel whenever the database enters
	// an error state (setError/setStale) and closed when it recovers
	// (clearError); used for waiting until not in an error state.
	readyCh chan struct{}

	last time.Time // Last time the threat list were synced

	updateAPIErrors uint // Number of times we attempted to contact the api and failed

	log *log.Logger
}
|
|
|
|
// threatsForUpdate is the update/persistence form of the threat lists, keyed
// by threat descriptor. Values are kept small between updates (see
// partialHashes.Hashes).
type threatsForUpdate map[ThreatDescriptor]partialHashes

// partialHashes holds one threat list's hashes plus the synchronization
// metadata needed to request incremental updates from the API.
//
// NOTE: this struct is gob-encoded to disk (via databaseFormat); field names
// must remain stable.
type partialHashes struct {
	// Since the Hashes field is only needed when storing to disk and when
	// updating, this field is cleared except for when it is in use.
	// This is done to reduce memory usage as the contents of this can be
	// regenerated from the tfl.
	Hashes hashPrefixes

	SHA256 []byte // The SHA256 over Hashes
	State  []byte // Arbitrary binary blob to synchronize state with API
}
|
|
|
|
// threatsForLookup is the query-optimized form of the threat lists: each
// descriptor maps to a set supporting fast partial-hash membership tests.
type threatsForLookup map[ThreatDescriptor]hashSet
|
|
|
|
// databaseFormat is a light struct used only for gob encoding and decoding.
// As written to disk, the format of the database file is basically the gzip
// compressed version of the gob encoding of databaseFormat.
type databaseFormat struct {
	Table threatsForUpdate // Full threat lists, including raw hashes
	Time  time.Time        // Timestamp of the last successful sync
}
|
|
|
|
// Init initializes the database from the specified file in config.DBPath.
|
|
// It reports true if the database was successfully loaded.
|
|
func (db *database) Init(config *Config, logger *log.Logger) bool {
|
|
db.mu.Lock()
|
|
defer db.mu.Unlock()
|
|
db.setError(errors.New("not intialized"))
|
|
db.config = config
|
|
db.log = logger
|
|
if db.config.DBPath == "" {
|
|
db.log.Printf("no database file specified")
|
|
db.setError(errors.New("no database loaded"))
|
|
return false
|
|
}
|
|
dbf, err := loadDatabase(db.config.DBPath)
|
|
if err != nil {
|
|
db.log.Printf("load failure: %v", err)
|
|
db.setError(err)
|
|
return false
|
|
}
|
|
// Validate that the database threat list stored on disk is not too stale.
|
|
if db.isStale(dbf.Time) {
|
|
db.log.Printf("database loaded is stale")
|
|
db.ml.Lock()
|
|
defer db.ml.Unlock()
|
|
db.setStale()
|
|
return false
|
|
}
|
|
// Validate that the database threat list stored on disk is at least a
|
|
// superset of the specified configuration.
|
|
tfuNew := make(threatsForUpdate)
|
|
for _, td := range db.config.ThreatLists {
|
|
if row, ok := dbf.Table[td]; ok {
|
|
tfuNew[td] = row
|
|
} else {
|
|
db.log.Printf("database configuration mismatch, missing %v", td)
|
|
db.setError(errors.New("database configuration mismatch"))
|
|
return false
|
|
}
|
|
}
|
|
db.tfu = tfuNew
|
|
db.generateThreatsForLookups(dbf.Time)
|
|
return true
|
|
}
|
|
|
|
// Status reports the health of the database. The database is considered faulted
|
|
// if there was an error during update or if the last update has gone stale. If
|
|
// in a faulted state, the db may repair itself on the next Update.
|
|
func (db *database) Status() error {
|
|
db.ml.RLock()
|
|
defer db.ml.RUnlock()
|
|
|
|
if db.err != nil {
|
|
return db.err
|
|
}
|
|
if db.isStale(db.last) {
|
|
db.setStale()
|
|
return db.err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateLag reports the amount of time in between when we expected to run
|
|
// a database update and the current time
|
|
func (db *database) UpdateLag() time.Duration {
|
|
lag := db.SinceLastUpdate()
|
|
if lag < db.config.UpdatePeriod {
|
|
return 0
|
|
}
|
|
return lag - db.config.UpdatePeriod
|
|
}
|
|
|
|
// SinceLastUpdate gives the duration since the last database update
|
|
func (db *database) SinceLastUpdate() time.Duration {
|
|
db.ml.RLock()
|
|
defer db.ml.RUnlock()
|
|
|
|
return db.config.now().Sub(db.last)
|
|
}
|
|
|
|
// Ready returns a channel that's closed when the database is ready for queries.
|
|
func (db *database) Ready() <-chan struct{} {
|
|
return db.readyCh
|
|
}
|
|
|
|
// Update synchronizes the local threat lists with those maintained by the
|
|
// global Safe Browsing API servers. If the update is successful, Status should
|
|
// report a nil error.
|
|
func (db *database) Update(api api) (time.Duration, bool) {
|
|
db.mu.Lock()
|
|
defer db.mu.Unlock()
|
|
|
|
// Construct the request.
|
|
var numTypes int
|
|
var s []*pb.FetchThreatListUpdatesRequest_ListUpdateRequest
|
|
for _, td := range db.config.ThreatLists {
|
|
var state []byte
|
|
if row, ok := db.tfu[td]; ok {
|
|
state = row.State
|
|
}
|
|
|
|
s = append(s, &pb.FetchThreatListUpdatesRequest_ListUpdateRequest{
|
|
ThreatType: pb.ThreatType(td.ThreatType),
|
|
PlatformType: pb.PlatformType(td.PlatformType),
|
|
ThreatEntryType: pb.ThreatEntryType(td.ThreatEntryType),
|
|
Constraints: &pb.FetchThreatListUpdatesRequest_ListUpdateRequest_Constraints{
|
|
SupportedCompressions: db.config.compressionTypes},
|
|
State: state,
|
|
})
|
|
numTypes++
|
|
}
|
|
req := &pb.FetchThreatListUpdatesRequest{
|
|
Client: &pb.ClientInfo{
|
|
ClientId: db.config.ID,
|
|
ClientVersion: db.config.Version,
|
|
},
|
|
ListUpdateRequests: s,
|
|
}
|
|
|
|
// Query the API for the threat list and update the database.
|
|
last := db.config.now()
|
|
resp, err := api.ListUpdate(req)
|
|
if err != nil {
|
|
db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
|
|
db.setError(err)
|
|
// backoff strategy: MIN((2**N-1 * 15 minutes) * (RAND + 1), 24 hours)
|
|
n := 1 << db.updateAPIErrors
|
|
delay := time.Duration(float64(n) * (rand.Float64() + 1) * float64(baseRetryDelay))
|
|
if delay > maxRetryDelay {
|
|
delay = maxRetryDelay
|
|
}
|
|
db.updateAPIErrors++
|
|
return delay, false
|
|
}
|
|
db.updateAPIErrors = 0
|
|
|
|
// add jitter to wait time to avoid all servers lining up
|
|
nextUpdateWait := db.config.UpdatePeriod + time.Duration(rand.Int31n(60)-30)*time.Second
|
|
if resp.MinimumWaitDuration != nil {
|
|
serverMinWait := time.Duration(resp.MinimumWaitDuration.Seconds)*time.Second + time.Duration(resp.MinimumWaitDuration.Nanos)
|
|
if serverMinWait > nextUpdateWait {
|
|
nextUpdateWait = serverMinWait
|
|
db.log.Printf("Server requested next update in %v", nextUpdateWait)
|
|
}
|
|
}
|
|
if len(resp.ListUpdateResponses) != numTypes {
|
|
db.setError(errors.New("safebrowsing: threat list count mismatch"))
|
|
db.log.Printf("invalid server response: got %d, want %d threat lists",
|
|
len(resp.ListUpdateResponses), numTypes)
|
|
return nextUpdateWait, false
|
|
}
|
|
|
|
// Update the threat database with the response.
|
|
db.generateThreatsForUpdate()
|
|
if err := db.tfu.update(resp); err != nil {
|
|
db.setError(err)
|
|
db.log.Printf("update failure: %v", err)
|
|
db.tfu = nil
|
|
return nextUpdateWait, false
|
|
}
|
|
dbf := databaseFormat{make(threatsForUpdate), last}
|
|
for td, phs := range db.tfu {
|
|
// Copy of partialHashes before generateThreatsForLookups clobbers it.
|
|
dbf.Table[td] = phs
|
|
}
|
|
db.generateThreatsForLookups(last)
|
|
|
|
// Regenerate the database and store it.
|
|
if db.config.DBPath != "" {
|
|
// Semantically, we ignore save errors, but we do log them.
|
|
if err := saveDatabase(db.config.DBPath, dbf); err != nil {
|
|
db.log.Printf("save failure: %v", err)
|
|
}
|
|
}
|
|
|
|
return nextUpdateWait, true
|
|
}
|
|
|
|
// Lookup looks up the full hash in the threat list and returns a partial
|
|
// hash and a set of ThreatDescriptors that may match the full hash.
|
|
func (db *database) Lookup(hash hashPrefix) (h hashPrefix, tds []ThreatDescriptor) {
|
|
if !hash.IsFull() {
|
|
panic("hash is not full")
|
|
}
|
|
|
|
db.ml.RLock()
|
|
for td, hs := range db.tfl {
|
|
if n := hs.Lookup(hash); n > 0 {
|
|
h = hash[:n]
|
|
tds = append(tds, td)
|
|
}
|
|
}
|
|
db.ml.RUnlock()
|
|
return h, tds
|
|
}
|
|
|
|
// setError clears the database state and sets the last error to be err.
|
|
//
|
|
// This assumes that the db.mu lock is already held.
|
|
func (db *database) setError(err error) {
|
|
db.tfu = nil
|
|
|
|
db.ml.Lock()
|
|
if db.err == nil {
|
|
db.readyCh = make(chan struct{})
|
|
}
|
|
db.tfl, db.err, db.last = nil, err, time.Time{}
|
|
db.ml.Unlock()
|
|
}
|
|
|
|
// isStale checks whether the last successful update should be considered stale.
|
|
// Staleness is defined as being older than two of the configured update periods
|
|
// plus jitter.
|
|
func (db *database) isStale(lastUpdate time.Time) bool {
|
|
if db.config.now().Sub(lastUpdate) > 2*(db.config.UpdatePeriod+jitter) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// setStale sets the error state to a stale message, without clearing
// the database state.
//
// This assumes that the db.ml lock is already held.
func (db *database) setStale() {
	if db.err == nil {
		// Transitioning from healthy to faulted: arm a fresh channel for
		// future Ready waiters.
		db.readyCh = make(chan struct{})
	}
	db.err = errStale
}
|
|
|
|
// clearError clears the db error state, and unblocks any callers of
// WaitUntilReady.
//
// This assumes that the db.mu lock is already held.
func (db *database) clearError() {
	db.ml.Lock()
	defer db.ml.Unlock()

	if db.err != nil {
		// Closing readyCh releases everyone waiting on Ready; setError or
		// setStale will arm a new channel on the next fault.
		close(db.readyCh)
	}
	db.err = nil
}
|
|
|
|
// generateThreatsForUpdate regenerates the threatsForUpdate hashes from
|
|
// the threatsForLookup. We do this to avoid holding onto the hash lists for
|
|
// a long time, needlessly occupying lots of memory.
|
|
//
|
|
// This assumes that the db.mu lock is already held.
|
|
func (db *database) generateThreatsForUpdate() {
|
|
if db.tfu == nil {
|
|
db.tfu = make(threatsForUpdate)
|
|
}
|
|
|
|
db.ml.RLock()
|
|
for td, hs := range db.tfl {
|
|
phs := db.tfu[td]
|
|
phs.Hashes = hs.Export()
|
|
db.tfu[td] = phs
|
|
}
|
|
db.ml.RUnlock()
|
|
}
|
|
|
|
// generateThreatsForLookups regenerates the threatsForLookup data structure
|
|
// from the threatsForUpdate data structure and stores the last timestamp.
|
|
// Since the hashes are effectively stored as a set inside the threatsForLookup,
|
|
// we clear out the hashes slice in threatsForUpdate so that it can be GCed.
|
|
//
|
|
// This assumes that the db.mu lock is already held.
|
|
func (db *database) generateThreatsForLookups(last time.Time) {
|
|
tfl := make(threatsForLookup)
|
|
for td, phs := range db.tfu {
|
|
var hs hashSet
|
|
hs.Import(phs.Hashes)
|
|
tfl[td] = hs
|
|
|
|
phs.Hashes = nil // Clear hashes to keep memory usage low
|
|
db.tfu[td] = phs
|
|
}
|
|
|
|
db.ml.Lock()
|
|
wasBad := db.err != nil
|
|
db.tfl, db.last = tfl, last
|
|
db.ml.Unlock()
|
|
|
|
if wasBad {
|
|
db.clearError()
|
|
db.log.Printf("database is now healthy")
|
|
}
|
|
}
|
|
|
|
// saveDatabase saves the database threat list to a file.
|
|
func saveDatabase(path string, db databaseFormat) (err error) {
|
|
var file *os.File
|
|
file, err = os.Create(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if cerr := file.Close(); err == nil {
|
|
err = cerr
|
|
}
|
|
}()
|
|
|
|
gz, err := gzip.NewWriterLevel(file, gzip.BestCompression)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if zerr := gz.Close(); err == nil {
|
|
err = zerr
|
|
}
|
|
}()
|
|
|
|
encoder := gob.NewEncoder(gz)
|
|
if err = encoder.Encode(db); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// loadDatabase loads the database state from a file.
|
|
func loadDatabase(path string) (db databaseFormat, err error) {
|
|
var file *os.File
|
|
file, err = os.Open(path)
|
|
if err != nil {
|
|
return db, err
|
|
}
|
|
defer func() {
|
|
if cerr := file.Close(); err == nil {
|
|
err = cerr
|
|
}
|
|
}()
|
|
|
|
gz, err := gzip.NewReader(file)
|
|
if err != nil {
|
|
return db, err
|
|
}
|
|
defer func() {
|
|
if zerr := gz.Close(); err == nil {
|
|
err = zerr
|
|
}
|
|
}()
|
|
|
|
decoder := gob.NewDecoder(gz)
|
|
if err = decoder.Decode(&db); err != nil {
|
|
return db, err
|
|
}
|
|
for _, dv := range db.Table {
|
|
if !bytes.Equal(dv.SHA256, dv.Hashes.SHA256()) {
|
|
return db, errors.New("safebrowsing: threat list SHA256 mismatch")
|
|
}
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
// update updates the threat list according to the API response.
//
// For a PARTIAL_UPDATE the existing entry is patched in place (removals by
// index, then additions); for a FULL_UPDATE the entry is rebuilt from
// scratch. After applying changes, the hashes are re-sorted, validated, and
// checked against the server-provided SHA256 checksum before the new client
// state is committed.
func (tfu threatsForUpdate) update(resp *pb.FetchThreatListUpdatesResponse) error {
	// For each update response do the removes and adds
	for _, m := range resp.GetListUpdateResponses() {
		td := ThreatDescriptor{
			PlatformType:    PlatformType(m.PlatformType),
			ThreatType:      ThreatType(m.ThreatType),
			ThreatEntryType: ThreatEntryType(m.ThreatEntryType),
		}

		phs, ok := tfu[td]
		switch m.ResponseType {
		case pb.FetchThreatListUpdatesResponse_ListUpdateResponse_PARTIAL_UPDATE:
			// A partial update patches an existing list; it is an error to
			// receive one for a list we have no state for.
			if !ok {
				return errors.New("safebrowsing: partial update received for non-existent key")
			}
		case pb.FetchThreatListUpdatesResponse_ListUpdateResponse_FULL_UPDATE:
			if len(m.Removals) > 0 {
				return errors.New("safebrowsing: indices to be removed included in a full update")
			}
			// Full update: discard any previous state and rebuild.
			phs = partialHashes{}
		default:
			return errors.New("safebrowsing: unknown response type")
		}

		// Hashes must be sorted for removal logic to work properly.
		phs.Hashes.Sort()

		// Mark removed entries by blanking them; they are compacted out
		// below. Indices refer to positions in the sorted list.
		for _, removal := range m.Removals {
			idxs, err := decodeIndices(removal)
			if err != nil {
				return err
			}

			for _, i := range idxs {
				if i < 0 || i >= int32(len(phs.Hashes)) {
					return errors.New("safebrowsing: invalid removal index")
				}
				phs.Hashes[i] = ""
			}
		}

		// If any removal was performed, compact the list of hashes.
		if len(m.Removals) > 0 {
			compactHashes := phs.Hashes[:0] // reuse the backing array
			for _, h := range phs.Hashes {
				if h != "" {
					compactHashes = append(compactHashes, h)
				}
			}
			phs.Hashes = compactHashes
		}

		// Append all newly added hashes.
		for _, addition := range m.Additions {
			hashes, err := decodeHashes(addition)
			if err != nil {
				return err
			}
			phs.Hashes = append(phs.Hashes, hashes...)
		}

		// Hashes must be sorted for SHA256 checksum to be correct.
		phs.Hashes.Sort()
		if err := phs.Hashes.Validate(); err != nil {
			return err
		}

		// Verify the resulting list against the server's checksum, when one
		// is provided, before committing the new state token.
		if cs := m.GetChecksum(); cs != nil {
			phs.SHA256 = cs.Sha256
		}
		if !bytes.Equal(phs.SHA256, phs.Hashes.SHA256()) {
			return errors.New("safebrowsing: threat list SHA256 mismatch")
		}

		phs.State = m.NewClientState
		tfu[td] = phs
	}
	return nil
}
|