# pipelines/components/arena/docker/job_generator.py
#
# 117 lines, 2.8 KiB, Python

import argparse
import datetime
import json
import os
import sys
import logging
import requests
import subprocess
import six
import time
import yaml
from subprocess import Popen,PIPE
from shlex import split
from utils import *
# Generate common options
def generate_options(args):
    """Build the arena command-line options shared by every job type.

    Args:
        args: Parsed argument object. The attributes read here are ``gpus``,
            ``cpu``, ``memory``, ``tensorboard``, ``tensorboard_image``,
            ``log_dir``, ``data``, ``env``, ``workflow_name``, ``step_name``
            and ``sync_source``. Optional string/list attributes may be
            ``None`` or empty, in which case they contribute no options.

    Returns:
        list: Command-line tokens to append after ``arena submit <type>``.

    Raises:
        ValueError: If ``sync_source`` is non-empty and does not end
            with ``.git``.
    """
    # str2bool comes from the ``utils`` star import; presumably it maps a
    # textual flag such as "true"/"false" to a bool -- confirm in utils.
    tensorboard = str2bool(args.tensorboard)
    tensorboard_image = args.tensorboard_image
    log_dir = args.log_dir
    sync_source = args.sync_source

    options = []

    if args.gpus > 0:
        options.extend(['--gpus', str(args.gpus)])

    # cpu and memory are compared against the string '0', which is treated
    # as "not requested".
    if args.cpu != '0':
        options.extend(['--cpu', str(args.cpu)])
    if args.memory != '0':
        options.extend(['--memory', str(args.memory)])

    # The image is only passed through when it differs from this value
    # (presumably the default arena already uses -- TODO confirm).
    if tensorboard_image != "tensorflow/tensorflow:1.12.0":
        options.extend(['--tensorboardImage', tensorboard_image])

    if tensorboard:
        options.append("--tensorboard")

    # NOTE(review): the log dir is handled independently of --tensorboard,
    # as in the original flat layout -- confirm this is intended.
    if log_dir and os.path.isdir(log_dir):
        options.extend(['--logdir', log_dir])
    else:
        logging.info("skip log dir :{0}".format(log_dir))

    # Entries without the expected separator are logged and dropped rather
    # than failing the whole submission.
    for d in args.data or ():
        if ":" in d:
            options.append("--data={0}".format(d))
        else:
            logging.info("--data={0} is illegal, skip.".format(d))

    for e in args.env or ():
        if "=" in e:
            options.append("--env={0}".format(e))
        else:
            logging.info("--env={0} is illegal, skip.".format(e))

    # Expose the workflow and step names to the job through its environment.
    if args.workflow_name:
        options.append("--env=WORKFLOW_NAME={0}".format(args.workflow_name))
    if args.step_name:
        options.append("--env=STEP_NAME={0}".format(args.step_name))

    if sync_source:
        # Only the ".git" suffix is checked; the URL scheme is not validated.
        if not sync_source.endswith(".git"):
            raise ValueError("sync_source must be an http git url")
        options.extend(['--sync-mode', 'git'])
        options.extend(['--sync-source', sync_source])

    return options
# Generate standalone job
def generate_job_command(args):
    """Assemble the ``arena submit tfjob`` command for a standalone job.

    Args:
        args: Parsed argument object; ``name`` and ``image`` are read here
            and the object is then passed to ``generate_options``.

    Returns:
        tuple: ``(command, job_type)`` where ``command`` is the list of
        command-line tokens and ``job_type`` is the string ``"tfjob"``.
    """
    name_flag = '--name={0}'.format(args.name)
    image_flag = '--image={0}'.format(args.image)

    command = ['arena', 'submit', 'tfjob', name_flag, image_flag]
    # The shared resource/data/env options follow the job-specific flags.
    command += generate_options(args)
    return command, "tfjob"
# Generate mpi job
def generate_mpjob_command(args):
    """Assemble the ``arena submit mpijob`` command for an MPI job.

    Args:
        args: Parsed argument object; ``name``, ``workers``, ``image`` and
            ``rdma`` are read here and the object is then passed to
            ``generate_options``. ``rdma`` must be a string.

    Returns:
        tuple: ``(command, job_type)`` where ``command`` is the list of
        command-line tokens and ``job_type`` is the string ``"mpijob"``.
    """
    name_flag = '--name={0}'.format(args.name)
    workers_flag = '--workers={0}'.format(args.workers)
    image_flag = '--image={0}'.format(args.image)
    use_rdma = args.rdma.lower() == "true"

    command = ['arena', 'submit', 'mpijob', name_flag, workers_flag, image_flag]
    # --rdma is a bare switch, added only for a case-insensitive "true".
    if use_rdma:
        command.append("--rdma")

    # The shared resource/data/env options follow the job-specific flags.
    command += generate_options(args)
    return command, "mpijob"