examples/xgboost_synthetic/xgboost_util.py

112 lines
3.4 KiB
Python

import logging
import re
from google.cloud import storage
import joblib
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
import pandas as pd
from xgboost import XGBRegressor
def read_input(file_name, test_size=0.25):
    """Load a CSV dataset and split it into train and test arrays.

    When ``file_name`` is a ``gs://`` URI the object is first downloaded to
    ``/tmp/data.csv`` and read from there; otherwise it is read directly as
    a local path.

    Rows without a ``SalePrice`` value are dropped. ``SalePrice`` is the
    target; the features are every remaining column whose dtype is not
    ``object``. The split is made without shuffling, so the test set is the
    trailing ``test_size`` share of the rows. Missing feature values are
    filled by a default ``SimpleImputer`` that is fitted on the training
    features only and then applied to the test features.

    Args:
        file_name: Local path or ``gs://bucket/path`` URI of the CSV file.
        test_size: Passed through to ``train_test_split``.

    Returns:
        ``((train_X, train_y), (test_X, test_y))``.
    """
    if file_name.startswith("gs://"):
        bucket_name, blob_path = split_gcs_uri(file_name)
        bucket = storage.Client().get_bucket(bucket_name)
        # From here on, read the downloaded local copy.
        file_name = "/tmp/data.csv"
        bucket.blob(blob_path).download_to_filename(file_name)

    frame = pd.read_csv(file_name)
    frame.dropna(axis=0, subset=['SalePrice'], inplace=True)

    target = frame.SalePrice
    without_target = frame.drop(['SalePrice'], axis=1)
    features = without_target.select_dtypes(exclude=['object'])

    splits = train_test_split(
        features.values,
        target.values,
        test_size=test_size,
        shuffle=False,
    )
    train_X, test_X, train_y, test_y = splits

    # Fit the imputer on the training split only, then reuse it for test.
    imputer = SimpleImputer()
    train_X = imputer.fit_transform(train_X)
    test_X = imputer.transform(test_X)

    return (train_X, train_y), (test_X, test_y)
def load_model(model_path):
    """Load a joblib-serialized model from a local path or a GCS URI.

    A ``gs://`` URI is downloaded to ``/tmp/model.dat`` before loading; any
    other value is handed to ``joblib.load`` unchanged.

    NOTE(review): ``joblib.load`` is pickle-based and can execute arbitrary
    code while deserializing, so only load model files from trusted
    locations.

    Args:
        model_path: Local path or ``gs://bucket/path`` URI of the model file.

    Returns:
        The deserialized model object.
    """
    if not model_path.startswith("gs://"):
        return joblib.load(model_path)

    bucket_name, blob_path = split_gcs_uri(model_path)
    blob = storage.Client().get_bucket(bucket_name).blob(blob_path)

    local_model_path = "/tmp/model.dat"
    logging.info("Downloading model to %s", local_model_path)
    blob.download_to_filename(local_model_path)

    return joblib.load(local_model_path)
def train_model(train_X,
                train_y,
                test_X,
                test_y,
                n_estimators,
                learning_rate,
                early_stopping_rounds=40):
    """Train an XGBRegressor with early stopping on the evaluation set.

    Args:
        train_X: Training features.
        train_y: Training targets.
        test_X: Features of the evaluation set watched for early stopping.
        test_y: Targets of the evaluation set watched for early stopping.
        n_estimators: Passed to ``XGBRegressor(n_estimators=...)``.
        learning_rate: Passed to ``XGBRegressor(learning_rate=...)``.
        early_stopping_rounds: Early-stopping patience forwarded to ``fit``.
            Defaults to 40, the value that was previously hard-coded.

    Returns:
        The fitted ``XGBRegressor``.
    """
    model = XGBRegressor(n_estimators=n_estimators,
                         learning_rate=learning_rate)

    # NOTE(review): recent XGBoost releases expect ``early_stopping_rounds``
    # on the estimator constructor rather than on ``fit`` -- confirm against
    # the pinned xgboost version before upgrading.
    model.fit(train_X,
              train_y,
              early_stopping_rounds=early_stopping_rounds,
              eval_set=[(test_X, test_y)])

    # best_iteration is reported one lower than the number of rounds, hence +1.
    logging.info("Best RMSE on eval: %.2f with %d rounds",
                 model.best_score,
                 model.best_iteration + 1)
    return model
def eval_model(model, test_X, test_y):
    """Evaluate the model on held-out data and log the mean absolute error.

    Args:
        model: Fitted estimator exposing ``predict``.
        test_X: Features to predict on.
        test_y: Ground-truth targets for ``test_X``.

    Returns:
        The mean absolute error between ``test_y`` and the predictions, so
        callers can use the metric instead of only reading it from the log.
    """
    predictions = model.predict(test_X)
    # scikit-learn metrics take (y_true, y_pred). MAE is symmetric, so the
    # logged value is unchanged; the call now follows the documented order.
    mae = mean_absolute_error(test_y, predictions)
    logging.info("mean_absolute_error=%.2f", mae)
    return mae
def save_model(model, model_file):
    """Serialize the model with joblib, optionally uploading it to GCS.

    For a local ``model_file`` the model is dumped straight to that path.
    For a ``gs://`` URI the model is dumped to ``/tmp/model.dat`` and that
    file is then uploaded to the given bucket and object path.

    Args:
        model: The model object to serialize.
        model_file: Local path or ``gs://bucket/path`` URI to write to.
    """
    is_remote = model_file.startswith("gs://")
    local_file = "/tmp/model.dat" if is_remote else model_file

    joblib.dump(model, local_file)
    logging.info("Model export success: %s", local_file)

    if not is_remote:
        return

    bucket_name, blob_path = split_gcs_uri(model_file)
    blob = storage.Client().get_bucket(bucket_name).blob(blob_path)
    logging.info("Uploading model to %s", model_file)
    blob.upload_from_filename(local_file)
# Compiled once at import time: group 1 is the bucket, group 2 the optional
# "/object/path" remainder.
_GCS_URI_PATTERN = re.compile(r"gs://([^/]*)(/.*)?")


def split_gcs_uri(gcs_uri):
    """Split a GCS URI into bucket and path.

    Args:
        gcs_uri: A URI of the form ``gs://bucket`` or ``gs://bucket/path``.

    Returns:
        A ``(bucket, path)`` tuple. ``path`` has its leading slashes removed
        and is the empty string when the URI names only a bucket.

    Raises:
        ValueError: If ``gcs_uri`` does not start with ``gs://``. Previously
            this surfaced as an ``AttributeError`` on a ``None`` match.
    """
    match = _GCS_URI_PATTERN.match(gcs_uri)
    if match is None:
        raise ValueError(
            "Not a GCS URI (expected gs://bucket/path): %r" % (gcs_uri,))

    # A missing path group defaults to "" so the lstrip below is always safe.
    bucket, path = match.groups(default="")
    return bucket, path.lstrip("/")