diff --git a/.swift-version b/.swift-version index d0b9234..9f55b2c 100644 --- a/.swift-version +++ b/.swift-version @@ -1 +1 @@ -DEVELOPMENT-SNAPSHOT-2016-05-09-a +3.0 diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..56fbca9 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,17 @@ +notifications: + slack: zewo:VjyVCCQvTOw9yrbzQysZezD1 +os: + - linux + - osx +language: generic +sudo: required +dist: trusty +osx_image: xcode8 +install: + - eval "$(curl -sL https://raw.githubusercontent.com/Zewo/Zewo/master/Scripts/Travis/install.sh)" +script: + - swift build + - swift build --configuration release + # tests require a live psql database + # TODO: setup psql on travis + # - swift test diff --git a/Package.swift b/Package.swift index 598b3fd..43bf7b4 100644 --- a/Package.swift +++ b/Package.swift @@ -3,7 +3,7 @@ import PackageDescription let package = Package( name: "PostgreSQL", dependencies: [ - .Package(url: "https://github.com/Zewo/CLibpq.git", majorVersion: 0, minor: 5), - .Package(url: "https://github.com/Zewo/SQL.git", majorVersion: 0, minor: 0) + .Package(url: "https://github.com/Zewo/CLibpq.git", majorVersion: 0, minor: 13), + .Package(url: "https://github.com/Zewo/SQL.git", majorVersion: 0, minor: 13) ] ) diff --git a/Sources/PostgreSQL/Connection.swift b/Sources/PostgreSQL/Connection.swift index d3296de..26b9bb6 100644 --- a/Sources/PostgreSQL/Connection.swift +++ b/Sources/PostgreSQL/Connection.swift @@ -1,13 +1,14 @@ @_exported import SQL import CLibpq +import Core + +public struct ConnectionError: Error, CustomStringConvertible { + public let description: String +} public final class Connection: ConnectionProtocol { public typealias QueryRenderer = PostgreSQL.QueryRenderer - public struct Error: ErrorProtocol, CustomStringConvertible { - public let description: String - } - public struct ConnectionInfo: ConnectionInfoProtocol { public var host: String public var port: Int @@ -17,19 +18,28 @@ public final class Connection: ConnectionProtocol { public var options: String? public var tty: String? - public init(_ uri: URI) throws { + public init?(uri: URL) { + do { + try self.init(uri) + } catch { + return nil + } + } - guard let host = uri.host, port = uri.port, databaseName = uri.path?.trim(["/"]) else { - throw Error(description: "Failed to extract host, port, database name from URI") + public init(_ uri: URL) throws { + let databaseName = uri.path.trim(["/"]) + + guard let host = uri.host, let port = uri.port else { + throw ConnectionError(description: "Failed to extract host, port, database name from URI") } self.host = host self.port = port self.databaseName = databaseName - self.username = uri.userInfo?.username - self.password = uri.userInfo?.password - + self.username = uri.user + self.password = uri.password } + public init(host: String, port: Int = 5432, databaseName: String, password: String? = nil, options: String? = nil, tty: String? = nil) { self.host = host self.port = port @@ -40,7 +50,6 @@ public final class Connection: ConnectionProtocol { } } - public enum InternalStatus { case Bad case Started @@ -87,7 +96,6 @@ public final class Connection: ConnectionProtocol { break } } - } public var logger: Logger? @@ -96,7 +104,7 @@ public final class Connection: ConnectionProtocol { public let connectionInfo: ConnectionInfo - public required init(_ info: ConnectionInfo) { + public required init(info: ConnectionInfo) { self.connectionInfo = info } @@ -124,12 +132,12 @@ public final class Connection: ConnectionProtocol { } } - public var mostRecentError: Error? { - guard let errorString = String(validatingUTF8: PQerrorMessage(connection)) where !errorString.isEmpty else { + public var mostRecentError: ConnectionError? { + guard let errorString = String(validatingUTF8: PQerrorMessage(connection)), !errorString.isEmpty else { return nil } - return Error(description: errorString) + return ConnectionError(description: errorString) } public func close() { @@ -138,17 +146,18 @@ public final class Connection: ConnectionProtocol { } public func createSavePointNamed(_ name: String) throws { - try execute("SAVEPOINT \(name)") + try execute("SAVEPOINT \(name)", parameters: nil) } public func rollbackToSavePointNamed(_ name: String) throws { - try execute("ROLLBACK TO SAVEPOINT \(name)") + try execute("ROLLBACK TO SAVEPOINT \(name)", parameters: nil) } public func releaseSavePointNamed(_ name: String) throws { - try execute("RELEASE SAVEPOINT \(name)") + try execute("RELEASE SAVEPOINT \(name)", parameters: nil) } + @discardableResult public func execute(_ statement: String, parameters: [Value?]?) throws -> Result { var statement = statement.sqlStringWithEscapedPlaceholdersUsingPrefix("$") { @@ -159,13 +168,15 @@ public final class Connection: ConnectionProtocol { guard let parameters = parameters else { guard let resultPointer = PQexec(connection, statement) else { - throw mostRecentError ?? Error(description: "Empty result") + throw mostRecentError ?? ConnectionError(description: "Empty result") } return try Result(resultPointer) } - var parameterData = [[UInt8]?]() + var parameterData = [UnsafePointer?]() + var deallocators = [() -> ()]() + defer { deallocators.forEach { $0() } } for parameter in parameters { @@ -174,28 +185,41 @@ public final class Connection: ConnectionProtocol { continue } + let data: AnyCollection switch value { - case .data(let data): - parameterData.append(Array(data)) - break + case .data(let value): + data = AnyCollection(value.map { Int8($0) }) + case .string(let string): - parameterData.append(Array(string.utf8) + [0]) - break + data = AnyCollection(string.utf8CString) + } + + let pointer = UnsafeMutablePointer.allocate(capacity: Int(data.count)) + deallocators.append { + pointer.deallocate(capacity: Int(data.count)) } - } + for (index, byte) in data.enumerated() { + pointer[index] = byte + } - guard let result:OpaquePointer = PQexecParams( - self.connection, - statement, - Int32(parameters.count), - nil, - parameterData.map { UnsafePointer($0) }, - nil, - nil, - 0 - ) else { - throw mostRecentError ?? Error(description: "Empty result") + parameterData.append(pointer) + } + + let result: OpaquePointer = try parameterData.withUnsafeBufferPointer { buffer in + guard let result = PQexecParams( + self.connection, + statement, + Int32(parameters.count), + nil, + buffer.isEmpty ? nil : buffer.baseAddress, + nil, + nil, + 0 + ) else { + throw mostRecentError ?? ConnectionError(description: "Empty result") + } + return result } return try Result(result) diff --git a/Sources/PostgreSQL/Result.swift b/Sources/PostgreSQL/Result.swift index c386712..9b100f1 100644 --- a/Sources/PostgreSQL/Result.swift +++ b/Sources/PostgreSQL/Result.swift @@ -1,11 +1,12 @@ import CLibpq +import Core @_exported import SQL -public class Result: SQL.ResultProtocol { +public enum ResultError: Error { + case badStatus(Result.Status, String) +} - public enum Error: ErrorProtocol { - case BadStatus(Status, String) - } +public class Result: SQL.ResultProtocol { public enum Status: Int, ResultStatus { case EmptyQuery @@ -67,7 +68,7 @@ public class Result: SQL.ResultProtocol { self.resultPointer = resultPointer guard status.successful else { - throw Error.BadStatus(status, String(validatingUTF8: PQresultErrorMessage(resultPointer)) ?? "No error message") + throw ResultError.badStatus(status, String(validatingUTF8: PQresultErrorMessage(resultPointer)) ?? "No error message") } } @@ -81,12 +82,14 @@ public class Result: SQL.ResultProtocol { public func data(atRow rowIndex: Int, forFieldIndex fieldIndex: Int) -> Data? { - let start = PQgetvalue(resultPointer, Int32(rowIndex), Int32(fieldIndex)) let count = PQgetlength(resultPointer, Int32(rowIndex), Int32(fieldIndex)) + guard count > 0, let start = PQgetvalue(resultPointer, Int32(rowIndex), Int32(fieldIndex)) else { + return Data() + } - let buffer = UnsafeBufferPointer(start: UnsafePointer(start), count: Int(count)) - - return Data(Array(buffer)) + return start.withMemoryRebound(to: UInt8.self, capacity: Int(count)) { start in + return Data(Array(UnsafeBufferPointer(start: start, count: Int(count)))) + } } public var count: Int { @@ -123,6 +126,5 @@ public class Result: SQL.ResultProtocol { } return result - }() } diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index e229a2c..40d7c9f 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -2,4 +2,5 @@ import XCTest @testable import PostgreSQLTests XCTMain([ + testCase(PostgreSQLTests.allTests), ]) diff --git a/Tests/PostgreSQL/PostgreSQLTests.swift b/Tests/PostgreSQLTests/PostgreSQLTests.swift similarity index 59% rename from Tests/PostgreSQL/PostgreSQLTests.swift rename to Tests/PostgreSQLTests/PostgreSQLTests.swift index b821909..f89a2be 100644 --- a/Tests/PostgreSQL/PostgreSQLTests.swift +++ b/Tests/PostgreSQLTests/PostgreSQLTests.swift @@ -8,16 +8,16 @@ import XCTest @testable import PostgreSQL - +import Core public final class StandardOutputAppender: Appender { public var name: String = "Standard Output Appender" public var closed: Bool = false - public var level: Log.Level = .all - + public var levels: Logger.Level = .all + public init () {} - - public func append(_ event: LoggingEvent) { + + public func append(event: Logger.Event) { var logMessage = "\(event.message) \n" let file = event.locationInfo.file logMessage += "In File: \(file)" @@ -46,25 +46,23 @@ struct Artist { extension Artist: ModelProtocol { typealias PrimaryKey = Int - - enum Field: String { - case id = "id" - case name = "name" - case genre = "genre" + + enum Field: String, ModelField { + static let primaryKey = Field.id + static let tableName = "artists" + case id + case name + case genre } - - static let tableName: String = "artists" - static var primaryKeyField: Field = .id - - + func serialize() -> [Field: ValueConvertible?] { return [.name: name, .genre: genre] } - + init(row: T) throws { try self.init( - name: row.value(Artist.field(.name)), - genre: row.value(Artist.field(.genre)) + name: row.value(Field.name.qualifiedField), + genre: row.value(Field.genre.qualifiedField) ) } } @@ -72,7 +70,7 @@ extension Artist: ModelProtocol { final class Album { var name: String var artistId: Artist.PrimaryKey - + init(name: String, artistId: Artist.PrimaryKey) { self.name = name self.artistId = artistId @@ -81,125 +79,110 @@ final class Album { extension Album: ModelProtocol { typealias PrimaryKey = Int - - enum Field: String { + + enum Field: String, ModelField { + static let primaryKey = Field.id + static let tableName = "albums" case id = "id" case name = "name" case artistId = "artist_id" } - - static let tableName: String = "artists" - static let primaryKeyField: Field = .id - + func serialize() -> [Field: ValueConvertible?] { return [ .name: name, .artistId: artistId ] } - + convenience init(row: T) throws { try self.init( - name: row.value(Album.field(.name)), - artistId: row.value(Album.field(.artistId)) + name: row.value(Field.name.qualifiedField), + artistId: row.value(Field.artistId.qualifiedField) ) - - } + } } // MARK: - Tests public class PostgreSQLTests: XCTestCase { - - let connection = try! PostgreSQL.Connection(URI("postgres://localhost:5432/swift_test")) - + let connection = try! PostgreSQL.Connection(info: .init(URL(string: "postgres://localhost:5432/swift_test")!)) let logger = Logger(name: "SQL Logger", appenders: [StandardOutputAppender()]) - - override func setUp() { + override public func setUp() { super.setUp() - + do { try connection.open() try connection.execute("DROP TABLE IF EXISTS albums") try connection.execute("DROP TABLE IF EXISTS artists") try connection.execute("CREATE TABLE IF NOT EXISTS artists(id SERIAL PRIMARY KEY, genre VARCHAR(50), name VARCHAR(255))") try connection.execute("CREATE TABLE IF NOT EXISTS albums(id SERIAL PRIMARY KEY, name VARCHAR(255), artist_id int references artists(id))") - + try connection.execute("INSERT INTO artists (name, genre) VALUES('Josh Rouse', 'Country')") - + connection.logger = logger - - } catch { XCTFail("Connection error: \(error)") } } - + func testSimpleRawQueries() throws { try connection.execute("SELECT * FROM artists") let result = try connection.execute("SELECT * FROM artists WHERE name = %@", parameters: "Josh Rouse") XCTAssert(try result.first?.value("name") == "Josh Rouse") } - - func testBulk() { - do { - for i in 0..<300 { - var entity = Entity(model: Artist(name: "NAME \(i)", genre: "GENRE \(i)")) - try entity.save(connection: connection) - } - - measure { - do { - let result = try Entity.fetchAll(connection: self.connection) - - for artist in result { - print(artist.model.genre) - } - } - catch { - XCTFail("\(error)") + + func testBulk() throws { + for i in 0..<300 { + let entity = Entity(model: Artist(name: "NAME \(i)", genre: "GENRE \(i)")) + _ = try entity.save(connection: connection) + } + + measure { + do { + let result = try Entity.fetchAll(connection: self.connection) + + for artist in result { + print(artist.model.genre) } } - } - catch { - print("ERROR") - XCTFail("\(error)") + catch { + XCTFail("\(error)") + } } } - + func testRockArtists() throws { - - do { - let artists = try Entity.fetchAll(connection: connection) - - try Entity.fetchAll(connection: connection) - - try connection.begin() - - for var artist in artists { - artist.model.genre = "Rock & Roll" - try artist.save(connection: connection) - } - - try connection.commit() - } - catch { - print(error) - throw error + let artists = try Entity.fetchAll(connection: connection) + + _ = try Entity.fetchAll(connection: connection) + + try connection.begin() + + for var artist in artists { + artist.model.genre = "Rock & Roll" + _ = try artist.save(connection: connection) } - - + + try connection.commit() } - - - - override func tearDown() { + + override public func tearDown() { // Put teardown code here. This method is called after the invocation of each test method in the class. super.tearDown() - + connection.close() - + } +} + +extension PostgreSQLTests { + public static var allTests: [(String, (PostgreSQLTests) -> () throws -> Void)] { + return [ + ("testBulk", testBulk), + ("testSimpleRawQueries", testSimpleRawQueries), + ("testRockArtists", testRockArtists), + ] } }