mirror of https://github.com/kubeflow/examples.git
Python package for indexing and serving the index (#150)
* Add a utility Python package for indexing and serving the index
* Add CLI arguments and conditional GCS download
* Complete skeleton CLIs for serving and index creation
* Fix lint issues
This commit is contained in:
parent
4bd30a1e68
commit
21506ffc51
|
|
@ -0,0 +1,9 @@
|
|||
# Image that installs the Python package found in the build context.
FROM python:3.6

# COPY is preferred over ADD for plain local files: ADD additionally
# auto-extracts archives and accepts URLs, neither of which is wanted here.
COPY . /app

WORKDIR /app

# Install the package from /app; skip pip's download cache so it does not
# end up in the image layer.
RUN pip install --no-cache-dir .

# The container starts a shell; commands are supplied by the caller.
ENTRYPOINT ["sh"]
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env bash

# Build the Docker image using this script's directory as the context and,
# when PROJECT is set, tag and push it to gcr.io/${PROJECT}.
#
# Environment:
#   PROJECT          Project name used in the gcr.io path; empty skips the push.
#   BUILD_IMAGE_TAG  Image tag to build (default: nmslib:devel).

set -e

PROJECT=${PROJECT:-}
BUILD_IMAGE_TAG=${BUILD_IMAGE_TAG:-nmslib:devel}

# Directory of this script used as docker context
_SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

pushd "$_SCRIPT_DIR"

# Expansions are quoted so unusual values cannot be word-split or globbed.
docker build -t "${BUILD_IMAGE_TAG}" .

# Push image to GCR if PROJECT available
if [[ -n "${PROJECT}" ]]; then
  docker tag "${BUILD_IMAGE_TAG}" "gcr.io/${PROJECT}/${BUILD_IMAGE_TAG}"
  docker push "gcr.io/${PROJECT}/${BUILD_IMAGE_TAG}"
fi

popd
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from nmslib_flask.gcs import maybe_download_gcs_file, maybe_upload_gcs_file
|
||||
from nmslib_flask.search_engine import CodeSearchEngine
|
||||
from nmslib_flask.search_server import CodeSearchServer
|
||||
|
||||
def parse_server_args(args):
    """Parse the command-line options of the search server.

    Args:
        args: list of argument strings, without the program name.

    Returns:
        An `argparse.Namespace` with `index_file`, `data_file`,
        `data_dir`, `host` and `port`.
    """
    # Declared as a table so each option reads as one row; the order is
    # the order in which the options are registered on the parser.
    option_table = (
        ('--index-file', dict(type=str, required=True,
                              help='Path to index file created by nmslib')),
        ('--data-file', dict(type=str, required=True,
                             help='Path to csv file for human-readable data')),
        ('--data-dir', dict(type=str, metavar='', default='/tmp',
                            help='Path to working data directory')),
        ('--host', dict(type=str, metavar='', default='0.0.0.0',
                        help='Host to start server on')),
        ('--port', dict(type=int, metavar='', default=8008,
                        help='Port to bind server to')),
    )

    arg_parser = argparse.ArgumentParser(prog='nmslib Flask Server')
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)

    return arg_parser.parse_args(args)
|
||||
|
||||
|
||||
def parse_creator_args(args):
    """Parse the command-line options of the index creator.

    Args:
        args: list of argument strings, without the program name.

    Returns:
        An `argparse.Namespace` with `data_file`, `output_file` and
        `data_dir`.
    """
    option_table = (
        ('--data-file', dict(type=str, required=True,
                             help='Path to csv data file for human-readable data')),
        ('--output-file', dict(type=str, metavar='', default='/tmp/index.nmslib',
                               help='Path to output index file')),
        ('--data-dir', dict(type=str, metavar='', default='/tmp',
                            help='Path to working data directory')),
    )

    arg_parser = argparse.ArgumentParser(prog='nmslib Index Creator')
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)

    return arg_parser.parse_args(args)
|
||||
|
||||
def server():
    """Command-line entry point that starts the search server.

    Reads options from `sys.argv`, makes sure the working directory
    exists, resolves the index and data files (downloading them when they
    are GCS paths), builds the search engine and runs the HTTP server.
    """
    cli_args = parse_server_args(sys.argv[1:])
    workdir = cli_args.data_dir

    if not os.path.isdir(workdir):
        os.makedirs(workdir, exist_ok=True)

    # Download relevant files if needed (index first, then data).
    local_index, local_data = (
        maybe_download_gcs_file(remote_or_local, workdir)
        for remote_or_local in (cli_args.index_file, cli_args.data_file)
    )

    engine = CodeSearchEngine(local_index, local_data)
    CodeSearchServer(engine=engine,
                     host=cli_args.host, port=cli_args.port).run()
|
||||
|
||||
|
||||
def creator():
    """Command-line entry point that builds an nmslib index.

    Reads options from `sys.argv`, resolves the data file (downloading it
    when it is a GCS path), builds the index and stores it at
    `--output-file`.

    When `--output-file` is a GCS path the index is first written into
    `--data-dir` and then uploaded. When it is a local path the index is
    written directly to that path; previously it was always written to
    `<data-dir>/<basename>` and a local output path in any other
    directory was silently never produced.
    """
    # Local import: the module-level imports only bring in the maybe_*
    # wrappers from the gcs module.
    from nmslib_flask.gcs import is_gcs_path

    args = parse_creator_args(sys.argv[1:])

    os.makedirs(args.data_dir, exist_ok=True)

    data_file = maybe_download_gcs_file(args.data_file, args.data_dir)

    # TODO(sanyamkapoor): parse data file into a numpy array
    data = np.load(data_file)

    upload_needed = is_gcs_path(args.output_file)
    if upload_needed:
        # Stage the index in the working directory before uploading it.
        index_path = os.path.join(args.data_dir,
                                  os.path.basename(args.output_file))
    else:
        index_path = args.output_file
        # Make sure the destination directory exists for local outputs.
        os.makedirs(os.path.dirname(index_path) or '.', exist_ok=True)

    CodeSearchEngine.create_index(data, index_path)

    if upload_needed:
        maybe_upload_gcs_file(index_path, args.output_file)
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import re
|
||||
import os
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
def is_gcs_path(gcs_path_string):
    """Tell whether a string looks like a GCS location.

    Accepted shapes are "gs://bucket_name" and
    "gs://bucket_name/file/path". The pattern is only anchored at the
    start of the string, so this is a prefix check: "gs://" must be
    followed by at least one character that is not "/".
    """
    matched = re.match(r'gs://([^/]+)(/.+)?', gcs_path_string)
    return matched is not None
|
||||
|
||||
def parse_gcs_path(gcs_path_string):
    """Split a GCS path into its bucket name and object path.

    Args:
        gcs_path_string: a string of the form "gs://bucket_name" or
            "gs://bucket_name/file/path".

    Returns:
        A `(bucket_name, bucket_path)` tuple. `bucket_path` is the empty
        string when the input names only a bucket.

    Raises:
        ValueError: if the string does not start with "gs://" followed by
            a non-empty bucket name.
    """
    scheme = 'gs://'

    # Same acceptance rule as `is_gcs_path`: the scheme followed by at
    # least one character that is not "/".
    if not gcs_path_string.startswith(scheme):
        raise ValueError("{} is not a valid GCS path".format(gcs_path_string))

    # Split only on the first "/" after the scheme. `partition` never
    # fails, so a bucket-only path ("gs://bucket") and object names that
    # contain "//" are both handled instead of raising an unpacking error.
    bucket_name, _, bucket_path = gcs_path_string[len(scheme):].partition('/')
    if not bucket_name:
        raise ValueError("{} is not a valid GCS path".format(gcs_path_string))

    return bucket_name, bucket_path
|
||||
|
||||
|
||||
def download_gcs_file(src_file, target_file):
    """Copy a GCS object to a local file.

    Args:
        src_file: GCS path of the object, parsed with `parse_gcs_path`.
        target_file: local filename the object is downloaded to.

    Returns:
        `target_file`, unchanged.
    """
    client = storage.Client()
    bucket_name, object_path = parse_gcs_path(src_file)
    remote_blob = client.get_bucket(bucket_name).blob(object_path)
    remote_blob.download_to_filename(target_file)
    return target_file
|
||||
|
||||
|
||||
def maybe_download_gcs_file(src_file, target_dir):
    """Download `src_file` into `target_dir` only if it is a GCS path.

    Anything that is not a GCS path is returned untouched. For a GCS
    path, the object is downloaded to `target_dir` under the basename of
    `src_file` and that local path is returned.
    """
    if is_gcs_path(src_file):
        local_copy = os.path.join(target_dir, os.path.basename(src_file))
        return download_gcs_file(src_file, local_copy)

    return src_file
|
||||
|
||||
|
||||
def upload_gcs_file(src_file, target_file):
    """Copy a local file to a GCS object.

    Args:
        src_file: local filename to upload.
        target_file: GCS path of the destination, parsed with
            `parse_gcs_path`.

    Returns:
        `target_file`, unchanged.
    """
    client = storage.Client()
    bucket_name, object_path = parse_gcs_path(target_file)
    remote_blob = client.get_bucket(bucket_name).blob(object_path)
    remote_blob.upload_from_filename(src_file)
    return target_file
|
||||
|
||||
|
||||
def maybe_upload_gcs_file(src_file, target_file):
    """Upload `src_file` to `target_file` only if the target is a GCS path.

    When `target_file` is not a GCS path nothing is copied. In both
    cases `target_file` is what comes back to the caller.
    """
    if is_gcs_path(target_file):
        return upload_gcs_file(src_file, target_file)

    return target_file
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
import nmslib
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CodeSearchEngine:
    """This is a utility class which takes an nmslib
    index file and a data file to return data from"""

    def __init__(self, index_file: str, data_file: str):
        """Create an index object and load `index_file` into it.

        `data_file` is only recorded for now; nothing reads it yet (see
        the TODO below).
        """
        self._index_file = index_file
        self._data_file = data_file

        self.index = CodeSearchEngine.nmslib_init()
        self.index.loadIndex(index_file)

        # TODO: load the reverse-index map for actual code data
        # self.data_map =

    def embed(self, query_str):
        """Turn a query string into an embedding (not implemented yet).

        Raises:
            NotImplementedError: always, until a model is wired in.
        """
        # TODO load trained model and embed input strings
        raise NotImplementedError

    def query(self, query_str: str, k=2):
        """Return the `k` nearest neighbours of `query_str` in the index.

        Returns:
            The `(idxs, dists)` pair produced by the index's `knnQuery`.

        Raises:
            NotImplementedError: propagated from `embed` until it is
                implemented.
        """
        embedding = self.embed(query_str)
        idxs, dists = self.index.knnQuery(embedding, k=k)

        # TODO(sanyamkapoor): initialize data map and return
        # list of dicts
        # [
        #     {'src': self.data_map[idx], 'dist': dist}
        #     for idx, dist in zip(idxs, dists)
        # ]
        return idxs, dists

    @staticmethod
    def nmslib_init():
        """Initializes an nmslib index object"""
        index = nmslib.init(method='hnsw', space='cosinesimil')
        return index

    @staticmethod
    def create_index(data: np.ndarray, save_path: str):
        """Add numpy data to the index and save to path"""
        # `np.ndarray` is the array type; `np.array` is only the factory
        # function and is not meaningful as an annotation.
        index = CodeSearchEngine.nmslib_init()
        index.addDataPointBatch(data)
        index.createIndex({'post': 2}, print_progress=True)
        index.saveIndex(save_path)
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from flask import Flask, request, abort, jsonify, make_response
|
||||
|
||||
|
||||
class CodeSearchServer:
    """This utility class wraps the search engine into
    an HTTP server based on Flask

    Routes:
        /ping             -> JSON body with status 200
        /query?query=...  -> JSON body with the engine's result, or a
                             400 response when the query is empty
    """

    def __init__(self, engine, host='0.0.0.0', port=8008):
        """Create the Flask app and register its routes.

        Args:
            engine: object exposing `query(query_str)`.
            host: interface passed to Flask's `run`.
            port: port passed to Flask's `run`.
        """
        self.app = Flask(__name__)
        self.host = host
        self.port = port
        self.engine = engine

        # Routes must be registered before `run()`; previously nothing
        # ever called `init_routes`, so the app served no endpoints.
        self._routes_initialized = False
        self.init_routes()

    def init_routes(self):
        """Register the HTTP routes on the Flask app (idempotent)."""
        # pylint: disable=unused-variable

        # Guard against double registration so that existing callers who
        # invoke `init_routes()` themselves keep working.
        if getattr(self, '_routes_initialized', False):
            return
        self._routes_initialized = True

        @self.app.route('/ping')
        def ping():
            return make_response(jsonify(status=200), 200)

        @self.app.route('/query')
        def query():
            query_str = request.args.get('query')
            if not query_str:
                abort(make_response(
                    jsonify(status=400, error="empty query"), 400))

            # The engine's lookup method is `query`; it has no `search`.
            # NOTE(review): the result must be JSON-serialisable for
            # `jsonify` to accept it -- confirm against the engine.
            result = self.engine.query(query_str)
            return make_response(jsonify(result=result))

    def run(self):
        """Start the Flask server on the configured host and port."""
        self.app.run(host=self.host, port=self.port)
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
Flask~=1.0.0
|
||||
nmslib~=1.7.0
|
||||
numpy~=1.14.0
|
||||
google-cloud-storage~=1.10.0
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
# Read the runtime dependencies from requirements.txt. Lines are stripped
# and blank or comment-only lines are dropped so that install_requires
# holds clean requirement strings instead of raw lines with newlines.
with open('requirements.txt', 'r') as f:
    install_requires = [
        line.strip() for line in f
        if line.strip() and not line.lstrip().startswith('#')
    ]

VERSION = '0.1.0'

setup(name='code-search-index-server',
      description='Kubeflow Code Search Demo - Index Server',
      url='https://www.github.com/kubeflow/examples',
      author='Sanyam Kapoor',
      author_email='sanyamkapoor@google.com',
      version=VERSION,
      license='MIT',
      packages=find_packages(),
      install_requires=install_requires,
      extras_require={},
      # Console commands mapped to the functions in nmslib_flask.cli.
      entry_points={
          'console_scripts': [
              'nmslib-serve=nmslib_flask.cli:server',
              'nmslib-create=nmslib_flask.cli:creator',
          ]
      })
|
||||
Loading…
Reference in New Issue