diff --git a/sqlx-core/src/postgres/options/pgpass.rs b/sqlx-core/src/postgres/options/pgpass.rs index 253bc10c82..831020f1c7 100644 --- a/sqlx-core/src/postgres/options/pgpass.rs +++ b/sqlx-core/src/postgres/options/pgpass.rs @@ -56,9 +56,29 @@ fn load_password_from_file( } } - let mut reader = BufReader::new(file); + let reader = BufReader::new(file); + load_password_from_reader(reader, host, port, username, database) +} + +fn load_password_from_reader( + mut reader: impl BufRead, + host: &str, + port: u16, + username: &str, + database: Option<&str>, +) -> Option { let mut line = String::new(); + // https://stackoverflow.com/a/55041833 + fn trim_newline(s: &mut String) { + if s.ends_with('\n') { + s.pop(); + if s.ends_with('\r') { + s.pop(); + } + } + } + while let Ok(n) = reader.read_line(&mut line) { if n == 0 { break; @@ -68,8 +88,8 @@ fn load_password_from_file( // comment, do nothing } else { // try to load password from line - let line = &line[..line.len() - 1]; // trim newline - if let Some(password) = load_password_from_line(line, host, port, username, database) { + trim_newline(&mut line); + if let Some(password) = load_password_from_line(&line, host, port, username, database) { return Some(password); } } @@ -163,7 +183,7 @@ fn find_next_field<'a>(line: &mut &'a str) -> Option> { #[cfg(test)] mod tests { - use super::{find_next_field, load_password_from_line}; + use super::{find_next_field, load_password_from_line, load_password_from_reader}; use std::borrow::Cow; #[test] @@ -263,4 +283,46 @@ mod tests { None ); } + + #[test] + fn test_load_password_from_reader() { + let file = b"\ + localhost:5432:bar:foo:baz\n\ + # mixed line endings (also a comment!)\n\ + *:5432:bar:foo:baz\r\n\ + # trailing space, comment with CRLF! \r\n\ + thishost:5432:bar:foo:baz \n\ + # malformed line \n\ + thathost:5432:foobar:foo\n\ + # missing trailing newline\n\ + localhost:5432:*:foo:baz + "; + + // normal + assert_eq!( + load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("bar")), + Some("baz".to_owned()) + ); + // wildcard + assert_eq!( + load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("foobar")), + Some("baz".to_owned()) + ); + // accept wildcard with missing db + assert_eq!( + load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", None), + Some("baz".to_owned()) + ); + + // doesn't match + assert_eq!( + load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")), + None + ); + // malformed entry + assert_eq!( + load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")), + None + ); + } }