Skip to content

Commit

Permalink
add login sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
9001 committed Sep 9, 2024
1 parent 6eee601 commit b540517
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 19 deletions.
4 changes: 4 additions & 0 deletions copyparty/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,13 +1067,17 @@ def add_cert(ap, cert_path):


def add_auth(ap):
ses_db = os.path.join(E.cfg, "sessions.db")
ap2 = ap.add_argument_group('IdP / identity provider / user authentication options')
ap2.add_argument("--idp-h-usr", metavar="HN", type=u, default="", help="bypass the copyparty authentication checks and assume the request-header \033[33mHN\033[0m contains the username of the requesting user (for use with authentik/oauth/...)\n\033[1;31mWARNING:\033[0m if you enable this, make sure clients are unable to specify this header themselves; must be washed away and replaced by a reverse-proxy")
ap2.add_argument("--idp-h-grp", metavar="HN", type=u, default="", help="assume the request-header \033[33mHN\033[0m contains the groupname of the requesting user; can be referenced in config files for group-based access control")
ap2.add_argument("--idp-h-key", metavar="HN", type=u, default="", help="optional but recommended safeguard; your reverse-proxy will insert a secret header named \033[33mHN\033[0m into all requests, and the other IdP headers will be ignored if this header is not present")
ap2.add_argument("--idp-gsep", metavar="RE", type=u, default="|:;+,", help="if there are multiple groups in \033[33m--idp-h-grp\033[0m, they are separated by one of the characters in \033[33mRE\033[0m")
ap2.add_argument("--no-bauth", action="store_true", help="disable basic-authentication support; do not accept passwords from the 'Authenticate' header at all. NOTE: This breaks support for the android app")
ap2.add_argument("--bauth-last", action="store_true", help="keeps basic-authentication enabled, but only as a last-resort; if a cookie is also provided then the cookie wins")
ap2.add_argument("--ses-db", metavar="PATH", type=u, default=ses_db, help="where to store the sessions database (if you run multiple copyparty instances, make sure they use different DBs)")
ap2.add_argument("--ses-len", metavar="CHARS", type=int, default=20, help="session key length; default is 120 bits ((20//4)*4*6)")
ap2.add_argument("--no-ses", action="store_true", help="disable sessions; use plaintext passwords in cookies")


def add_chpw(ap):
Expand Down
77 changes: 74 additions & 3 deletions copyparty/authsrv.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,10 @@ def __init__(

# fwd-decl
self.vfs = VFS(log_func, "", "", AXS(), {})
self.acct: dict[str, str] = {}
self.iacct: dict[str, str] = {}
self.acct: dict[str, str] = {} # uname->pw
self.iacct: dict[str, str] = {} # pw->uname
self.ases: dict[str, str] = {} # uname->session
self.sesa: dict[str, str] = {} # session->uname
self.defpw: dict[str, str] = {}
self.grps: dict[str, list[str]] = {}
self.re_pwd: Optional[re.Pattern] = None
Expand Down Expand Up @@ -2181,8 +2183,11 @@ def _reload(self, verbosity: int = 9) -> None:
self.grps = grps
self.iacct = {v: k for k, v in acct.items()}

self.load_sessions()

self.re_pwd = None
pwds = [re.escape(x) for x in self.iacct.keys()]
pwds.extend(list(self.sesa))
if pwds:
if self.ah.on:
zs = r"(\[H\] pw:.*|[?&]pw=)([^&]+)"
Expand Down Expand Up @@ -2257,6 +2262,72 @@ def _reload(self, verbosity: int = 9) -> None:
cur.close()
db.close()

def load_sessions(self, quiet=False) -> None:
# mutex me
if self.args.no_ses:
self.ases = {}
self.sesa = {}
return

import sqlite3

ases = {}
blen = (self.args.ses_len // 4) * 4 # 3 bytes in 4 chars
blen = (blen * 3) // 4 # bytes needed for ses_len chars

db = sqlite3.connect(self.args.ses_db)
cur = db.cursor()

for uname, sid in cur.execute("select un, si from us"):
if uname in self.acct:
ases[uname] = sid

n = []
q = "insert into us values (?,?,?)"
for uname in self.acct:
if uname not in ases:
sid = ub64enc(os.urandom(blen)).decode("utf-8")
cur.execute(q, (uname, sid, int(time.time())))
ases[uname] = sid
n.append(uname)

if n:
db.commit()

cur.close()
db.close()

self.ases = ases
self.sesa = {v: k for k, v in ases.items()}
if n and not quiet:
t = ", ".join(n[:3])
if len(n) > 3:
t += "..."
self.log("added %d new sessions (%s)" % (len(n), t))

def forget_session(self, broker: Optional["BrokerCli"], uname: str) -> None:
with self.mutex:
self._forget_session(uname)

if broker:
broker.ask("_reload_sessions").get()

def _forget_session(self, uname: str) -> None:
if self.args.no_ses:
return

import sqlite3

db = sqlite3.connect(self.args.ses_db)
cur = db.cursor()
cur.execute("delete from us where un = ?", (uname,))
db.commit()
cur.close()
db.close()

self.sesa.pop(self.ases.get(uname, ""), "")
self.ases.pop(uname, "")

def chpw(self, broker: Optional["BrokerCli"], uname, pw) -> tuple[bool, str]:
if not self.args.chpw:
return False, "feature disabled in server config"
Expand All @@ -2276,7 +2347,7 @@ def chpw(self, broker: Optional["BrokerCli"], uname, pw) -> tuple[bool, str]:
if hpw == self.acct[uname]:
return False, "that's already your password my dude"

if hpw in self.iacct:
if hpw in self.iacct or hpw in self.sesa:
return False, "password is taken"

with self.mutex:
Expand Down
4 changes: 4 additions & 0 deletions copyparty/broker_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def reload(self) -> None:
for _, proc in enumerate(self.procs):
proc.q_pend.put((0, "reload", []))

def reload_sessions(self) -> None:
for _, proc in enumerate(self.procs):
proc.q_pend.put((0, "reload_sessions", []))

def collector(self, proc: MProcess) -> None:
"""receive message from hub in other process"""
while True:
Expand Down
4 changes: 4 additions & 0 deletions copyparty/broker_mpw.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def main(self) -> None:
self.asrv.reload()
self.logw("mpw.asrv reloaded")

elif dest == "reload_sessions":
with self.asrv.mutex:
self.asrv.load_sessions()

elif dest == "listen":
self.httpsrv.listen(args[0], args[1])

Expand Down
1 change: 1 addition & 0 deletions copyparty/broker_thr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, hub: "SvcHub") -> None:
self.iphash = HMaccas(os.path.join(self.args.E.cfg, "iphash"), 8)
self.httpsrv = HttpSrv(self, None)
self.reload = self.noop
self.reload_sessions = self.noop

def shutdown(self) -> None:
# self.log("broker", "shutting down")
Expand Down
52 changes: 41 additions & 11 deletions copyparty/httpcli.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def log(self, msg: str, c: Union[int, str] = 0) -> None:

def unpwd(self, m: Match[str]) -> str:
a, b, c = m.groups()
return "%s\033[7m %s \033[27m%s" % (a, self.asrv.iacct[b], c)
uname = self.asrv.iacct.get(b) or self.asrv.sesa.get(b)
return "%s\033[7m %s \033[27m%s" % (a, uname, c)

def _check_nonfatal(self, ex: Pebkac, post: bool) -> bool:
if post:
Expand Down Expand Up @@ -504,6 +505,8 @@ def run(self) -> bool:
zs = base64.b64decode(zb).decode("utf-8")
# try "pwd", "x:pwd", "pwd:x"
for bauth in [zs] + zs.split(":", 1)[::-1]:
if bauth in self.asrv.sesa:
break
hpw = self.asrv.ah.hash(bauth)
if self.asrv.iacct.get(hpw):
break
Expand Down Expand Up @@ -565,7 +568,11 @@ def run(self) -> bool:
self.uname = "*"
else:
self.pw = uparam.get("pw") or self.headers.get("pw") or bauth or cookie_pw
self.uname = self.asrv.iacct.get(self.asrv.ah.hash(self.pw)) or "*"
self.uname = (
self.asrv.sesa.get(self.pw)
or self.asrv.iacct.get(self.asrv.ah.hash(self.pw))
or "*"
)

self.rvol = self.asrv.vfs.aread[self.uname]
self.wvol = self.asrv.vfs.awrite[self.uname]
Expand Down Expand Up @@ -2088,6 +2095,9 @@ def handle_post_multipart(self) -> bool:
if act == "chpw":
return self.handle_chpw()

if act == "logout":
return self.handle_logout()

raise Pebkac(422, 'invalid action "{}"'.format(act))

def handle_zip_post(self) -> bool:
Expand Down Expand Up @@ -2409,7 +2419,8 @@ def handle_chpw(self) -> bool:
msg = "new password OK"

redir = (self.args.SRS + "?h") if ok else ""
html = self.j2s("msg", h1=msg, h2='<a href="/?h">ack</a>', redir=redir)
h2 = '<a href="' + self.args.SRS + '?h">ack</a>'
html = self.j2s("msg", h1=msg, h2=h2, redir=redir)
self.reply(html.encode("utf-8"))
return True

Expand All @@ -2422,9 +2433,8 @@ def handle_login(self) -> bool:
uhash = ""
self.parser.drop()

self.out_headerlist = [
x for x in self.out_headerlist if x[0] != "Set-Cookie" or "cppw" != x[1][:4]
]
if not pwd:
raise Pebkac(422, "password cannot be blank")

dst = self.args.SRS
if self.vpath:
Expand All @@ -2442,9 +2452,27 @@ def handle_login(self) -> bool:
self.reply(html.encode("utf-8"))
return True

def handle_logout(self) -> bool:
assert self.parser
self.parser.drop()

self.log("logout " + self.uname)
self.asrv.forget_session(self.conn.hsrv.broker, self.uname)
self.get_pwd_cookie("x")

dst = self.args.SRS + "?h"
h2 = '<a href="' + dst + '">ack</a>'
html = self.j2s("msg", h1="ok bye", h2=h2, redir=dst)
self.reply(html.encode("utf-8"))
return True

def get_pwd_cookie(self, pwd: str) -> tuple[bool, str]:
hpwd = self.asrv.ah.hash(pwd)
uname = self.asrv.iacct.get(hpwd)
uname = self.asrv.sesa.get(pwd)
if not uname:
hpwd = self.asrv.ah.hash(pwd)
uname = self.asrv.iacct.get(hpwd)
if uname:
pwd = self.asrv.ases.get(uname) or pwd
if uname:
msg = "hi " + uname
dur = int(60 * 60 * self.args.logout)
Expand All @@ -2456,8 +2484,9 @@ def get_pwd_cookie(self, pwd: str) -> tuple[bool, str]:
zb = hashlib.sha512(pwd.encode("utf-8", "replace")).digest()
logpwd = "%" + base64.b64encode(zb[:12]).decode("utf-8")

self.log("invalid password: {}".format(logpwd), 3)
self.cbonk(self.conn.hsrv.gpwd, pwd, "pw", "invalid passwords")
if pwd != "x":
self.log("invalid password: {}".format(logpwd), 3)
self.cbonk(self.conn.hsrv.gpwd, pwd, "pw", "invalid passwords")

msg = "naw dude"
pwd = "x" # nosec
Expand All @@ -2469,10 +2498,11 @@ def get_pwd_cookie(self, pwd: str) -> tuple[bool, str]:
for k in ("cppwd", "cppws") if self.is_https else ("cppwd",):
ck = gencookie(k, pwd, self.args.R, False)
self.out_headerlist.append(("Set-Cookie", ck))
self.out_headers.pop("Set-Cookie", None) # drop keepalive
else:
k = "cppws" if self.is_https else "cppwd"
ck = gencookie(k, pwd, self.args.R, self.is_https, dur, "; HttpOnly")
self.out_headerlist.append(("Set-Cookie", ck))
self.out_headers["Set-Cookie"] = ck

return dur > 0, msg

Expand Down
68 changes: 67 additions & 1 deletion copyparty/svchub.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def __init__(
noch.update([x for x in zsl if x])
args.chpw_no = noch

if not self.args.no_ses:
self.setup_session_db()

if args.shr:
self.setup_share_db()

Expand Down Expand Up @@ -369,6 +372,64 @@ def __init__(

self.broker = Broker(self)

def setup_session_db(self) -> None:
if not HAVE_SQLITE3:
self.args.no_ses = True
t = "WARNING: sqlite3 not available; disabling sessions, will use plaintext passwords in cookies"
self.log("root", t, 3)
return

import sqlite3

create = True
db_path = self.args.ses_db
self.log("root", "opening sessions-db %s" % (db_path,))
for n in range(2):
try:
db = sqlite3.connect(db_path)
cur = db.cursor()
try:
cur.execute("select count(*) from us").fetchone()
create = False
break
except:
pass
except Exception as ex:
if n:
raise
t = "sessions-db corrupt; deleting and recreating: %r"
self.log("root", t % (ex,), 3)
try:
cur.close() # type: ignore
except:
pass
try:
db.close() # type: ignore
except:
pass
os.unlink(db_path)

sch = [
r"create table kv (k text, v int)",
r"create table us (un text, si text, t0 int)",
# username, session-id, creation-time
r"create index us_un on us(un)",
r"create index us_si on us(si)",
r"create index us_t0 on us(t0)",
r"insert into kv values ('sver', 1)",
]

assert db # type: ignore
assert cur # type: ignore
if create:
for cmd in sch:
cur.execute(cmd)
self.log("root", "created new sessions-db")
db.commit()

cur.close()
db.close()

def setup_share_db(self) -> None:
al = self.args
if not HAVE_SQLITE3:
Expand Down Expand Up @@ -545,7 +606,7 @@ def _feature_test(self) -> None:
fng = []
t_ff = "transcode audio, create spectrograms, video thumbnails"
to_check = [
(HAVE_SQLITE3, "sqlite", "file and media indexing"),
(HAVE_SQLITE3, "sqlite", "sessions and file/media indexing"),
(HAVE_PIL, "pillow", "image thumbnails (plenty fast)"),
(HAVE_VIPS, "vips", "image thumbnails (faster, eats more ram)"),
(HAVE_WEBP, "pillow-webp", "create thumbnails as webp files"),
Expand Down Expand Up @@ -945,6 +1006,11 @@ def _reload_blocking(self, rescan_all_vols: bool = True, up2k: bool = True) -> N

self._reload(rescan_all_vols=rescan_all_vols, up2k=up2k)

def _reload_sessions(self) -> None:
with self.asrv.mutex:
self.asrv.load_sessions(True)
self.broker.reload_sessions()

def stop_thr(self) -> None:
while not self.stop_req:
with self.stop_cond:
Expand Down
Loading

0 comments on commit b540517

Please sign in to comment.