mirror of https://github.com/kubeflow/examples.git
Integrate nmslib (#194)
* Integrate NMSLib server with new data file * Integrate UI with query URL of search server
This commit is contained in:
parent
636cf1c3d0
commit
994fdf82c0
|
|
@ -7,9 +7,9 @@ and `nmslib-serve` binaries (see `setup.py`). Use `-h` to get a list
|
|||
of input CLI arguments to both.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import csv
|
||||
import numpy as np
|
||||
|
||||
from code_search.nmslib.gcs import maybe_download_gcs_file, maybe_upload_gcs_file
|
||||
|
|
@ -21,6 +21,8 @@ def parse_server_args(args):
|
|||
|
||||
parser.add_argument('--tmp-dir', type=str, metavar='', default='/tmp/nmslib',
|
||||
help='Path to temporary data directory')
|
||||
parser.add_argument('--data-file', type=str, required=True,
|
||||
help='Path to CSV file containing human-readable data')
|
||||
parser.add_argument('--index-file', type=str, required=True,
|
||||
help='Path to index file created by nmslib')
|
||||
parser.add_argument('--problem', type=str, required=True,
|
||||
|
|
@ -36,6 +38,7 @@ def parse_server_args(args):
|
|||
|
||||
args = parser.parse_args(args)
|
||||
args.tmp_dir = os.path.expanduser(args.tmp_dir)
|
||||
args.data_file = os.path.expanduser(args.data_file)
|
||||
args.index_file = os.path.expanduser(args.index_file)
|
||||
args.data_dir = os.path.expanduser(args.data_dir)
|
||||
|
||||
|
|
@ -59,32 +62,41 @@ def parse_creator_args(args):
|
|||
|
||||
return args
|
||||
|
||||
def server():
|
||||
args = parse_server_args(sys.argv[1:])
|
||||
def server(argv=None):
|
||||
args = parse_server_args(argv)
|
||||
|
||||
if not os.path.isdir(args.tmp_dir):
|
||||
os.makedirs(args.tmp_dir, exist_ok=True)
|
||||
|
||||
# Download relevant files if needed
|
||||
index_file = maybe_download_gcs_file(args.index_file, args.tmp_dir)
|
||||
data_file = maybe_download_gcs_file(args.data_file, args.tmp_dir)
|
||||
|
||||
search_engine = CodeSearchEngine(args.problem, args.data_dir, args.serving_url,
|
||||
index_file)
|
||||
index_file, data_file)
|
||||
|
||||
search_server = CodeSearchServer(engine=search_engine,
|
||||
host=args.host, port=args.port)
|
||||
search_server.run()
|
||||
|
||||
|
||||
def creator():
|
||||
args = parse_creator_args(sys.argv[1:])
|
||||
def creator(argv=None):
|
||||
args = parse_creator_args(argv)
|
||||
|
||||
if not os.path.isdir(args.tmp_dir):
|
||||
os.makedirs(args.tmp_dir, exist_ok=True)
|
||||
os.makedirs(args.tmp_dir)
|
||||
|
||||
data_file = maybe_download_gcs_file(args.data_file, args.tmp_dir)
|
||||
|
||||
data = np.load(data_file)
|
||||
data = np.empty((0, 128), dtype=np.float32)
|
||||
with open(data_file, 'r') as csv_file:
|
||||
data_reader = csv.reader(csv_file)
|
||||
next(data_reader, None) # Skip the header
|
||||
for row in data_reader:
|
||||
vector_string = row[-1]
|
||||
embedding_vector = [float(value) for value in vector_string.split(',')]
|
||||
np_row = np.expand_dims(embedding_vector, axis=0)
|
||||
data = np.append(data, np_row, axis=0)
|
||||
|
||||
tmp_index_file = os.path.join(args.tmp_dir, os.path.basename(args.index_file))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import csv
|
||||
import requests
|
||||
import nmslib
|
||||
from code_search.t2t.query import get_encoder, encode_query
|
||||
|
|
@ -7,11 +8,14 @@ from code_search.t2t.query import get_encoder, encode_query
|
|||
class CodeSearchEngine:
|
||||
"""This is a utility class which takes an nmslib
|
||||
index file and a data file to return data from"""
|
||||
def __init__(self, problem, data_dir, serving_url, index_file):
|
||||
def __init__(self, problem, data_dir, serving_url, index_file, data_file):
|
||||
self._serving_url = serving_url
|
||||
self._problem = problem
|
||||
self._data_dir = data_dir
|
||||
self._index_file = index_file
|
||||
self._data_file = data_file
|
||||
|
||||
self._data_index = self.read_lookup_data_file(data_file)
|
||||
|
||||
self.index = CodeSearchEngine.nmslib_init()
|
||||
self.index.loadIndex(index_file)
|
||||
|
|
@ -22,7 +26,7 @@ class CodeSearchEngine:
|
|||
This involves encoding the input query
|
||||
for the TF Serving service
|
||||
"""
|
||||
encoder, _ = get_encoder(self._problem, self._data_dir)
|
||||
encoder = get_encoder(self._problem, self._data_dir)
|
||||
encoded_query = encode_query(encoder, query_str)
|
||||
data = {"instances": [{"input": {"b64": encoded_query}}]}
|
||||
|
||||
|
|
@ -36,15 +40,22 @@ class CodeSearchEngine:
|
|||
|
||||
def query(self, query_str, k=2):
|
||||
embedding = self.embed(query_str)
|
||||
idxs, dists = self.index.knnQuery(embedding, k=k)
|
||||
idxs, dists = self.index.knnQuery(embedding['predictions'][0], k=k)
|
||||
|
||||
# TODO(sanyamkapoor): initialize data map and return
|
||||
# list of dicts
|
||||
# [
|
||||
# {'src': self.data_map[idx], 'dist': dist}
|
||||
# for idx, dist in zip(idxs, dists)
|
||||
# ]
|
||||
return idxs, dists
|
||||
result = [self._data_index[id] for id in idxs]
|
||||
for i, dist in enumerate(dists):
|
||||
result[i]['score'] = str(dist)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def read_lookup_data_file(data_file):
|
||||
data_list = []
|
||||
with open(data_file, 'r') as csv_file:
|
||||
dict_reader = csv.DictReader(csv_file)
|
||||
for row in dict_reader:
|
||||
row.pop('function_embedding')
|
||||
data_list.append(row)
|
||||
return data_list
|
||||
|
||||
@staticmethod
|
||||
def nmslib_init():
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from flask import Flask, request, abort, jsonify, make_response
|
||||
from flask_cors import CORS
|
||||
|
||||
|
||||
class CodeSearchServer:
|
||||
|
|
@ -40,4 +41,5 @@ class CodeSearchServer:
|
|||
return make_response(jsonify(result=result))
|
||||
|
||||
def run(self):
|
||||
CORS(self.app)
|
||||
self.app.run(host=self.host, port=self.port)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
astor~=0.6.0
|
||||
apache-beam[gcp]~=2.5.0
|
||||
Flask~=1.0.0
|
||||
flask-cors~=3.0.0
|
||||
google-cloud-storage~=1.10.0
|
||||
nltk~=3.3.0
|
||||
nmslib~=1.7.0
|
||||
|
|
|
|||
|
|
@ -2260,6 +2260,11 @@
|
|||
"resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz",
|
||||
"integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw="
|
||||
},
|
||||
"cookiejar": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/cookiejar/-/cookiejar-2.1.2.tgz",
|
||||
"integrity": "sha512-Mw+adcfzPxcPeI+0WlvRrr/3lGVO0bD75SxX6811cxSh1Wbxx7xZBGK1eVtDf6si8rg2lhnUjsVLMFMfbRIuwA=="
|
||||
},
|
||||
"copy-descriptor": {
|
||||
"version": "0.1.1",
|
||||
"resolved": "https://registry.npmjs.org/copy-descriptor/-/copy-descriptor-0.1.1.tgz",
|
||||
|
|
@ -4044,6 +4049,11 @@
|
|||
"resolved": "https://registry.npmjs.org/format/-/format-0.2.2.tgz",
|
||||
"integrity": "sha1-1hcBB+nv3E7TDJ3DkBbflCtctYs="
|
||||
},
|
||||
"formidable": {
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/formidable/-/formidable-1.2.1.tgz",
|
||||
"integrity": "sha512-Fs9VRguL0gqGHkXS5GQiMCr1VhZBxz0JnJs4JmMp/2jL18Fmbzvv7vOFRU+U8TBkHEE/CX1qDXzJplVULgsLeg=="
|
||||
},
|
||||
"forwarded": {
|
||||
"version": "0.1.2",
|
||||
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz",
|
||||
|
|
@ -10519,6 +10529,33 @@
|
|||
"schema-utils": "^0.3.0"
|
||||
}
|
||||
},
|
||||
"superagent": {
|
||||
"version": "3.8.3",
|
||||
"resolved": "https://registry.npmjs.org/superagent/-/superagent-3.8.3.tgz",
|
||||
"integrity": "sha512-GLQtLMCoEIK4eDv6OGtkOoSMt3D+oq0y3dsxMuYuDvaNUvuT8eFBuLmfR0iYYzHC1e8hpzC6ZsxbuP6DIalMFA==",
|
||||
"requires": {
|
||||
"component-emitter": "^1.2.0",
|
||||
"cookiejar": "^2.1.0",
|
||||
"debug": "^3.1.0",
|
||||
"extend": "^3.0.0",
|
||||
"form-data": "^2.3.1",
|
||||
"formidable": "^1.2.0",
|
||||
"methods": "^1.1.1",
|
||||
"mime": "^1.4.1",
|
||||
"qs": "^6.5.1",
|
||||
"readable-stream": "^2.3.5"
|
||||
},
|
||||
"dependencies": {
|
||||
"debug": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/debug/-/debug-3.1.0.tgz",
|
||||
"integrity": "sha512-OX8XqP7/1a9cqkxYw2yXss15f26NKWBpDXQd0/uK/KPqdQhxbPa994hnzjcE2VqQpDslf55723cKPUOGSmMY3g==",
|
||||
"requires": {
|
||||
"ms": "2.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"supports-color": {
|
||||
"version": "5.4.0",
|
||||
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.4.0.tgz",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@
|
|||
"react": "^16.4.1",
|
||||
"react-dom": "^16.4.1",
|
||||
"react-scripts": "1.1.4",
|
||||
"react-syntax-highlighter": "^8.0.1"
|
||||
"react-syntax-highlighter": "^8.0.1",
|
||||
"superagent": "^3.8.3"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class App extends Component {
|
|||
<div className="Search-Results">
|
||||
<h2 className="Search-Results-Title">Search Results</h2>
|
||||
{
|
||||
codeResults.map((attrs) => <CodeSample {...attrs}/>)
|
||||
codeResults.map((attrs, index) => <CodeSample key={index} {...attrs}/>)
|
||||
}
|
||||
</div>
|
||||
}
|
||||
|
|
@ -72,13 +72,11 @@ class App extends Component {
|
|||
const {queryStr} = this.state;
|
||||
if (queryStr) {
|
||||
this.setState({loading: true});
|
||||
code_search_api(queryStr, (response) => {
|
||||
const {status, results} = response;
|
||||
if (status === 200) {
|
||||
this.setState({codeResults: results, loading: false});
|
||||
} else {
|
||||
this.setState({loading: false});
|
||||
}
|
||||
code_search_api(queryStr).then((res) => {
|
||||
this.setState({codeResults: res.body.result, loading: false});
|
||||
}).catch((err) => {
|
||||
console.log(err);
|
||||
this.setState({loading: false});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import CodeIcon from '@material-ui/icons/Code';
|
|||
|
||||
class CodeSample extends Component {
|
||||
render() {
|
||||
const {nwo, path, function_string, lineno} = this.props;
|
||||
const {nwo, path, original_function, lineno} = this.props;
|
||||
|
||||
const codeUrl = `${nwo}/blob/master/${path}#L${lineno}`;
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ class CodeSample extends Component {
|
|||
</div>
|
||||
|
||||
<SyntaxHighlighter style={docco}>
|
||||
{function_string}
|
||||
{original_function}
|
||||
</SyntaxHighlighter>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -34,8 +34,8 @@ class CodeSample extends Component {
|
|||
CodeSample.propTypes = {
|
||||
nwo: PropTypes.string.isRequired,
|
||||
path: PropTypes.string.isRequired,
|
||||
function_string: PropTypes.string.isRequired,
|
||||
lineno: PropTypes.number.isRequired,
|
||||
original_function: PropTypes.string.isRequired,
|
||||
lineno: PropTypes.string.isRequired,
|
||||
};
|
||||
|
||||
export default CodeSample;
|
||||
|
|
|
|||
|
|
@ -1,35 +1,9 @@
|
|||
const results = [
|
||||
{
|
||||
nwo: 'activatedgeek/torchrl',
|
||||
path: 'torchrl/agents/random_gym_agent.py',
|
||||
lineno: 19,
|
||||
function_string: `
|
||||
def act(self, obs):
|
||||
return [[self.action_space.sample()] for _ in range(len(obs))]
|
||||
`,
|
||||
},
|
||||
{
|
||||
nwo: 'activatedgeek/torchrl',
|
||||
path: 'torchrl/policies/epsilon_greedy.py',
|
||||
lineno: 4,
|
||||
function_string: `
|
||||
distribution = np.ones((len(choices), action_size),
|
||||
dtype=np.float32) * eps / action_size
|
||||
distribution[np.arange(len(choices)), choices] += 1.0 - eps
|
||||
actions = np.array([
|
||||
np.random.choice(np.arange(action_size), p=dist)
|
||||
for dist in distribution
|
||||
])
|
||||
return np.expand_dims(actions, axis=1)
|
||||
`,
|
||||
},
|
||||
];
|
||||
import request from 'superagent';
|
||||
|
||||
function code_search_api(str, callback) {
|
||||
// TODO: make a real request, this is simulated
|
||||
window.setTimeout(() => {
|
||||
callback({status: 200, results: results});
|
||||
}, 2000);
|
||||
const SEARCH_URL='//localhost:8008/query'
|
||||
|
||||
function code_search_api(str) {
|
||||
return request.get(SEARCH_URL).query({'q': str});
|
||||
}
|
||||
|
||||
export default code_search_api;
|
||||
|
|
|
|||
Loading…
Reference in New Issue