From 95f6c6d954e361960fd1d044250f308968a7b797 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Sun, 24 Jun 2018 17:35:19 -0700 Subject: [PATCH 1/8] add skorch compat --- osprey/data/torch_skeleton_config.yaml | 38 ++++++++++++++++++++++++++ osprey/eval_scopes.py | 16 ++++++++++- osprey/execute_skeleton.py | 3 +- 3 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 osprey/data/torch_skeleton_config.yaml diff --git a/osprey/data/torch_skeleton_config.yaml b/osprey/data/torch_skeleton_config.yaml new file mode 100644 index 0000000..631f39c --- /dev/null +++ b/osprey/data/torch_skeleton_config.yaml @@ -0,0 +1,38 @@ +estimator: + eval: Pipeline([ + ('scale', RobustScaler()), + ('classifier', NeuralNetClassifier(nn.Sequential(nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 10), + nn.Softmax(dim=1)), + max_epochs=10)), + ]) + eval_scope: ['sklearn', 'torch'] + +scoring: accuracy + +strategy: + name: gp + params: + seeds: 5 + +search_space: + classifier__lr: + min: 1e-3 + max: 1e-1 + num: 10 + type: jump + var_type: float + warp: log + +cv: 5 + +dataset_loader: + name: sklearn_dataset + params: + method: load_digits + +trials: + uri: sqlite:///osprey-trials.db + +random_seed: 42 diff --git a/osprey/eval_scopes.py b/osprey/eval_scopes.py index b55509f..e2af20b 100644 --- a/osprey/eval_scopes.py +++ b/osprey/eval_scopes.py @@ -8,7 +8,7 @@ from sklearn.base import BaseEstimator -__all__ = ['msmbuilder', 'import_all_estimators'] +__all__ = ['msmbuilder', 'torch', 'import_all_estimators'] def msmbuilder(): @@ -22,6 +22,20 @@ def msmbuilder(): return scope +def torch(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + import torch + from torch import nn + import skorch + from sklearn.pipeline import Pipeline + + scope = import_all_estimators(skorch) + scope.update({'nn': nn}) + scope['Pipeline'] = Pipeline + return scope + + def import_all_estimators(pkg): def estimator_in_module(mod): diff --git a/osprey/execute_skeleton.py b/osprey/execute_skeleton.py index 0b9684c..d15cc53 100644 --- a/osprey/execute_skeleton.py +++ b/osprey/execute_skeleton.py @@ -8,7 +8,8 @@ 'random_example': 'random_example.yaml', 'gp_example': 'sklearn_skeleton_config.yaml', 'grid_example': 'grid_example.yaml', - 'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml'} + 'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml', + 'torch': 'torch_skeleton_config.yaml'} def execute(args, parser): From ad335fb79630139106f24df5a20f3eff0ddb4d2b Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Tue, 26 Jun 2018 18:19:36 -0700 Subject: [PATCH 2/8] add test --- devtools/conda-recipe/meta.yaml | 1 + osprey/tests/test_cli_worker_and_dump.py | 31 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml index b97d318..3f9ec5f 100644 --- a/devtools/conda-recipe/meta.yaml +++ b/devtools/conda-recipe/meta.yaml @@ -33,6 +33,7 @@ test: - nose - nose-timer - gpy + - skorch - msmbuilder - msmb_data - mdtraj diff --git a/osprey/tests/test_cli_worker_and_dump.py b/osprey/tests/test_cli_worker_and_dump.py index c71f105..21416f6 100644 --- a/osprey/tests/test_cli_worker_and_dump.py +++ b/osprey/tests/test_cli_worker_and_dump.py @@ -15,6 +15,13 @@ except: HAVE_MSMBUILDER = False +try: + __import__('skorch') + HAVE_SKORCH = True +except: + HAVE_SKORCH = False + + OSPREY_BIN = find_executable('osprey') @@ -136,6 +143,30 @@ def test_gp_example(): shutil.rmtree(dirname) +@skipif(not HAVE_SKORCH, 'this test requires Skorch') +def test_torch_example(): + assert OSPREY_BIN is not None + cwd = os.path.abspath(os.curdir) + dirname = tempfile.mkdtemp() + + try: + os.chdir(dirname) + subprocess.check_call([OSPREY_BIN, 'skeleton', '-t', 'torch', + '-f', 'config.yaml']) + subprocess.check_call([OSPREY_BIN, 'worker', 'config.yaml', '-n', '1']) + assert os.path.exists('osprey-trials.db') + + subprocess.check_call([OSPREY_BIN, 'current_best', 'config.yaml']) + + yield _test_dump_1 + + yield _test_plot_1 + + finally: + os.chdir(cwd) + shutil.rmtree(dirname) + + def test_grid_example(): assert OSPREY_BIN is not None cwd = os.path.abspath(os.curdir) From 3f02258921438f7bdb4a0e18fbd0359857f1c8af Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Tue, 26 Jun 2018 18:47:24 -0700 Subject: [PATCH 3/8] add py36 --- .travis.yml | 3 ++- devtools/travis-ci/build_docs.sh | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index d41f5b9..616e428 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ env: matrix: - CONDA_PY=2.7 - CONDA_PY=3.5 + - CONDA_PY=3.6 branches: only: @@ -31,7 +32,7 @@ deploy: local-dir: docs/_deploy/ on: branch: master - condition: "$CONDA_PY = 3.5" + condition: "$CONDA_PY = 3.6" region: us-east-1 detect_encoding: true access_key_id: $AWS_ACCESS_KEY diff --git a/devtools/travis-ci/build_docs.sh b/devtools/travis-ci/build_docs.sh index c3336e2..52c8961 100755 --- a/devtools/travis-ci/build_docs.sh +++ b/devtools/travis-ci/build_docs.sh @@ -8,15 +8,15 @@ conda create --yes -n docenv python=$CONDA_PY source activate docenv conda install -yq --use-local osprey +# Install doc requirements +conda install --yes --file docs/requirements.txt + # We don't use conda for these: # sphinx_rtd_theme's latest releases are not available # neither is msmb_theme # neither is sphinx > 1.3.1 (fix #1892 autodoc problem) pip install -I sphinx==1.3.5 sphinx_rtd_theme==0.1.9 msmb_theme==1.2.0 -# Install doc requirements -conda install --yes --file docs/requirements.txt - # Make docs cd docs && make html && cd - From 6ab9afd6a08cfeb12d7d9d3e325dbb4044a1977a Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Wed, 27 Jun 2018 16:52:53 -0700 Subject: [PATCH 4/8] add py34 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 616e428..e8e1973 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ sudo: false env: matrix: - - CONDA_PY=2.7 + - CONDA_PY=3.4 - CONDA_PY=3.5 - CONDA_PY=3.6 From a3c43963b5accfccadc96056aa9796abd3051982 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Hern=C3=A1ndez?= Date: Thu, 28 Jun 2018 11:02:39 -0700 Subject: [PATCH 5/8] Update .travis.yml --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index e8e1973..92071ea 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,7 @@ install: - conda install -yq python-coveralls script: - - conda build --quiet devtools/conda-recipe + - if [ $CONDA_PY == 3.6 ]; then conda build --quiet devtools/conda-recipe; fi - devtools/travis-ci/build_docs.sh after_success: From e39d865fcb7f98bc2595cbff2f4e656eb22fffd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Hern=C3=A1ndez?= Date: Thu, 18 Apr 2019 16:29:33 -0400 Subject: [PATCH 6/8] Update .travis.yml --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 8db0113..0dbb122 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,7 +18,7 @@ install: - conda install -yq python-coveralls script: - - if [ $CONDA_PY == 3.6 ]; then conda build --quiet devtools/conda-recipe; fi + - conda build --quiet devtools/conda-recipe - devtools/travis-ci/build_docs.sh after_success: From 59de59a7ef707e1d8cf029a3a5907d5ec46e8af2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Hern=C3=A1ndez?= Date: Thu, 18 Apr 2019 18:31:30 -0400 Subject: [PATCH 7/8] Update meta.yaml --- devtools/conda-recipe/meta.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml index 3f9ec5f..3dbbde3 100644 --- a/devtools/conda-recipe/meta.yaml +++ b/devtools/conda-recipe/meta.yaml @@ -34,6 +34,7 @@ test: - nose-timer - gpy - skorch + - torch - msmbuilder - msmb_data - mdtraj From 3b2d24df0871cfb6060be928f754e515250ea11e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Hern=C3=A1ndez?= Date: Thu, 18 Apr 2019 18:46:00 -0400 Subject: [PATCH 8/8] Update meta.yaml --- devtools/conda-recipe/meta.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml index 3dbbde3..b035bca 100644 --- a/devtools/conda-recipe/meta.yaml +++ b/devtools/conda-recipe/meta.yaml @@ -34,7 +34,7 @@ test: - nose-timer - gpy - skorch - - torch + - pytorch - msmbuilder - msmb_data - mdtraj