mirror of https://github.com/kubeflow/examples.git
Refactor dataflow pipelines (#197)
* Update to a new dataflow package
* [WIP] updating docstrings, fixing redundancies
* Limit the scope of Github Transform pipeline, make everything unicode
* Add ability to start github pipelines from transformed bigquery dataset
* Upgrade batch prediction pipeline to be modular
* Fix lint errors
* Add write disposition to BigQuery transform
* Update documentation format
* Nicer names for modules
* Add unicode encoding to parsed function docstring tuples
* Use Apache Beam options parser to expose all CLI arguments
This commit is contained in:
parent
1746820f8f
commit
767c90ff20
@@ -1,2 +1 @@
 include requirements.txt
-include files/*
@@ -1,120 +0,0 @@
"""Entrypoint for Dataflow jobs"""

from __future__ import print_function

import argparse
import os
import apache_beam as beam
import apache_beam.options.pipeline_options as pipeline_options

import code_search.transforms.process_github_files as process_github_files
import code_search.transforms.code_embed as code_embed


def create_pipeline_opts(args):
  """Create standard Pipeline Options for Beam"""

  options = pipeline_options.PipelineOptions()
  options.view_as(pipeline_options.StandardOptions).runner = args.runner

  google_cloud_options = options.view_as(pipeline_options.GoogleCloudOptions)
  google_cloud_options.project = args.project
  if args.runner == 'DataflowRunner':
    google_cloud_options.job_name = args.job_name
    google_cloud_options.temp_location = '{}/temp'.format(args.storage_bucket)
    google_cloud_options.staging_location = '{}/staging'.format(args.storage_bucket)

  worker_options = options.view_as(pipeline_options.WorkerOptions)
  worker_options.num_workers = args.num_workers
  worker_options.max_num_workers = args.max_num_workers
  worker_options.machine_type = args.machine_type

  setup_options = options.view_as(pipeline_options.SetupOptions)
  setup_options.setup_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'setup.py')

  return options


def parse_arguments(argv):
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  parser.add_argument('-r', '--runner', metavar='', type=str, default='DirectRunner',
                      help='Type of runner - DirectRunner or DataflowRunner')
  parser.add_argument('-i', '--input', metavar='', type=str, default='',
                      help='Path to input file')
  parser.add_argument('-o', '--output', metavar='', type=str,
                      help='Output string of the format <dataset>:<table>')

  predict_args_parser = parser.add_argument_group('Batch Prediction Arguments')
  predict_args_parser.add_argument('--problem', metavar='', type=str,
                                   help='Name of the T2T problem')
  predict_args_parser.add_argument('--data-dir', metavar='', type=str,
                                   help='Path to directory of the T2T problem data')
  predict_args_parser.add_argument('--saved-model-dir', metavar='', type=str,
                                   help='Path to directory containing Tensorflow SavedModel')

  # Dataflow related arguments
  dataflow_args_parser = parser.add_argument_group('Dataflow Runner Arguments')
  dataflow_args_parser.add_argument('-p', '--project', metavar='', type=str, default='Project',
                                    help='Project ID')
  dataflow_args_parser.add_argument('-j', '--job-name', metavar='', type=str, default='Beam Job',
                                    help='Job name')
  dataflow_args_parser.add_argument('--storage-bucket', metavar='', type=str, default='gs://bucket',
                                    help='Path to Google Storage Bucket')
  dataflow_args_parser.add_argument('--num-workers', metavar='', type=int, default=1,
                                    help='Number of workers')
  dataflow_args_parser.add_argument('--max-num-workers', metavar='', type=int, default=1,
                                    help='Maximum number of workers')
  dataflow_args_parser.add_argument('--machine-type', metavar='', type=str, default='n1-standard-1',
                                    help='Google Cloud Machine Type to use')

  parsed_args = parser.parse_args(argv)
  return parsed_args


def create_github_pipeline(argv=None):
  """Creates the Github source code pre-processing pipeline.

  This pipeline takes an SQL file for BigQuery as an input
  and puts the results in a file and a new BigQuery table.
  An SQL file is included with the module.
  """
  args = parse_arguments(argv)

  default_sql_file = os.path.abspath('{}/../../files/select_github_archive.sql'.format(__file__))
  args.input = args.input or default_sql_file

  pipeline_opts = create_pipeline_opts(args)

  with open(args.input, 'r') as f:
    query_string = f.read()

  pipeline = beam.Pipeline(options=pipeline_opts)
  (pipeline  # pylint: disable=expression-not-assigned
    | process_github_files.ProcessGithubFiles(args.project, query_string,
                                              args.output, args.storage_bucket)
  )
  result = pipeline.run()
  if args.runner == 'DirectRunner':
    result.wait_until_finish()


def create_batch_predict_pipeline(argv=None):
  """Creates Batch Prediction Pipeline using trained model.

  This pipeline takes in a collection of CSV files returned
  by the Github Pipeline, embeds the code text using the
  trained model in a given model directory.
  """
  args = parse_arguments(argv)
  pipeline_opts = create_pipeline_opts(args)

  pipeline = beam.Pipeline(options=pipeline_opts)
  (pipeline  # pylint: disable=expression-not-assigned
    | code_embed.GithubBatchPredict(args.project, args.problem,
                                    args.data_dir, args.saved_model_dir)
  )
  result = pipeline.run()
  if args.runner == 'DirectRunner':
    result.wait_until_finish()


if __name__ == '__main__':
  create_batch_predict_pipeline()
@@ -0,0 +1,68 @@
import os
import sys
import apache_beam.options.pipeline_options as pipeline_options


class PipelineCLIOptions(pipeline_options.StandardOptions,
                         pipeline_options.WorkerOptions,
                         pipeline_options.SetupOptions,
                         pipeline_options.GoogleCloudOptions):
  """A unified arguments parser.

  This parser directly exposes all the underlying Beam
  options available to the user (along with some custom
  arguments). To use, simply pass the arguments list as
  `PipelineCLIOptions(argv)`.

  Args:
    argv: A list of strings representing CLI options.
  """

  @classmethod
  def _add_argparse_args(cls, parser):
    add_parser_arguments(parser)


def add_parser_arguments(parser):
  additional_args_parser = parser.add_argument_group('Custom Arguments')
  additional_args_parser.add_argument('--target_dataset', metavar='', type=str,
                                      help='BigQuery dataset for output results')
  additional_args_parser.add_argument('--pre_transformed', action='store_true',
                                      help='Use a pre-transformed BigQuery dataset')

  predict_args_parser = parser.add_argument_group('Batch Prediction Arguments')
  predict_args_parser.add_argument('--problem', metavar='', type=str,
                                   help='Name of the T2T problem')
  predict_args_parser.add_argument('--data_dir', metavar='', type=str,
                                   help='Path to directory of the T2T problem data')
  predict_args_parser.add_argument('--saved_model_dir', metavar='', type=str,
                                   help='Path to directory containing Tensorflow SavedModel')


def prepare_pipeline_opts(argv=None):
  """Prepare pipeline options from CLI arguments.

  This uses the unified PipelineCLIOptions parser
  and adds modifications on top. It adds a `setup_file`
  to allow installation of dependencies on Dataflow workers.
  These implicit changes allow ease-of-use.

  Use `-h` CLI argument to see the list of all possible
  arguments.

  Args:
    argv: A list of strings representing the CLI arguments.

  Returns:
    A PipelineCLIOptions object whose `_visible_options`
    contains the parsed Namespace object.
  """
  argv = argv or sys.argv[1:]
  argv.extend([
    '--setup_file',
    os.path.abspath(os.path.join(__file__, '../../../../setup.py')),
  ])

  pipeline_opts = PipelineCLIOptions(flags=argv)

  return pipeline_opts
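For context, a minimal sketch of how the unified parser above is typically driven. The flag names come from `add_parser_arguments` plus the standard Beam option classes mixed into `PipelineCLIOptions`; the project, bucket, and problem values below are placeholders, not values taken from this commit.

# Hypothetical invocation; 'my-gcp-project', the GCS paths, and the problem name are illustrative only.
flags = [
    '--runner', 'DataflowRunner',
    '--project', 'my-gcp-project',
    '--temp_location', 'gs://my-bucket/temp',
    '--target_dataset', 'code_search',
    '--problem', 'my_t2t_problem',
    '--data_dir', 'gs://my-bucket/data',
    '--saved_model_dir', 'gs://my-bucket/export',
]

# prepare_pipeline_opts appends the repository's setup.py as --setup_file
# and hands everything to PipelineCLIOptions.
opts = prepare_pipeline_opts(flags)
args = opts._visible_options  # the parsed Namespace, as used by the pipelines below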
@@ -0,0 +1,49 @@
import apache_beam as beam

import code_search.dataflow.cli.arguments as arguments
import code_search.dataflow.transforms.github_bigquery as gh_bq
import code_search.dataflow.transforms.function_embeddings as func_embed
import code_search.dataflow.do_fns.dict_to_csv as dict_to_csv


def create_function_embeddings(argv=None):
  """Creates Batch Prediction Pipeline using trained model.

  At a high level, this pipeline does the following things:
    - Read the Processed Github Dataset from BigQuery
    - Encode the functions using T2T problem
    - Get function embeddings using `kubeflow_batch_predict.dataflow.batch_prediction`
    - All results are stored in a BigQuery dataset (`args.target_dataset`)
    - See `transforms.github_dataset.GithubBatchPredict` for details of tables created
    - Additionally, store CSV of docstring, original functions and other metadata for
      reverse index lookup during search engine queries.
  """
  pipeline_opts = arguments.prepare_pipeline_opts(argv)
  args = pipeline_opts._visible_options  # pylint: disable=protected-access

  pipeline = beam.Pipeline(options=pipeline_opts)

  token_pairs = (pipeline
    | "Read Transformed Github Dataset" >> gh_bq.ReadTransformedGithubDataset(
        args.project, dataset=args.target_dataset)
    | "Compute Function Embeddings" >> func_embed.FunctionEmbeddings(args.project,
                                                                     args.target_dataset,
                                                                     args.problem,
                                                                     args.data_dir,
                                                                     args.saved_model_dir)
  )

  (token_pairs  # pylint: disable=expression-not-assigned
    | "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
        ['nwo', 'path', 'function_name', 'lineno', 'original_function', 'function_embedding']))
    | "Write Embeddings to CSV" >> beam.io.WriteToText('{}/func-index'.format(args.data_dir),
                                                       file_name_suffix='.csv')
  )

  result = pipeline.run()
  if args.runner == 'DirectRunner':
    result.wait_until_finish()


if __name__ == '__main__':
  create_function_embeddings()
@@ -0,0 +1,52 @@
import apache_beam as beam

import code_search.dataflow.cli.arguments as arguments
import code_search.dataflow.transforms.github_bigquery as gh_bq
import code_search.dataflow.transforms.github_dataset as github_dataset
import code_search.dataflow.do_fns.dict_to_csv as dict_to_csv


def preprocess_github_dataset(argv=None):
  """Apache Beam pipeline for pre-processing Github dataset.

  At a high level, this pipeline does the following things:
    - Read Github Python files from BigQuery
    - If Github Python files have already been processed, use the
      pre-processed table instead (using flag `--pre-transformed`)
    - Tokenize files into pairs of function definitions and docstrings
    - All results are stored in a BigQuery dataset (`args.target_dataset`)
    - See `transforms.github_dataset.TransformGithubDataset` for details of tables created
    - Additionally, store pairs of docstring and function tokens in a CSV file
      for training
  """
  pipeline_opts = arguments.prepare_pipeline_opts(argv)
  args = pipeline_opts._visible_options  # pylint: disable=protected-access

  pipeline = beam.Pipeline(options=pipeline_opts)

  if args.pre_transformed:
    token_pairs = (pipeline
      | "Read Transformed Github Dataset" >> gh_bq.ReadTransformedGithubDataset(
          args.project, dataset=args.target_dataset)
    )
  else:
    token_pairs = (pipeline
      | "Read Github Dataset" >> gh_bq.ReadGithubDataset(args.project)
      | "Transform Github Dataset" >> github_dataset.TransformGithubDataset(args.project,
                                                                            args.target_dataset)
    )

  (token_pairs  # pylint: disable=expression-not-assigned
    | "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
        ['docstring_tokens', 'function_tokens']))
    | "Write CSV" >> beam.io.WriteToText('{}/func-doc-pairs'.format(args.data_dir),
                                         file_name_suffix='.csv')
  )

  result = pipeline.run()
  if args.runner == 'DirectRunner':
    result.wait_until_finish()


if __name__ == '__main__':
  preprocess_github_dataset()
@@ -0,0 +1,50 @@
import csv
import io
import apache_beam as beam


class DictToCSVString(beam.DoFn):
  """Convert incoming dict to a CSV string.

  This DoFn converts a Python dict into
  a CSV string.

  Args:
    fieldnames: A list of strings representing keys of a dict.
  """
  def __init__(self, fieldnames):
    super(DictToCSVString, self).__init__()

    self.fieldnames = fieldnames

  def process(self, element, *_args, **_kwargs):
    """Convert a Python dict instance into CSV string.

    This routine uses the Python CSV DictWriter to
    robustly convert an input dict to a comma-separated
    CSV string. This also handles appropriate escaping of
    characters like the delimiter ",". The dict values
    must be serializable into a string.

    Args:
      element: A dict mapping string keys to string values.
        {
          "key1": "STRING",
          "key2": "STRING"
        }

    Yields:
      A string representing the row in CSV format.
    """
    fieldnames = self.fieldnames
    filtered_element = {
      key: value.encode('utf-8')
      for (key, value) in element.iteritems()
      if key in fieldnames
    }
    with io.BytesIO() as stream:
      writer = csv.DictWriter(stream, fieldnames)
      writer.writerow(filtered_element)
      csv_string = stream.getvalue().strip('\r\n')

    yield csv_string
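A quick illustration of the DoFn above, run directly in Python 2 (matching the `iteritems`/`encode` usage in the diff); the field values are made up:

fn = DictToCSVString([u'nwo', u'path'])
row = {u'nwo': u'kubeflow/examples', u'path': u'a,b.py', u'lineno': u'42'}
print(next(fn.process(row)))  # -> kubeflow/examples,"a,b.py"  (extra keys dropped, comma escaped)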
@@ -0,0 +1,164 @@
"""Beam DoFns specific to `code_search.dataflow.transforms.function_embeddings`."""

import apache_beam as beam

from code_search.t2t.query import get_encoder, encode_query


class EncodeFunctionTokens(beam.DoFn):
  """Encode function tokens.

  This DoFn prepares the function tokens for
  inference by a SavedModel estimator downstream.

  Args:
    problem: A string representing the registered Tensor2Tensor Problem.
    data_dir: A string representing the path to data directory.
  """
  def __init__(self, problem, data_dir):
    super(EncodeFunctionTokens, self).__init__()

    self.problem = problem
    self.data_dir = data_dir

  @property
  def function_tokens_key(self):
    return u'function_tokens'

  @property
  def instances_key(self):
    return u'instances'

  def process(self, element, *_args, **_kwargs):
    """Encode the function instance.

    This DoFn takes a tokenized function string and
    encodes it into a base64 string of TFExample
    binary format. The "function_tokens" are encoded
    and stored into the "instances" key in a format
    ready for consumption by TensorFlow SavedModel
    estimators. The encoder is provided by a
    Tensor2Tensor problem as provided in the constructor.

    Args:
      element: A Python dict of the form,
        {
          "nwo": "STRING",
          "path": "STRING",
          "function_name": "STRING",
          "lineno": "STRING",
          "original_function": "STRING",
          "function_tokens": "STRING",
          "docstring_tokens": "STRING",
        }

    Yields:
      An updated Python dict of the form
        {
          "nwo": "STRING",
          "path": "STRING",
          "function_name": "STRING",
          "lineno": "STRING",
          "original_function": "STRING",
          "function_tokens": "STRING",
          "docstring_tokens": "STRING",
          "instances": [
            {
              "input": {
                "b64": "STRING",
              }
            }
          ]
        }
    """
    encoder = get_encoder(self.problem, self.data_dir)
    encoded_function = encode_query(encoder, element.get(self.function_tokens_key))

    element[self.instances_key] = [{'input': {'b64': encoded_function}}]
    yield element


class ProcessFunctionEmbedding(beam.DoFn):
  """Process results from PredictionDoFn.

  This is a DoFn for post-processing on inference
  results from a SavedModel estimator which are
  returned by the PredictionDoFn.
  """

  @property
  def function_embedding_key(self):
    return 'function_embedding'

  @property
  def predictions_key(self):
    return 'predictions'

  @property
  def pop_keys(self):
    return [
      'predictions',
      'docstring_tokens',
      'function_tokens',
      'instances',
    ]

  def process(self, element, *_args, **_kwargs):
    """Post-Process Function embedding.

    This converts the incoming function instance
    embedding into a serializable string for downstream
    tasks. It also pops any extraneous keys which are
    no longer required. The "lineno" key is also converted
    to a string for serializability downstream.

    Args:
      element: A Python dict of the form,
        {
          "nwo": "STRING",
          "path": "STRING",
          "function_name": "STRING",
          "lineno": "STRING",
          "original_function": "STRING",
          "function_tokens": "STRING",
          "docstring_tokens": "STRING",
          "instances": [
            {
              "input": {
                "b64": "STRING",
              }
            }
          ],
          "predictions": [
            {
              "outputs": [
                FLOAT,
                FLOAT,
                ...
              ]
            }
          ],
        }

    Yields:
      An updated Python dict of the form,
        {
          "nwo": "STRING",
          "path": "STRING",
          "function_name": "STRING",
          "lineno": "STRING",
          "original_function": "STRING",
          "function_embedding": "STRING",
        }
    """
    prediction = element.get(self.predictions_key)[0]['outputs']
    element[self.function_embedding_key] = ','.join([
      str(val).decode('utf-8') for val in prediction
    ])

    element['lineno'] = str(element['lineno']).decode('utf-8')

    for key in self.pop_keys:
      element.pop(key)

    yield element
@@ -0,0 +1,126 @@
"""Beam DoFns specific to `code_search.dataflow.transforms.github_dataset`."""

import logging
import apache_beam as beam
from apache_beam import pvalue


class SplitRepoPath(beam.DoFn):
  """Update element keys to separate repo path and file path.

  This DoFn's only purpose is to be used after
  `code_search.dataflow.transforms.github_bigquery.ReadGithubDataset`
  to split the source dictionary key into two target keys.
  """

  @property
  def source_key(self):
    return u'repo_path'

  @property
  def target_keys(self):
    return [u'nwo', u'path']

  def process(self, element, *_args, **_kwargs):
    """Process Python file attributes.

    This simple DoFn splits the `repo_path` into
    independent properties of owner (`nwo`) and
    relative file path (`path`). This value is
    space-delimited, so splitting on the first space
    is enough.

    Args:
      element: A Python dict of the form,
        {
          "repo_path": "STRING",
          "content": "STRING",
        }

    Yields:
      An updated Python dict of the form,
        {
          "nwo": "STRING",
          "path": "STRING",
          "content": "STRING",
        }
    """
    values = element.pop(self.source_key).split(' ', 1)

    for key, value in zip(self.target_keys, values):
      element[key] = value

    yield element


class TokenizeFunctionDocstrings(beam.DoFn):
  """Tokenize function and docstrings.

  This DoFn takes in the rows from BigQuery and tokenizes
  the file content present in the content key. This
  yields an updated dictionary with the new tokenized
  data in the pairs key.
  """

  @property
  def content_key(self):
    return 'content'

  @property
  def info_keys(self):
    return [
      u'function_name',
      u'lineno',
      u'original_function',
      u'function_tokens',
      u'docstring_tokens',
    ]

  def process(self, element, *_args, **_kwargs):
    """Get list of Function-Docstring tokens.

    This processes each Python file's content
    and returns a list of metadata for each extracted
    pair. These contain the tokenized functions and
    docstrings. In cases where the tokenization fails,
    a side output is returned. All values are unicode
    for serialization.

    Args:
      element: A Python dict of the form,
        {
          "nwo": "STRING",
          "path": "STRING",
          "content": "STRING",
        }

    Yields:
      A Python list of the form,
        [
          {
            "nwo": "STRING",
            "path": "STRING",
            "function_name": "STRING",
            "lineno": "STRING",
            "original_function": "STRING",
            "function_tokens": "STRING",
            "docstring_tokens": "STRING",
          },
          ...
        ]
    """
    try:
      import code_search.dataflow.utils as utils

      content_blob = element.pop(self.content_key)
      pairs = utils.get_function_docstring_pairs(content_blob)

      result = [
        dict(zip(self.info_keys, pair_tuple), **element)
        for pair_tuple in pairs
      ]

      yield result
    except Exception as e:  # pylint: disable=broad-except
      logging.warning('Tokenization failed, %s', e.message)
      yield pvalue.TaggedOutput('err', element)
@@ -1,4 +1,5 @@
 import apache_beam as beam
+import apache_beam.io.gcp.bigquery as bigquery
 import apache_beam.io.gcp.internal.clients as clients
 
 
@@ -10,10 +11,12 @@ class BigQueryRead(beam.PTransform):
     string.
   """
 
-  def __init__(self, project):
+  def __init__(self, project, dataset=None, table=None):
     super(BigQueryRead, self).__init__()
 
     self.project = project
+    self.dataset = dataset
+    self.table = table
 
   @property
   def limit(self):
@@ -47,12 +50,14 @@ class BigQueryWrite(beam.PTransform):
     ]
   """
 
-  def __init__(self, project, dataset, table, batch_size=500):
+  def __init__(self, project, dataset, table, batch_size=500,
+               write_disposition=bigquery.BigQueryDisposition.WRITE_TRUNCATE):
     super(BigQueryWrite, self).__init__()
 
     self.project = project
     self.dataset = dataset
     self.table = table
+    self.write_disposition = write_disposition
     self.batch_size = batch_size
 
   @property
@@ -69,7 +74,8 @@ class BigQueryWrite(beam.PTransform):
           dataset=self.dataset,
           table=self.table,
           schema=self.output_schema,
-          batch_size=self.batch_size)
+          batch_size=self.batch_size,
+          write_disposition=self.write_disposition)
     )
 
   @staticmethod
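The new `write_disposition` argument defaults to truncating the target table. As a sketch (using the `WriteTokenizedData` subclass defined later in this commit, with placeholder project/dataset/table names), a caller that wants to append instead can pass the Beam disposition explicitly:

import apache_beam.io.gcp.bigquery as bigquery
import code_search.dataflow.transforms.github_bigquery as gh_bq

# 'my-project' and the dataset/table names here are illustrative only.
write = gh_bq.WriteTokenizedData('my-project', 'target_dataset', 'token_pairs',
                                 write_disposition=bigquery.BigQueryDisposition.WRITE_APPEND)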
@@ -0,0 +1,46 @@
import apache_beam as beam
import kubeflow_batch_predict.dataflow.batch_prediction as batch_prediction

import code_search.dataflow.do_fns.function_embeddings as func_embeddings
import code_search.dataflow.transforms.github_bigquery as github_bigquery


class FunctionEmbeddings(beam.PTransform):
  """Batch Prediction for Github dataset.

  This Beam pipeline takes in the transformed dataset,
  prepares each element's function tokens for prediction
  by encoding it into base64 format and returns an updated
  dictionary element with the embedding for further processing.
  """

  def __init__(self, project, target_dataset, problem, data_dir, saved_model_dir):
    super(FunctionEmbeddings, self).__init__()

    self.project = project
    self.target_dataset = target_dataset
    self.problem = problem
    self.data_dir = data_dir
    self.saved_model_dir = saved_model_dir

  def expand(self, input_or_inputs):
    batch_predict = (input_or_inputs
      | "Encoded Function Tokens" >> beam.ParDo(func_embeddings.EncodeFunctionTokens(
          self.problem, self.data_dir))
      | "Compute Function Embeddings" >> beam.ParDo(batch_prediction.PredictionDoFn(),
                                                    self.saved_model_dir).with_outputs('err',
                                                                                       main='main')
    )

    predictions = batch_predict.main

    formatted_predictions = (predictions
      | "Process Function Embeddings" >> beam.ParDo(func_embeddings.ProcessFunctionEmbedding())
    )

    (formatted_predictions  # pylint: disable=expression-not-assigned
      | "Save Function Embeddings" >> github_bigquery.WriteGithubFunctionEmbeddings(
          self.project, self.target_dataset)
    )

    return formatted_predictions
@@ -0,0 +1,143 @@
import apache_beam.io.gcp.bigquery as bigquery
import code_search.dataflow.transforms.bigquery as bq_transform


# Default internal table names
PAIRS_TABLE = 'token_pairs'
FAILED_TOKENIZE_TABLE = 'failed_tokenize'
FUNCTION_EMBEDDINGS_TABLE = 'function_embeddings'


class ReadGithubDataset(bq_transform.BigQueryRead):
  """Read original Github files from BigQuery.

  This utility Transform reads Python files
  from a BigQuery public dump which are smaller
  than 15k lines of code, contain at least one
  function definition and its repository has been
  watched at least twice since 2017.

  NOTE: Make sure to modify the `self.limit` property
  as desired.
  """

  @property
  def limit(self):
    # return 500
    return None

  @property
  def query_string(self):
    query = """
      SELECT
        MAX(CONCAT(f.repo_name, ' ', f.path)) AS repo_path,
        c.content
      FROM
        `bigquery-public-data.github_repos.files` AS f
      JOIN
        `bigquery-public-data.github_repos.contents` AS c
      ON
        f.id = c.id
      JOIN (
        --this part of the query makes sure repo is watched at least twice since 2017
        SELECT
          repo
        FROM (
          SELECT
            repo.name AS repo
          FROM
            `githubarchive.year.2017`
          WHERE
            type="WatchEvent"
          UNION ALL
          SELECT
            repo.name AS repo
          FROM
            `githubarchive.month.2018*`
          WHERE
            type="WatchEvent" )
        GROUP BY
          1
        HAVING
          COUNT(*) >= 2 ) AS r
      ON
        f.repo_name = r.repo
      WHERE
        f.path LIKE '%.py' AND --with python extension
        c.size < 15000 AND --get rid of ridiculously long files
        REGEXP_CONTAINS(c.content, r'def ') --contains function definition
      GROUP BY
        c.content
    """

    if self.limit:
      query += '\nLIMIT {}'.format(self.limit)
    return query


class WriteFailedTokenizedData(bq_transform.BigQueryWrite):
  @property
  def column_list(self):
    return [
      ('nwo', 'STRING'),
      ('path', 'STRING'),
      ('content', 'STRING')
    ]


class WriteTokenizedData(bq_transform.BigQueryWrite):
  @property
  def column_list(self):
    return [
      ('nwo', 'STRING'),
      ('path', 'STRING'),
      ('function_name', 'STRING'),
      ('lineno', 'STRING'),
      ('original_function', 'STRING'),
      ('function_tokens', 'STRING'),
      ('docstring_tokens', 'STRING'),
    ]


class ReadTransformedGithubDataset(bq_transform.BigQueryRead):

  def __init__(self, project, dataset=None, table=PAIRS_TABLE):
    super(ReadTransformedGithubDataset, self).__init__(project, dataset=dataset, table=table)

  @property
  def limit(self):
    # return 500
    return None

  @property
  def query_string(self):
    query = """
      SELECT
        nwo, path, function_name, lineno, original_function, function_tokens, docstring_tokens
      FROM
        {}.{}
    """.format(self.dataset, self.table)

    if self.limit:
      query += '\nLIMIT {}'.format(self.limit)
    return query


class WriteGithubFunctionEmbeddings(bq_transform.BigQueryWrite):

  def __init__(self, project, dataset, table=FUNCTION_EMBEDDINGS_TABLE, batch_size=500,
               write_disposition=bigquery.BigQueryDisposition.WRITE_TRUNCATE):
    super(WriteGithubFunctionEmbeddings, self).__init__(project, dataset, table,
                                                        batch_size=batch_size,
                                                        write_disposition=write_disposition)

  @property
  def column_list(self):
    return [
      ('nwo', 'STRING'),
      ('path', 'STRING'),
      ('function_name', 'STRING'),
      ('lineno', 'STRING'),
      ('original_function', 'STRING'),
      ('function_embedding', 'STRING')
    ]
@@ -0,0 +1,61 @@
import apache_beam as beam

import code_search.dataflow.do_fns.github_dataset as gh_do_fns
import code_search.dataflow.transforms.github_bigquery as gh_bq


class TransformGithubDataset(beam.PTransform):
  """Transform the BigQuery Github Dataset.

  This is a Beam Pipeline which reads the Github Dataset from
  BigQuery, tokenizes functions and docstrings in Python files,
  and dumps into a new BigQuery dataset for further processing.
  All tiny docstrings (smaller than `self.min_docstring_tokens`)
  are filtered out.

  This transform creates the following tables in the `target_dataset`,
  which are defined as properties for easy modification.
    - `self.failed_tokenize_table`
    - `self.pairs_table`
  """

  def __init__(self, project, target_dataset,
               pairs_table=gh_bq.PAIRS_TABLE,
               failed_tokenize_table=gh_bq.FAILED_TOKENIZE_TABLE):
    super(TransformGithubDataset, self).__init__()

    self.project = project
    self.target_dataset = target_dataset
    self.pairs_table = pairs_table
    self.failed_tokenize_table = failed_tokenize_table

  @property
  def min_docstring_tokens(self):
    return 5

  def expand(self, input_or_inputs):
    tokenize_result = (input_or_inputs
      | "Split 'repo_path'" >> beam.ParDo(gh_do_fns.SplitRepoPath())
      | "Tokenize Code/Docstring Pairs" >> beam.ParDo(
          gh_do_fns.TokenizeFunctionDocstrings()).with_outputs('err', main='rows')
    )

    pairs, tokenize_errors = tokenize_result.rows, tokenize_result.err

    (tokenize_errors  # pylint: disable=expression-not-assigned
      | "Failed Tokenization" >> gh_bq.WriteFailedTokenizedData(self.project, self.target_dataset,
                                                                self.failed_tokenize_table)
    )

    flat_rows = (pairs
      | "Flatten Rows" >> beam.FlatMap(lambda x: x)
      | "Filter Tiny Docstrings" >> beam.Filter(
          lambda row: len(row['docstring_tokens'].split(' ')) > self.min_docstring_tokens)
    )

    (flat_rows  # pylint: disable=expression-not-assigned
      | "Save Tokens" >> gh_bq.WriteTokenizedData(self.project, self.target_dataset,
                                                  self.pairs_table)
    )

    return flat_rows
@@ -0,0 +1,80 @@
import ast
import astor
import nltk.tokenize as tokenize
import spacy


def tokenize_docstring(text):
  """Tokenize docstrings.

  Args:
    text: A docstring to be tokenized.

  Returns:
    A list of strings representing the tokens in the docstring.
  """
  en = spacy.load('en')
  tokens = en.tokenizer(text.decode('utf8'))
  return [token.text.lower() for token in tokens if not token.is_space]


def tokenize_code(text):
  """Tokenize code strings.

  This simply considers whitespaces as token delimiters.

  Args:
    text: A code string to be tokenized.

  Returns:
    A list of strings representing the tokens in the code.
  """
  return tokenize.RegexpTokenizer(r'\w+').tokenize(text)


def get_function_docstring_pairs(blob):
  """Extract (function/method, docstring) pairs from a given code blob.

  This method reads a string representing a Python file, builds an
  abstract syntax tree (AST) and returns a list of Docstring and Function
  pairs along with supporting metadata.

  Args:
    blob: A string representing the Python file contents.

  Returns:
    A list of tuples of the form:
      [
        (
          function_name,
          lineno,
          original_function,
          function_tokens,
          docstring_tokens
        ),
        ...
      ]
  """
  pairs = []
  try:
    module = ast.parse(blob)
    classes = [node for node in module.body if isinstance(node, ast.ClassDef)]
    functions = [node for node in module.body if isinstance(node, ast.FunctionDef)]
    for _class in classes:
      functions.extend([node for node in _class.body if isinstance(node, ast.FunctionDef)])

    for f in functions:
      source = astor.to_source(f)
      docstring = ast.get_docstring(f) if ast.get_docstring(f) else ''
      func = source.replace(ast.get_docstring(f, clean=False), '') if docstring else source
      pair_tuple = (
        f.name.decode('utf-8'),
        str(f.lineno).decode('utf-8'),
        source.decode('utf-8'),
        ' '.join(tokenize_code(func)).decode('utf-8'),
        ' '.join(tokenize_docstring(docstring.split('\n\n')[0])).decode('utf-8'),
      )
      pairs.append(pair_tuple)
  except (AssertionError, MemoryError, SyntaxError, UnicodeEncodeError):
    pass
  return pairs
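To make the tuple layout above concrete, here is a small sketch of what the extractor yields for a toy blob (Python 2; the exact token strings depend on the spacy tokenizer and astor output, so the values shown are only approximate):

blob = 'def add(a, b):\n  """Add two numbers."""\n  return a + b\n'
pairs = get_function_docstring_pairs(blob)
# pairs is roughly:
# [(u'add',                     # function_name
#   u'1',                       # lineno, as unicode
#   u'def add(a, b):\n    """Add two numbers."""\n    return a + b\n',  # original_function
#   u'def add a b return a b',  # function_tokens (docstring stripped, \w+ tokens)
#   u'add two numbers .')]      # docstring_tokens (spacy-tokenized, lowercased)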
@@ -1,3 +0,0 @@
from code_search.do_fns.github_files import ExtractFuncInfo
from code_search.do_fns.github_files import TokenizeCodeDocstring
from code_search.do_fns.github_files import SplitRepoPath
@@ -1,76 +0,0 @@
"""Beam DoFns for prediction related tasks"""
import io
import csv
from cStringIO import StringIO
import apache_beam as beam
from code_search.transforms.process_github_files import ProcessGithubFiles
from code_search.t2t.query import get_encoder, encode_query


class GithubCSVToDict(beam.DoFn):
  """Split a text row and convert into a dict."""

  def process(self, element):  # pylint: disable=no-self-use
    element = element.encode('utf-8')
    row = StringIO(element)
    reader = csv.reader(row, delimiter=',')

    keys = ProcessGithubFiles.get_key_list()
    values = next(reader)  # pylint: disable=stop-iteration-return

    result = dict(zip(keys, values))
    yield result


class GithubDictToCSV(beam.DoFn):
  """Convert dictionary to writable CSV string."""

  def process(self, element):  # pylint: disable=no-self-use
    element['function_embedding'] = ','.join(str(val) for val in element['function_embedding'])

    target_keys = ['nwo', 'path', 'function_name', 'function_embedding']
    target_values = [element[key].encode('utf-8') for key in target_keys]

    with io.BytesIO() as fs:
      cw = csv.writer(fs)
      cw.writerow(target_values)
      result_str = fs.getvalue().strip('\r\n')

    return result_str


class EncodeExample(beam.DoFn):
  """Encode string to integer tokens.

  This is needed so that the data can be sent in
  for prediction
  """
  def __init__(self, problem, data_dir):
    super(EncodeExample, self).__init__()

    self.problem = problem
    self.data_dir = data_dir

  def process(self, element):
    encoder = get_encoder(self.problem, self.data_dir)
    encoded_function = encode_query(encoder, element['function_tokens'])

    element['instances'] = [{'input': {'b64': encoded_function}}]
    yield element


class ProcessPrediction(beam.DoFn):
  """Process results from PredictionDoFn.

  This class processes predictions from another
  DoFn, to make sure it is a correctly formatted dict.
  """
  def process(self, element):  # pylint: disable=no-self-use
    element['function_embedding'] = ','.join([
      str(val) for val in element['predictions'][0]['outputs']
    ])

    element.pop('function_tokens')
    element.pop('instances')
    element.pop('predictions')

    yield element
@@ -1,79 +0,0 @@
"""Beam DoFns for Github related tasks"""
import time
import logging
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.metrics import Metrics


class SplitRepoPath(beam.DoFn):
  # pylint: disable=abstract-method
  """Split the space-delimited file `repo_path` into owner repository (`nwo`)
  and file path (`path`)"""

  def process(self, element):  # pylint: disable=no-self-use
    nwo, path = element.pop('repo_path').split(' ', 1)
    element['nwo'] = nwo
    element['path'] = path
    yield element


class TokenizeCodeDocstring(beam.DoFn):
  # pylint: disable=abstract-method
  """Compute code/docstring pairs from incoming BigQuery row dict"""
  def __init__(self):
    super(TokenizeCodeDocstring, self).__init__()

    self.tokenization_time_ms = Metrics.counter(self.__class__, 'tokenization_time_ms')

  def process(self, element):  # pylint: disable=no-self-use
    try:
      import code_search.utils as utils

      start_time = time.time()
      element['pairs'] = utils.get_function_docstring_pairs(element.pop('content'))
      self.tokenization_time_ms.inc(int((time.time() - start_time) * 1000.0))

      yield element
    except Exception as e:  # pylint: disable=broad-except
      logging.warning('Tokenization failed, %s', e.message)
      yield pvalue.TaggedOutput('err_rows', element)


class ExtractFuncInfo(beam.DoFn):
  # pylint: disable=abstract-method
  """Convert pair tuples to dict.

  This takes a list of values from `TokenizeCodeDocstring`
  and converts into a dictionary so that values can be
  indexed by names instead of indices. `info_keys` is the
  list of names of those values in order which will become
  the keys of each new dict.
  """
  def __init__(self, info_keys):
    super(ExtractFuncInfo, self).__init__()

    self.info_keys = info_keys

  def process(self, element):
    try:
      info_rows = [dict(zip(self.info_keys, pair)) for pair in element.pop('pairs')]
      info_rows = [self.merge_two_dicts(info_dict, element) for info_dict in info_rows]
      info_rows = map(self.dict_to_unicode, info_rows)
      yield info_rows
    except Exception as e:  # pylint: disable=broad-except
      logging.warning('Function Info extraction failed, %s', e.message)
      yield pvalue.TaggedOutput('err_rows', element)

  @staticmethod
  def merge_two_dicts(dict_a, dict_b):
    result = dict_a.copy()
    result.update(dict_b)
    return result

  @staticmethod
  def dict_to_unicode(data_dict):
    for k, v in data_dict.items():
      if isinstance(v, str):
        data_dict[k] = v.decode('utf-8', 'ignore')
    return data_dict
@@ -1,54 +0,0 @@
import apache_beam as beam
import kubeflow_batch_predict.dataflow.batch_prediction as batch_prediction

import code_search.do_fns.embeddings as embeddings
import code_search.transforms.github_bigquery as github_bigquery


class GithubBatchPredict(beam.PTransform):
  """Batch Prediction for Github dataset"""

  def __init__(self, project, problem, data_dir, saved_model_dir):
    super(GithubBatchPredict, self).__init__()

    self.project = project
    self.problem = problem
    self.data_dir = data_dir
    self.saved_model_dir = saved_model_dir

    ##
    # Target dataset and table to store prediction outputs.
    # Non-configurable for now.
    #
    self.index_dataset = 'code_search'
    self.index_table = 'search_index'

    self.batch_size = 100

  def expand(self, input_or_inputs):
    rows = (input_or_inputs
      | "Read Processed Github Dataset" >> github_bigquery.ReadProcessedGithubData(self.project)
    )

    batch_predict = (rows
      | "Prepare Encoded Input" >> beam.ParDo(embeddings.EncodeExample(self.problem,
                                                                       self.data_dir))
      | "Execute Predictions" >> beam.ParDo(batch_prediction.PredictionDoFn(),
                                            self.saved_model_dir).with_outputs("errors",
                                                                               main="main")
    )

    predictions = batch_predict.main

    formatted_predictions = (predictions
      | "Process Predictions" >> beam.ParDo(embeddings.ProcessPrediction())
    )

    (formatted_predictions  # pylint: disable=expression-not-assigned
      | "Save Index Data" >> github_bigquery.WriteGithubIndexData(self.project,
                                                                  self.index_dataset,
                                                                  self.index_table,
                                                                  batch_size=self.batch_size)
    )

    return formatted_predictions
@@ -1,87 +0,0 @@
import code_search.transforms.bigquery as bigquery


class ReadOriginalGithubPythonData(bigquery.BigQueryRead):
  @property
  def limit(self):
    return None

  @property
  def query_string(self):
    query = """
      SELECT
        MAX(CONCAT(f.repo_name, ' ', f.path)) AS repo_path,
        c.content
      FROM
        `bigquery-public-data.github_repos.files` AS f
      JOIN
        `bigquery-public-data.github_repos.contents` AS c
      ON
        f.id = c.id
      JOIN (
        --this part of the query makes sure repo is watched at least twice since 2017
        SELECT
          repo
        FROM (
          SELECT
            repo.name AS repo
          FROM
            `githubarchive.year.2017`
          WHERE
            type="WatchEvent"
          UNION ALL
          SELECT
            repo.name AS repo
          FROM
            `githubarchive.month.2018*`
          WHERE
            type="WatchEvent" )
        GROUP BY
          1
        HAVING
          COUNT(*) >= 2 ) AS r
      ON
        f.repo_name = r.repo
      WHERE
        f.path LIKE '%.py' AND --with python extension
        c.size < 15000 AND --get rid of ridiculously long files
        REGEXP_CONTAINS(c.content, r'def ') --contains function definition
      GROUP BY
        c.content
    """

    if self.limit:
      query += '\nLIMIT {}'.format(self.limit)
    return query


class ReadProcessedGithubData(bigquery.BigQueryRead):
  @property
  def limit(self):
    return 100

  @property
  def query_string(self):
    query = """
      SELECT
        nwo, path, function_name, lineno, original_function, function_tokens
      FROM
        code_search.function_docstrings
    """

    if self.limit:
      query += '\nLIMIT {}'.format(self.limit)
    return query


class WriteGithubIndexData(bigquery.BigQueryWrite):
  @property
  def column_list(self):
    return [
      ('nwo', 'STRING'),
      ('path', 'STRING'),
      ('function_name', 'STRING'),
      ('lineno', 'INTEGER'),
      ('original_function', 'STRING'),
      ('function_embedding', 'STRING')
    ]
@@ -1,133 +0,0 @@
import io
import csv
import apache_beam as beam
import apache_beam.io.gcp.internal.clients as clients

import code_search.do_fns as do_fns


class ProcessGithubFiles(beam.PTransform):
  # pylint: disable=too-many-instance-attributes

  """A collection of `DoFn`s for Pipeline Transform. Reads the Github dataset from BigQuery
  and writes back the processed code-docstring pairs in a query-friendly format back to BigQuery
  table.
  """
  data_columns = ['nwo', 'path', 'function_name', 'lineno', 'original_function',
                  'function_tokens', 'docstring_tokens']
  data_types = ['STRING', 'STRING', 'STRING', 'INTEGER', 'STRING', 'STRING', 'STRING']

  def __init__(self, project, query_string, output_string, storage_bucket):
    super(ProcessGithubFiles, self).__init__()

    self.project = project
    self.query_string = query_string
    self.output_dataset, self.output_table = output_string.split(':')
    self.storage_bucket = storage_bucket

    self.num_shards = 10

  def expand(self, input_or_inputs):
    tokenize_result = (input_or_inputs
      | "Read Github Dataset" >> beam.io.Read(beam.io.BigQuerySource(query=self.query_string,
                                                                     use_standard_sql=True))
      | "Split 'repo_path'" >> beam.ParDo(do_fns.SplitRepoPath())
      | "Tokenize Code/Docstring Pairs" >> beam.ParDo(do_fns.TokenizeCodeDocstring())
                                               .with_outputs('err_rows', main='rows')
    )

    # pylint: disable=expression-not-assigned
    (tokenize_result.err_rows
      | "Failed Row Tokenization" >> beam.io.WriteToBigQuery(project=self.project,
                                                             dataset=self.output_dataset,
                                                             table=self.output_table + '_failed',
                                                             schema=self.create_failed_output_schema())
    )
    # pylint: enable=expression-not-assigned

    info_result = (tokenize_result.rows
      | "Extract Function Info" >> beam.ParDo(do_fns.ExtractFuncInfo(self.data_columns[2:]))
                                       .with_outputs('err_rows', main='rows')
    )

    # pylint: disable=expression-not-assigned
    (info_result.err_rows
      | "Failed Function Info" >> beam.io.WriteToBigQuery(project=self.project,
                                                          dataset=self.output_dataset,
                                                          table=self.output_table + '_failed',
                                                          schema=self.create_failed_output_schema())
    )
    # pylint: enable=expression-not-assigned

    processed_rows = (info_result.rows | "Flatten Rows" >> beam.FlatMap(lambda x: x))

    # pylint: disable=expression-not-assigned
    (processed_rows
      | "Filter Tiny Docstrings" >> beam.Filter(
          lambda row: len(row['docstring_tokens'].split(' ')) > 5)
      | "Format For Write" >> beam.Map(self.format_for_write)
      | "Write To File" >> beam.io.WriteToText('{}/data/pairs'.format(self.storage_bucket),
                                               file_name_suffix='.csv',
                                               num_shards=self.num_shards))
    # pylint: enable=expression-not-assigned

    return (processed_rows
      | "Save Tokens" >> beam.io.WriteToBigQuery(project=self.project,
                                                 dataset=self.output_dataset,
                                                 table=self.output_table,
                                                 schema=self.create_output_schema())
    )

  @staticmethod
  def get_key_list():
    filter_keys = [
      'original_function',
      'lineno',
    ]
    key_list = [col for col in ProcessGithubFiles.data_columns
                if col not in filter_keys]
    return key_list

  def format_for_write(self, row):
    """This method filters keys that we don't need in the
    final CSV. It must ensure that there are no multi-line
    column fields. For instance, 'original_function' is a
    multi-line string and makes CSV parsing hard for any
    derived Dataflow steps. This uses the CSV Writer
    to handle all edge cases like quote escaping."""

    target_keys = self.get_key_list()
    target_values = [row[key].encode('utf-8') for key in target_keys]

    with io.BytesIO() as fs:
      cw = csv.writer(fs)
      cw.writerow(target_values)
      result_str = fs.getvalue().strip('\r\n')

    return result_str

  def create_output_schema(self):
    table_schema = clients.bigquery.TableSchema()

    for column, data_type in zip(self.data_columns, self.data_types):
      field_schema = clients.bigquery.TableFieldSchema()
      field_schema.name = column
      field_schema.type = data_type
      field_schema.mode = 'nullable'
      table_schema.fields.append(field_schema)

    return table_schema

  def create_failed_output_schema(self):
    table_schema = clients.bigquery.TableSchema()

    for column, data_type in zip(self.data_columns[:2] + ['content'],
                                 self.data_types[:2] + ['STRING']):
      field_schema = clients.bigquery.TableFieldSchema()
      field_schema.name = column
      field_schema.type = data_type
      field_schema.mode = 'nullable'
      table_schema.fields.append(field_schema)

    return table_schema
@@ -1,38 +0,0 @@
import ast
import astor
import nltk.tokenize as tokenize
import spacy


def tokenize_docstring(text):
  """Apply tokenization using spacy to docstrings."""
  en = spacy.load('en')
  tokens = en.tokenizer(text.decode('utf8', 'ignore'))
  return [token.text.lower() for token in tokens if not token.is_space]


def tokenize_code(text):
  """A very basic procedure for tokenizing code strings."""
  return tokenize.RegexpTokenizer(r'\w+').tokenize(text)


def get_function_docstring_pairs(blob):
  """Extract (function/method, docstring) pairs from a given code blob."""
  pairs = []
  try:
    module = ast.parse(blob)
    classes = [node for node in module.body if isinstance(node, ast.ClassDef)]
    functions = [node for node in module.body if isinstance(node, ast.FunctionDef)]
    for _class in classes:
      functions.extend([node for node in _class.body if isinstance(node, ast.FunctionDef)])

    for f in functions:
      source = astor.to_source(f)
      docstring = ast.get_docstring(f) if ast.get_docstring(f) else ''
      func = source.replace(ast.get_docstring(f, clean=False), '') if docstring else source

      pairs.append((f.name, f.lineno, source, ' '.join(tokenize_code(func)),
                    ' '.join(tokenize_docstring(docstring.split('\n\n')[0]))))
  except (AssertionError, MemoryError, SyntaxError, UnicodeEncodeError):
    pass
  return pairs
@@ -1,39 +0,0 @@
SELECT
  MAX(CONCAT(f.repo_name, ' ', f.path)) AS repo_path,
  c.content
FROM
  `bigquery-public-data.github_repos.files` AS f
JOIN
  `bigquery-public-data.github_repos.contents` AS c
ON
  f.id = c.id
JOIN (
  --this part of the query makes sure repo is watched at least twice since 2017
  SELECT
    repo
  FROM (
    SELECT
      repo.name AS repo
    FROM
      `githubarchive.year.2017`
    WHERE
      type="WatchEvent"
    UNION ALL
    SELECT
      repo.name AS repo
    FROM
      `githubarchive.month.2018*`
    WHERE
      type="WatchEvent" )
  GROUP BY
    1
  HAVING
    COUNT(*) >= 2 ) AS r
ON
  f.repo_name = r.repo
WHERE
  f.path LIKE '%.py' AND --with python extension
  c.size < 15000 AND --get rid of ridiculously long files
  REGEXP_CONTAINS(c.content, r'def ') --contains function definition
GROUP BY
  c.content
@@ -60,8 +60,6 @@ setup(name='code-search',
       },
       entry_points={
         'console_scripts': [
-          'code-search-preprocess=code_search.cli:create_github_pipeline',
-          'code-search-predict=code_search.cli:create_batch_predict_pipeline',
           'nmslib-serve=code_search.nmslib.cli:server',
           'nmslib-create=code_search.nmslib.cli:creator',
         ]