Fix ROC Component (#559)
* Fix ROC component.
* Fix ROC component.
* Follow up on CR comments.
This commit is contained in:
parent
302e93ce99
commit
a23abf85fd
|
|
@ -32,20 +32,29 @@ from tensorflow.python.lib.io import file_io
|
||||||
def main(argv=None):
|
def main(argv=None):
|
||||||
parser = argparse.ArgumentParser(description='ML Trainer')
|
parser = argparse.ArgumentParser(description='ML Trainer')
|
||||||
parser.add_argument('--predictions', type=str, help='GCS path of prediction file pattern.')
|
parser.add_argument('--predictions', type=str, help='GCS path of prediction file pattern.')
|
||||||
parser.add_argument('--trueclass', type=str, help='The name of the class as true value.')
|
parser.add_argument('--trueclass', type=str, default='true',
|
||||||
|
help='The name of the class as true value. If missing, assuming it is ' +
|
||||||
|
'binary classification and default to "true".')
|
||||||
|
parser.add_argument('--true_score_column', type=str, default='true',
|
||||||
|
help='The name of the column for positive prob. If missing, assuming it is ' +
|
||||||
|
'binary classification and defaults to "true".')
|
||||||
parser.add_argument('--target_lambda', type=str,
|
parser.add_argument('--target_lambda', type=str,
|
||||||
help='a lambda function as a string to determine positive or negative.' +
|
help='a lambda function as a string to determine positive or negative.' +
|
||||||
'For example, "lambda x: x[\'a\'] and x[\'b\']". If missing, ' +
|
'For example, "lambda x: x[\'a\'] and x[\'b\']". If missing, ' +
|
||||||
'trueclass must be set and input must have a "target" column.')
|
'input must have a "target" column.')
|
||||||
parser.add_argument('--output', type=str, help='GCS path of the output directory.')
|
parser.add_argument('--output', type=str, help='GCS path of the output directory.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not args.target_lambda and not args.trueclass:
|
|
||||||
raise ValueError('Either target_lambda or trueclass must be set.')
|
|
||||||
|
|
||||||
schema_file = os.path.join(os.path.dirname(args.predictions), 'schema.json')
|
schema_file = os.path.join(os.path.dirname(args.predictions), 'schema.json')
|
||||||
schema = json.loads(file_io.read_file_to_string(schema_file))
|
schema = json.loads(file_io.read_file_to_string(schema_file))
|
||||||
names = [x['name'] for x in schema]
|
names = [x['name'] for x in schema]
|
||||||
|
|
||||||
|
if not args.target_lambda and 'target' not in names:
|
||||||
|
raise ValueError('There is no "target" column, and target_lambda is not provided.')
|
||||||
|
|
||||||
|
if args.true_score_column not in names:
|
||||||
|
raise ValueError('Cannot find column name "%s"' % args.true_score_column)
|
||||||
|
|
||||||
dfs = []
|
dfs = []
|
||||||
files = file_io.get_matching_files(args.predictions)
|
files = file_io.get_matching_files(args.predictions)
|
||||||
for file in files:
|
for file in files:
|
||||||
|
|
@ -57,8 +66,8 @@ def main(argv=None):
|
||||||
df['target'] = df.apply(eval(args.target_lambda), axis=1)
|
df['target'] = df.apply(eval(args.target_lambda), axis=1)
|
||||||
else:
|
else:
|
||||||
df['target'] = df['target'].apply(lambda x: 1 if x == args.trueclass else 0)
|
df['target'] = df['target'].apply(lambda x: 1 if x == args.trueclass else 0)
|
||||||
fpr, tpr, thresholds = roc_curve(df['target'], df[args.trueclass])
|
fpr, tpr, thresholds = roc_curve(df['target'], df[args.true_score_column])
|
||||||
roc_auc = roc_auc_score(df['target'], df[args.trueclass])
|
roc_auc = roc_auc_score(df['target'], df[args.true_score_column])
|
||||||
df_roc = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
|
df_roc = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
|
||||||
roc_file = os.path.join(args.output, 'roc.csv')
|
roc_file = os.path.join(args.output, 'roc.csv')
|
||||||
with file_io.FileIO(roc_file, 'w') as f:
|
with file_io.FileIO(roc_file, 'w') as f:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue