Fix ROC Component (#559)

* Fix ROC component.

* fix ROC component.

* Follow up on CR comments.
This commit is contained in:
qimingj 2018-12-17 23:03:38 -08:00 committed by Kubernetes Prow Robot
parent 302e93ce99
commit a23abf85fd
1 changed files with 16 additions and 7 deletions

View File

@ -32,20 +32,29 @@ from tensorflow.python.lib.io import file_io
def main(argv=None):
parser = argparse.ArgumentParser(description='ML Trainer')
parser.add_argument('--predictions', type=str, help='GCS path of prediction file pattern.')
parser.add_argument('--trueclass', type=str, help='The name of the class as true value.')
parser.add_argument('--trueclass', type=str, default='true',
help='The name of the class as true value. If missing, assuming it is ' +
'binary classification and default to "true".')
parser.add_argument('--true_score_column', type=str, default='true',
help='The name of the column for positive prob. If missing, assuming it is ' +
'binary classification and defaults to "true".')
parser.add_argument('--target_lambda', type=str,
help='a lambda function as a string to determine positive or negative.' +
'For example, "lambda x: x[\'a\'] and x[\'b\']". If missing, ' +
'trueclass must be set and input must have a "target" column.')
'input must have a "target" column.')
parser.add_argument('--output', type=str, help='GCS path of the output directory.')
args = parser.parse_args()
if not args.target_lambda and not args.trueclass:
raise ValueError('Either target_lambda or trueclass must be set.')
schema_file = os.path.join(os.path.dirname(args.predictions), 'schema.json')
schema = json.loads(file_io.read_file_to_string(schema_file))
names = [x['name'] for x in schema]
if not args.target_lambda and 'target' not in names:
raise ValueError('There is no "target" column, and target_lambda is not provided.')
if args.true_score_column not in names:
raise ValueError('Cannot find column name "%s"' % args.true_score_column)
dfs = []
files = file_io.get_matching_files(args.predictions)
for file in files:
@ -57,8 +66,8 @@ def main(argv=None):
df['target'] = df.apply(eval(args.target_lambda), axis=1)
else:
df['target'] = df['target'].apply(lambda x: 1 if x == args.trueclass else 0)
fpr, tpr, thresholds = roc_curve(df['target'], df[args.trueclass])
roc_auc = roc_auc_score(df['target'], df[args.trueclass])
fpr, tpr, thresholds = roc_curve(df['target'], df[args.true_score_column])
roc_auc = roc_auc_score(df['target'], df[args.true_score_column])
df_roc = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
roc_file = os.path.join(args.output, 'roc.csv')
with file_io.FileIO(roc_file, 'w') as f: