mirror of https://github.com/kubeflow/examples.git
Refactor dataflow pipelines (#197)
* Update to a new dataflow package
* [WIP] updating docstrings, fixing redundancies
* Limit the scope of Github Transform pipeline, make everything unicode
* Add ability to start github pipelines from transformed bigquery dataset
* Upgrade batch prediction pipeline to be modular
* Fix lint errors
* Add write disposition to BigQuery transform
* Update documentation format
* Nicer names for modules
* Add unicode encoding to parsed function docstring tuples
* Use Apache Beam options parser to expose all CLI arguments
This commit is contained in: parent 1746820f8f, commit 767c90ff20

@ -1,2 +1 @@
include requirements.txt
include files/*
@ -1,120 +0,0 @@
|
|||
"""Entrypoint for Dataflow jobs"""
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import os
|
||||
import apache_beam as beam
|
||||
import apache_beam.options.pipeline_options as pipeline_options
|
||||
|
||||
import code_search.transforms.process_github_files as process_github_files
|
||||
import code_search.transforms.code_embed as code_embed
|
||||
|
||||
|
||||
def create_pipeline_opts(args):
|
||||
"""Create standard Pipeline Options for Beam"""
|
||||
|
||||
options = pipeline_options.PipelineOptions()
|
||||
options.view_as(pipeline_options.StandardOptions).runner = args.runner
|
||||
|
||||
google_cloud_options = options.view_as(pipeline_options.GoogleCloudOptions)
|
||||
google_cloud_options.project = args.project
|
||||
if args.runner == 'DataflowRunner':
|
||||
google_cloud_options.job_name = args.job_name
|
||||
google_cloud_options.temp_location = '{}/temp'.format(args.storage_bucket)
|
||||
google_cloud_options.staging_location = '{}/staging'.format(args.storage_bucket)
|
||||
|
||||
worker_options = options.view_as(pipeline_options.WorkerOptions)
|
||||
worker_options.num_workers = args.num_workers
|
||||
worker_options.max_num_workers = args.max_num_workers
|
||||
worker_options.machine_type = args.machine_type
|
||||
|
||||
setup_options = options.view_as(pipeline_options.SetupOptions)
|
||||
setup_options.setup_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'setup.py')
|
||||
|
||||
return options
|
||||
|
||||
def parse_arguments(argv):
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument('-r', '--runner', metavar='', type=str, default='DirectRunner',
|
||||
help='Type of runner - DirectRunner or DataflowRunner')
|
||||
parser.add_argument('-i', '--input', metavar='', type=str, default='',
|
||||
help='Path to input file')
|
||||
parser.add_argument('-o', '--output', metavar='', type=str,
|
||||
help='Output string of the format <dataset>:<table>')
|
||||
|
||||
predict_args_parser = parser.add_argument_group('Batch Prediction Arguments')
|
||||
predict_args_parser.add_argument('--problem', metavar='', type=str,
|
||||
help='Name of the T2T problem')
|
||||
predict_args_parser.add_argument('--data-dir', metavar='', type=str,
|
||||
help='Path to directory of the T2T problem data')
|
||||
predict_args_parser.add_argument('--saved-model-dir', metavar='', type=str,
|
||||
help='Path to directory containing Tensorflow SavedModel')
|
||||
|
||||
# Dataflow related arguments
|
||||
dataflow_args_parser = parser.add_argument_group('Dataflow Runner Arguments')
|
||||
dataflow_args_parser.add_argument('-p', '--project', metavar='', type=str, default='Project',
|
||||
help='Project ID')
|
||||
dataflow_args_parser.add_argument('-j', '--job-name', metavar='', type=str, default='Beam Job',
|
||||
help='Job name')
|
||||
dataflow_args_parser.add_argument('--storage-bucket', metavar='', type=str, default='gs://bucket',
|
||||
help='Path to Google Storage Bucket')
|
||||
dataflow_args_parser.add_argument('--num-workers', metavar='', type=int, default=1,
|
||||
help='Number of workers')
|
||||
dataflow_args_parser.add_argument('--max-num-workers', metavar='', type=int, default=1,
|
||||
help='Maximum number of workers')
|
||||
dataflow_args_parser.add_argument('--machine-type', metavar='', type=str, default='n1-standard-1',
|
||||
help='Google Cloud Machine Type to use')
|
||||
|
||||
parsed_args = parser.parse_args(argv)
|
||||
return parsed_args
|
||||
|
||||
|
||||
def create_github_pipeline(argv=None):
|
||||
"""Creates the Github source code pre-processing pipeline.
|
||||
|
||||
This pipeline takes an SQL file for BigQuery as an input
|
||||
and puts the results in a file and a new BigQuery table.
|
||||
An SQL file is included with the module.
|
||||
"""
|
||||
args = parse_arguments(argv)
|
||||
|
||||
default_sql_file = os.path.abspath('{}/../../files/select_github_archive.sql'.format(__file__))
|
||||
args.input = args.input or default_sql_file
|
||||
|
||||
pipeline_opts = create_pipeline_opts(args)
|
||||
|
||||
with open(args.input, 'r') as f:
|
||||
query_string = f.read()
|
||||
|
||||
pipeline = beam.Pipeline(options=pipeline_opts)
|
||||
(pipeline #pylint: disable=expression-not-assigned
|
||||
| process_github_files.ProcessGithubFiles(args.project, query_string,
|
||||
args.output, args.storage_bucket)
|
||||
)
|
||||
result = pipeline.run()
|
||||
if args.runner == 'DirectRunner':
|
||||
result.wait_until_finish()
|
||||
|
||||
|
||||
def create_batch_predict_pipeline(argv=None):
|
||||
"""Creates Batch Prediction Pipeline using trained model.
|
||||
|
||||
This pipeline takes in a collection of CSV files returned
|
||||
by the Github Pipeline, embeds the code text using the
|
||||
trained model in a given model directory.
|
||||
"""
|
||||
args = parse_arguments(argv)
|
||||
pipeline_opts = create_pipeline_opts(args)
|
||||
|
||||
pipeline = beam.Pipeline(options=pipeline_opts)
|
||||
(pipeline #pylint: disable=expression-not-assigned
|
||||
| code_embed.GithubBatchPredict(args.project, args.problem,
|
||||
args.data_dir, args.saved_model_dir)
|
||||
)
|
||||
result = pipeline.run()
|
||||
if args.runner == 'DirectRunner':
|
||||
result.wait_until_finish()
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_batch_predict_pipeline()
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
import os
|
||||
import sys
|
||||
import apache_beam.options.pipeline_options as pipeline_options
|
||||
|
||||
|
||||
class PipelineCLIOptions(pipeline_options.StandardOptions,
|
||||
pipeline_options.WorkerOptions,
|
||||
pipeline_options.SetupOptions,
|
||||
pipeline_options.GoogleCloudOptions):
|
||||
"""A unified arguments parser.
|
||||
|
||||
This parser directly exposes all the underlying Beam
|
||||
options available to the user (along with some custom
|
||||
arguments). To use, simply pass the arguments list as
|
||||
`PipelineCLIOptions(argv)`.
|
||||
|
||||
Args:
|
||||
argv: A list of strings representing CLI options.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _add_argparse_args(cls, parser):
|
||||
add_parser_arguments(parser)
|
||||
|
||||
|
||||
def add_parser_arguments(parser):
|
||||
additional_args_parser = parser.add_argument_group('Custom Arguments')
|
||||
additional_args_parser.add_argument('--target_dataset', metavar='', type=str,
|
||||
help='BigQuery dataset for output results')
|
||||
additional_args_parser.add_argument('--pre_transformed', action='store_true',
|
||||
help='Use a pre-transformed BigQuery dataset')
|
||||
|
||||
predict_args_parser = parser.add_argument_group('Batch Prediction Arguments')
|
||||
predict_args_parser.add_argument('--problem', metavar='', type=str,
|
||||
help='Name of the T2T problem')
|
||||
predict_args_parser.add_argument('--data_dir', metavar='', type=str,
|
||||
help='Path to directory of the T2T problem data')
|
||||
predict_args_parser.add_argument('--saved_model_dir', metavar='', type=str,
|
||||
help='Path to directory containing Tensorflow SavedModel')
|
||||
|
||||
|
||||
def prepare_pipeline_opts(argv=None):
|
||||
"""Prepare pipeline options from CLI arguments.
|
||||
|
||||
This uses the unified PipelineCLIOptions parser
|
||||
and adds modifications on top. It adds a `setup_file`
|
||||
to allow installation of dependencies on Dataflow workers.
|
||||
These implicit additions keep the CLI easy to use.
|
||||
|
||||
Use `-h` CLI argument to see the list of all possible
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
argv: A list of strings representing the CLI arguments.
|
||||
|
||||
Returns:
|
||||
A PipelineCLIOptions object whose `_visible_options`
|
||||
contains the parsed Namespace object.
|
||||
"""
|
||||
argv = argv or sys.argv[1:]
|
||||
argv.extend([
|
||||
'--setup_file',
|
||||
os.path.abspath(os.path.join(__file__, '../../../../setup.py')),
|
||||
])
|
||||
|
||||
pipeline_opts = PipelineCLIOptions(flags=argv)
|
||||
|
||||
return pipeline_opts
|
||||
|
|
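As the docstring above notes, the raw argument list is passed straight to the parser. A minimal usage sketch of `prepare_pipeline_opts` (not part of this commit; all flag values are placeholders):

import code_search.dataflow.cli.arguments as arguments

pipeline_opts = arguments.prepare_pipeline_opts([
    '--runner=DirectRunner',
    '--project=my-gcp-project',      # standard Beam GoogleCloudOptions flag
    '--target_dataset=code_search',  # custom flag defined above
    '--data_dir=/tmp/code_search',   # custom Batch Prediction flag defined above
])

# Beam-native and custom flags land on the same parsed namespace.
args = pipeline_opts._visible_options  # pylint: disable=protected-access
assert args.runner == 'DirectRunner' and args.target_dataset == 'code_search'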
@ -0,0 +1,49 @@
|
|||
import apache_beam as beam
|
||||
|
||||
import code_search.dataflow.cli.arguments as arguments
|
||||
import code_search.dataflow.transforms.github_bigquery as gh_bq
|
||||
import code_search.dataflow.transforms.function_embeddings as func_embed
|
||||
import code_search.dataflow.do_fns.dict_to_csv as dict_to_csv
|
||||
|
||||
|
||||
def create_function_embeddings(argv=None):
|
||||
"""Creates Batch Prediction Pipeline using trained model.
|
||||
|
||||
At a high level, this pipeline does the following things:
|
||||
- Read the Processed Github Dataset from BigQuery
|
||||
- Encode the functions using T2T problem
|
||||
- Get function embeddings using `kubeflow_batch_predict.dataflow.batch_prediction`
|
||||
- All results are stored in a BigQuery dataset (`args.target_dataset`)
|
||||
- See `transforms.function_embeddings.FunctionEmbeddings` for details of tables created
|
||||
- Additionally, store CSV of docstring, original functions and other metadata for
|
||||
reverse index lookup during search engine queries.
|
||||
"""
|
||||
pipeline_opts = arguments.prepare_pipeline_opts(argv)
|
||||
args = pipeline_opts._visible_options # pylint: disable=protected-access
|
||||
|
||||
pipeline = beam.Pipeline(options=pipeline_opts)
|
||||
|
||||
token_pairs = (pipeline
|
||||
| "Read Transformed Github Dataset" >> gh_bq.ReadTransformedGithubDataset(
|
||||
args.project, dataset=args.target_dataset)
|
||||
| "Compute Function Embeddings" >> func_embed.FunctionEmbeddings(args.project,
|
||||
args.target_dataset,
|
||||
args.problem,
|
||||
args.data_dir,
|
||||
args.saved_model_dir)
|
||||
)
|
||||
|
||||
(token_pairs # pylint: disable=expression-not-assigned
|
||||
| "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
|
||||
['nwo', 'path', 'function_name', 'lineno', 'original_function', 'function_embedding']))
|
||||
| "Write Embeddings to CSV" >> beam.io.WriteToText('{}/func-index'.format(args.data_dir),
|
||||
file_name_suffix='.csv')
|
||||
)
|
||||
|
||||
result = pipeline.run()
|
||||
if args.runner == 'DirectRunner':
|
||||
result.wait_until_finish()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_function_embeddings()
|
||||
|
|
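A hedged invocation sketch for the entrypoint above. Every value is a placeholder and the T2T problem name is only an assumption; a real Dataflow run would also need the usual Google Cloud flags (staging and temp locations, and so on):

# Hypothetical local invocation of create_function_embeddings() defined above.
create_function_embeddings([
    '--runner=DirectRunner',
    '--project=my-gcp-project',                   # placeholder GCP project
    '--target_dataset=code_search',               # receives the function_embeddings table
    '--problem=github_function_docstring',        # assumed T2T problem name
    '--data_dir=gs://my-bucket/data',             # T2T data dir; func-index*.csv also written here
    '--saved_model_dir=gs://my-bucket/export/1',  # exported TensorFlow SavedModel
])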
@ -0,0 +1,52 @@
|
|||
import apache_beam as beam
|
||||
|
||||
import code_search.dataflow.cli.arguments as arguments
|
||||
import code_search.dataflow.transforms.github_bigquery as gh_bq
|
||||
import code_search.dataflow.transforms.github_dataset as github_dataset
|
||||
import code_search.dataflow.do_fns.dict_to_csv as dict_to_csv
|
||||
|
||||
|
||||
def preprocess_github_dataset(argv=None):
|
||||
"""Apache Beam pipeline for pre-processing Github dataset.
|
||||
|
||||
At a high level, this pipeline does the following things:
|
||||
- Read Github Python files from BigQuery
|
||||
- If Github Python files have already been processed, use the
|
||||
pre-processed table instead (using the flag `--pre_transformed`)
|
||||
- Tokenize files into pairs of function definitions and docstrings
|
||||
- All results are stored in a BigQuery dataset (`args.target_dataset`)
|
||||
- See `transforms.github_dataset.TransformGithubDataset` for details of tables created
|
||||
- Additionally, store pairs of docstring and function tokens in a CSV file
|
||||
for training
|
||||
"""
|
||||
pipeline_opts = arguments.prepare_pipeline_opts(argv)
|
||||
args = pipeline_opts._visible_options # pylint: disable=protected-access
|
||||
|
||||
pipeline = beam.Pipeline(options=pipeline_opts)
|
||||
|
||||
if args.pre_transformed:
|
||||
token_pairs = (pipeline
|
||||
| "Read Transformed Github Dataset" >> gh_bq.ReadTransformedGithubDataset(
|
||||
args.project, dataset=args.target_dataset)
|
||||
)
|
||||
else:
|
||||
token_pairs = (pipeline
|
||||
| "Read Github Dataset" >> gh_bq.ReadGithubDataset(args.project)
|
||||
| "Transform Github Dataset" >> github_dataset.TransformGithubDataset(args.project,
|
||||
args.target_dataset)
|
||||
)
|
||||
|
||||
(token_pairs # pylint: disable=expression-not-assigned
|
||||
| "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
|
||||
['docstring_tokens', 'function_tokens']))
|
||||
| "Write CSV" >> beam.io.WriteToText('{}/func-doc-pairs'.format(args.data_dir),
|
||||
file_name_suffix='.csv')
|
||||
)
|
||||
|
||||
result = pipeline.run()
|
||||
if args.runner == 'DirectRunner':
|
||||
result.wait_until_finish()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
preprocess_github_dataset()
|
||||
|
|
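A matching hedged sketch for this entrypoint (placeholder values; add `--pre_transformed` to start from an already transformed `token_pairs` table):

# Hypothetical local invocation of preprocess_github_dataset() defined above.
preprocess_github_dataset([
    '--runner=DirectRunner',
    '--project=my-gcp-project',        # placeholder GCP project
    '--target_dataset=code_search',    # receives the token_pairs / failed_tokenize tables
    '--data_dir=gs://my-bucket/data',  # func-doc-pairs*.csv training files are written here
])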
@ -0,0 +1,50 @@
|
|||
import csv
|
||||
import io
|
||||
import apache_beam as beam
|
||||
|
||||
|
||||
class DictToCSVString(beam.DoFn):
|
||||
"""Convert incoming dict to a CSV string.
|
||||
|
||||
This DoFn converts a Python dict into
|
||||
a CSV string.
|
||||
|
||||
Args:
|
||||
fieldnames: A list of strings representing keys of a dict.
|
||||
"""
|
||||
def __init__(self, fieldnames):
|
||||
super(DictToCSVString, self).__init__()
|
||||
|
||||
self.fieldnames = fieldnames
|
||||
|
||||
def process(self, element, *_args, **_kwargs):
|
||||
"""Convert a Python dict instance into CSV string.
|
||||
|
||||
This routine uses the Python CSV DictWriter to
|
||||
robustly convert an input dict to a comma-separated
|
||||
CSV string. This also handles appropriate escaping of
|
||||
characters like the delimiter ",". The dict values
|
||||
must be serializable into a string.
|
||||
|
||||
Args:
|
||||
element: A dict mapping string keys to string values.
|
||||
{
|
||||
"key1": "STRING",
|
||||
"key2": "STRING"
|
||||
}
|
||||
|
||||
Yields:
|
||||
A string representing the row in CSV format.
|
||||
"""
|
||||
fieldnames = self.fieldnames
|
||||
filtered_element = {
|
||||
key: value.encode('utf-8')
|
||||
for (key, value) in element.iteritems()
|
||||
if key in fieldnames
|
||||
}
|
||||
with io.BytesIO() as stream:
|
||||
writer = csv.DictWriter(stream, fieldnames)
|
||||
writer.writerow(filtered_element)
|
||||
csv_string = stream.getvalue().strip('\r\n')
|
||||
|
||||
yield csv_string
|
||||
|
|
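Because `process` is a plain generator, the DoFn above can be sanity-checked outside a pipeline. An illustrative Python 2 snippet (not part of the diff):

do_fn = DictToCSVString(['docstring_tokens', 'function_tokens'])
row = {
    u'docstring_tokens': u'adds, two numbers',    # embedded comma forces quoting
    u'function_tokens': u'def add a b return a b',
    u'license': u'MIT',                           # keys outside fieldnames are dropped
}
print(next(do_fn.process(row)))
# "adds, two numbers",def add a b return a b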
@ -0,0 +1,164 @@
|
|||
"""Beam DoFns specific to `code_search.dataflow.transforms.function_embeddings`."""
|
||||
|
||||
import apache_beam as beam
|
||||
|
||||
from code_search.t2t.query import get_encoder, encode_query
|
||||
|
||||
|
||||
class EncodeFunctionTokens(beam.DoFn):
|
||||
"""Encode function tokens.
|
||||
|
||||
This DoFn prepares the function tokens for
|
||||
inference by a SavedModel estimator downstream.
|
||||
|
||||
Args:
|
||||
problem: A string representing the registered Tensor2Tensor Problem.
|
||||
data_dir: A string representing the path to data directory.
|
||||
"""
|
||||
def __init__(self, problem, data_dir):
|
||||
super(EncodeFunctionTokens, self).__init__()
|
||||
|
||||
self.problem = problem
|
||||
self.data_dir = data_dir
|
||||
|
||||
@property
|
||||
def function_tokens_key(self):
|
||||
return u'function_tokens'
|
||||
|
||||
@property
|
||||
def instances_key(self):
|
||||
return u'instances'
|
||||
|
||||
def process(self, element, *_args, **_kwargs):
|
||||
"""Encode the function instance.
|
||||
|
||||
This DoFn takes a tokenized function string and
|
||||
encodes it into a base64 string of TFExample
|
||||
binary format. The "function_tokens" are encoded
|
||||
and stored into the "instances" key in a format
|
||||
ready for consumption by TensorFlow SavedModel
|
||||
estimators. The encoder is provided by a
|
||||
Tensor2Tensor problem as provided in the constructor.
|
||||
|
||||
Args:
|
||||
element: A Python dict of the form,
|
||||
{
|
||||
"nwo": "STRING",
|
||||
"path": "STRING",
|
||||
"function_name": "STRING",
|
||||
"lineno": "STRING",
|
||||
"original_function": "STRING",
|
||||
"function_tokens": "STRING",
|
||||
"docstring_tokens": "STRING",
|
||||
}
|
||||
|
||||
Yields:
|
||||
An updated Python dict of the form
|
||||
{
|
||||
"nwo": "STRING",
|
||||
"path": "STRING",
|
||||
"function_name": "STRING",
|
||||
"lineno": "STRING",
|
||||
"original_function": "STRING",
|
||||
"function_tokens": "STRING",
|
||||
"docstring_tokens": "STRING",
|
||||
"instances": [
|
||||
{
|
||||
"input": {
|
||||
"b64": "STRING",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
encoder = get_encoder(self.problem, self.data_dir)
|
||||
encoded_function = encode_query(encoder, element.get(self.function_tokens_key))
|
||||
|
||||
element[self.instances_key] = [{'input': {'b64': encoded_function}}]
|
||||
yield element
|
||||
|
||||
|
||||
class ProcessFunctionEmbedding(beam.DoFn):
|
||||
"""Process results from PredictionDoFn.
|
||||
|
||||
This is a DoFn for post-processing on inference
|
||||
results from a SavedModel estimator which are
|
||||
returned by the PredictionDoFn.
|
||||
"""
|
||||
|
||||
@property
|
||||
def function_embedding_key(self):
|
||||
return 'function_embedding'
|
||||
|
||||
@property
|
||||
def predictions_key(self):
|
||||
return 'predictions'
|
||||
|
||||
@property
|
||||
def pop_keys(self):
|
||||
return [
|
||||
'predictions',
|
||||
'docstring_tokens',
|
||||
'function_tokens',
|
||||
'instances',
|
||||
]
|
||||
|
||||
def process(self, element, *_args, **_kwargs):
|
||||
"""Post-Process Function embedding.
|
||||
|
||||
This converts the incoming function instance
|
||||
embedding into a serializable string for downstream
|
||||
tasks. It also pops any extraneous keys which are
|
||||
no longer required. The "lineno" key is also converted
|
||||
to a string for serializability downstream.
|
||||
|
||||
Args:
|
||||
element: A Python dict of the form,
|
||||
{
|
||||
"nwo": "STRING",
|
||||
"path": "STRING",
|
||||
"function_name": "STRING",
|
||||
"lineno": "STRING",
|
||||
"original_function": "STRING",
|
||||
"function_tokens": "STRING",
|
||||
"docstring_tokens": "STRING",
|
||||
"instances": [
|
||||
{
|
||||
"input": {
|
||||
"b64": "STRING",
|
||||
}
|
||||
}
|
||||
],
|
||||
"predictions": [
|
||||
{
|
||||
"outputs": [
|
||||
FLOAT,
|
||||
FLOAT,
|
||||
...
|
||||
]
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
Yields:
|
||||
An updated Python dict of the form,
|
||||
{
|
||||
"nwo": "STRING",
|
||||
"path": "STRING",
|
||||
"function_name": "STRING",
|
||||
"lineno": "STRING",
|
||||
"original_function": "STRING",
|
||||
"function_embedding": "STRING",
|
||||
}
|
||||
"""
|
||||
prediction = element.get(self.predictions_key)[0]['outputs']
|
||||
element[self.function_embedding_key] = ','.join([
|
||||
str(val).decode('utf-8') for val in prediction
|
||||
])
|
||||
|
||||
element['lineno'] = str(element['lineno']).decode('utf-8')
|
||||
|
||||
for key in self.pop_keys:
|
||||
element.pop(key)
|
||||
|
||||
yield element
|
||||
|
|
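To illustrate the post-processing contract documented above, `ProcessFunctionEmbedding` can be fed a fabricated prediction result (all values below are made up):

element = {
    'nwo': u'someorg/somerepo', 'path': u'pkg/module.py',
    'function_name': u'add', 'lineno': 3,
    'original_function': u'def add(a, b):\n    return a + b\n',
    'function_tokens': u'def add a b return a b',
    'docstring_tokens': u'adds two numbers',
    'instances': [{'input': {'b64': u'...'}}],
    'predictions': [{'outputs': [0.1, 0.2, 0.3]}],
}
out = next(ProcessFunctionEmbedding().process(element))
# out['function_embedding'] == u'0.1,0.2,0.3', out['lineno'] == u'3', and the
# 'predictions', 'instances' and '*_tokens' keys have been popped.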
@ -0,0 +1,126 @@
|
|||
"""Beam DoFns specific to `code_search.dataflow.transforms.github_dataset`."""
|
||||
|
||||
import logging
|
||||
import apache_beam as beam
|
||||
from apache_beam import pvalue
|
||||
|
||||
|
||||
class SplitRepoPath(beam.DoFn):
|
||||
"""Update element keys to separate repo path and file path.
|
||||
|
||||
This DoFn's only purpose is to be used after
|
||||
`code_search.dataflow.transforms.github_bigquery.ReadGithubDataset`
|
||||
to split the source dictionary key into two target keys.
|
||||
"""
|
||||
|
||||
@property
|
||||
def source_key(self):
|
||||
return u'repo_path'
|
||||
|
||||
@property
|
||||
def target_keys(self):
|
||||
return [u'nwo', u'path']
|
||||
|
||||
def process(self, element, *_args, **_kwargs):
|
||||
"""Process Python file attributes.
|
||||
|
||||
This simple DoFn splits the `repo_path` into
|
||||
independent properties of owner (`nwo`) and
|
||||
relative file path (`path`). The value is
space-delimited, so splitting on the first
space is enough.
|
||||
|
||||
Args:
|
||||
element: A Python dict of the form,
|
||||
{
|
||||
"repo_path": "STRING",
|
||||
"content": "STRING",
|
||||
}
|
||||
|
||||
Yields:
|
||||
An updated Python dict of the form,
|
||||
{
|
||||
"nwo": "STRING",
|
||||
"path": "STRING",
|
||||
"content": "STRING",
|
||||
}
|
||||
"""
|
||||
values = element.pop(self.source_key).split(' ', 1)
|
||||
|
||||
for key, value in zip(self.target_keys, values):
|
||||
element[key] = value
|
||||
|
||||
yield element
|
||||
|
||||
|
||||
class TokenizeFunctionDocstrings(beam.DoFn):
|
||||
"""Tokenize function and docstrings.
|
||||
|
||||
This DoFn takes in rows from BigQuery and tokenizes
the file content present in the `content` key, yielding
a list of dicts, one per extracted function-docstring pair.
|
||||
"""
|
||||
|
||||
@property
|
||||
def content_key(self):
|
||||
return 'content'
|
||||
|
||||
@property
|
||||
def info_keys(self):
|
||||
return [
|
||||
u'function_name',
|
||||
u'lineno',
|
||||
u'original_function',
|
||||
u'function_tokens',
|
||||
u'docstring_tokens',
|
||||
]
|
||||
|
||||
def process(self, element, *_args, **_kwargs):
|
||||
"""Get list of Function-Docstring tokens
|
||||
|
||||
This processes each Python file's content
|
||||
and returns a list of metadata for each extracted
|
||||
pair. These contain the tokenized functions and
|
||||
docstrings. In cases where the tokenization fails,
|
||||
a side output is returned. All values are unicode
|
||||
for serialization.
|
||||
|
||||
Args:
|
||||
element: A Python dict of the form,
|
||||
{
|
||||
"nwo": "STRING",
|
||||
"path": "STRING",
|
||||
"content": "STRING",
|
||||
}
|
||||
|
||||
Yields:
|
||||
A Python list of the form,
|
||||
[
|
||||
{
|
||||
"nwo": "STRING",
|
||||
"path": "STRING",
|
||||
"function_name": "STRING",
|
||||
"lineno": "STRING",
|
||||
"original_function": "STRING",
|
||||
"function_tokens": "STRING",
|
||||
"docstring_tokens": "STRING",
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
try:
|
||||
import code_search.dataflow.utils as utils
|
||||
|
||||
content_blob = element.pop(self.content_key)
|
||||
pairs = utils.get_function_docstring_pairs(content_blob)
|
||||
|
||||
result = [
|
||||
dict(zip(self.info_keys, pair_tuple), **element)
|
||||
for pair_tuple in pairs
|
||||
]
|
||||
|
||||
yield result
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
logging.warning('Tokenization failed, %s', e.message)
|
||||
yield pvalue.TaggedOutput('err', element)
|
||||
|
|
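An illustrative end-to-end run of the two DoFns above on a tiny fabricated element (assumes the package and the spacy `en` model used by `code_search.dataflow.utils` are installed):

element = {
    'repo_path': u'someorg/somerepo path/to/module.py',  # space-delimited, as read from BigQuery
    'content': u'def add(a, b):\n    """Add two numbers and return the sum."""\n    return a + b\n',
}
element = next(SplitRepoPath().process(element))
# element['nwo'] == u'someorg/somerepo', element['path'] == u'path/to/module.py'

rows = next(TokenizeFunctionDocstrings().process(element))
# rows is a list with a single dict carrying nwo, path, function_name, lineno,
# original_function, function_tokens and docstring_tokens, all unicode.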
@ -1,4 +1,5 @@
|
|||
import apache_beam as beam
|
||||
import apache_beam.io.gcp.bigquery as bigquery
|
||||
import apache_beam.io.gcp.internal.clients as clients
|
||||
|
||||
|
||||
|
|
@ -10,10 +11,12 @@ class BigQueryRead(beam.PTransform):
|
|||
string.
|
||||
"""
|
||||
|
||||
def __init__(self, project):
|
||||
def __init__(self, project, dataset=None, table=None):
|
||||
super(BigQueryRead, self).__init__()
|
||||
|
||||
self.project = project
|
||||
self.dataset = dataset
|
||||
self.table = table
|
||||
|
||||
@property
|
||||
def limit(self):
|
||||
|
|
@ -47,12 +50,14 @@ class BigQueryWrite(beam.PTransform):
|
|||
]
|
||||
"""
|
||||
|
||||
def __init__(self, project, dataset, table, batch_size=500):
|
||||
def __init__(self, project, dataset, table, batch_size=500,
|
||||
write_disposition=bigquery.BigQueryDisposition.WRITE_TRUNCATE):
|
||||
super(BigQueryWrite, self).__init__()
|
||||
|
||||
self.project = project
|
||||
self.dataset = dataset
|
||||
self.table = table
|
||||
self.write_disposition = write_disposition
|
||||
self.batch_size = batch_size
|
||||
|
||||
@property
|
||||
|
|
@ -69,7 +74,8 @@ class BigQueryWrite(beam.PTransform):
|
|||
dataset=self.dataset,
|
||||
table=self.table,
|
||||
schema=self.output_schema,
|
||||
batch_size=self.batch_size)
|
||||
batch_size=self.batch_size,
|
||||
write_disposition=self.write_disposition)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
import apache_beam as beam
|
||||
import kubeflow_batch_predict.dataflow.batch_prediction as batch_prediction
|
||||
|
||||
import code_search.dataflow.do_fns.function_embeddings as func_embeddings
|
||||
import code_search.dataflow.transforms.github_bigquery as github_bigquery
|
||||
|
||||
|
||||
class FunctionEmbeddings(beam.PTransform):
|
||||
"""Batch Prediction for Github dataset.
|
||||
|
||||
This Beam pipeline takes in the transformed dataset,
|
||||
prepares each element's function tokens for prediction
|
||||
by encoding it into base64 format and returns an updated
|
||||
dictionary element with the embedding for further processing.
|
||||
"""
|
||||
|
||||
def __init__(self, project, target_dataset, problem, data_dir, saved_model_dir):
|
||||
super(FunctionEmbeddings, self).__init__()
|
||||
|
||||
self.project = project
|
||||
self.target_dataset = target_dataset
|
||||
self.problem = problem
|
||||
self.data_dir = data_dir
|
||||
self.saved_model_dir = saved_model_dir
|
||||
|
||||
def expand(self, input_or_inputs):
|
||||
batch_predict = (input_or_inputs
|
||||
| "Encoded Function Tokens" >> beam.ParDo(func_embeddings.EncodeFunctionTokens(
|
||||
self.problem, self.data_dir))
|
||||
| "Compute Function Embeddings" >> beam.ParDo(batch_prediction.PredictionDoFn(),
|
||||
self.saved_model_dir).with_outputs('err',
|
||||
main='main')
|
||||
)
|
||||
|
||||
predictions = batch_predict.main
|
||||
|
||||
formatted_predictions = (predictions
|
||||
| "Process Function Embeddings" >> beam.ParDo(func_embeddings.ProcessFunctionEmbedding())
|
||||
)
|
||||
|
||||
(formatted_predictions # pylint: disable=expression-not-assigned
|
||||
| "Save Function Embeddings" >> github_bigquery.WriteGithubFunctionEmbeddings(
|
||||
self.project, self.target_dataset)
|
||||
)
|
||||
|
||||
return formatted_predictions
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
import apache_beam.io.gcp.bigquery as bigquery
|
||||
import code_search.dataflow.transforms.bigquery as bq_transform
|
||||
|
||||
|
||||
# Default internal table names
|
||||
PAIRS_TABLE = 'token_pairs'
|
||||
FAILED_TOKENIZE_TABLE = 'failed_tokenize'
|
||||
FUNCTION_EMBEDDINGS_TABLE = 'function_embeddings'
|
||||
|
||||
|
||||
class ReadGithubDataset(bq_transform.BigQueryRead):
|
||||
"""Read original Github files from BigQuery.
|
||||
|
||||
This utility transform reads Python files from the
public BigQuery Github dump, keeping only files that
are smaller than 15,000 bytes, contain at least one
function definition, and belong to a repository that
has been watched at least twice since 2017.
|
||||
|
||||
NOTE: Make sure to modify the `self.limit` property
|
||||
as desired.
|
||||
"""
|
||||
|
||||
@property
|
||||
def limit(self):
|
||||
# return 500
|
||||
return None
|
||||
|
||||
@property
|
||||
def query_string(self):
|
||||
query = """
|
||||
SELECT
|
||||
MAX(CONCAT(f.repo_name, ' ', f.path)) AS repo_path,
|
||||
c.content
|
||||
FROM
|
||||
`bigquery-public-data.github_repos.files` AS f
|
||||
JOIN
|
||||
`bigquery-public-data.github_repos.contents` AS c
|
||||
ON
|
||||
f.id = c.id
|
||||
JOIN (
|
||||
--this part of the query makes sure repo is watched at least twice since 2017
|
||||
SELECT
|
||||
repo
|
||||
FROM (
|
||||
SELECT
|
||||
repo.name AS repo
|
||||
FROM
|
||||
`githubarchive.year.2017`
|
||||
WHERE
|
||||
type="WatchEvent"
|
||||
UNION ALL
|
||||
SELECT
|
||||
repo.name AS repo
|
||||
FROM
|
||||
`githubarchive.month.2018*`
|
||||
WHERE
|
||||
type="WatchEvent" )
|
||||
GROUP BY
|
||||
1
|
||||
HAVING
|
||||
COUNT(*) >= 2 ) AS r
|
||||
ON
|
||||
f.repo_name = r.repo
|
||||
WHERE
|
||||
f.path LIKE '%.py' AND --with python extension
|
||||
c.size < 15000 AND --get rid of ridiculously long files
|
||||
REGEXP_CONTAINS(c.content, r'def ') --contains function definition
|
||||
GROUP BY
|
||||
c.content
|
||||
"""
|
||||
|
||||
if self.limit:
|
||||
query += '\nLIMIT {}'.format(self.limit)
|
||||
return query
|
||||
|
||||
|
||||
class WriteFailedTokenizedData(bq_transform.BigQueryWrite):
|
||||
@property
|
||||
def column_list(self):
|
||||
return [
|
||||
('nwo', 'STRING'),
|
||||
('path', 'STRING'),
|
||||
('content', 'STRING')
|
||||
]
|
||||
|
||||
|
||||
class WriteTokenizedData(bq_transform.BigQueryWrite):
|
||||
@property
|
||||
def column_list(self):
|
||||
return [
|
||||
('nwo', 'STRING'),
|
||||
('path', 'STRING'),
|
||||
('function_name', 'STRING'),
|
||||
('lineno', 'STRING'),
|
||||
('original_function', 'STRING'),
|
||||
('function_tokens', 'STRING'),
|
||||
('docstring_tokens', 'STRING'),
|
||||
]
|
||||
|
||||
|
||||
class ReadTransformedGithubDataset(bq_transform.BigQueryRead):
|
||||
|
||||
def __init__(self, project, dataset=None, table=PAIRS_TABLE):
|
||||
super(ReadTransformedGithubDataset, self).__init__(project, dataset=dataset, table=table)
|
||||
|
||||
@property
|
||||
def limit(self):
|
||||
# return 500
|
||||
return None
|
||||
|
||||
@property
|
||||
def query_string(self):
|
||||
query = """
|
||||
SELECT
|
||||
nwo, path, function_name, lineno, original_function, function_tokens, docstring_tokens
|
||||
FROM
|
||||
{}.{}
|
||||
""".format(self.dataset, self.table)
|
||||
|
||||
if self.limit:
|
||||
query += '\nLIMIT {}'.format(self.limit)
|
||||
return query
|
||||
|
||||
|
||||
class WriteGithubFunctionEmbeddings(bq_transform.BigQueryWrite):
|
||||
|
||||
def __init__(self, project, dataset, table=FUNCTION_EMBEDDINGS_TABLE, batch_size=500,
|
||||
write_disposition=bigquery.BigQueryDisposition.WRITE_TRUNCATE):
|
||||
super(WriteGithubFunctionEmbeddings, self).__init__(project, dataset, table,
|
||||
batch_size=batch_size,
|
||||
write_disposition=write_disposition)
|
||||
|
||||
@property
|
||||
def column_list(self):
|
||||
return [
|
||||
('nwo', 'STRING'),
|
||||
('path', 'STRING'),
|
||||
('function_name', 'STRING'),
|
||||
('lineno', 'STRING'),
|
||||
('original_function', 'STRING'),
|
||||
('function_embedding', 'STRING')
|
||||
]
|
||||
|
|
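The write transforms above inherit the `write_disposition` argument added to `BigQueryWrite`, which defaults to `WRITE_TRUNCATE`. A hedged sketch of appending to an existing pairs table instead (project and dataset names are placeholders):

import apache_beam.io.gcp.bigquery as bigquery
import code_search.dataflow.transforms.github_bigquery as gh_bq

# Keep previously written rows instead of overwriting the table on every run.
write_pairs = gh_bq.WriteTokenizedData(
    'my-gcp-project', 'code_search', gh_bq.PAIRS_TABLE,
    write_disposition=bigquery.BigQueryDisposition.WRITE_APPEND)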
@ -0,0 +1,61 @@
|
|||
import apache_beam as beam
|
||||
|
||||
import code_search.dataflow.do_fns.github_dataset as gh_do_fns
|
||||
import code_search.dataflow.transforms.github_bigquery as gh_bq
|
||||
|
||||
|
||||
class TransformGithubDataset(beam.PTransform):
|
||||
"""Transform the BigQuery Github Dataset.
|
||||
|
||||
This is a Beam Pipeline which reads the Github Dataset from
|
||||
BigQuery, tokenizes functions and docstrings in Python files,
|
||||
and dumps into a new BigQuery dataset for further processing.
|
||||
All tiny docstrings (with at most `self.min_docstring_tokens` tokens)
are filtered out.
|
||||
|
||||
This transform creates the following tables in the `target_dataset`;
the table names can be overridden via constructor arguments:
|
||||
- `self.failed_tokenize_table`
|
||||
- `self.pairs_table`
|
||||
"""
|
||||
|
||||
def __init__(self, project, target_dataset,
|
||||
pairs_table=gh_bq.PAIRS_TABLE,
|
||||
failed_tokenize_table=gh_bq.FAILED_TOKENIZE_TABLE):
|
||||
super(TransformGithubDataset, self).__init__()
|
||||
|
||||
self.project = project
|
||||
self.target_dataset = target_dataset
|
||||
self.pairs_table = pairs_table
|
||||
self.failed_tokenize_table = failed_tokenize_table
|
||||
|
||||
@property
|
||||
def min_docstring_tokens(self):
|
||||
return 5
|
||||
|
||||
def expand(self, input_or_inputs):
|
||||
tokenize_result = (input_or_inputs
|
||||
| "Split 'repo_path'" >> beam.ParDo(gh_do_fns.SplitRepoPath())
|
||||
| "Tokenize Code/Docstring Pairs" >> beam.ParDo(
|
||||
gh_do_fns.TokenizeFunctionDocstrings()).with_outputs('err', main='rows')
|
||||
)
|
||||
|
||||
pairs, tokenize_errors = tokenize_result.rows, tokenize_result.err
|
||||
|
||||
(tokenize_errors # pylint: disable=expression-not-assigned
|
||||
| "Failed Tokenization" >> gh_bq.WriteFailedTokenizedData(self.project, self.target_dataset,
|
||||
self.failed_tokenize_table)
|
||||
)
|
||||
|
||||
flat_rows = (pairs
|
||||
| "Flatten Rows" >> beam.FlatMap(lambda x: x)
|
||||
| "Filter Tiny Docstrings" >> beam.Filter(
|
||||
lambda row: len(row['docstring_tokens'].split(' ')) > self.min_docstring_tokens)
|
||||
)
|
||||
|
||||
(flat_rows # pylint: disable=expression-not-assigned
|
||||
| "Save Tokens" >> gh_bq.WriteTokenizedData(self.project, self.target_dataset,
|
||||
self.pairs_table)
|
||||
)
|
||||
|
||||
return flat_rows
|
||||
|
|
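Because the output table names are now constructor arguments, the transform can be dropped into an ad-hoc pipeline with custom tables. A hedged sketch (placeholder names, default DirectRunner):

import apache_beam as beam

import code_search.dataflow.transforms.github_bigquery as gh_bq
import code_search.dataflow.transforms.github_dataset as github_dataset

with beam.Pipeline() as pipeline:
    (pipeline  # pylint: disable=expression-not-assigned
     | "Read Github Dataset" >> gh_bq.ReadGithubDataset('my-gcp-project')
     | "Transform Github Dataset" >> github_dataset.TransformGithubDataset(
         'my-gcp-project', 'code_search',
         pairs_table='token_pairs_v2',
         failed_tokenize_table='failed_tokenize_v2'))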
@ -0,0 +1,80 @@
|
|||
import ast
|
||||
import astor
|
||||
import nltk.tokenize as tokenize
|
||||
import spacy
|
||||
|
||||
|
||||
def tokenize_docstring(text):
|
||||
"""Tokenize docstrings.
|
||||
|
||||
Args:
|
||||
text: A docstring to be tokenized.
|
||||
|
||||
Returns:
|
||||
A list of strings representing the tokens in the docstring.
|
||||
"""
|
||||
en = spacy.load('en')
|
||||
tokens = en.tokenizer(text.decode('utf8'))
|
||||
return [token.text.lower() for token in tokens if not token.is_space]
|
||||
|
||||
|
||||
def tokenize_code(text):
|
||||
"""Tokenize code strings.
|
||||
|
||||
This simply extracts runs of word characters (alphanumerics
and underscores), treating everything else as a delimiter.
|
||||
|
||||
Args:
|
||||
text: A code string to be tokenized.
|
||||
|
||||
Returns:
|
||||
A list of strings representing the tokens in the code.
|
||||
"""
|
||||
return tokenize.RegexpTokenizer(r'\w+').tokenize(text)
|
||||
|
||||
|
||||
def get_function_docstring_pairs(blob):
|
||||
"""Extract (function/method, docstring) pairs from a given code blob.
|
||||
|
||||
This method reads a string representing a Python file, builds an
|
||||
abstract syntax tree (AST) and returns a list of Docstring and Function
|
||||
pairs along with supporting metadata.
|
||||
|
||||
Args:
|
||||
blob: A string representing the Python file contents.
|
||||
|
||||
Returns:
|
||||
A list of tuples of the form:
|
||||
[
|
||||
(
|
||||
function_name,
|
||||
lineno,
|
||||
original_function,
|
||||
function_tokens,
|
||||
docstring_tokens
|
||||
),
|
||||
...
|
||||
]
|
||||
"""
|
||||
pairs = []
|
||||
try:
|
||||
module = ast.parse(blob)
|
||||
classes = [node for node in module.body if isinstance(node, ast.ClassDef)]
|
||||
functions = [node for node in module.body if isinstance(node, ast.FunctionDef)]
|
||||
for _class in classes:
|
||||
functions.extend([node for node in _class.body if isinstance(node, ast.FunctionDef)])
|
||||
|
||||
for f in functions:
|
||||
source = astor.to_source(f)
|
||||
docstring = ast.get_docstring(f) if ast.get_docstring(f) else ''
|
||||
func = source.replace(ast.get_docstring(f, clean=False), '') if docstring else source
|
||||
pair_tuple = (
|
||||
f.name.decode('utf-8'),
|
||||
str(f.lineno).decode('utf-8'),
|
||||
source.decode('utf-8'),
|
||||
' '.join(tokenize_code(func)).decode('utf-8'),
|
||||
' '.join(tokenize_docstring(docstring.split('\n\n')[0])).decode('utf-8'),
|
||||
)
|
||||
pairs.append(pair_tuple)
|
||||
except (AssertionError, MemoryError, SyntaxError, UnicodeEncodeError):
|
||||
pass
|
||||
return pairs
|
||||
|
|
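For intuition, here is roughly what the two tokenizers defined above produce (illustrative; the docstring tokens depend on the installed spacy `en` model):

print(tokenize_code('def add(a, b):\n    return a + b'))
# ['def', 'add', 'a', 'b', 'return', 'a', 'b']

print(tokenize_docstring('Adds two numbers, returning their sum.'))
# roughly [u'adds', u'two', u'numbers', u',', u'returning', u'their', u'sum', u'.']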
@ -1,3 +0,0 @@
|
|||
from code_search.do_fns.github_files import ExtractFuncInfo
|
||||
from code_search.do_fns.github_files import TokenizeCodeDocstring
|
||||
from code_search.do_fns.github_files import SplitRepoPath
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
"""Beam DoFns for prediction related tasks"""
|
||||
import io
|
||||
import csv
|
||||
from cStringIO import StringIO
|
||||
import apache_beam as beam
|
||||
from code_search.transforms.process_github_files import ProcessGithubFiles
|
||||
from code_search.t2t.query import get_encoder, encode_query
|
||||
|
||||
class GithubCSVToDict(beam.DoFn):
|
||||
"""Split a text row and convert into a dict."""
|
||||
|
||||
def process(self, element): # pylint: disable=no-self-use
|
||||
element = element.encode('utf-8')
|
||||
row = StringIO(element)
|
||||
reader = csv.reader(row, delimiter=',')
|
||||
|
||||
keys = ProcessGithubFiles.get_key_list()
|
||||
values = next(reader) # pylint: disable=stop-iteration-return
|
||||
|
||||
result = dict(zip(keys, values))
|
||||
yield result
|
||||
|
||||
|
||||
class GithubDictToCSV(beam.DoFn):
|
||||
"""Convert dictionary to writable CSV string."""
|
||||
|
||||
def process(self, element): # pylint: disable=no-self-use
|
||||
element['function_embedding'] = ','.join(str(val) for val in element['function_embedding'])
|
||||
|
||||
target_keys = ['nwo', 'path', 'function_name', 'function_embedding']
|
||||
target_values = [element[key].encode('utf-8') for key in target_keys]
|
||||
|
||||
with io.BytesIO() as fs:
|
||||
cw = csv.writer(fs)
|
||||
cw.writerow(target_values)
|
||||
result_str = fs.getvalue().strip('\r\n')
|
||||
|
||||
return result_str
|
||||
|
||||
|
||||
class EncodeExample(beam.DoFn):
|
||||
"""Encode string to integer tokens.
|
||||
|
||||
This is needed so that the data can be sent in
|
||||
for prediction
|
||||
"""
|
||||
def __init__(self, problem, data_dir):
|
||||
super(EncodeExample, self).__init__()
|
||||
|
||||
self.problem = problem
|
||||
self.data_dir = data_dir
|
||||
|
||||
def process(self, element):
|
||||
encoder = get_encoder(self.problem, self.data_dir)
|
||||
encoded_function = encode_query(encoder, element['function_tokens'])
|
||||
|
||||
element['instances'] = [{'input': {'b64': encoded_function}}]
|
||||
yield element
|
||||
|
||||
|
||||
class ProcessPrediction(beam.DoFn):
|
||||
"""Process results from PredictionDoFn.
|
||||
|
||||
This class processes predictions from another
|
||||
DoFn, to make sure it is a correctly formatted dict.
|
||||
"""
|
||||
def process(self, element): # pylint: disable=no-self-use
|
||||
element['function_embedding'] = ','.join([
|
||||
str(val) for val in element['predictions'][0]['outputs']
|
||||
])
|
||||
|
||||
element.pop('function_tokens')
|
||||
element.pop('instances')
|
||||
element.pop('predictions')
|
||||
|
||||
yield element
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
"""Beam DoFns for Github related tasks"""
|
||||
import time
|
||||
import logging
|
||||
import apache_beam as beam
|
||||
from apache_beam import pvalue
|
||||
from apache_beam.metrics import Metrics
|
||||
|
||||
|
||||
class SplitRepoPath(beam.DoFn):
|
||||
# pylint: disable=abstract-method
|
||||
"""Split the space-delimited file `repo_path` into owner repository (`nwo`)
|
||||
and file path (`path`)"""
|
||||
|
||||
def process(self, element): # pylint: disable=no-self-use
|
||||
nwo, path = element.pop('repo_path').split(' ', 1)
|
||||
element['nwo'] = nwo
|
||||
element['path'] = path
|
||||
yield element
|
||||
|
||||
|
||||
class TokenizeCodeDocstring(beam.DoFn):
|
||||
# pylint: disable=abstract-method
|
||||
"""Compute code/docstring pairs from incoming BigQuery row dict"""
|
||||
def __init__(self):
|
||||
super(TokenizeCodeDocstring, self).__init__()
|
||||
|
||||
self.tokenization_time_ms = Metrics.counter(self.__class__, 'tokenization_time_ms')
|
||||
|
||||
def process(self, element): # pylint: disable=no-self-use
|
||||
try:
|
||||
import code_search.utils as utils
|
||||
|
||||
start_time = time.time()
|
||||
element['pairs'] = utils.get_function_docstring_pairs(element.pop('content'))
|
||||
self.tokenization_time_ms.inc(int((time.time() - start_time) * 1000.0))
|
||||
|
||||
yield element
|
||||
except Exception as e: #pylint: disable=broad-except
|
||||
logging.warning('Tokenization failed, %s', e.message)
|
||||
yield pvalue.TaggedOutput('err_rows', element)
|
||||
|
||||
|
||||
class ExtractFuncInfo(beam.DoFn):
|
||||
# pylint: disable=abstract-method
|
||||
"""Convert pair tuples to dict.
|
||||
|
||||
This takes a list of values from `TokenizeCodeDocstring`
|
||||
and converts into a dictionary so that values can be
|
||||
indexed by names instead of indices. `info_keys` is the
|
||||
list of names of those values in order which will become
|
||||
the keys of each new dict.
|
||||
"""
|
||||
def __init__(self, info_keys):
|
||||
super(ExtractFuncInfo, self).__init__()
|
||||
|
||||
self.info_keys = info_keys
|
||||
|
||||
def process(self, element):
|
||||
try:
|
||||
info_rows = [dict(zip(self.info_keys, pair)) for pair in element.pop('pairs')]
|
||||
info_rows = [self.merge_two_dicts(info_dict, element) for info_dict in info_rows]
|
||||
info_rows = map(self.dict_to_unicode, info_rows)
|
||||
yield info_rows
|
||||
except Exception as e: #pylint: disable=broad-except
|
||||
logging.warning('Function Info extraction failed, %s', e.message)
|
||||
yield pvalue.TaggedOutput('err_rows', element)
|
||||
|
||||
@staticmethod
|
||||
def merge_two_dicts(dict_a, dict_b):
|
||||
result = dict_a.copy()
|
||||
result.update(dict_b)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def dict_to_unicode(data_dict):
|
||||
for k, v in data_dict.items():
|
||||
if isinstance(v, str):
|
||||
data_dict[k] = v.decode('utf-8', 'ignore')
|
||||
return data_dict
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
import apache_beam as beam
|
||||
import kubeflow_batch_predict.dataflow.batch_prediction as batch_prediction
|
||||
|
||||
import code_search.do_fns.embeddings as embeddings
|
||||
import code_search.transforms.github_bigquery as github_bigquery
|
||||
|
||||
|
||||
class GithubBatchPredict(beam.PTransform):
|
||||
"""Batch Prediction for Github dataset"""
|
||||
|
||||
def __init__(self, project, problem, data_dir, saved_model_dir):
|
||||
super(GithubBatchPredict, self).__init__()
|
||||
|
||||
self.project = project
|
||||
self.problem = problem
|
||||
self.data_dir = data_dir
|
||||
self.saved_model_dir = saved_model_dir
|
||||
|
||||
##
|
||||
# Target dataset and table to store prediction outputs.
|
||||
# Non-configurable for now.
|
||||
#
|
||||
self.index_dataset = 'code_search'
|
||||
self.index_table = 'search_index'
|
||||
|
||||
self.batch_size = 100
|
||||
|
||||
def expand(self, input_or_inputs):
|
||||
rows = (input_or_inputs
|
||||
| "Read Processed Github Dataset" >> github_bigquery.ReadProcessedGithubData(self.project)
|
||||
)
|
||||
|
||||
batch_predict = (rows
|
||||
| "Prepare Encoded Input" >> beam.ParDo(embeddings.EncodeExample(self.problem,
|
||||
self.data_dir))
|
||||
| "Execute Predictions" >> beam.ParDo(batch_prediction.PredictionDoFn(),
|
||||
self.saved_model_dir).with_outputs("errors",
|
||||
main="main")
|
||||
)
|
||||
|
||||
predictions = batch_predict.main
|
||||
|
||||
formatted_predictions = (predictions
|
||||
| "Process Predictions" >> beam.ParDo(embeddings.ProcessPrediction())
|
||||
)
|
||||
|
||||
(formatted_predictions # pylint: disable=expression-not-assigned
|
||||
| "Save Index Data" >> github_bigquery.WriteGithubIndexData(self.project,
|
||||
self.index_dataset,
|
||||
self.index_table,
|
||||
batch_size=self.batch_size)
|
||||
)
|
||||
|
||||
return formatted_predictions
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
import code_search.transforms.bigquery as bigquery
|
||||
|
||||
|
||||
class ReadOriginalGithubPythonData(bigquery.BigQueryRead):
|
||||
@property
|
||||
def limit(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def query_string(self):
|
||||
query = """
|
||||
SELECT
|
||||
MAX(CONCAT(f.repo_name, ' ', f.path)) AS repo_path,
|
||||
c.content
|
||||
FROM
|
||||
`bigquery-public-data.github_repos.files` AS f
|
||||
JOIN
|
||||
`bigquery-public-data.github_repos.contents` AS c
|
||||
ON
|
||||
f.id = c.id
|
||||
JOIN (
|
||||
--this part of the query makes sure repo is watched at least twice since 2017
|
||||
SELECT
|
||||
repo
|
||||
FROM (
|
||||
SELECT
|
||||
repo.name AS repo
|
||||
FROM
|
||||
`githubarchive.year.2017`
|
||||
WHERE
|
||||
type="WatchEvent"
|
||||
UNION ALL
|
||||
SELECT
|
||||
repo.name AS repo
|
||||
FROM
|
||||
`githubarchive.month.2018*`
|
||||
WHERE
|
||||
type="WatchEvent" )
|
||||
GROUP BY
|
||||
1
|
||||
HAVING
|
||||
COUNT(*) >= 2 ) AS r
|
||||
ON
|
||||
f.repo_name = r.repo
|
||||
WHERE
|
||||
f.path LIKE '%.py' AND --with python extension
|
||||
c.size < 15000 AND --get rid of ridiculously long files
|
||||
REGEXP_CONTAINS(c.content, r'def ') --contains function definition
|
||||
GROUP BY
|
||||
c.content
|
||||
"""
|
||||
|
||||
if self.limit:
|
||||
query += '\nLIMIT {}'.format(self.limit)
|
||||
return query
|
||||
|
||||
|
||||
class ReadProcessedGithubData(bigquery.BigQueryRead):
|
||||
@property
|
||||
def limit(self):
|
||||
return 100
|
||||
|
||||
@property
|
||||
def query_string(self):
|
||||
query = """
|
||||
SELECT
|
||||
nwo, path, function_name, lineno, original_function, function_tokens
|
||||
FROM
|
||||
code_search.function_docstrings
|
||||
"""
|
||||
|
||||
if self.limit:
|
||||
query += '\nLIMIT {}'.format(self.limit)
|
||||
return query
|
||||
|
||||
|
||||
class WriteGithubIndexData(bigquery.BigQueryWrite):
|
||||
@property
|
||||
def column_list(self):
|
||||
return [
|
||||
('nwo', 'STRING'),
|
||||
('path', 'STRING'),
|
||||
('function_name', 'STRING'),
|
||||
('lineno', 'INTEGER'),
|
||||
('original_function', 'STRING'),
|
||||
('function_embedding', 'STRING')
|
||||
]
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
import io
|
||||
import csv
|
||||
import apache_beam as beam
|
||||
import apache_beam.io.gcp.internal.clients as clients
|
||||
|
||||
import code_search.do_fns as do_fns
|
||||
|
||||
|
||||
class ProcessGithubFiles(beam.PTransform):
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
"""A collection of `DoFn`s for Pipeline Transform. Reads the Github dataset from BigQuery
|
||||
and writes back the processed code-docstring pairs in a query-friendly format back to BigQuery
|
||||
table.
|
||||
"""
|
||||
data_columns = ['nwo', 'path', 'function_name', 'lineno', 'original_function',
|
||||
'function_tokens', 'docstring_tokens']
|
||||
data_types = ['STRING', 'STRING', 'STRING', 'INTEGER', 'STRING', 'STRING', 'STRING']
|
||||
|
||||
def __init__(self, project, query_string, output_string, storage_bucket):
|
||||
super(ProcessGithubFiles, self).__init__()
|
||||
|
||||
self.project = project
|
||||
self.query_string = query_string
|
||||
self.output_dataset, self.output_table = output_string.split(':')
|
||||
self.storage_bucket = storage_bucket
|
||||
|
||||
self.num_shards = 10
|
||||
|
||||
def expand(self, input_or_inputs):
|
||||
tokenize_result = (input_or_inputs
|
||||
| "Read Github Dataset" >> beam.io.Read(beam.io.BigQuerySource(query=self.query_string,
|
||||
use_standard_sql=True))
|
||||
| "Split 'repo_path'" >> beam.ParDo(do_fns.SplitRepoPath())
|
||||
| "Tokenize Code/Docstring Pairs" >> beam.ParDo(do_fns.TokenizeCodeDocstring())
|
||||
.with_outputs('err_rows', main='rows')
|
||||
)
|
||||
|
||||
#pylint: disable=expression-not-assigned
|
||||
(tokenize_result.err_rows
|
||||
| "Failed Row Tokenization" >> beam.io.WriteToBigQuery(project=self.project,
|
||||
dataset=self.output_dataset,
|
||||
table=self.output_table + '_failed',
|
||||
schema=self.create_failed_output_schema())
|
||||
)
|
||||
# pylint: enable=expression-not-assigned
|
||||
|
||||
|
||||
info_result = (tokenize_result.rows
|
||||
| "Extract Function Info" >> beam.ParDo(do_fns.ExtractFuncInfo(self.data_columns[2:]))
|
||||
.with_outputs('err_rows', main='rows')
|
||||
)
|
||||
|
||||
#pylint: disable=expression-not-assigned
|
||||
(info_result.err_rows
|
||||
| "Failed Function Info" >> beam.io.WriteToBigQuery(project=self.project,
|
||||
dataset=self.output_dataset,
|
||||
table=self.output_table + '_failed',
|
||||
schema=self.create_failed_output_schema())
|
||||
)
|
||||
# pylint: enable=expression-not-assigned
|
||||
|
||||
processed_rows = (info_result.rows | "Flatten Rows" >> beam.FlatMap(lambda x: x))
|
||||
|
||||
# pylint: disable=expression-not-assigned
|
||||
(processed_rows
|
||||
| "Filter Tiny Docstrings" >> beam.Filter(
|
||||
lambda row: len(row['docstring_tokens'].split(' ')) > 5)
|
||||
| "Format For Write" >> beam.Map(self.format_for_write)
|
||||
| "Write To File" >> beam.io.WriteToText('{}/data/pairs'.format(self.storage_bucket),
|
||||
file_name_suffix='.csv',
|
||||
num_shards=self.num_shards))
|
||||
# pylint: enable=expression-not-assigned
|
||||
|
||||
return (processed_rows
|
||||
| "Save Tokens" >> beam.io.WriteToBigQuery(project=self.project,
|
||||
dataset=self.output_dataset,
|
||||
table=self.output_table,
|
||||
schema=self.create_output_schema())
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_key_list():
|
||||
filter_keys = [
|
||||
'original_function',
|
||||
'lineno',
|
||||
]
|
||||
key_list = [col for col in ProcessGithubFiles.data_columns
|
||||
if col not in filter_keys]
|
||||
return key_list
|
||||
|
||||
def format_for_write(self, row):
|
||||
"""This method filters keys that we don't need in the
|
||||
final CSV. It must ensure that there are no multi-line
|
||||
column fields. For instance, 'original_function' is a
|
||||
multi-line string and makes CSV parsing hard for any
|
||||
derived Dataflow steps. This uses the CSV Writer
|
||||
to handle all edge cases like quote escaping."""
|
||||
|
||||
target_keys = self.get_key_list()
|
||||
target_values = [row[key].encode('utf-8') for key in target_keys]
|
||||
|
||||
with io.BytesIO() as fs:
|
||||
cw = csv.writer(fs)
|
||||
cw.writerow(target_values)
|
||||
result_str = fs.getvalue().strip('\r\n')
|
||||
|
||||
return result_str
|
||||
|
||||
def create_output_schema(self):
|
||||
table_schema = clients.bigquery.TableSchema()
|
||||
|
||||
for column, data_type in zip(self.data_columns, self.data_types):
|
||||
field_schema = clients.bigquery.TableFieldSchema()
|
||||
field_schema.name = column
|
||||
field_schema.type = data_type
|
||||
field_schema.mode = 'nullable'
|
||||
table_schema.fields.append(field_schema)
|
||||
|
||||
return table_schema
|
||||
|
||||
def create_failed_output_schema(self):
|
||||
table_schema = clients.bigquery.TableSchema()
|
||||
|
||||
for column, data_type in zip(self.data_columns[:2] + ['content'],
|
||||
self.data_types[:2] + ['STRING']):
|
||||
field_schema = clients.bigquery.TableFieldSchema()
|
||||
field_schema.name = column
|
||||
field_schema.type = data_type
|
||||
field_schema.mode = 'nullable'
|
||||
table_schema.fields.append(field_schema)
|
||||
|
||||
return table_schema
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
import ast
|
||||
import astor
|
||||
import nltk.tokenize as tokenize
|
||||
import spacy
|
||||
|
||||
|
||||
def tokenize_docstring(text):
|
||||
"""Apply tokenization using spacy to docstrings."""
|
||||
en = spacy.load('en')
|
||||
tokens = en.tokenizer(text.decode('utf8', 'ignore'))
|
||||
return [token.text.lower() for token in tokens if not token.is_space]
|
||||
|
||||
|
||||
def tokenize_code(text):
|
||||
"""A very basic procedure for tokenizing code strings."""
|
||||
return tokenize.RegexpTokenizer(r'\w+').tokenize(text)
|
||||
|
||||
|
||||
def get_function_docstring_pairs(blob):
|
||||
"""Extract (function/method, docstring) pairs from a given code blob."""
|
||||
pairs = []
|
||||
try:
|
||||
module = ast.parse(blob)
|
||||
classes = [node for node in module.body if isinstance(node, ast.ClassDef)]
|
||||
functions = [node for node in module.body if isinstance(node, ast.FunctionDef)]
|
||||
for _class in classes:
|
||||
functions.extend([node for node in _class.body if isinstance(node, ast.FunctionDef)])
|
||||
|
||||
for f in functions:
|
||||
source = astor.to_source(f)
|
||||
docstring = ast.get_docstring(f) if ast.get_docstring(f) else ''
|
||||
func = source.replace(ast.get_docstring(f, clean=False), '') if docstring else source
|
||||
|
||||
pairs.append((f.name, f.lineno, source, ' '.join(tokenize_code(func)),
|
||||
' '.join(tokenize_docstring(docstring.split('\n\n')[0]))))
|
||||
except (AssertionError, MemoryError, SyntaxError, UnicodeEncodeError):
|
||||
pass
|
||||
return pairs
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
SELECT
|
||||
MAX(CONCAT(f.repo_name, ' ', f.path)) AS repo_path,
|
||||
c.content
|
||||
FROM
|
||||
`bigquery-public-data.github_repos.files` AS f
|
||||
JOIN
|
||||
`bigquery-public-data.github_repos.contents` AS c
|
||||
ON
|
||||
f.id = c.id
|
||||
JOIN (
|
||||
--this part of the query makes sure repo is watched at least twice since 2017
|
||||
SELECT
|
||||
repo
|
||||
FROM (
|
||||
SELECT
|
||||
repo.name AS repo
|
||||
FROM
|
||||
`githubarchive.year.2017`
|
||||
WHERE
|
||||
type="WatchEvent"
|
||||
UNION ALL
|
||||
SELECT
|
||||
repo.name AS repo
|
||||
FROM
|
||||
`githubarchive.month.2018*`
|
||||
WHERE
|
||||
type="WatchEvent" )
|
||||
GROUP BY
|
||||
1
|
||||
HAVING
|
||||
COUNT(*) >= 2 ) AS r
|
||||
ON
|
||||
f.repo_name = r.repo
|
||||
WHERE
|
||||
f.path LIKE '%.py' AND --with python extension
|
||||
c.size < 15000 AND --get rid of ridiculously long files
|
||||
REGEXP_CONTAINS(c.content, r'def ') --contains function definition
|
||||
GROUP BY
|
||||
c.content
|
||||
|
|
@ -60,8 +60,6 @@ setup(name='code-search',
|
|||
},
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'code-search-preprocess=code_search.cli:create_github_pipeline',
|
||||
'code-search-predict=code_search.cli:create_batch_predict_pipeline',
|
||||
'nmslib-serve=code_search.nmslib.cli:server',
|
||||
'nmslib-create=code_search.nmslib.cli:creator',
|
||||
]
|
||||
|