pipelines/samples/contrib/aws-samples/mnist-kmeans-sagemaker/invoke_endpoint.py

41 lines
1011 B
Python

import json
import io
import boto3
import pickle
import gzip
import numpy
import sys
# The SageMaker endpoint name is a required first command-line argument.
if len(sys.argv) < 2:
    # Use sys.exit() rather than the bare exit() helper: exit() is injected by
    # the `site` module for interactive use and is not guaranteed to exist
    # (e.g. under `python -S`). Passing a string makes sys.exit() write the
    # message to stderr and terminate with exit status 1.
    sys.exit("Must pass your endpoint name")
ENDPOINT_NAME = sys.argv[1]
# Simple function to create a csv from numpy array
def np2csv(arr):
    """Serialize an array-like to CSV text.

    The values are written with ``numpy.savetxt`` using ``,`` as the
    delimiter and the ``%g`` format, one array row per line.

    Args:
        arr: Array-like accepted by ``numpy.savetxt``.

    Returns:
        The CSV content as a ``str`` with trailing whitespace (including the
        final newline emitted by ``savetxt``) removed.
    """
    buffer = io.BytesIO()
    numpy.savetxt(buffer, arr, fmt="%g", delimiter=",")
    raw_bytes = buffer.getvalue()
    text = raw_bytes.decode()
    return text.rstrip()
# Prepare input for the model.
# Download the gzipped, pickled MNIST dataset from the
# "sagemaker-sample-files" bucket into the current working directory.
s3 = boto3.client("s3")
s3.download_file(
    "sagemaker-sample-files", "datasets/image/MNIST/mnist.pkl.gz", "mnist.pkl.gz"
)
# NOTE(security): pickle.load() can execute arbitrary code while
# deserializing. This file comes from a remote bucket, so only run this
# script if you trust that bucket's contents.
with gzip.open("mnist.pkl.gz", "rb") as f:
    # The pickle holds three items; only the first (the training set) is used.
    train_set, _, _ = pickle.load(f, encoding="latin1")
# Slice with 30:31 instead of indexing with 30 so the selection keeps its
# leading dimension (presumably a single sample serialized as one CSV row --
# TODO confirm the layout of train_set[0]).
payload = np2csv(train_set[0][30:31])

# Run prediction against the endpoint created by the pipeline.
# The region defaults to us-east-1, matching the previous hard-coded value;
# an optional second command-line argument overrides it so endpoints deployed
# in other regions can be invoked as well.
REGION_NAME = sys.argv[2] if len(sys.argv) > 2 else "us-east-1"
runtime = boto3.Session(region_name=REGION_NAME).client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName=ENDPOINT_NAME, ContentType="text/csv", Body=payload
)
# The response body is a stream; read it fully and parse it as JSON.
result = json.loads(response["Body"].read().decode())
print(result)