Fix ROC Component (#559)
* Fix ROC component. * fix ROC component. * Follow up on CR comments.
This commit is contained in:
parent
302e93ce99
commit
a23abf85fd
|
|
@ -32,20 +32,29 @@ from tensorflow.python.lib.io import file_io
|
|||
def main(argv=None):
|
||||
parser = argparse.ArgumentParser(description='ML Trainer')
|
||||
parser.add_argument('--predictions', type=str, help='GCS path of prediction file pattern.')
|
||||
parser.add_argument('--trueclass', type=str, help='The name of the class as true value.')
|
||||
parser.add_argument('--trueclass', type=str, default='true',
|
||||
help='The name of the class as true value. If missing, assuming it is ' +
|
||||
'binary classification and default to "true".')
|
||||
parser.add_argument('--true_score_column', type=str, default='true',
|
||||
help='The name of the column for positive prob. If missing, assuming it is ' +
|
||||
'binary classification and defaults to "true".')
|
||||
parser.add_argument('--target_lambda', type=str,
|
||||
help='a lambda function as a string to determine positive or negative.' +
|
||||
'For example, "lambda x: x[\'a\'] and x[\'b\']". If missing, ' +
|
||||
'trueclass must be set and input must have a "target" column.')
|
||||
'input must have a "target" column.')
|
||||
parser.add_argument('--output', type=str, help='GCS path of the output directory.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.target_lambda and not args.trueclass:
|
||||
raise ValueError('Either target_lambda or trueclass must be set.')
|
||||
|
||||
schema_file = os.path.join(os.path.dirname(args.predictions), 'schema.json')
|
||||
schema = json.loads(file_io.read_file_to_string(schema_file))
|
||||
names = [x['name'] for x in schema]
|
||||
|
||||
if not args.target_lambda and 'target' not in names:
|
||||
raise ValueError('There is no "target" column, and target_lambda is not provided.')
|
||||
|
||||
if args.true_score_column not in names:
|
||||
raise ValueError('Cannot find column name "%s"' % args.true_score_column)
|
||||
|
||||
dfs = []
|
||||
files = file_io.get_matching_files(args.predictions)
|
||||
for file in files:
|
||||
|
|
@ -57,8 +66,8 @@ def main(argv=None):
|
|||
df['target'] = df.apply(eval(args.target_lambda), axis=1)
|
||||
else:
|
||||
df['target'] = df['target'].apply(lambda x: 1 if x == args.trueclass else 0)
|
||||
fpr, tpr, thresholds = roc_curve(df['target'], df[args.trueclass])
|
||||
roc_auc = roc_auc_score(df['target'], df[args.trueclass])
|
||||
fpr, tpr, thresholds = roc_curve(df['target'], df[args.true_score_column])
|
||||
roc_auc = roc_auc_score(df['target'], df[args.true_score_column])
|
||||
df_roc = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
|
||||
roc_file = os.path.join(args.output, 'roc.csv')
|
||||
with file_io.FileIO(roc_file, 'w') as f:
|
||||
|
|
|
|||
Loading…
Reference in New Issue