Integrate nmslib (#194)

* Integrate NMSLib server with new data file

* Integrate UI with query URL of search server
This commit is contained in:
Sanyam Kapoor 2018-07-23 17:17:24 -07:00 committed by k8s-ci-robot
parent 636cf1c3d0
commit 994fdf82c0
9 changed files with 98 additions and 62 deletions

View File

@ -7,9 +7,9 @@ and `nmslib-serve` binaries (see `setup.py`). Use `-h` to get a list
of input CLI arguments to both.
"""
import sys
import os
import argparse
import csv
import numpy as np
from code_search.nmslib.gcs import maybe_download_gcs_file, maybe_upload_gcs_file
@ -21,6 +21,8 @@ def parse_server_args(args):
parser.add_argument('--tmp-dir', type=str, metavar='', default='/tmp/nmslib',
help='Path to temporary data directory')
parser.add_argument('--data-file', type=str, required=True,
help='Path to CSV file containing human-readable data')
parser.add_argument('--index-file', type=str, required=True,
help='Path to index file created by nmslib')
parser.add_argument('--problem', type=str, required=True,
@ -36,6 +38,7 @@ def parse_server_args(args):
args = parser.parse_args(args)
args.tmp_dir = os.path.expanduser(args.tmp_dir)
args.data_file = os.path.expanduser(args.data_file)
args.index_file = os.path.expanduser(args.index_file)
args.data_dir = os.path.expanduser(args.data_dir)
@ -59,32 +62,41 @@ def parse_creator_args(args):
return args
def server():
args = parse_server_args(sys.argv[1:])
def server(argv=None):
args = parse_server_args(argv)
if not os.path.isdir(args.tmp_dir):
os.makedirs(args.tmp_dir, exist_ok=True)
# Download relevant files if needed
index_file = maybe_download_gcs_file(args.index_file, args.tmp_dir)
data_file = maybe_download_gcs_file(args.data_file, args.tmp_dir)
search_engine = CodeSearchEngine(args.problem, args.data_dir, args.serving_url,
index_file)
index_file, data_file)
search_server = CodeSearchServer(engine=search_engine,
host=args.host, port=args.port)
search_server.run()
def creator():
args = parse_creator_args(sys.argv[1:])
def creator(argv=None):
args = parse_creator_args(argv)
if not os.path.isdir(args.tmp_dir):
os.makedirs(args.tmp_dir, exist_ok=True)
os.makedirs(args.tmp_dir)
data_file = maybe_download_gcs_file(args.data_file, args.tmp_dir)
data = np.load(data_file)
data = np.empty((0, 128), dtype=np.float32)
with open(data_file, 'r') as csv_file:
data_reader = csv.reader(csv_file)
next(data_reader, None) # Skip the header
for row in data_reader:
vector_string = row[-1]
embedding_vector = [float(value) for value in vector_string.split(',')]
np_row = np.expand_dims(embedding_vector, axis=0)
data = np.append(data, np_row, axis=0)
tmp_index_file = os.path.join(args.tmp_dir, os.path.basename(args.index_file))

View File

@ -1,4 +1,5 @@
import json
import csv
import requests
import nmslib
from code_search.t2t.query import get_encoder, encode_query
@ -7,11 +8,14 @@ from code_search.t2t.query import get_encoder, encode_query
class CodeSearchEngine:
"""This is a utility class which takes an nmslib
index file and a data file to return data from"""
def __init__(self, problem, data_dir, serving_url, index_file):
def __init__(self, problem, data_dir, serving_url, index_file, data_file):
self._serving_url = serving_url
self._problem = problem
self._data_dir = data_dir
self._index_file = index_file
self._data_file = data_file
self._data_index = self.read_lookup_data_file(data_file)
self.index = CodeSearchEngine.nmslib_init()
self.index.loadIndex(index_file)
@ -22,7 +26,7 @@ class CodeSearchEngine:
This involves encoding the input query
for the TF Serving service
"""
encoder, _ = get_encoder(self._problem, self._data_dir)
encoder = get_encoder(self._problem, self._data_dir)
encoded_query = encode_query(encoder, query_str)
data = {"instances": [{"input": {"b64": encoded_query}}]}
@ -36,15 +40,22 @@ class CodeSearchEngine:
def query(self, query_str, k=2):
embedding = self.embed(query_str)
idxs, dists = self.index.knnQuery(embedding, k=k)
idxs, dists = self.index.knnQuery(embedding['predictions'][0], k=k)
# TODO(sanyamkapoor): initialize data map and return
# list of dicts
# [
# {'src': self.data_map[idx], 'dist': dist}
# for idx, dist in zip(idxs, dists)
# ]
return idxs, dists
result = [self._data_index[id] for id in idxs]
for i, dist in enumerate(dists):
result[i]['score'] = str(dist)
return result
@staticmethod
def read_lookup_data_file(data_file):
data_list = []
with open(data_file, 'r') as csv_file:
dict_reader = csv.DictReader(csv_file)
for row in dict_reader:
row.pop('function_embedding')
data_list.append(row)
return data_list
@staticmethod
def nmslib_init():

View File

@ -1,4 +1,5 @@
from flask import Flask, request, abort, jsonify, make_response
from flask_cors import CORS
class CodeSearchServer:
@ -40,4 +41,5 @@ class CodeSearchServer:
return make_response(jsonify(result=result))
def run(self):
CORS(self.app)
self.app.run(host=self.host, port=self.port)

View File

@ -1,6 +1,7 @@
astor~=0.6.0
apache-beam[gcp]~=2.5.0
Flask~=1.0.0
flask-cors~=3.0.0
google-cloud-storage~=1.10.0
nltk~=3.3.0
nmslib~=1.7.0

View File

@ -2260,6 +2260,11 @@
"resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz",
"integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw="
},
"cookiejar": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/cookiejar/-/cookiejar-2.1.2.tgz",
"integrity": "sha512-Mw+adcfzPxcPeI+0WlvRrr/3lGVO0bD75SxX6811cxSh1Wbxx7xZBGK1eVtDf6si8rg2lhnUjsVLMFMfbRIuwA=="
},
"copy-descriptor": {
"version": "0.1.1",
"resolved": "https://registry.npmjs.org/copy-descriptor/-/copy-descriptor-0.1.1.tgz",
@ -4044,6 +4049,11 @@
"resolved": "https://registry.npmjs.org/format/-/format-0.2.2.tgz",
"integrity": "sha1-1hcBB+nv3E7TDJ3DkBbflCtctYs="
},
"formidable": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/formidable/-/formidable-1.2.1.tgz",
"integrity": "sha512-Fs9VRguL0gqGHkXS5GQiMCr1VhZBxz0JnJs4JmMp/2jL18Fmbzvv7vOFRU+U8TBkHEE/CX1qDXzJplVULgsLeg=="
},
"forwarded": {
"version": "0.1.2",
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz",
@ -10519,6 +10529,33 @@
"schema-utils": "^0.3.0"
}
},
"superagent": {
"version": "3.8.3",
"resolved": "https://registry.npmjs.org/superagent/-/superagent-3.8.3.tgz",
"integrity": "sha512-GLQtLMCoEIK4eDv6OGtkOoSMt3D+oq0y3dsxMuYuDvaNUvuT8eFBuLmfR0iYYzHC1e8hpzC6ZsxbuP6DIalMFA==",
"requires": {
"component-emitter": "^1.2.0",
"cookiejar": "^2.1.0",
"debug": "^3.1.0",
"extend": "^3.0.0",
"form-data": "^2.3.1",
"formidable": "^1.2.0",
"methods": "^1.1.1",
"mime": "^1.4.1",
"qs": "^6.5.1",
"readable-stream": "^2.3.5"
},
"dependencies": {
"debug": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/debug/-/debug-3.1.0.tgz",
"integrity": "sha512-OX8XqP7/1a9cqkxYw2yXss15f26NKWBpDXQd0/uK/KPqdQhxbPa994hnzjcE2VqQpDslf55723cKPUOGSmMY3g==",
"requires": {
"ms": "2.0.0"
}
}
}
},
"supports-color": {
"version": "5.4.0",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.4.0.tgz",

View File

@ -8,7 +8,8 @@
"react": "^16.4.1",
"react-dom": "^16.4.1",
"react-scripts": "1.1.4",
"react-syntax-highlighter": "^8.0.1"
"react-syntax-highlighter": "^8.0.1",
"superagent": "^3.8.3"
},
"scripts": {
"start": "react-scripts start",

View File

@ -54,7 +54,7 @@ class App extends Component {
<div className="Search-Results">
<h2 className="Search-Results-Title">Search Results</h2>
{
codeResults.map((attrs) => <CodeSample {...attrs}/>)
codeResults.map((attrs, index) => <CodeSample key={index} {...attrs}/>)
}
</div>
}
@ -72,13 +72,11 @@ class App extends Component {
const {queryStr} = this.state;
if (queryStr) {
this.setState({loading: true});
code_search_api(queryStr, (response) => {
const {status, results} = response;
if (status === 200) {
this.setState({codeResults: results, loading: false});
} else {
this.setState({loading: false});
}
code_search_api(queryStr).then((res) => {
this.setState({codeResults: res.body.result, loading: false});
}).catch((err) => {
console.log(err);
this.setState({loading: false});
});
}
};

View File

@ -6,7 +6,7 @@ import CodeIcon from '@material-ui/icons/Code';
class CodeSample extends Component {
render() {
const {nwo, path, function_string, lineno} = this.props;
const {nwo, path, original_function, lineno} = this.props;
const codeUrl = `${nwo}/blob/master/${path}#L${lineno}`;
@ -24,7 +24,7 @@ class CodeSample extends Component {
</div>
<SyntaxHighlighter style={docco}>
{function_string}
{original_function}
</SyntaxHighlighter>
</div>
);
@ -34,8 +34,8 @@ class CodeSample extends Component {
CodeSample.propTypes = {
nwo: PropTypes.string.isRequired,
path: PropTypes.string.isRequired,
function_string: PropTypes.string.isRequired,
lineno: PropTypes.number.isRequired,
original_function: PropTypes.string.isRequired,
lineno: PropTypes.string.isRequired,
};
export default CodeSample;

View File

@ -1,35 +1,9 @@
const results = [
{
nwo: 'activatedgeek/torchrl',
path: 'torchrl/agents/random_gym_agent.py',
lineno: 19,
function_string: `
def act(self, obs):
return [[self.action_space.sample()] for _ in range(len(obs))]
`,
},
{
nwo: 'activatedgeek/torchrl',
path: 'torchrl/policies/epsilon_greedy.py',
lineno: 4,
function_string: `
distribution = np.ones((len(choices), action_size),
dtype=np.float32) * eps / action_size
distribution[np.arange(len(choices)), choices] += 1.0 - eps
actions = np.array([
np.random.choice(np.arange(action_size), p=dist)
for dist in distribution
])
return np.expand_dims(actions, axis=1)
`,
},
];
import request from 'superagent';
function code_search_api(str, callback) {
// TODO: make a real request, this is simulated
window.setTimeout(() => {
callback({status: 200, results: results});
}, 2000);
const SEARCH_URL='//localhost:8008/query'
function code_search_api(str) {
return request.get(SEARCH_URL).query({'q': str});
}
export default code_search_api;