feat(components): Adding Trainer component for PyTorch - KFP (#5767)

* Create README.md

Initial Readme

* Create README.md

Initial commit for the PyTorch pipeline examples

* Update README.md

* Adding PyTorch training component

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding PyTorch - trainer unit tests

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing UT and adding setup.py

Signed-off-by: ankan94 <ankan@ideas2it.com>

* Update setup.py and trainer component Python files; apply .pylintrc

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* Addressed review comments; applied pylint and black

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* added copyright headers in tests files

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* updated base component

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* Changing to Apache2 License

Reviewed license change with Legal team

* Switch to Apache2 License

* Add tox for test automation with presubmit script

* Add model archiver to dependencies

* Cleanup setup.py

* Cleanup

* Cleanup

* Cleanup

* Using a common fixture for unit tests and adding a GPU fix for saving the model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Changing to Apache2 License

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* resolve typo

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* Added OWNERS

* OWNERS will be added in a separate PR

* fix detect_version function in setup.py

Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>

* Cleanup for setup description field warnings

Co-authored-by: Geeta Chauhan <4461127+chauhang@users.noreply.github.com>
Co-authored-by: ankan94 <ankan@ideas2it.com>
Co-authored-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com>
Co-authored-by: Geeta Chauhan <gchauhan@fb.com>
This commit is contained in:
shrinath-suresh 2021-06-08 09:42:56 +05:30 committed by GitHub
parent b08b29f46e
commit c5325db7d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 2346 additions and 1 deletions


@@ -0,0 +1,7 @@
*lightning_logs
build
dist
.tox
*.egg-info
.coverag*
.pytest*
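The ignore patterns above are simple globs. As a rough illustration (Python's `fnmatch` only approximates gitignore matching, which git implements with richer rules), this sketch shows which file names they would exclude:

```python
from fnmatch import fnmatch

# Patterns taken verbatim from the .gitignore above
patterns = ["*lightning_logs", "build", "dist", ".tox",
            "*.egg-info", ".coverag*", ".pytest*"]

def ignored(name):
    # Approximates gitignore matching for these flat, directory-free patterns
    return any(fnmatch(name, pattern) for pattern in patterns)

print(ignored(".coverage"))      # True  (matched by ".coverag*")
print(ignored(".pytest_cache"))  # True  (matched by ".pytest*")
print(ignored("trainer.py"))     # False (no pattern matches)
```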


@@ -0,0 +1,439 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=third_party
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-function-docstring,
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]". This option is deprecated
# and it will be removed in Pylint 2.0.
files-output=no
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
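The naming regexes in this [BASIC] section can be exercised directly with Python's `re` module; for example, `variable-rgx` and `const-rgx` behave as follows (patterns copied verbatim from the options above):

```python
import re

# variable-rgx: lowercase snake_case only
variable_rgx = re.compile(r"^[a-z][a-z0-9_]*$")
# const-rgx: UPPER_CASE constants, dunders, or lowercase names
const_rgx = re.compile(r"^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$")

print(bool(variable_rgx.match("learning_rate")))  # True
print(bool(variable_rgx.match("LearningRate")))   # False (CamelCase rejected)
print(bool(const_rgx.match("MAX_EPOCHS")))        # True
```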
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException
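The `evaluation` option in the [REPORTS] section above defines pylint's global score. A small sketch of the same arithmetic (variable names as in the option; errors are weighted five times as heavily as warnings, refactors, and conventions):

```python
def pylint_score(error, warning, refactor, convention, statement):
    # Mirrors the `evaluation` expression from the pylintrc above
    return 10.0 - ((float(5 * error + warning + refactor + convention)
                    / statement) * 10)

# 1 error, 2 warnings, 3 conventions across 100 statements
print(pylint_score(error=1, warning=2, refactor=0, convention=3,
                   statement=100))  # 9.0
```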


@@ -0,0 +1,190 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Copyright (c) 2021 Facebook, Inc. and its affiliates.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,23 @@
# PyTorch Kubeflow Pipeline Components
PyTorch Kubeflow Pipeline Components provides an SDK and a set of components for building Kubeflow pipelines with PyTorch. You can use the predefined components in this repository to construct your pipeline with the Kubeflow Pipelines SDK.
## Installation
### Requirements
- Python >= 3.6
- A Kubeflow cluster (on-premises or on any cloud)
### Install latest release
Use the following command to install PyTorch Pipeline Components from PyPI.
```shell
pip install -U pytorch-kfp-components
```
### Install from source
Use the following commands to install PyTorch Kubeflow Pipeline Components from GitHub.
```shell
git clone https://github.com/kubeflow/pipelines.git
pip install pipelines/components/PyTorch/pytorch_kfp_components/.
```
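Before installing, it can help to confirm the interpreter meets the Python >= 3.6 requirement listed above. A trivial guard (illustrative only, not part of the package):

```python
import sys

# The README requires Python >= 3.6; fail fast on older interpreters
assert sys.version_info >= (3, 6), "pytorch-kfp-components requires Python 3.6+"
print("Python version OK:", sys.version_info >= (3, 6))
```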


@@ -0,0 +1,6 @@
[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"


@@ -0,0 +1,439 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=third_party
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-function-docstring,
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]". This option is deprecated
# and it will be removed in Pylint 2.0.
files-output=no
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis).
# It supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException


@ -0,0 +1,17 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Kubeflow Pipeline Root."""
__version__ = "0.1.1dev"


@ -0,0 +1,115 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline Base component class."""
import abc
from six import with_metaclass
from pytorch_kfp_components.types import standard_component_specs
class BaseComponent(with_metaclass(abc.ABCMeta, object)): # pylint: disable=R0903
"""Pipeline Base component class."""
def __init__(self):
pass
@classmethod
def _validate_spec(
cls,
spec: standard_component_specs,
input_dict: dict,
output_dict: dict,
exec_properties: dict,
):
"""validate the specifications 'type'.
Args:
spec: The standard component specifications
input_dict : A dictionary of inputs.
ouput-dict :
exec_properties : A dict of execution properties.
"""
for key, value in input_dict.items():
cls._type_check(
actual_value=value, key=key, spec_dict=spec.INPUT_DICT
)
for key, value in output_dict.items():
cls._type_check(
actual_value=value, key=key, spec_dict=spec.OUTPUT_DICT
)
for key, value in exec_properties.items():
cls._type_check(
actual_value=value,
key=key,
spec_dict=spec.EXECUTION_PROPERTIES
)
@classmethod
def _optional_check(cls, actual_value: any, key: str, spec_dict: dict):
"""Checks for optional specification.
Args:
actual_value : Value of the dictionary.
key: key for the correspondin value.
spec_dict : The dict of specification for validation.
Returns :
is_optional : The optional key.
Raises :
ValueError : If the key is not optional
"""
is_optional = spec_dict[key].optional
if not is_optional and not actual_value:
raise ValueError(
"{key} is not optional. Received value: {actual_value}".format(
key=key, actual_value=actual_value
)
)
return is_optional
@classmethod
def _type_check(cls, actual_value, key, spec_dict):
"""Checks the type of specifactions.
Args:
actual_value : Value of the dictionary.
key: key for the correspondin value.
spec_dict : The dict of specification for validation.
Raises :
TypeError : If key value type does not match expected value type.
"""
if not actual_value:
is_optional = cls._optional_check(
actual_value=actual_value, key=key, spec_dict=spec_dict
)
if is_optional:
return
expected_type = spec_dict[key].type
actual_type = type(actual_value)
if actual_type != expected_type:
raise TypeError(
"{key} must be of type {expected_type} but received as {actual_type}"
.format(
key=key,
expected_type=expected_type,
actual_type=actual_type,
)
)
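The validation pattern above (skip empty optional entries, fail fast otherwise) can be sketched in isolation. `SpecField` and the `SPEC` dict below are hypothetical stand-ins for the entries defined in `standard_component_specs`:

```python
from collections import namedtuple

# Hypothetical stand-in for one entry of standard_component_specs
SpecField = namedtuple("SpecField", ["type", "optional"])

SPEC = {
    "module_file": SpecField(type=str, optional=False),
    "trainer_args": SpecField(type=dict, optional=True),
}


def type_check(actual_value, key, spec_dict):
    """Mirrors BaseComponent._type_check: empty optional values are
    skipped, mandatory empty values raise, and type mismatches raise."""
    if not actual_value:
        if not spec_dict[key].optional:
            raise ValueError(f"{key} is not optional. Received: {actual_value}")
        return
    if type(actual_value) != spec_dict[key].type:
        raise TypeError(
            f"{key} must be of type {spec_dict[key].type} "
            f"but received {type(actual_value)}"
        )


type_check("model.py", "module_file", SPEC)  # correct type: no error
type_check(None, "trainer_args", SPEC)       # optional and empty: skipped
```

A mandatory key with an empty value, e.g. `type_check(None, "module_file", SPEC)`, raises `ValueError`, matching the `_optional_check` behavior above.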


@ -0,0 +1,43 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline Base Executor class."""
import abc
import logging
from six import with_metaclass
class BaseExecutor(with_metaclass(abc.ABCMeta, object)): # pylint: disable=R0903
"""Pipeline Base Executor abstract class."""
def __init__(self):
pass # pylint: disable=W0107
@abc.abstractmethod
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict): # pylint: disable=C0103
"""A Do function that does nothing."""
pass # pylint: disable=W0107
def _log_startup(
self, input_dict: dict, output_dict: dict, exec_properties
):
"""Log inputs, outputs, and executor properties in a standard
format."""
class_name = self.__class__.__name__
logging.debug("Starting %s execution.", class_name)
logging.debug("Inputs for %s are: %s .", class_name, input_dict)
logging.debug("Outputs for %s are: %s.", class_name, output_dict)
logging.debug(
"Execution Properties for %s are: %s",
class_name, exec_properties)


@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training Component class."""
from typing import Optional, Dict
from pytorch_kfp_components.components.trainer.executor import Executor
from pytorch_kfp_components.components.base.base_component import BaseComponent
from pytorch_kfp_components.types import standard_component_specs
class Trainer(BaseComponent):
"""Initializes the Trainer class."""
def __init__( # pylint: disable=R0913
self,
module_file: Optional = None,
data_module_file: Optional = None,
data_module_args: Optional[Dict] = None,
module_file_args: Optional[Dict] = None,
trainer_args: Optional[Dict] = None,
):
"""Initializes the PyTorch Lightning training process.
Args:
module_file : The file from which the model class is loaded.
data_module_file : The file from which the data module class is loaded.
data_module_args : The arguments of the data module.
module_file_args : The arguments of the model class.
trainer_args : Arguments specific to the PyTorch Lightning trainer.
Raises:
NotImplementedError : If either of the mandatory args,
module_file or data_module_file, is empty.
"""
super(Trainer, self).__init__() # pylint: disable=R1725
input_dict = {
standard_component_specs.TRAINER_MODULE_FILE: module_file,
standard_component_specs.TRAINER_DATA_MODULE_FILE: data_module_file,
}
output_dict = {}
exec_properties = {
standard_component_specs.TRAINER_DATA_MODULE_ARGS: data_module_args,
standard_component_specs.TRAINER_MODULE_ARGS: module_file_args,
standard_component_specs.PTL_TRAINER_ARGS: trainer_args,
}
spec = standard_component_specs.TrainerSpec()
self._validate_spec(
spec=spec,
input_dict=input_dict,
output_dict=output_dict,
exec_properties=exec_properties,
)
if module_file and data_module_file:
# Both module file and data module file are present
Executor().Do(
input_dict=input_dict,
output_dict=output_dict,
exec_properties=exec_properties,
)
self.ptl_trainer = output_dict.get(
standard_component_specs.PTL_TRAINER_OBJ, None
)
self.output_dict = output_dict
else:
raise NotImplementedError(
"Module file and Datamodule file are mandatory. "
"Custom training methods are yet to be implemented"
)


@ -0,0 +1,122 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training Executor class."""
import os
from argparse import Namespace
import pytorch_lightning as pl
import torch
from pytorch_kfp_components.components.trainer.generic_executor import (
GenericExecutor,
)
from pytorch_kfp_components.types import standard_component_specs
class Executor(GenericExecutor):
"""The Training Executor class."""
def __init__(self): # pylint:disable=useless-super-delegation
super().__init__()
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
"""This function of the Executor invokes the PyTorch Lightning training
loop.
Args:
input_dict : The dictionary of inputs.Example
: model file, data module file
output_dict :
exec_properties : A dict of execution properties
including data_module_args,
trainer_args, module_file_args
Returns:
trainer : The object of PyTorch-Lightning Trainer.
Raises:
ValueError : If both of module_file_arfs or trainer_args are empty.
TypeError : If the type of trainer_args is not dict.
NotImplementedError : If mandatory args;
module_file or data_module_file is empty.
"""
self._log_startup(
input_dict=input_dict,
output_dict=output_dict,
exec_properties=exec_properties,
)
(
module_file,
data_module_file,
trainer_args,
module_file_args,
data_module_args,
) = self._GetFnArgs(
input_dict=input_dict,
output_dict=output_dict,
execution_properties=exec_properties,
)
(
model_class,
data_module_class,
) = self.derive_model_and_data_module_class(
module_file=module_file, data_module_file=data_module_file
)
if data_module_class:
data_module = data_module_class(
**data_module_args if data_module_args else {}
)
data_module.prepare_data()
data_module.setup(stage="fit")
model = model_class(**module_file_args if module_file_args else {})
if (not module_file_args) and (not trainer_args):
raise ValueError("Module file & trainer args can't be empty")
if not isinstance(trainer_args, dict):
raise TypeError("trainer_args must be a dict")
trainer_args.update(module_file_args)
parser = Namespace(**trainer_args)
trainer = pl.Trainer.from_argparse_args(parser)
trainer.fit(model, data_module)
trainer.test()
if "checkpoint_dir" in module_file_args:
model_save_path = module_file_args["checkpoint_dir"]
else:
model_save_path = "/tmp"
if "model_name" in module_file_args:
model_name = module_file_args["model_name"]
else:
model_name = "model_state_dict.pth"
model_save_path = os.path.join(model_save_path, model_name)
if trainer.global_rank == 0:
print("Saving model to {}".format(model_save_path))
torch.save(model.state_dict(), model_save_path)
output_dict[standard_component_specs.TRAINER_MODEL_SAVE_PATH
] = model_save_path
output_dict[standard_component_specs.PTL_TRAINER_OBJ] = trainer
else:
raise NotImplementedError(
"Data module class is mandatory. "
"User defined training module is yet to be supported."
)
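The argument-handling step of `Do` — merging `module_file_args` into `trainer_args` and exposing the result as namespace attributes for `pl.Trainer.from_argparse_args` — can be sketched with the standard library alone; the argument values below are illustrative:

```python
from argparse import Namespace

# Illustrative values; in the executor these come from exec_properties
trainer_args = {"max_epochs": 5, "gpus": 0}
module_file_args = {"lr": 0.01, "checkpoint_dir": "output"}

# Merge module-level args into the trainer args (module args win on clashes)
trainer_args.update(module_file_args)

# Expose the merged dict as attributes, the shape from_argparse_args expects
parser = Namespace(**trainer_args)

print(parser.max_epochs, parser.lr, parser.checkpoint_dir)
```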


@ -0,0 +1,115 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic Executor Class."""
import importlib
import inspect
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
from pytorch_kfp_components.types import standard_component_specs
class GenericExecutor(BaseExecutor):
"""Generic Executor Class that does nothing."""
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
# TODO: Code to train pretrained model
pass
def _GetFnArgs(
self, input_dict: dict, output_dict: dict, execution_properties: dict
):
"""Gets the input/output/execution properties from the dictionary.
Args:
input_dict : The dictionary of inputs.Example :
model file, data module file
output_dict :
exec_properties : A dict of execution properties including
data_module_args,trainer_args, module_file_args
Returns:
module_file : The model file name
data_module_file : A data module file name
trainer_args: A dictionary of trainer args
module_file_args : A dictionary of model specific args
data_module_args : A dictionary of data module args.
"""
module_file = input_dict.get(
standard_component_specs.TRAINER_MODULE_FILE
)
data_module_file = input_dict.get(
standard_component_specs.TRAINER_DATA_MODULE_FILE
)
trainer_args = execution_properties.get(
standard_component_specs.PTL_TRAINER_ARGS
)
module_file_args = execution_properties.get(
standard_component_specs.TRAINER_MODULE_ARGS
)
data_module_args = execution_properties.get(
standard_component_specs.TRAINER_DATA_MODULE_ARGS
)
return (
module_file,
data_module_file,
trainer_args,
module_file_args,
data_module_args,
)
def derive_model_and_data_module_class(
self, module_file: str, data_module_file: str
):
"""Derives the model file and data modul file.
Args :
module_file : A model file name (type:str)
data_module_file : A data module file name (type:str)
Returns :
model_class : The model class
data_module_class : The data module class.
Raises :
ValueError: If the model file or data module file is empty.
"""
model_class = None
data_module_class = None
class_module = importlib.import_module(module_file.split(".")[0])
data_module = importlib.import_module(data_module_file.split(".")[0])
for cls in inspect.getmembers(
class_module,
lambda member: inspect.isclass(member) and member.__module__ ==
class_module.__name__,
):
model_class = cls[1]
if not model_class:
raise ValueError(f"Unable to load module_file - {module_file}")
for cls in inspect.getmembers(
data_module,
lambda member: inspect.isclass(member) and member.__module__ ==
data_module.__name__,
):
data_module_class = cls[1]
if not data_module_class:
raise ValueError(
f"Unable to load data_module_file - {data_module_file}"
)
return model_class, data_module_class
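The class-discovery loop in `derive_model_and_data_module_class` relies on `inspect.getmembers` with a filter that keeps only classes defined in the imported module itself. A self-contained sketch, using a throwaway temp module in place of a real model file:

```python
import importlib
import inspect
import os
import sys
import tempfile

# Write a throwaway module so the discovery loop has something to find
tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "demo_model.py"), "w") as handle:
    handle.write("class DemoModel:\n    pass\n")
sys.path.insert(0, tmp_dir)

module = importlib.import_module("demo_model")

# Keep only classes whose __module__ matches the module, i.e. classes
# defined there rather than imported into it
model_class = None
for _, member in inspect.getmembers(
    module,
    lambda m: inspect.isclass(m) and m.__module__ == module.__name__,
):
    model_class = member

print(model_class.__name__)
```

As in the executor, if the loop finds nothing, `model_class` stays `None` and the caller raises `ValueError`.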


@ -42,7 +42,6 @@ MINIO_BUCKET_NAME = "bucket_name"
MINIO_DESTINATION = "destination"
MINIO_ENDPOINT = "endpoint"
class Parameters: # pylint: disable=R0903
"""Parameter class to match the desired type."""


@ -0,0 +1,2 @@
[metadata]
description_file = README.md


@ -0,0 +1,106 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup script."""
import importlib
import os
import types
from setuptools import setup, find_packages
def make_required_install_packages():
return [
"kfp>=1.6.1",
"torch>=1.7.1",
"torchserve>=0.3.0",
"torch-model-archiver",
"pytorch-lightning==1.3.2",
]
def make_required_test_packages():
return make_required_install_packages() + [
"mock>=4.0.0",
"flake8>=3.0.0",
"pylint",
"pytest>=6.0.0",
"wget",
"pandas",
"minio"
]
def make_dependency_links():
return []
def detect_version(base_path):
loader = importlib.machinery.SourceFileLoader(
fullname="version",
path=os.path.join(base_path,
"pytorch_kfp_components/__init__.py"),
)
version = types.ModuleType(loader.name)
loader.exec_module(version)
return version.__version__
if __name__ == "__main__":
relative_directory = os.path.relpath(
os.path.dirname(os.path.abspath(__file__)))
version = detect_version(relative_directory)
setup(
name="pytorch-kfp-components",
version=version,
description="PyTorch Kubeflow Pipeline",
url="https://github.com/kubeflow/pipelines/tree/master/components",
author="The PyTorch Kubeflow Pipeline Components authors",
author_email="pytorch-kfp-components@fb.com",
license="Apache License 2.0",
extras_require={"tests": make_required_test_packages()},
include_package_data=True,
python_requires=">=3.6",
install_requires=make_required_install_packages(),
dependency_links=make_dependency_links(),
keywords=[
"Kubeflow",
"ML workflow",
"PyTorch",
],
classifiers=[
"Development Status :: 3 - Alpha",
"Operating System :: Unix",
"Operating System :: MacOS",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
package_dir={
"pytorch_kfp_components":
os.path.join(relative_directory, "pytorch_kfp_components")
},
packages=find_packages(where=relative_directory),
)
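`detect_version` above reads `__version__` by executing `__init__.py` through `SourceFileLoader` instead of importing the package, so it works even when the package's dependencies are not installed. A standalone sketch with a temporary stand-in for `pytorch_kfp_components/__init__.py`:

```python
import importlib.machinery
import os
import tempfile
import types

# Temporary stand-in for pytorch_kfp_components/__init__.py
pkg_dir = tempfile.mkdtemp()
init_path = os.path.join(pkg_dir, "__init__.py")
with open(init_path, "w") as handle:
    handle.write('__version__ = "0.1.1dev"\n')

# Execute the file into a fresh module object without importing the package
loader = importlib.machinery.SourceFileLoader("version", init_path)
version_module = types.ModuleType(loader.name)
loader.exec_module(version_module)

print(version_module.__version__)
```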


@ -0,0 +1,100 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=W0221
# pylint: disable=W0613
# pylint: disable=W0223
from argparse import ArgumentParser
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.metrics import Accuracy
class IrisClassification(pl.LightningModule):
def __init__(self, **kwargs):
super(IrisClassification, self).__init__()
self.train_acc = Accuracy()
self.val_acc = Accuracy()
self.test_acc = Accuracy()
self.args = kwargs
self.fc1 = nn.Linear(4, 10)
self.fc2 = nn.Linear(10, 10)
self.fc3 = nn.Linear(10, 3)
self.cross_entropy_loss = nn.CrossEntropyLoss()
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
return x
@staticmethod
def add_model_specific_args(parent_parser):
"""
Add model specific arguments like learning rate
:param parent_parser: Application specific parser
:return: Returns the augmented argument parser
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.001)",
)
return parser
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), self.args.get("lr", 0.01))
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
_, y_hat = torch.max(logits, dim=1)
loss = self.cross_entropy_loss(logits, y)
self.train_acc(y_hat, y)
self.log(
"train_acc",
self.train_acc.compute(),
on_step=False,
on_epoch=True,
)
self.log("train_loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
_, y_hat = torch.max(logits, dim=1)
loss = F.cross_entropy(logits, y)
self.val_acc(y_hat, y)
self.log("val_acc", self.val_acc.compute())
self.log("val_loss", loss, sync_dist=True)
def test_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
_, y_hat = torch.max(logits, dim=1)
self.test_acc(y_hat, y)
self.log("test_acc", self.test_acc.compute())


@ -0,0 +1,116 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
import pytorch_lightning as pl
import torch
from pytorch_lightning import seed_everything
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, random_split, TensorDataset
class IrisDataModule(pl.LightningDataModule):
def __init__(self, **kwargs):
"""
Initialization of inherited lightning data module
"""
super(IrisDataModule, self).__init__()
self.train_set = None
self.val_set = None
self.test_set = None
self.args = kwargs
def prepare_data(self):
"""
Implementation of abstract class
"""
def setup(self, stage=None):
"""
Downloads the data, parses it and splits it into train, test and validation sets
:param stage: Stage - training or testing
"""
iris = load_iris()
df = iris.data
target = iris["target"]
data = torch.Tensor(df).float()
labels = torch.Tensor(target).long()
RANDOM_SEED = 42
seed_everything(RANDOM_SEED)
data_set = TensorDataset(data, labels)
self.train_set, self.val_set = random_split(data_set, [130, 20])
self.train_set, self.test_set = random_split(self.train_set, [110, 20])
@staticmethod
def add_model_specific_args(parent_parser):
"""
Adds model specific arguments batch size and num workers
:param parent_parser: Application specific parser
:return: Returns the augmented argument parser
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument(
"--batch-size",
type=int,
default=128,
metavar="N",
help="input batch size for training (default: 16)",
)
parser.add_argument(
"--num-workers",
type=int,
default=3,
metavar="N",
help="number of workers (default: 3)",
)
return parser
def create_data_loader(self, dataset):
"""
Generic data loader function
:param dataset: Input data set
:return: Returns the constructed dataloader
"""
return DataLoader(
dataset,
batch_size=self.args.get("batch_size", 16),
num_workers=self.args.get("num_workers", 3),
)
def train_dataloader(self):
train_loader = self.create_data_loader(dataset=self.train_set)
return train_loader
def val_dataloader(self):
validation_loader = self.create_data_loader(dataset=self.val_set)
return validation_loader
def test_dataloader(self):
test_loader = self.create_data_loader(dataset=self.test_set)
return test_loader
if __name__ == "__main__":
pass
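`setup` above splits the 150 Iris rows twice: 130/20 into train/validation, then the 130 again into 110/20 train/test. A plain-Python stand-in for `random_split` makes the arithmetic easy to check:

```python
import random

random.seed(42)


def split(items, sizes):
    """Minimal stand-in for torch.utils.data.random_split."""
    assert sum(sizes) == len(items), "split sizes must cover the dataset"
    shuffled = random.sample(items, len(items))
    return shuffled[: sizes[0]], shuffled[sizes[0]:]


indices = list(range(150))                 # 150 Iris samples
train, val = split(indices, [130, 20])     # first split: train vs validation
train, test = split(train, [110, 20])      # second split: train vs test

print(len(train), len(val), len(test))
```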


@ -0,0 +1,68 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import logging
import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler
logger = logging.getLogger(__name__)
class IRISClassifierHandler(BaseHandler):
"""
IRISClassifier handler class. This handler takes an input tensor and
outputs the type of iris based on the input
"""
def __init__(self):
super(IRISClassifierHandler, self).__init__()
def preprocess(self, data):
"""
preprocessing step - Reads the input array and converts it to tensor
:param data: Input to be passed through the layers for prediction
:return: output - Preprocessed input
"""
input_data_str = data[0].get("data")
if input_data_str is None:
input_data_str = data[0].get("body")
input_data = input_data_str.decode("utf-8")
input_tensor = torch.Tensor(ast.literal_eval(input_data))
return input_tensor
def postprocess(self, inference_output):
"""
Does postprocess after inference to be returned to user
:param inference_output: Output of inference
:return: output - Output after post processing
"""
predicted_idx = str(np.argmax(inference_output.cpu().detach().numpy()))
if self.mapping:
return [self.mapping[str(predicted_idx)]]
return [predicted_idx]
_service = IRISClassifierHandler()
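For context, `preprocess` expects the request body to be a UTF-8 string representation of a nested list of iris features. The sketch below replays just that parsing step without TorchServe or torch; the payload values are hypothetical:

```python
import ast

# Hypothetical request body as TorchServe would deliver it to preprocess().
raw_body = b"[[5.1, 3.5, 1.4, 0.2]]"
parsed = ast.literal_eval(raw_body.decode("utf-8"))
# Inside the real handler this nested list is wrapped as torch.Tensor(parsed).
assert parsed == [[5.1, 3.5, 1.4, 0.2]]
```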


@@ -0,0 +1,66 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
import pytorch_lightning as pl
from pytorch_pipeline.components.trainer.component import Trainer
# Argument parser for user defined paths
parser = ArgumentParser()
parser.add_argument(
    "--tensorboard_root",
    type=str,
    default="output/tensorboard",
    help="Tensorboard root path (default: output/tensorboard)",
)
parser.add_argument(
    "--checkpoint_dir",
    type=str,
    default="output",
    help="Path to save model checkpoints (default: output)",
)
parser.add_argument(
    "--model_name",
    type=str,
    default="iris.pt",
    help="Name of the model to be saved as (default: iris.pt)",
)

parser = pl.Trainer.add_argparse_args(parent_parser=parser)
args = vars(parser.parse_args())

# Fall back to 5 epochs when --max_epochs is not supplied
if not args["max_epochs"]:
    args["max_epochs"] = 5

trainer_args = {}

# Initiating the training process
trainer = Trainer(
    module_file="iris_classification.py",
    data_module_file="iris_data_module.py",
    module_file_args=args,
    data_module_args=None,
    trainer_args=trainer_args,
)
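The `max_epochs` fallback above can be exercised in isolation. This sketch stands in for the Lightning-augmented parser, assuming it leaves `max_epochs` unset (falsy) when the flag is not passed:

```python
from argparse import ArgumentParser

# Stand-in for pl.Trainer.add_argparse_args: max_epochs defaults to None here.
parser = ArgumentParser()
parser.add_argument("--max_epochs", type=int, default=None)
args = vars(parser.parse_args([]))  # empty argv: flag not supplied

# Same defaulting logic as the driver script above.
if not args["max_epochs"]:
    args["max_epochs"] = 5

assert args["max_epochs"] == 5
```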


@@ -0,0 +1,32 @@
#!/bin/bash -ex
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
source_root=$(pwd)
cd "$source_root/components/PyTorch/pytorch-kfp-components"
# Verify package build correctly
python setup.py bdist_wheel clean
# Verify package can be installed and loaded correctly
WHEEL_FILE=$(find "$source_root/components/PyTorch/pytorch-kfp-components/dist/" -name "pytorch_kfp_components*.whl")
pip3 install --upgrade "$WHEEL_FILE"
python -c "import pytorch_kfp_components"
pwd
# Run lint and tests
./tests/run_tests.sh


@@ -0,0 +1,2 @@
pip install -U tox virtualenv
tox "$@"


@@ -0,0 +1,202 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for trainer component."""
import os
import shutil
import sys
import tempfile
import pytest
import pytorch_lightning
from pytorch_kfp_components.components.trainer.component import Trainer
dirname, filename = os.path.split(os.path.abspath(__file__))
IRIS_DIR = os.path.join(dirname, "iris")
sys.path.insert(0, IRIS_DIR)
MODULE_FILE_ARGS = {"lr": 0.1}
TRAINER_ARGS = {"max_epochs": 5}
DATA_MODULE_ARGS = {"num_workers": 2}
# pylint:disable=redefined-outer-name
@pytest.fixture(scope="class")
def trainer_params():
    trainer_params = {
        "module_file": "iris_classification.py",
        "data_module_file": "iris_data_module.py",
        "module_file_args": MODULE_FILE_ARGS,
        "data_module_args": DATA_MODULE_ARGS,
        "trainer_args": TRAINER_ARGS,
    }
    return trainer_params


MANDATORY_ARGS = [
    "module_file",
    "data_module_file",
]
OPTIONAL_ARGS = ["module_file_args", "data_module_args", "trainer_args"]
DEFAULT_MODEL_NAME = "model_state_dict.pth"
DEFAULT_SAVE_PATH = f"/tmp/{DEFAULT_MODEL_NAME}"


def invoke_training(trainer_params):  # pylint: disable=W0621
    """This function invokes the training process."""
    trainer = Trainer(
        module_file=trainer_params["module_file"],
        data_module_file=trainer_params["data_module_file"],
        module_file_args=trainer_params["module_file_args"],
        trainer_args=trainer_params["trainer_args"],
        data_module_args=trainer_params["data_module_args"],
    )
    return trainer


@pytest.mark.parametrize("mandatory_key", MANDATORY_ARGS)
def test_mandatory_keys_type_check(trainer_params, mandatory_key):
    """Tests the unexpected 'type' of mandatory args.

    Args:
        mandatory_key: mandatory arguments for invoking training
    """
    test_input = ["input_path"]
    trainer_params[mandatory_key] = test_input
    expected_exception_msg = (
        f"{mandatory_key} must be of type <class 'str'> "
        f"but received as {type(test_input)}"
    )
    with pytest.raises(TypeError, match=expected_exception_msg):
        invoke_training(trainer_params=trainer_params)
@pytest.mark.parametrize("optional_key", OPTIONAL_ARGS)
def test_optional_keys_type_check(trainer_params, optional_key):
    """Tests the unexpected 'type' of optional args.

    Args:
        optional_key: optional arguments for invoking training
    """
    test_input = "test_input"
    trainer_params[optional_key] = test_input
    expected_exception_msg = (
        f"{optional_key} must be of type <class 'dict'> "
        f"but received as {type(test_input)}"
    )
    with pytest.raises(TypeError, match=expected_exception_msg):
        invoke_training(trainer_params=trainer_params)


@pytest.mark.parametrize("input_key", MANDATORY_ARGS + ["module_file_args"])
def test_mandatory_params(trainer_params, input_key):
    """Test for empty mandatory arguments.

    Args:
        input_key: name of the mandatory arg for training
    """
    trainer_params[input_key] = None
    expected_exception_msg = (
        f"{input_key} is not optional. "
        f"Received value: {trainer_params[input_key]}"
    )
    with pytest.raises(ValueError, match=expected_exception_msg):
        invoke_training(trainer_params=trainer_params)


def test_data_module_args_optional(trainer_params):
    """Test for empty optional argument: data_module_args."""
    trainer_params["data_module_args"] = None
    invoke_training(trainer_params=trainer_params)
    assert os.path.exists(DEFAULT_SAVE_PATH)
    os.remove(DEFAULT_SAVE_PATH)


def test_trainer_args_none(trainer_params):
    """Test for empty trainer specific arguments."""
    trainer_params["trainer_args"] = None
    expected_exception_msg = r"trainer_args must be a dict"
    with pytest.raises(TypeError, match=expected_exception_msg):
        invoke_training(trainer_params=trainer_params)


def test_training_success(trainer_params):
    """Test the training success case with all required args."""
    trainer = invoke_training(trainer_params=trainer_params)
    assert os.path.exists(DEFAULT_SAVE_PATH)
    os.remove(DEFAULT_SAVE_PATH)
    assert hasattr(trainer, "ptl_trainer")
    assert isinstance(
        trainer.ptl_trainer, pytorch_lightning.trainer.trainer.Trainer
    )


def test_training_success_with_custom_model_name(trainer_params):
    """Test for successful training with a custom model name."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer_params["module_file_args"]["model_name"] = "iris.pth"
    invoke_training(trainer_params=trainer_params)
    assert "iris.pth" in os.listdir(tmp_dir)
    shutil.rmtree(tmp_dir)
    trainer_params["module_file_args"].pop("checkpoint_dir")
    trainer_params["module_file_args"].pop("model_name")
def test_training_failure_with_empty_module_file_args(trainer_params):
    """Test for training failure with empty module file args."""
    trainer_params["module_file_args"] = {}
    exception_msg = "module_file_args is not optional. Received value: {}"
    with pytest.raises(ValueError, match=exception_msg):
        invoke_training(trainer_params=trainer_params)


def test_training_success_with_empty_trainer_args(trainer_params):
    """Test for successful training with empty trainer args."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["max_epochs"] = 5
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer_params["trainer_args"] = {}
    invoke_training(trainer_params=trainer_params)
    assert DEFAULT_MODEL_NAME in os.listdir(tmp_dir)
    shutil.rmtree(tmp_dir)


def test_training_success_with_empty_data_module_args(trainer_params):
    """Test for successful training with empty data module args."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer_params["data_module_args"] = None
    invoke_training(trainer_params=trainer_params)
    assert DEFAULT_MODEL_NAME in os.listdir(tmp_dir)
    shutil.rmtree(tmp_dir)


def test_trainer_output(trainer_params):
    """Test for successful training with proper saving of training output."""
    tmp_dir = tempfile.mkdtemp()
    trainer_params["module_file_args"]["checkpoint_dir"] = tmp_dir
    trainer = invoke_training(trainer_params=trainer_params)
    assert hasattr(trainer, "output_dict")
    assert trainer.output_dict is not None
    assert trainer.output_dict["model_save_path"] == os.path.join(
        tmp_dir, DEFAULT_MODEL_NAME
    )
    assert isinstance(
        trainer.output_dict["ptl_trainer"],
        pytorch_lightning.trainer.trainer.Trainer,
    )


@@ -0,0 +1,44 @@
[tox]
envlist = clean,py38
skip_missing_interpreters = true

[flake8]
exclude =
    .git,
    .tox,
    .pytest_cache,
    __pycache__,
    dist,
    build,
    *.egg-info,
    .pylintrc

[testenv]
usedevelop = True
install_command = pip install -U {opts} {packages}
extras = tests
testpaths = tests
deps =
    pytest
    pytest-cov
    absl-py
    sklearn
    wget
    pandas
    minio
depends =
    {py38}: clean
    report: py38
commands =
    flake8 --version
    flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
    pylint --rcfile=tox.ini --exit-zero pytorch_kfp_components
    py.test --cov=pytorch-kfp-components --cov-append --cov-report=term-missing -vvv -s {posargs}

[testenv:clean]
deps = coverage
skip_install = true
commands = coverage erase


@@ -0,0 +1,6 @@
# PyTorch Pipeline Samples
This folder contains different PyTorch Kubeflow pipeline examples using the PyTorch KFP Components SDK.
1. CIFAR-10 example for computer vision
2. BERT example for NLP