Skip to content
This repository has been archived by the owner on Sep 24, 2020. It is now read-only.

Commit

Permalink
Merge branch 'master' into release/cli-0.10.0rc
Browse files Browse the repository at this point in the history
  • Loading branch information
raubitsj authored Sep 11, 2020
2 parents df05bf7 + dc8f20e commit 9c2c2b2
Show file tree
Hide file tree
Showing 18 changed files with 211 additions and 54 deletions.
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def test_settings(test_dir, mocker):
project="test",
console="off",
host="test",
api_key=DUMMY_API_KEY,
run_id=wandb.util.generate_id(),
_start_datetime=datetime.datetime.now())
settings.setdefaults()
Expand Down Expand Up @@ -172,6 +173,7 @@ def runner(monkeypatch, mocker):
'project_name': 'test_model', 'files': ['weights.h5'],
'attach': False, 'team_name': 'Manual Entry'})
monkeypatch.setattr(webbrowser, 'open_new_tab', lambda x: True)
mocker.patch("wandb.lib.apikey.isatty", lambda stream: True)
mocker.patch("wandb.lib.apikey.input", lambda x: 1)
mocker.patch("wandb.lib.apikey.getpass.getpass", lambda x: DUMMY_API_KEY)
return CliRunner()
Expand Down
38 changes: 38 additions & 0 deletions tests/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def api(runner):
return Api()


def test_api_auto_login_no_tty(mocker):
with pytest.raises(wandb.UsageError):
Api()


def test_parse_project_path(api):
e, p = api._parse_project_path("user/proj")
assert e == "user"
Expand Down Expand Up @@ -96,6 +101,12 @@ def test_run_history(mock_server, api):
assert run.history(pandas=False)[0] == {'acc': 10, 'loss': 90}


def test_run_history_keys(mock_server, api):
run = api.run("test/test/test")
assert run.history(keys=["acc", "loss"], pandas=False) == [
{"loss": 0, "acc": 100}, {"loss": 1, "acc": 0}]


def test_run_config(mock_server, api):
run = api.run("test/test/test")
assert run.config == {'epochs': 10}
Expand Down Expand Up @@ -264,6 +275,33 @@ def test_artifact_run_logged(runner, mock_server, api):
assert arts[0].name == "abc123"


def test_artifact_manual_use(runner, mock_server, api):
run = api.run("test/test/test")
art = api.artifact("entity/project/mnist:v0", type="dataset")
run.use_artifact(art)
assert True


def test_artifact_manual_log(runner, mock_server, api):
run = api.run("test/test/test")
art = api.artifact("entity/project/mnist:v0", type="dataset")
run.log_artifact(art)
assert True


def test_artifact_manual_error(runner, mock_server, api):
run = api.run("test/test/test")
art = wandb.Artifact("test", type="dataset")
with pytest.raises(wandb.CommError):
run.log_artifact(art)
with pytest.raises(wandb.CommError):
run.use_artifact(art)
with pytest.raises(wandb.CommError):
run.use_artifact("entity/project/mnist:v0")
with pytest.raises(wandb.CommError):
run.log_artifact("entity/project/mnist:v0")


@pytest.mark.skipif(platform.system() == "Windows",
reason="Verify is broken on Windows")
def test_artifact_verify(runner, mock_server, api):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run(ctx):
}
]
},
"sampledHistory": ['{"loss": 0, "acc": 100}'],
"sampledHistory": [[{"loss": 0, "acc": 100}, {"loss": 1, "acc": 0}]],
"shouldStop": False,
"failed": False,
"stopped": stopped,
Expand Down
6 changes: 3 additions & 3 deletions tests/wandb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ def test_login_key(capsys):
assert wandb.api.api_key == "A" * 40


def test_login_existing_key():
def test_login_invalid_key():
os.environ["WANDB_API_KEY"] = "B" * 40
wandb.ensure_configured()
wandb.login()
assert wandb.api.api_key == "B" * 40
with pytest.raises(wandb.UsageError):
wandb.login()
del os.environ["WANDB_API_KEY"]


Expand Down
2 changes: 1 addition & 1 deletion wandb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
Config = wandb_sdk.Config

from wandb.apis import InternalApi, PublicApi
from wandb.errors.error import CommError
from wandb.errors.error import CommError, UsageError

from wandb.lib import preinit as _preinit
from wandb.lib import lazyloader as _lazyloader
Expand Down
54 changes: 53 additions & 1 deletion wandb/apis/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,8 @@ def _sampled_history(self, keys, x_axis="_step", samples=500):
''')

response = self._exec(query, specs=[json.dumps(spec)])
return [line for line in response['project']['run']['sampledHistory']]
# sampledHistory returns one list per spec, we only send one spec
return response['project']['run']['sampledHistory'][0]

def _full_history(self, samples=500, stream="default"):
node = "history" if stream == "default" else "events"
Expand Down Expand Up @@ -1074,6 +1075,55 @@ def logged_artifacts(self, per_page=100):
def used_artifacts(self, per_page=100):
return RunArtifacts(self.client, self, mode="used", per_page=per_page)

@normalize_exceptions
def use_artifact(self, artifact):
""" Declare an artifact as an input to a run.
Args:
artifact (:obj:`Artifact`): An artifact returned from
`wandb.Api().artifact(name)`
Returns:
A :obj:`Artifact` object.
"""
api = InternalApi(default_settings={
"entity": self.entity, "project": self.project})
api.set_current_run_id(self.id)

if isinstance(artifact, Artifact):
api.use_artifact(artifact.id)
return artifact
elif isinstance(artifact, wandb.Artifact):
raise ValueError("Only existing artifacts are accepted by this api. "
"Manually create one with `wandb artifacts put`")
else:
raise ValueError('You must pass a wandb.Api().artifact() to use_artifact')

@normalize_exceptions
def log_artifact(self, artifact, aliases=None):
""" Declare an artifact as output of a run.
Args:
artifact (:obj:`Artifact`): An artifact returned from
`wandb.Api().artifact(name)`
aliases (list, optional): Aliases to apply to this artifact
Returns:
A :obj:`Artifact` object.
"""
api = InternalApi(default_settings={
"entity": self.entity, "project": self.project})
api.set_current_run_id(self.id)

if isinstance(artifact, Artifact):
artifact_collection_name = artifact.name.split(':')[0]
api.create_artifact(artifact.type, artifact_collection_name,
artifact.digest, aliases=aliases)
return artifact
elif isinstance(artifact, wandb.Artifact):
raise ValueError("Only existing artifacts are accepted by this api. "
"Manually create one with `wandb artifacts put`")
else:
raise ValueError('You must pass a wandb.Api().artifact() to use_artifact')

@property
def summary(self):
if self._summary is None:
Expand Down Expand Up @@ -1777,6 +1827,8 @@ def update_variables(self):
self.variables.update({'cursor': self.cursor})

def convert_objects(self):
if self.last_response['project'] is None:
return []
return [ArtifactType(self.client, self.entity, self.project, r["node"]["name"], r["node"])
for r in self.last_response['project']['artifactTypes']['edges']]

Expand Down
2 changes: 1 addition & 1 deletion wandb/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def login(key, host, cloud, relogin, anonymously, no_offline=False):
api.set_setting("base_url", host.strip("/"), globally=True, persist=True)
key = key[0] if len(key) > 0 else None

wandb.login(relogin=relogin, key=key, anonymous=anon_mode)
wandb.login(relogin=relogin, key=key, anonymous=anon_mode, force=True)


@cli.command(
Expand Down
1 change: 1 addition & 0 deletions wandb/errors/term.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


LOG_STRING = click.style('wandb', fg='blue', bold=True)
LOG_STRING_NOCOLOR = 'wandb'
ERROR_STRING = click.style('ERROR', bg='red', fg='green')
WARN_STRING = click.style('WARNING', fg='yellow')
PRINTED_MESSAGES = set()
Expand Down
2 changes: 2 additions & 0 deletions wandb/internal/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,8 @@ def create_artifact(self, artifact_type_name, artifact_collection_name, digest,
if not is_user_created:
run_name = run_name or self.current_run_id

if aliases is None:
aliases = []
response = self.gql(mutation, variable_values={
'entityName': entity_name,
'projectName': project_name,
Expand Down
3 changes: 2 additions & 1 deletion wandb/internal/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,11 @@ def send_artifact(self, data):
is_user_created=artifact.user_created,
)

metadata = json.loads(artifact.metadata) if artifact.metadata else None
saver.save(
type=artifact.type,
name=artifact.name,
metadata=json.loads(artifact.metadata),
metadata=metadata,
description=artifact.description,
aliases=artifact.aliases,
use_after_commit=artifact.use_after_commit,
Expand Down
37 changes: 25 additions & 12 deletions wandb/lib/apikey.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from six.moves import input
import wandb
from wandb.apis import InternalApi
from wandb.errors.term import LOG_STRING
from wandb.errors import term
from wandb.util import isatty


LOGIN_CHOICE_ANON = "Private W&B dashboard, no account required"
LOGIN_CHOICE_NEW = "Create a W&B account"
LOGIN_CHOICE_EXISTS = "Use an existing W&B account"
LOGIN_CHOICE_DRYRUN = "Don't visualize my results"
LOGIN_CHOICE_NOTTY = "Unconfigured"
LOGIN_CHOICES = [
LOGIN_CHOICE_ANON,
LOGIN_CHOICE_NEW,
Expand All @@ -42,7 +43,7 @@ def _prompt_choice():
return (
int(
input(
"%s: Enter your choice: " % LOG_STRING
"%s: Enter your choice: " % term.LOG_STRING
)
)
- 1 # noqa: W503
Expand All @@ -51,15 +52,24 @@ def _prompt_choice():
return -1


def prompt_api_key(
def prompt_api_key( # noqa: C901
settings,
api=None,
input_callback=None,
browser_callback=None,
no_offline=False,
no_create=False,
local=False,
):
"""Prompt for api key.
Returns:
str - if key is configured
None - if dryrun is selected
False - if unconfigured (notty)
"""
input_callback = input_callback or getpass.getpass
log_string = term.LOG_STRING
api = api or InternalApi()
anon_mode = _fixup_anon_mode(settings.anonymous)
jupyter = settings._jupyter or False
Expand All @@ -71,8 +81,11 @@ def prompt_api_key(
choices.remove(LOGIN_CHOICE_ANON)
if jupyter or no_offline:
choices.remove(LOGIN_CHOICE_DRYRUN)
if jupyter or no_create:
choices.remove(LOGIN_CHOICE_NEW)

if jupyter and 'google.colab' in sys.modules:
log_string = term.LOG_STRING_NOCOLOR
key = wandb.jupyter.attempt_colab_login(api.app_url)
if key is not None:
write_key(settings, key)
Expand All @@ -82,9 +95,11 @@ def prompt_api_key(
result = LOGIN_CHOICE_ANON
# If we're not in an interactive environment, default to dry-run.
elif not jupyter and (not isatty(sys.stdout) or not isatty(sys.stdin)):
result = LOGIN_CHOICE_DRYRUN
result = LOGIN_CHOICE_NOTTY
elif local:
result = LOGIN_CHOICE_EXISTS
elif len(choices) == 1:
result = choices[0]
else:
for i, choice in enumerate(choices):
wandb.termlog("(%i) %s" % (i + 1, choice))
Expand All @@ -97,6 +112,7 @@ def prompt_api_key(
result = choices[idx]
wandb.termlog("You chose '%s'" % result)

api_ask = "%s: Paste an API key from your profile and hit enter: " % log_string
if result == LOGIN_CHOICE_ANON:
key = api.create_anonymous_api_key()

Expand All @@ -109,10 +125,7 @@ def prompt_api_key(
wandb.termlog(
"Create an account here: {}/authorize?signup=true".format(app_url)
)
key = input_callback(
"%s: Paste an API key from your profile and hit enter"
% LOG_STRING
).strip()
key = input_callback(api_ask).strip()

write_key(settings, key)
return key
Expand All @@ -125,12 +138,12 @@ def prompt_api_key(
app_url
)
)
key = input_callback(
"%s: Paste an API key from your profile and hit enter"
% LOG_STRING
).strip()
key = input_callback(api_ask).strip()
write_key(settings, key)
return key
elif result == LOGIN_CHOICE_NOTTY:
# TODO: Needs refactor as this needs to be handled by caller
return False
else:
# Jupyter environments don't have a tty, but we can still try logging in using
# the browser callback if one is supplied.
Expand Down
4 changes: 2 additions & 2 deletions wandb/lib/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class ServerError(Exception):


class Server(object):
def __init__(self, api=None):
self._api = api or InternalApi()
def __init__(self, api=None, settings=None):
self._api = api or InternalApi(default_settings=settings)
self._error_network = None
self._viewer = {}
self._flags = {}
Expand Down
4 changes: 4 additions & 0 deletions wandb/sdk/wandb_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def setup(self, kwargs):
print("Ignored wandb.init() arg %s when running a sweep" % key)
settings.apply_init(kwargs)

login_key = wandb_login._login(_disable_warning=True, _settings=settings)
if not login_key:
settings.mode = "offline"

# TODO(jhr): should this be moved? probably.
d = dict(_start_time=time.time(), _start_datetime=datetime.datetime.now(),)
settings.update(d)
Expand Down
Loading

0 comments on commit 9c2c2b2

Please sign in to comment.