diff --git a/dotenv/main.py b/dotenv/main.py index 349ec069..1a882380 100644 --- a/dotenv/main.py +++ b/dotenv/main.py @@ -2,47 +2,90 @@ from __future__ import absolute_import, print_function, unicode_literals import codecs -import fileinput import io import os import re +import shutil import sys from subprocess import Popen +import tempfile import warnings -from collections import OrderedDict +from collections import OrderedDict, namedtuple +from contextlib import contextmanager from .compat import StringIO, PY2, WIN, text_type -__escape_decoder = codecs.getdecoder('unicode_escape') -__posix_variable = re.compile('\$\{[^\}]*\}') # noqa +__posix_variable = re.compile(r'\$\{[^\}]*\}') +_binding = re.compile( + r""" + ( + \s* # leading whitespace + (?:export\s+)? # export -def decode_escaped(escaped): - return __escape_decoder(escaped)[0] + ( '[^']+' # single-quoted key + | [^=\#\s]+ # or unquoted key + )? + (?: + (?:\s*=\s*) # equal sign -def parse_line(line): - line = line.strip() + ( '(?:\\'|[^'])*' # single-quoted value + | "(?:\\"|[^"])*" # or double-quoted value + | [^\#\r\n]* # or unquoted value + ) + )? - # Ignore lines with `#` or which doesn't have `=` in it. - if not line or line.startswith('#') or '=' not in line: - return None, None + \s* # trailing whitespace + (?:\#[^\r\n]*)? # comment + (?:\r|\n|\r\n)? # newline + ) + """, + re.MULTILINE | re.VERBOSE, +) - k, v = line.split('=', 1) +_escape_sequence = re.compile(r"\\[\\'\"abfnrtv]") - if k.startswith('export '): - (_, _, k) = k.partition('export ') - # Remove any leading and trailing spaces in key, value - k, v = k.strip(), v.strip() +Binding = namedtuple('Binding', 'key value original') - if v: - v = v.encode('unicode-escape').decode('ascii') - quoted = v[0] == v[-1] in ['"', "'"] - if quoted: - v = decode_escaped(v[1:-1]) - return k, v +def decode_escapes(string): + def decode_match(match): + return codecs.decode(match.group(0), 'unicode-escape') + + return _escape_sequence.sub(decode_match, string) + + +def is_surrounded_by(string, char): + return ( + len(string) > 1 + and string[0] == string[-1] == char + ) + + +def parse_binding(string, position): + match = _binding.match(string, position) + (matched, key, value) = match.groups() + if key is None or value is None: + key = None + value = None + else: + value_quoted = is_surrounded_by(value, "'") or is_surrounded_by(value, '"') + if value_quoted: + value = decode_escapes(value[1:-1]) + else: + value = value.strip() + return (Binding(key=key, value=value, original=matched), match.end()) + + +def parse_stream(stream): + string = stream.read() + position = 0 + length = len(string) + while position < length: + (binding, position) = parse_binding(string, position) + yield binding class DotEnv(): @@ -52,19 +95,17 @@ def __init__(self, dotenv_path, verbose=False): self._dict = None self.verbose = verbose + @contextmanager def _get_stream(self): - self._is_file = False if isinstance(self.dotenv_path, StringIO): - return self.dotenv_path - - if os.path.isfile(self.dotenv_path): - self._is_file = True - return io.open(self.dotenv_path) - - if self.verbose: - warnings.warn("File doesn't exist {}".format(self.dotenv_path)) - - return StringIO('') + yield self.dotenv_path + elif os.path.isfile(self.dotenv_path): + with io.open(self.dotenv_path) as stream: + yield stream + else: + if self.verbose: + warnings.warn("File doesn't exist {}".format(self.dotenv_path)) + yield StringIO('') def dict(self): """Return dotenv as dict""" @@ -76,17 +117,10 @@ def dict(self): return self._dict def parse(self): - f = self._get_stream() - - for line in f: - key, value = parse_line(line) - if not key: - continue - - yield key, value - - if self._is_file: - f.close() + with self._get_stream() as stream: + for mapping in parse_stream(stream): + if mapping.key is not None and mapping.value is not None: + yield mapping.key, mapping.value def set_as_environment_variables(self, override=False): """ @@ -126,6 +160,20 @@ def get_key(dotenv_path, key_to_get): return DotEnv(dotenv_path, verbose=True).get(key_to_get) +@contextmanager +def rewrite(path): + try: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as dest: + with io.open(path) as source: + yield (source, dest) + except BaseException: + if os.path.isfile(dest.name): + os.unlink(dest.name) + raise + else: + shutil.move(dest.name, path) + + def set_key(dotenv_path, key_to_set, value_to_set, quote_mode="always"): """ Adds or Updates a key/value to the given .env @@ -141,20 +189,19 @@ def set_key(dotenv_path, key_to_set, value_to_set, quote_mode="always"): if " " in value_to_set: quote_mode = "always" - line_template = '{}="{}"' if quote_mode == "always" else '{}={}' + line_template = '{}="{}"\n' if quote_mode == "always" else '{}={}\n' line_out = line_template.format(key_to_set, value_to_set) - replaced = False - for line in fileinput.input(dotenv_path, inplace=True): - k, v = parse_line(line) - if k == key_to_set: - replaced = True - line = "{}\n".format(line_out) - print(line, end='') - - if not replaced: - with io.open(dotenv_path, "a") as f: - f.write("{}\n".format(line_out)) + with rewrite(dotenv_path) as (source, dest): + replaced = False + for mapping in parse_stream(source): + if mapping.key == key_to_set: + dest.write(line_out) + replaced = True + else: + dest.write(mapping.original) + if not replaced: + dest.write(line_out) return True, key_to_set, value_to_set @@ -166,18 +213,17 @@ def unset_key(dotenv_path, key_to_unset, quote_mode="always"): If the .env path given doesn't exist, fails If the given key doesn't exist in the .env, fails """ - removed = False - if not os.path.exists(dotenv_path): warnings.warn("can't delete from %s - it doesn't exist." % dotenv_path) return None, key_to_unset - for line in fileinput.input(dotenv_path, inplace=True): - k, v = parse_line(line) - if k == key_to_unset: - removed = True - line = '' - print(line, end='') + removed = False + with rewrite(dotenv_path) as (source, dest): + for mapping in parse_stream(source): + if mapping.key == key_to_unset: + removed = True + else: + dest.write(mapping.original) if not removed: warnings.warn("key %s not removed from %s - key doesn't exist." % (key_to_unset, dotenv_path)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 15c47af8..b594592a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- -from os import environ +import os from os.path import dirname, join +import pytest +import sh + import dotenv -from dotenv.version import __version__ from dotenv.cli import cli as dotenv_cli - -import sh +from dotenv.version import __version__ here = dirname(__file__) dotenv_path = join(here, '.env') @@ -38,6 +39,22 @@ def test_set_key(dotenv_file): with open(dotenv_file, 'r') as fp: assert 'HELLO="WORLD 2"\nfoo="bar"' == fp.read().strip() + success, key_to_set, value_to_set = dotenv.set_key(dotenv_file, "HELLO", "WORLD\n3") + + with open(dotenv_file, "r") as fp: + assert 'HELLO="WORLD\n3"\nfoo="bar"' == fp.read().strip() + + +def test_set_key_permission_error(dotenv_file): + os.chmod(dotenv_file, 0o000) + + with pytest.raises(Exception): + dotenv.set_key(dotenv_file, "HELLO", "WORLD") + + os.chmod(dotenv_file, 0o600) + with open(dotenv_file, "r") as fp: + assert fp.read() == "" + def test_list(cli, dotenv_file): success, key_to_set, value_to_set = dotenv.set_key(dotenv_file, 'HELLO', 'WORLD') @@ -59,6 +76,13 @@ def test_list_wo_file(cli): assert 'Invalid value for "-f"' in result.output +def test_empty_value(): + with open(dotenv_path, "w") as f: + f.write("TEST=") + assert dotenv.get_key(dotenv_path, "TEST") == "" + sh.rm(dotenv_path) + + def test_key_value_without_quotes(): with open(dotenv_path, 'w') as f: f.write("TEST = value \n") @@ -95,18 +119,41 @@ def test_value_with_special_characters(): sh.rm(dotenv_path) -def test_unset(): - sh.touch(dotenv_path) - success, key_to_set, value_to_set = dotenv.set_key(dotenv_path, 'HELLO', 'WORLD') - stored_value = dotenv.get_key(dotenv_path, 'HELLO') - assert stored_value == 'WORLD' - success, key_to_unset = dotenv.unset_key(dotenv_path, 'HELLO') - assert success is True - assert dotenv.get_key(dotenv_path, 'HELLO') is None - success, key_to_unset = dotenv.unset_key(dotenv_path, 'RANDOM') - assert success is None +def test_value_with_new_lines(): + with open(dotenv_path, 'w') as f: + f.write('TEST="a\nb"') + assert dotenv.get_key(dotenv_path, 'TEST') == "a\nb" + sh.rm(dotenv_path) + + with open(dotenv_path, 'w') as f: + f.write("TEST='a\nb'") + assert dotenv.get_key(dotenv_path, 'TEST') == "a\nb" + sh.rm(dotenv_path) + + +def test_value_after_comment(): + with open(dotenv_path, "w") as f: + f.write("# comment\nTEST=a") + assert dotenv.get_key(dotenv_path, "TEST") == "a" sh.rm(dotenv_path) - success, key_to_unset = dotenv.unset_key(dotenv_path, 'HELLO') + + +def test_unset_ok(dotenv_file): + with open(dotenv_file, "w") as f: + f.write("a=b\nc=d") + + success, key_to_unset = dotenv.unset_key(dotenv_file, "a") + + assert success is True + assert key_to_unset == "a" + with open(dotenv_file, "r") as f: + assert f.read() == "c=d" + sh.rm(dotenv_file) + + +def test_unset_non_existing_file(): + success, key_to_unset = dotenv.unset_key('/non-existing', 'HELLO') + assert success is None @@ -180,7 +227,7 @@ def test_get_key_with_interpolation(cli): stored_value = dotenv.get_key(dotenv_path, 'BAR') assert stored_value == 'CONCATENATED_WORLD_POSIX_VAR' # test replace from environ taking precedence over file - environ["HELLO"] = "TAKES_PRECEDENCE" + os.environ["HELLO"] = "TAKES_PRECEDENCE" stored_value = dotenv.get_key(dotenv_path, 'FOO') assert stored_value == "TAKES_PRECEDENCE" sh.rm(dotenv_path) @@ -194,10 +241,10 @@ def test_get_key_with_interpolation_of_unset_variable(cli): stored_value = dotenv.get_key(dotenv_path, 'FOO') assert stored_value == '' # unless present in environment - environ['NOT_SET'] = 'BAR' + os.environ['NOT_SET'] = 'BAR' stored_value = dotenv.get_key(dotenv_path, 'FOO') assert stored_value == 'BAR' - del(environ['NOT_SET']) + del(os.environ['NOT_SET']) sh.rm(dotenv_path) diff --git a/tests/test_core.py b/tests/test_core.py index 45a1f86a..bda2e3b7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,7 +9,7 @@ import sh from dotenv import load_dotenv, find_dotenv, set_key, dotenv_values -from dotenv.main import parse_line +from dotenv.main import Binding, parse_stream from dotenv.compat import StringIO from IPython.terminal.embed import InteractiveShellEmbed @@ -25,21 +25,71 @@ def restore_os_environ(): @pytest.mark.parametrize("test_input,expected", [ - ("a=b", ("a", "b")), - (" a = b ", ("a", "b")), - ("export a=b", ("a", "b")), - (" export 'a'=b", ("'a'", "b")), - (" export 'a'=b", ("'a'", "b")), - ("# a=b", (None, None)), - ("# a=b", (None, None)), - ("a=b space ", ('a', 'b space')), - ("a='b space '", ('a', 'b space ')), - ('a="b space "', ('a', 'b space ')), - ("export export_spam=1", ("export_spam", "1")), - ("export port=8000", ("port", "8000")), + ("", []), + ("a=b", [Binding(key="a", value="b", original="a=b")]), + ("'a'=b", [Binding(key="'a'", value="b", original="'a'=b")]), + ("[=b", [Binding(key="[", value="b", original="[=b")]), + (" a = b ", [Binding(key="a", value="b", original=" a = b ")]), + ("export a=b", [Binding(key="a", value="b", original="export a=b")]), + (" export 'a'=b", [Binding(key="'a'", value="b", original=" export 'a'=b")]), + (" export 'a'=b", [Binding(key="'a'", value="b", original=" export 'a'=b")]), + ("# a=b", [Binding(key=None, value=None, original="# a=b")]), + ('a=b # comment', [Binding(key="a", value="b", original="a=b # comment")]), + ("a=b space ", [Binding(key="a", value="b space", original="a=b space ")]), + ("a='b space '", [Binding(key="a", value="b space ", original="a='b space '")]), + ('a="b space "', [Binding(key="a", value="b space ", original='a="b space "')]), + ("export export_a=1", [Binding(key="export_a", value="1", original="export export_a=1")]), + ("export port=8000", [Binding(key="port", value="8000", original="export port=8000")]), + ('a="b\nc"', [Binding(key="a", value="b\nc", original='a="b\nc"')]), + ("a='b\nc'", [Binding(key="a", value="b\nc", original="a='b\nc'")]), + ('a="b\nc"', [Binding(key="a", value="b\nc", original='a="b\nc"')]), + ('a="b\\nc"', [Binding(key="a", value='b\nc', original='a="b\\nc"')]), + ('a="b\\"c"', [Binding(key="a", value='b"c', original='a="b\\"c"')]), + ("a='b\\'c'", [Binding(key="a", value="b'c", original="a='b\\'c'")]), + ("a=à", [Binding(key="a", value="à", original="a=à")]), + ('a="à"', [Binding(key="a", value="à", original='a="à"')]), + ('garbage', [Binding(key=None, value=None, original="garbage")]), + ( + "a=b\nc=d", + [ + Binding(key="a", value="b", original="a=b\n"), + Binding(key="c", value="d", original="c=d"), + ], + ), + ( + "a=b\r\nc=d", + [ + Binding(key="a", value="b", original="a=b\r\n"), + Binding(key="c", value="d", original="c=d"), + ], + ), + ( + 'a="\nb=c', + [ + Binding(key="a", value='"', original='a="\n'), + Binding(key="b", value='c', original="b=c"), + ] + ), + ( + '# comment\na="b\nc"\nd=e\n', + [ + Binding(key=None, value=None, original="# comment\n"), + Binding(key="a", value="b\nc", original='a="b\nc"\n'), + Binding(key="d", value="e", original="d=e\n"), + ], + ), + ( + 'garbage[%$#\na=b', + [ + Binding(key=None, value=None, original="garbage[%$#\n"), + Binding(key="a", value="b", original='a=b'), + ], + ), ]) -def test_parse_line(test_input, expected): - assert parse_line(test_input) == expected +def test_parse_stream(test_input, expected): + result = parse_stream(StringIO(test_input)) + + assert list(result) == expected def test_warns_if_file_does_not_exist():