-
-
Notifications
You must be signed in to change notification settings - Fork 505
/
sh.py
3675 lines (2930 loc) · 124 KB
/
sh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
https://sh.readthedocs.io/en/latest/
https://github.com/amoffat/sh
"""
# ===============================================================================
# Copyright (C) 2011-2023 by Andrew Moffat
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ===============================================================================
import asyncio
from collections import deque
from collections.abc import Mapping
import platform
from importlib import metadata
# look up the installed distribution's version; fall back to a placeholder
# when no package metadata for "sh" can be found
try:
    __version__ = metadata.version("sh")
except metadata.PackageNotFoundError:  # pragma: no cover
    __version__ = "unknown"

# refuse to import on Windows, with a message that names the version
if "windows" in platform.system().lower():  # pragma: no cover
    raise ImportError(
        f"sh {__version__} is currently only supported on Linux and macOS."
    )
import errno
import fcntl
import gc
import getpass
import glob as glob_module
import inspect
import logging
import os
import pty
import pwd
import re
import select
import signal
import stat
import struct
import sys
import termios
import textwrap
import threading
import time
import traceback
import tty
import warnings
import weakref
from asyncio import Queue as AQueue
from contextlib import contextmanager
from functools import partial
from io import BytesIO, StringIO, UnsupportedOperation
from io import open as fdopen
from locale import getpreferredencoding
from queue import Empty, Queue
from shlex import quote as shlex_quote
from types import GeneratorType, ModuleType
from typing import Any, Dict, Type, Union
# home page of the project
__project_url__ = "https://github.com/amoffat/sh"

# accepted spellings for selecting stdout / stderr
# NOTE(review): presumably the values of a "tee" option -- the consumer is not
# visible in this part of the file
TEE_STDOUT = {True, "out", 1}
TEE_STDERR = {"err", 2}

# the locale's preferred encoding, or UTF-8 when the locale reports none
DEFAULT_ENCODING = getpreferredencoding() or "UTF-8"

# note that AIX is deliberately grouped with Darwin here
IS_MACOS = platform.system() in ("AIX", "Darwin")

# prefix for the names of the loggers that the Logger class creates
SH_LOGGER_NAME = __name__

# normally i would hate this idea of using a global to signify whether we are
# running tests, because it breaks the assumption that what is running in the
# tests is what will run live, but we ONLY use this in a place that has no
# serious side-effects that could change anything. as long as we do that, it
# should be ok
RUNNING_TESTS = bool(int(os.environ.get("SH_TESTS_RUNNING", "0")))

# when set, the select()-based poller is used even if poll() is available
FORCE_USE_SELECT = bool(int(os.environ.get("SH_TESTS_USE_SELECT", "0")))

# a re-entrant lock for pushd. this way, multiple threads that happen to use
# pushd will all see the current working directory for the duration of the
# with-context
PUSHD_LOCK = threading.RLock()
def get_num_args(fn):
    """return the number of named positional parameters that `fn` declares.
    *args, **kwargs and keyword-only parameters are not counted"""
    spec = inspect.getfullargspec(fn)
    named_positional = spec.args
    return len(named_positional)
# every attribute name available on str; RunningCommand.__getattr__ forwards
# lookups of these names to the command's decoded output
_unicode_methods = set(dir(""))

# whether the select module on this platform provides poll()
HAS_POLL = hasattr(select, "poll")

# event tags that the poller classes pair with a file in their poll() results
POLLER_EVENT_READ = 1
POLLER_EVENT_WRITE = 2
POLLER_EVENT_HUP = 4
POLLER_EVENT_ERROR = 8
class PollPoller:
    """a poller backed by select.poll(). files may be registered either as raw
    file descriptors or as objects exposing fileno(); poll() reports events
    against whatever was originally registered"""

    def __init__(self):
        self._poll = select.poll()
        # file descriptor <-> file object bidirectional maps
        self.fd_lookup = {}
        self.fo_lookup = {}

    def __bool__(self):
        """a poller is truthy while at least one file is registered"""
        return len(self.fd_lookup) != 0

    # python 2 spelling of __bool__, kept so that code calling it by name
    # continues to work
    __nonzero__ = __bool__

    def __len__(self):
        """the number of registered files"""
        return len(self.fd_lookup)

    def _set_fileobject(self, f):
        """record f in both lookup maps"""
        if hasattr(f, "fileno"):
            fd = f.fileno()
            self.fd_lookup[fd] = f
            self.fo_lookup[f] = fd
        else:
            # f is already a descriptor, so it maps to itself
            self.fd_lookup[f] = f
            self.fo_lookup[f] = f

    def _remove_fileobject(self, f):
        """drop f from both lookup maps. raises KeyError if f is not
        registered"""
        if hasattr(f, "fileno"):
            fd = f.fileno()
            del self.fd_lookup[fd]
            del self.fo_lookup[f]
        else:
            del self.fd_lookup[f]
            del self.fo_lookup[f]

    def _get_file_descriptor(self, f):
        """the descriptor recorded for f, or None if f is unknown"""
        return self.fo_lookup.get(f)

    def _get_file_object(self, fd):
        """the object that was registered for fd, or None if fd is unknown"""
        return self.fd_lookup.get(fd)

    def _register(self, f, events):
        # f can be a file descriptor or file object
        self._set_fileobject(f)
        fd = self._get_file_descriptor(f)
        self._poll.register(fd, events)

    def register_read(self, f):
        """watch f for readable data"""
        self._register(f, select.POLLIN | select.POLLPRI)

    def register_write(self, f):
        """watch f for writability"""
        self._register(f, select.POLLOUT)

    def register_error(self, f):
        """watch f for error, hang-up and invalid-descriptor conditions"""
        self._register(f, select.POLLERR | select.POLLHUP | select.POLLNVAL)

    def unregister(self, f):
        """stop watching f"""
        fd = self._get_file_descriptor(f)
        self._poll.unregister(fd)
        self._remove_fileobject(f)

    def poll(self, timeout):
        """wait up to `timeout` seconds (None blocks until an event arrives)
        and return a list of (file, POLLER_EVENT_*) tuples. each ready
        descriptor yields at most one tuple; when several conditions are set,
        read takes precedence over write, then hang-up, then error"""
        if timeout is not None:
            # convert from seconds to milliseconds
            timeout *= 1000
        changes = self._poll.poll(timeout)

        results = []
        for fd, events in changes:
            f = self._get_file_object(fd)
            if events & (select.POLLIN | select.POLLPRI):
                results.append((f, POLLER_EVENT_READ))
            elif events & select.POLLOUT:
                results.append((f, POLLER_EVENT_WRITE))
            elif events & select.POLLHUP:
                results.append((f, POLLER_EVENT_HUP))
            elif events & (select.POLLERR | select.POLLNVAL):
                results.append((f, POLLER_EVENT_ERROR))
        return results
class SelectPoller:
    """a poller backed by select.select(). it keeps one list of files per
    kind of interest (read, write, error)"""

    def __init__(self):
        self.rlist = []
        self.wlist = []
        self.xlist = []

    def __bool__(self):
        """a poller is truthy while anything is registered in any list"""
        return len(self.rlist) + len(self.wlist) + len(self.xlist) != 0

    # python 2 spelling of __bool__, kept so that code calling it by name
    # continues to work
    __nonzero__ = __bool__

    def __len__(self):
        """total number of registrations across the three lists"""
        return len(self.rlist) + len(self.wlist) + len(self.xlist)

    @staticmethod
    def _register(f, events):
        """append f to the list `events` unless it is already present"""
        if f not in events:
            events.append(f)

    @staticmethod
    def _unregister(f, events):
        """remove f from the list `events` if it is present"""
        if f in events:
            events.remove(f)

    def register_read(self, f):
        """watch f for readable data"""
        self._register(f, self.rlist)

    def register_write(self, f):
        """watch f for writability"""
        self._register(f, self.wlist)

    def register_error(self, f):
        """watch f for exceptional conditions"""
        self._register(f, self.xlist)

    def unregister(self, f):
        """stop watching f for every kind of event"""
        self._unregister(f, self.rlist)
        self._unregister(f, self.wlist)
        self._unregister(f, self.xlist)

    def poll(self, timeout):
        """wait up to `timeout` seconds (None blocks) and return a list of
        (file, POLLER_EVENT_*) tuples: readable files first, then writable
        ones, then those in an exceptional condition"""
        _in, _out, _err = select.select(self.rlist, self.wlist, self.xlist, timeout)

        results = []
        for f in _in:
            results.append((f, POLLER_EVENT_READ))
        for f in _out:
            results.append((f, POLLER_EVENT_WRITE))
        for f in _err:
            results.append((f, POLLER_EVENT_ERROR))
        return results
# here we use a poller interface that transparently selects the most capable
# poller (out of either select.select or select.poll). this was added by
# zhangyafeikimi when he discovered that if the fds created internally by sh
# numbered > 1024, select.select failed (a limitation of select.select). this
# can happen if your script opens a lot of files
Poller: Union[Type[SelectPoller], Type[PollPoller]] = SelectPoller
# prefer poll() whenever the platform has it, unless the select-based poller
# has been forced through FORCE_USE_SELECT
if HAS_POLL and not FORCE_USE_SELECT:
    Poller = PollPoller
class ForkException(Exception):
    """an exception whose message embeds the text of another exception,
    indented beneath an "Original exception:" heading"""

    def __init__(self, orig_exc):
        # orig_exc is the text of the original exception, not the exception
        # object itself: textwrap.indent only accepts strings
        indented = textwrap.indent(orig_exc, " ")
        heading = "\nOriginal exception:\n===================\n"
        super().__init__(f"{heading}{indented}\n")
class ErrorReturnCodeMeta(type):
    """a metaclass that makes subclass checks name-based, so that an
    ErrorReturnCode (or derived) class imported from one sh module is
    considered a subclass of the equally named class from another sh module.
    this is mostly needed in the tests, where assertRaises may be handed a
    different class object than the one the code under test raises
    """

    def __subclasscheck__(self, o):
        # gather the names of o's direct bases only; deeper ancestors are not
        # consulted
        direct_base_names = set()
        for base in o.__bases__:
            direct_base_names.add(base.__name__)

        own_name = self.__name__
        if own_name in direct_base_names:
            return True
        return o.__name__ == own_name
class ErrorReturnCode(Exception):
    """base class for all exceptions as a result of a command's exit status
    being deemed an error. this base class is dynamically subclassed into
    derived classes with the format: ErrorReturnCode_NNN where NNN is the exit
    code number. the reason for this is it reduces boiler plate code when
    testing error return codes:

        try:
            some_cmd()
        except ErrorReturnCode_12:
            print("couldn't do X")

    vs:

        try:
            some_cmd()
        except ErrorReturnCode as e:
            if e.exit_code == 12:
                print("couldn't do X")

    it's not much of a savings, but i believe it makes the code easier to read
    """

    # NOTE: this is the python 2 way of declaring a metaclass. python 3 treats
    # it as an ordinary class attribute, so this base class itself is created
    # by `type`. it is left untouched on purpose: turning it into a real
    # metaclass would change the result of issubclass() against this class
    __metaclass__ = ErrorReturnCodeMeta

    # maximum number of bytes of stdout/stderr embedded in the exception
    # message when truncation is enabled
    truncate_cap = 750

    def __reduce__(self):
        """pickle support: rebuild the exception from its constructor
        arguments rather than from the formatted message"""
        return self.__class__, (self.full_cmd, self.stdout, self.stderr, self.truncate)

    def __init__(self, full_cmd, stdout, stderr, truncate=True):
        """full_cmd is the command line that ran; stdout and stderr are the
        captured output as bytes; truncate controls whether the copies placed
        in the message are capped at truncate_cap bytes (the stdout and stderr
        attributes always hold the full output)"""
        # exit_code is expected to be a class attribute supplied by the
        # dynamically generated subclass
        self.exit_code = self.exit_code  # makes pylint happy
        self.full_cmd = full_cmd
        self.stdout = stdout
        self.stderr = stderr
        self.truncate = truncate

        exc_stdout = self.stdout
        if truncate:
            exc_stdout = exc_stdout[: self.truncate_cap]
            out_delta = len(self.stdout) - len(exc_stdout)
            if out_delta:
                exc_stdout += (f"... ({out_delta} more, please see e.stdout)").encode()

        exc_stderr = self.stderr
        if truncate:
            exc_stderr = exc_stderr[: self.truncate_cap]
            err_delta = len(self.stderr) - len(exc_stderr)
            if err_delta:
                exc_stderr += (f"... ({err_delta} more, please see e.stderr)").encode()

        msg = (
            f"\n\n RAN: {self.full_cmd}"
            f"\n\n STDOUT:\n{exc_stdout.decode(DEFAULT_ENCODING, 'replace')}"
            f"\n\n STDERR:\n{exc_stderr.decode(DEFAULT_ENCODING, 'replace')}"
        )
        super().__init__(msg)
class SignalException(ErrorReturnCode):
    """the ErrorReturnCode flavour that get_rc_exc uses as the base for
    negative return codes, i.e. signal numbers"""
class TimeoutException(Exception):
    """the exception thrown when a command is killed because a specified
    timeout (via _timeout or .wait(timeout)) was hit

    exit_code is the exit code to report, and may be None; full_cmd is the
    command line that was running"""

    def __init__(self, exit_code, full_cmd):
        self.exit_code = exit_code
        self.full_cmd = full_cmd
        # no message is passed on, so str() of this exception stays empty
        super().__init__()
# the signals that, going by the name, turn a signal-terminated command into
# an exception
# NOTE(review): the code that consults this set is outside this part of the
# file -- confirm there how it is applied
SIGNALS_THAT_SHOULD_THROW_EXCEPTION = {
    signal.SIGABRT,
    signal.SIGBUS,
    signal.SIGFPE,
    signal.SIGILL,
    signal.SIGINT,
    signal.SIGKILL,
    signal.SIGPIPE,
    signal.SIGQUIT,
    signal.SIGSEGV,
    signal.SIGTERM,
    signal.SIGSYS,
}
# we subclass AttributeError because:
# https://github.com/ipython/ipython/issues/2577
# https://github.com/amoffat/sh/issues/97#issuecomment-10610629
class CommandNotFound(AttributeError):
    """raised when a requested command cannot be found. code that handles
    AttributeError will handle this as well"""
# recognizes the names of the dynamically generated exception classes, e.g.
# ErrorReturnCode_1, SignalException_9 and SignalException_SIGHUP
rc_exc_regex = re.compile(r"(ErrorReturnCode|SignalException)_((\d+)|SIG[a-zA-Z]+)")
# cache of generated exception classes. get_rc_exc stores entries under the
# integer return code, while get_exc_from_name probes it with the exception's
# name, hence both key types in the annotation
rc_exc_cache: Dict[Union[int, str], Type[ErrorReturnCode]] = {}
# maps the value of every SIG* name in the signal module back to that name
SIGNAL_MAPPING = {
    v: k for k, v in signal.__dict__.items() if re.match(r"SIG[a-zA-Z]+", k)
}
def get_exc_from_name(name):
    """resolves an exception name to the matching exception class. accepted
    names look like:

        ErrorReturnCode_1
        SignalException_9
        SignalException_SIGHUP

    returns None when the name does not have that shape. this is primarily
    used for importing exceptions from sh into user code, for instance, to
    capture those exceptions"""
    try:
        return rc_exc_cache[name]
    except KeyError:
        parsed = rc_exc_regex.match(name)
        if not parsed:
            return None

        kind, suffix = parsed.group(1), parsed.group(2)
        if kind != "SignalException":
            rc = int(suffix)
        else:
            # the suffix is either a signal number or a SIG* name; either way
            # signals are represented as negative return codes
            try:
                rc = -int(suffix)
            except ValueError:
                rc = -getattr(signal, suffix)

        return get_rc_exc(rc)
def get_rc_exc(rc):
    """builds (or fetches) the exception class for an exit code or a negative
    signal number. non-negative codes give an ErrorReturnCode_<code> class,
    negative ones a SignalException_<SIGNAME> class.

    generated classes are cached by return code, so each code maps to exactly
    one class object and identity is preserved"""
    if rc in rc_exc_cache:
        return rc_exc_cache[rc]

    if rc >= 0:
        parent = ErrorReturnCode
        exc_name = f"ErrorReturnCode_{rc}"
    else:
        parent = SignalException
        exc_name = f"SignalException_{SIGNAL_MAPPING[abs(rc)]}"

    generated = ErrorReturnCodeMeta(exc_name, (parent,), {"exit_code": rc})
    rc_exc_cache[rc] = generated
    return generated
# we monkey patch glob. i'm normally generally against monkey patching, but i
# decided to do this really un-intrusive patch because we need a way to detect
# if a list that we pass into an sh command was generated from glob. the reason
# being that glob returns an empty list if a pattern is not found, and so
# commands will treat the empty list as no arguments, which can be a problem,
# ie:
#
# ls(glob("*.ojfawe"))
#
# ^ will show the contents of your home directory, because it's essentially
# running ls([]) which, as a process, is just "ls".
#
# so we subclass list and monkey patch the glob function. nobody should be the
# wiser, but we'll have results that we can make some determinations on
# keep a handle on the real glob.glob so that the replacement can delegate to it
_old_glob = glob_module.glob
class GlobResults(list):
    """a list of glob matches that also remembers, in `path`, the pattern
    that produced it"""

    def __init__(self, path, results):
        super().__init__(results)
        self.path = path
def glob(path, *args, **kwargs):
    """stand-in for glob.glob: runs the original function with the same
    arguments and wraps the matches in a GlobResults that records `path`"""
    matches = _old_glob(path, *args, **kwargs)
    return GlobResults(path, matches)
# install the wrapper in place of glob.glob (this is the monkey patch)
glob_module.glob = glob # type: ignore
def canonicalize(path):
    """expand a leading ~ in `path` and return the result as an absolute,
    normalized path"""
    expanded = os.path.expanduser(path)
    return os.path.abspath(expanded)
def _which(program, paths=None):
"""takes a program name or full path, plus an optional collection of search
paths, and returns the full path of the requested executable. if paths is
specified, it is the entire list of search paths, and the PATH env is not
used at all. otherwise, PATH env is used to look for the program"""
def is_exe(file_path):
return (
os.path.exists(file_path)
and os.access(file_path, os.X_OK)
and os.path.isfile(os.path.realpath(file_path))
)
found_path = None
fpath, fname = os.path.split(program)
# if there's a path component, then we've specified a path to the program,
# and we should just test if that program is executable. if it is, return
if fpath:
program = canonicalize(program)
if is_exe(program):
found_path = program
# otherwise, we've just passed in the program name, and we need to search
# the paths to find where it actually lives
else:
paths_to_search = []
if isinstance(paths, (tuple, list)):
paths_to_search.extend(paths)
else:
env_paths = os.environ.get("PATH", "").split(os.pathsep)
paths_to_search.extend(env_paths)
for path in paths_to_search:
exe_file = os.path.join(canonicalize(path), program)
if is_exe(exe_file):
found_path = exe_file
break
return found_path
def resolve_command_path(program):
    """return the full path of the executable for `program`, or None if it
    cannot be found. if the name as given does not resolve and it contains
    underscores, the same name with dashes is tried too"""
    found = _which(program)

    # our actual command might have a dash in it, but we can't call that from
    # python (we have to use underscores), so we'll check if a dash version of
    # our underscore command exists and use that if it does
    if not found and "_" in program:
        found = _which(program.replace("_", "-"))

    return found or None
def resolve_command(name, command_cls, baked_args=None):
    """locate the executable for `name` and wrap it in `command_cls`. when
    `baked_args` is non-empty, it is applied as keyword arguments through the
    command's bake(). returns None if the executable cannot be found"""
    path = resolve_command_path(name)
    if not path:
        return None

    cmd = command_cls(path)
    if baked_args:
        cmd = cmd.bake(**baked_args)
    return cmd
class Logger:
    """provides a memory-inexpensive logger. a gotcha about python's builtin
    logger is that logger objects are never garbage collected. if you create a
    thousand loggers with unique names, they'll sit there in memory until your
    script is done. with sh, it's easy to create loggers with unique names if
    we want our loggers to include our command arguments. for example, these
    are all unique loggers:

        ls -l
        ls -l /tmp
        ls /tmp

    so instead of creating unique loggers, and without sacrificing logging
    output, we use this class, which maintains as part of its state, the logging
    "context", which will be the very unique name. this allows us to get a
    logger with a very general name, eg: "command", and have a unique name
    appended to it via the context, eg: "ls -l /tmp" """

    def __init__(self, name, context=None):
        self.name = name
        self.log = logging.getLogger(f"{SH_LOGGER_NAME}.{name}")
        # the context is stored %-escaped (see sanitize_context) because it
        # ends up inside a %-format string in _format_msg
        self.context = self.sanitize_context(context)

    def _format_msg(self, msg, *a):
        """prefix msg with the context, if there is one, and apply the
        %-style arguments"""
        if self.context:
            msg = f"{self.context}: {msg}"
        return msg % a

    @staticmethod
    def sanitize_context(context):
        """escape every "%" in context so that it survives %-formatting.
        a falsy context becomes the empty string"""
        if context:
            context = context.replace("%", "%%")
        return context or ""

    def get_child(self, name, context):
        """create a Logger whose name and context extend this one's, each
        joined with a dot"""
        new_name = self.name + "." + name
        # our own context is already escaped and the child's constructor
        # escapes whatever it is given, so undo our escaping first. otherwise
        # every "%" from the parent would come out doubled in the log output
        parent_context = self.context.replace("%%", "%")
        new_context = parent_context + "." + context
        return Logger(new_name, new_context)

    def info(self, msg, *a):
        self.log.info(self._format_msg(msg, *a))

    def debug(self, msg, *a):
        self.log.debug(self._format_msg(msg, *a))

    def error(self, msg, *a):
        self.log.error(self._format_msg(msg, *a))

    def exception(self, msg, *a):
        self.log.exception(self._format_msg(msg, *a))
def default_logger_str(cmd, call_args, pid=None):
    """the default logging context for a command: the repr of `cmd`, plus the
    pid when a truthy one is given. `call_args` is accepted but not used"""
    pid_part = f", pid {pid}" if pid else ""
    return f"<Command {cmd!r}{pid_part}>"
class RunningCommand:
    """this represents an executing Command object. it is returned as the
    result of __call__() being executed on a Command instance. this creates a
    reference to a OProc instance, which is a low-level wrapper around the
    process that was exec'd

    this is the class that gets manipulated the most by user code, and so it
    implements various convenience methods and logical mechanisms for the
    underlying process. for example, if a user tries to access a
    backgrounded-process's stdout/err, the RunningCommand object is smart enough
    to know to wait() on the process to finish first. and when the process
    finishes, RunningCommand is smart enough to translate exit codes to
    exceptions."""

    # these are attributes that we allow to pass through to OProc
    _OProc_attr_allowlist = {
        "signal",
        "terminate",
        "kill",
        "kill_group",
        "signal_group",
        "pid",
        "sid",
        "pgid",
        "ctty",
        "input_thread_exc",
        "output_thread_exc",
        "bg_thread_exc",
    }

    def __init__(self, cmd, call_args, stdin, stdout, stderr):
        """cmd is the sequence of program and arguments to run, call_args the
        mapping of special options for this call, and stdin/stdout/stderr are
        handed on to OProc. unless one of the options says otherwise (with,
        piped, iter, iter_noblock, async, bg), the constructor spawns the
        process and blocks until it has finished"""
        # self.ran is used for auditing what actually ran. for example, in
        # exceptions, or if you just want to know what was ran after the
        # command ran
        self.ran = " ".join([shlex_quote(str(arg)) for arg in cmd])

        self.call_args = call_args
        self.cmd = cmd

        self.process = None
        self._waited_until_completion = False
        should_wait = True
        spawn_process = True

        # if we're using an async for loop on this object, we need to put the underlying
        # iterable in no-block mode. however, we will only know if we're using an async
        # for loop after this object is constructed. so we'll set it to False now, but
        # then later set it to True if we need it
        self._force_noblock_iter = False

        # this event is used when we want to `await` a RunningCommand. see how it gets
        # used in self.__await__
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # no event loop is running in this thread, so there is nothing
            # that could await us
            self.aio_output_complete = None
        else:
            self.aio_output_complete = asyncio.Event()

        # this is used to track if we've already raised StopIteration, and if we
        # have, raise it immediately again if the user tries to call next() on
        # us. https://github.com/amoffat/sh/issues/273
        self._stopped_iteration = False

        # with contexts shouldn't run at all yet, they prepend
        # to every command in the context
        if call_args["with"]:
            spawn_process = False
            get_prepend_stack().append(self)

        if call_args["piped"] or call_args["iter"] or call_args["iter_noblock"]:
            should_wait = False

        if call_args["async"]:
            should_wait = False

        # we're running in the background, return self and let us lazily
        # evaluate
        if call_args["bg"]:
            should_wait = False

        # redirection
        if call_args["err_to_out"]:
            stderr = OProc.STDOUT

        # hand the done callback this object as its first argument
        done_callback = call_args["done"]
        if done_callback:
            call_args["done"] = partial(done_callback, self)

        # set up which stream should write to the pipe
        # TODO, make pipe None by default and limit the size of the Queue
        # in oproc.OProc
        pipe = OProc.STDOUT
        if call_args["iter"] == "out" or call_args["iter"] is True:
            pipe = OProc.STDOUT
        elif call_args["iter"] == "err":
            pipe = OProc.STDERR

        if call_args["iter_noblock"] == "out" or call_args["iter_noblock"] is True:
            pipe = OProc.STDOUT
        elif call_args["iter_noblock"] == "err":
            pipe = OProc.STDERR

        # there's currently only one case where we wouldn't spawn a child
        # process, and that's if we're using a with-context with our command
        self._spawned_and_waited = False
        if spawn_process:
            log_str_factory = call_args["log_msg"] or default_logger_str
            logger_str = log_str_factory(self.ran, call_args)
            self.log = Logger("command", logger_str)

            self.log.debug("starting process")

            if should_wait:
                self._spawned_and_waited = True

            # this lock is needed because of a race condition where a background
            # thread, created in the OProc constructor, may try to access
            # self.process, but it has not been assigned yet
            process_assign_lock = threading.Lock()
            with process_assign_lock:
                self.process = OProc(
                    self,
                    self.log,
                    cmd,
                    stdin,
                    stdout,
                    stderr,
                    self.call_args,
                    pipe,
                    process_assign_lock,
                )

            # now that the process exists, rebuild the logging context so that
            # it includes the pid
            logger_str = log_str_factory(self.ran, call_args, self.process.pid)
            self.log.context = self.log.sanitize_context(logger_str)
            self.log.info("process started")

            if should_wait:
                self.wait()

    def wait(self, timeout=None):
        """waits for the running command to finish. this is called on all
        running commands, eventually, except for ones that run in the background

        if timeout is a number, it is the number of seconds to wait for the process to
        resolve. otherwise block on wait.

        this function can raise a TimeoutException, either because of a `_timeout` on
        the command itself as it was
        launched, or because of a timeout passed into this method.

        a negative timeout raises RuntimeError. returns self.
        """
        if not self._waited_until_completion:
            # if we've been given a timeout, we need to poll is_alive()
            if timeout is not None:
                waited_for = 0
                sleep_amt = 0.1
                alive = False
                exit_code = None
                if timeout < 0:
                    raise RuntimeError("timeout cannot be negative")

                # while we still have time to wait, run this loop
                # notice that alive and exit_code are only defined in this loop, but
                # the loop is also guaranteed to run, defining them, given the
                # constraints that timeout is non-negative
                while waited_for <= timeout:
                    alive, exit_code = self.process.is_alive()

                    # if we're alive, we need to wait some more, but let's sleep
                    # before we poll again
                    if alive:
                        time.sleep(sleep_amt)
                        waited_for += sleep_amt

                    # but if we're not alive, we're done waiting
                    else:
                        break

                # if we've made it this far, and we're still alive, then it means we
                # timed out waiting
                if alive:
                    raise TimeoutException(None, self.ran)

                # if we didn't time out, we fall through and let the rest of the code
                # handle exit_code. notice that we set _waited_until_completion here,
                # only if we didn't time out. this allows us to re-wait again on
                # timeout, if we catch the TimeoutException in the parent frame
                self._waited_until_completion = True

            else:
                exit_code = self.process.wait()
                self._waited_until_completion = True

            if self.process.timed_out:
                # if we timed out, our exit code represents a signal, which is
                # negative, so let's make it positive to store in our
                # TimeoutException
                raise TimeoutException(-exit_code, self.ran)

            else:
                self.handle_command_exit_code(exit_code)

                # if an iterable command is using an instance of OProc for its stdin,
                # wait on it. the process is probably set to "piped", which means it
                # won't be waited on, which means exceptions won't propagate up to the
                # main thread. this allows them to bubble up
                if self.process._stdin_process:
                    self.process._stdin_process.command.wait()

            self.log.debug("process completed")
        return self

    def is_alive(self):
        """returns whether or not we're still alive. this call has side-effects on
        OProc"""
        return self.process.is_alive()[0]

    def handle_command_exit_code(self, code):
        """here we determine if we had an exception, or an error code that we
        weren't expecting to see. if we did, we create and raise an exception
        """
        ca = self.call_args
        exc_class = get_exc_exit_code_would_raise(code, ca["ok_code"], ca["piped"])
        if exc_class:
            exc = exc_class(
                self.ran, self.process.stdout, self.process.stderr, ca["truncate_exc"]
            )
            raise exc

    @property
    def stdout(self):
        """the process's stdout; waits for the command to finish first"""
        self.wait()
        return self.process.stdout

    @property
    def stderr(self):
        """the process's stderr; waits for the command to finish first"""
        self.wait()
        return self.process.stderr

    @property
    def exit_code(self):
        """the process's exit code; waits for the command to finish first"""
        self.wait()
        return self.process.exit_code

    def __len__(self):
        """length of the decoded output"""
        return len(str(self))

    def __enter__(self):
        """we don't actually do anything here because anything that should have
        been done would have been done in the Command.__call__ call.
        essentially all that has to happen is the command be pushed on the
        prepend stack."""
        pass

    def __iter__(self):
        return self

    def __next__(self):
        """allow us to iterate over the output of our command"""

        if self._stopped_iteration:
            raise StopIteration()

        pq = self.process._pipe_queue

        # the idea with this is, if we're using regular `_iter` (non-asyncio), then we
        # want to have blocking be True when we read from the pipe queue, so our cpu
        # doesn't spin too fast. however, if we *are* using asyncio (an async for loop),
        # then we want non-blocking pipe queue reads, because we'll do an asyncio.sleep,
        # in the coroutine that is doing the iteration, this way coroutines have better
        # yielding (see queue_connector in __aiter__).
        block_pq_read = not self._force_noblock_iter

        # we do this because if get blocks, we can't catch a KeyboardInterrupt
        # so the slight timeout allows for that.
        while True:
            try:
                chunk = pq.get(block_pq_read, self.call_args["iter_poll_time"])
            except Empty:
                # in the non-blocking modes, an empty queue is reported to the
                # caller; otherwise we simply poll the queue again
                if self.call_args["iter_noblock"] or self._force_noblock_iter:
                    return errno.EWOULDBLOCK
            else:
                # None on the queue marks the end of the output
                if chunk is None:
                    self.wait()
                    self._stopped_iteration = True
                    raise StopIteration()
                try:
                    return chunk.decode(
                        self.call_args["encoding"], self.call_args["decode_errors"]
                    )
                except UnicodeDecodeError:
                    # hand back the raw bytes if they cannot be decoded
                    return chunk

    def __await__(self):
        """awaiting a RunningCommand waits on the aio_output_complete event
        and then evaluates to the decoded output"""

        async def wait_for_completion():
            await self.aio_output_complete.wait()
            return str(self)

        return wait_for_completion().__await__()

    def __aiter__(self):
        # maxsize is critical to making sure our queue_connector function below yields
        # when it awaits _aio_queue.put(chunk). if we didn't have a maxsize, our loop
        # would happily iterate through `chunk in self` and put onto the queue without
        # any blocking, and therefore no yielding, which would prevent other coroutines
        # from running.
        self._aio_queue = AQueue(maxsize=1)
        self._force_noblock_iter = True

        # the sole purpose of this coroutine is to connect our pipe_queue (which is
        # being populated by a thread) to an asyncio-friendly queue. then, in __anext__,
        # we can iterate over that asyncio queue.
        async def queue_connector():
            try:
                # this will spin as fast as possible if there's no data to read,
                # thanks to self._force_noblock_iter. so we sleep below.
                for chunk in self:
                    if chunk == errno.EWOULDBLOCK:
                        # let us have better coroutine yielding.
                        await asyncio.sleep(0.01)
                    else:
                        await self._aio_queue.put(chunk)
            finally:
                # None tells __anext__ that there is nothing more to read
                await self._aio_queue.put(None)

        task = asyncio.create_task(queue_connector())
        self._aio_task = task
        return self

    async def __anext__(self):
        chunk = await self._aio_queue.get()
        if chunk is not None:
            return chunk
        else:
            # the connector task is done: surface its exception if it had one,
            # otherwise end the iteration
            exc = self._aio_task.exception()
            if exc is not None:
                raise exc
            raise StopAsyncIteration

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.call_args["with"] and get_prepend_stack():
            get_prepend_stack().pop()

    def __str__(self):
        """the decoded stdout of the command, or "" when there is no process
        or no output. reading stdout waits for the command to finish"""
        if self.process and self.stdout:
            return self.stdout.decode(
                self.call_args["encoding"], self.call_args["decode_errors"]
            )
        return ""

    def __eq__(self, other):
        """a RunningCommand only equals itself"""
        return id(self) == id(other)

    def __hash__(self):
        # defining __eq__ alone would set __hash__ to None and make instances
        # unhashable. equality here is identity, so the default identity hash
        # is the consistent choice
        return object.__hash__(self)

    def __contains__(self, item):
        return item in str(self)

    def __getattr__(self, p):
        # let the allowlisted attributes pass through to the OProc object
        if p in self._OProc_attr_allowlist:
            if self.process:
                return getattr(self.process, p)
            else:
                raise AttributeError(
                    f"{type(self).__name__!r} object has no attribute {p!r}"
                )

        # see if strings have what we're looking for
        if p in _unicode_methods:
            return getattr(str(self), p)

        raise AttributeError(f"{type(self).__name__!r} object has no attribute {p!r}")

    def __repr__(self):
        try:
            return str(self)
        except UnicodeDecodeError:
            # fall back to the repr of the raw bytes
            if self.process:
                if self.stdout:
                    return repr(self.stdout)
            return repr("")

    def __long__(self):
        return int(str(self).strip())

    def __float__(self):
        return float(str(self).strip())

    def __int__(self):
        return int(str(self).strip())
def output_redirect_is_filename(out):
    """whether an output redirection target names a file, i.e. it is a str or
    an object with an __fspath__ attribute"""
    if isinstance(out, str):
        return True
    return hasattr(out, "__fspath__")
def get_prepend_stack():
    """return the prepend stack stored on Command.thread_local, creating an
    empty one the first time it is asked for"""
    local_state = Command.thread_local
    try:
        return local_state._prepend_stack
    except AttributeError:
        local_state._prepend_stack = []
        return local_state._prepend_stack
def special_kwarg_validator(passed_kwargs, merged_kwargs, invalid_list):
s1 = set(passed_kwargs.keys())
invalid_args = []
for elem in invalid_list: