Skip to content

Commit

Permalink
fix(completions): explicitly use utf8 for Windows compat
Browse files Browse the repository at this point in the history
fix(completions): handle symlinked profile files

refactor(completions): write temporary file to a consistent location
  • Loading branch information
daniel-makerx authored and achidlow committed Dec 19, 2022
1 parent 55fe4fc commit 5033f8e
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 71 deletions.
92 changes: 44 additions & 48 deletions src/algokit/cli/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,29 @@ def __init__(self, shell: str | None) -> None:

def install(self) -> None:
self._save_source()
self._insert_profile_line()
logger.info(f"AlgoKit completions installed for {self.shell} 🎉")
if self._insert_profile_line():
logger.info(f"AlgoKit completions installed for {self.shell} 🎉")
else:
logger.info(f"{self.profile_path} already contains completion source 🤔")
home_based_profile_path = _get_home_based_path(self.profile_path)
logger.info(f"Restart shell or run `. {home_based_profile_path}` to enable completions")

def uninstall(self) -> None:
self._remove_source()
self._remove_profile_line()
logger.info(f"AlgoKit completions uninstalled for {self.shell} 🎉")
if self._remove_profile_line():
logger.info(f"AlgoKit completions uninstalled for {self.shell} 🎉")
else:
logger.info(f"AlgoKit completions not installed for {self.shell} 🤔")

@property
def source(self) -> str:
completion_class = click.shell_completion.get_completion_class(self.shell)
completion = completion_class(
# class is only instantiated to get source snippet, so don't need to pass a real command
None, # type: ignore
{},
"algokit",
"_ALGOKIT_COMPLETE",
cli=None, # type: ignore
ctx_args={},
prog_name="algokit",
complete_var="_ALGOKIT_COMPLETE",
)
try:
return completion.source()
Expand All @@ -90,65 +94,57 @@ def _save_source(self) -> None:
# grab source before attempting to write file in case it fails
source = self.source
logger.debug(f"Writing source script {self.source_path}")
with open(self.source_path, "w") as source_file:
source_file.write(source)
source_file.flush()
self.source_path.write_text(source, encoding="utf-8")

def _remove_source(self) -> None:
logger.debug(f"Removing source script {self.source_path}")
self.source_path.unlink(missing_ok=True)

def _insert_profile_line(self) -> None:
do_write = True
if self.profile_path.exists():
with open(self.profile_path) as file:
for line in file:
if self.profile_line in line:
logger.debug(f"{self.profile_path} already contains completion source")
# profile already contains source of completion script. nothing to do
do_write = False
break

if do_write:
logger.debug(f"Appending completion source to {self.profile_path}")
# got to end of file, so append profile line
atomic_write([self.profile_line], self.profile_path, "a")

def _remove_profile_line(self) -> None:
if not self.profile_path.exists():
def _insert_profile_line(self) -> bool:
try:
content = self.profile_path.read_text(encoding="utf-8")
except FileNotFoundError:
pass
else:
if self.profile_line in content:
# profile already contains source of completion script. nothing to do
return False

logger.debug(f"Appending completion source to {self.profile_path}")
# got to end of file, so append profile line
atomic_write(self.profile_line, self.profile_path, "a")
return True

def _remove_profile_line(self) -> bool:
try:
content = self.profile_path.read_text(encoding="utf-8")
except FileNotFoundError:
logger.debug(f"{self.profile_path} not found")
# nothing to do
return

return False
# see if profile script contains profile_line, if it does remove it
do_write = False
lines = []
with open(self.profile_path) as file:
for line in file:
if self.profile_line in line:
do_write = True
logger.debug(f"Completion source found in {self.profile_path}")
else:
lines.append(line)
if self.profile_line not in content:
return False
logger.debug(f"Completion source found in {self.profile_path}")
content = content.replace(self.profile_line, "")

if do_write:
logger.debug(f"Removing completion source found in {self.profile_path}")
atomic_write(lines, self.profile_path, "w")
logger.debug(f"Removing completion source found in {self.profile_path}")
atomic_write(content, self.profile_path, "w")
return True


def _get_home_based_path(path: Path) -> Path:
home = Path("~").expanduser()
home = Path.home()
try:
home_based_path = path.relative_to(home)
return "~" / home_based_path
except ValueError:
return path
else:
return "~" / home_based_path


def _get_current_shell() -> str:
try:
shell = shellingham.detect_shell()
shell_name: str = shell[0]
shell_name, *_ = shellingham.detect_shell() # type: tuple[str, str]
except Exception as ex:
logger.debug("Could not determine current shell", exc_info=ex)
logger.warning("Could not determine current shell. Try specifying a supported shell with --shell")
Expand Down
35 changes: 17 additions & 18 deletions src/algokit/core/atomic_write.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,35 @@
import os
import platform
import shutil
import stat
import tempfile
from pathlib import Path
from typing import Literal


# from https://python.plainenglish.io/simple-safe-atomic-writes-in-python3-44b98830a013
def atomic_write(file_contents: list[str], target_file_path: Path, mode: str = "w") -> None:
# Use the same directory as the destination file so replace is atomic
temp_file = tempfile.NamedTemporaryFile(delete=False, dir=target_file_path.parent)
temp_file_path = Path(temp_file.name)
temp_file.close()
def atomic_write(file_contents: str, target_file_path: Path, mode: Literal["a", "w"] = "w") -> None:
# if target path is a symlink, we want to use the real path as the replacement target,
# otherwise we'd just be overwriting the symlink
target_file_path = target_file_path.resolve()
temp_file_path = target_file_path.with_suffix(f"{target_file_path.suffix}.algokit~")
try:
# preserve file metadata if it already exists
if target_file_path.exists():
try:
_copy_with_metadata(target_file_path, temp_file_path)
with open(temp_file_path, mode) as file:
file.writelines(file_contents)
file.flush()
os.fsync(file.fileno())

os.replace(temp_file_path, target_file_path)
except FileNotFoundError:
pass
# write content to new temp file
with temp_file_path.open(mode=mode, encoding="utf-8") as fp:
fp.write(file_contents)
# overwrite destination with the temp file
temp_file_path.replace(target_file_path)
finally:
temp_file_path.unlink(missing_ok=True)


def _copy_with_metadata(source: Path, target: Path) -> None:
# copy content, stat-info (mode too), timestamps...
shutil.copy2(source, target)
os_type = platform.system().lower()
if os_type != "windows":
# try copy owner+group if platform supports it
if hasattr(os, "chown"):
# copy owner and group
st = os.stat(source)
st = source.stat()
os.chown(target, st[stat.ST_UID], st[stat.ST_GID])
4 changes: 3 additions & 1 deletion tests/completions/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run_command(self, command: str, shell: str | None = None) -> ClickInvokeResu
command += f" --shell {shell}"

result = invoke(command, env=self.env)
normalized_output = normalize_path(result.output, self.home.name, "{home}").replace("\\", "/")
normalized_output = normalize_path(result.output, str(self.home_path), "{home}").replace("\\", "/")
return ClickInvokeResult(exit_code=result.exit_code, output=normalized_output)

@property
Expand All @@ -86,6 +86,7 @@ def test_completions_installs_correctly_with_specified_shell(shell: str):
assert result.exit_code == 0
# content of this file is defined by click, so only assert it exists not its content
assert context.source_path.exists()
assert not context.profile_path.with_suffix(".algokit~").exists()
profile = context.profile_contents
verify(get_combined_verify_output(result.output, "profile", profile), options=NamerFactory.with_parameters(shell))

Expand Down Expand Up @@ -120,6 +121,7 @@ def test_completions_uninstalls_correctly(shell: str):
assert result.exit_code == 0
assert not context.source_path.exists()
profile = context.profile_contents
assert not context.profile_path.with_suffix(".algokit~").exists()
assert profile == ORIGINAL_PROFILE_CONTENTS
verify(result.output, options=NamerFactory.with_parameters(shell))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
DEBUG: Writing source script {home}/.config/algokit/.algokit-completions.bash
DEBUG: {home}/.bashrc already contains completion source
AlgoKit completions installed for bash 🎉
{home}/.bashrc already contains completion source 🤔
Restart shell or run `. ~/.bashrc` to enable completions
----
profile:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
DEBUG: Removing source script {home}/.config/algokit/.algokit-completions.bash
DEBUG: {home}/.bashrc not found
AlgoKit completions uninstalled for bash 🎉
AlgoKit completions not installed for bash 🤔
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
DEBUG: Removing source script {home}/.config/algokit/.algokit-completions.bash
AlgoKit completions uninstalled for bash 🎉
AlgoKit completions not installed for bash 🤔

0 comments on commit 5033f8e

Please sign in to comment.