diff --git a/components/local/roc/src/roc.py b/components/local/roc/src/roc.py index ac0f736d38..9de1edf244 100644 --- a/components/local/roc/src/roc.py +++ b/components/local/roc/src/roc.py @@ -32,20 +32,29 @@ from tensorflow.python.lib.io import file_io def main(argv=None): parser = argparse.ArgumentParser(description='ML Trainer') parser.add_argument('--predictions', type=str, help='GCS path of prediction file pattern.') - parser.add_argument('--trueclass', type=str, help='The name of the class as true value.') + parser.add_argument('--trueclass', type=str, default='true', + help='The name of the class as true value. If missing, assuming it is ' + + 'binary classification and default to "true".') + parser.add_argument('--true_score_column', type=str, default='true', + help='The name of the column for positive prob. If missing, assuming it is ' + + 'binary classification and defaults to "true".') parser.add_argument('--target_lambda', type=str, help='a lambda function as a string to determine positive or negative.' + 'For example, "lambda x: x[\'a\'] and x[\'b\']". If missing, ' + - 'trueclass must be set and input must have a "target" column.') + 'input must have a "target" column.') parser.add_argument('--output', type=str, help='GCS path of the output directory.') args = parser.parse_args() - if not args.target_lambda and not args.trueclass: - raise ValueError('Either target_lambda or trueclass must be set.') - schema_file = os.path.join(os.path.dirname(args.predictions), 'schema.json') schema = json.loads(file_io.read_file_to_string(schema_file)) names = [x['name'] for x in schema] + + if not args.target_lambda and 'target' not in names: + raise ValueError('There is no "target" column, and target_lambda is not provided.') + + if args.true_score_column not in names: + raise ValueError('Cannot find column name "%s"' % args.true_score_column) + dfs = [] files = file_io.get_matching_files(args.predictions) for file in files: @@ -57,8 +66,8 @@ def main(argv=None): df['target'] = df.apply(eval(args.target_lambda), axis=1) else: df['target'] = df['target'].apply(lambda x: 1 if x == args.trueclass else 0) - fpr, tpr, thresholds = roc_curve(df['target'], df[args.trueclass]) - roc_auc = roc_auc_score(df['target'], df[args.trueclass]) + fpr, tpr, thresholds = roc_curve(df['target'], df[args.true_score_column]) + roc_auc = roc_auc_score(df['target'], df[args.true_score_column]) df_roc = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds}) roc_file = os.path.join(args.output, 'roc.csv') with file_io.FileIO(roc_file, 'w') as f: