diff --git a/src/app/rosetta/lib/rosetta.ml b/src/app/rosetta/lib/rosetta.ml index 7c6bd056326..57e02059913 100644 --- a/src/app/rosetta/lib/rosetta.ml +++ b/src/app/rosetta/lib/rosetta.ml @@ -17,16 +17,30 @@ let router ~graphql_uri ~db ~logger route body = | _ -> Deferred.return (Error `Page_not_found) -let server_handler ~db ~graphql_uri ~logger ~body _sock req = +let server_handler ~pool ~graphql_uri ~logger ~body _sock req = let uri = Cohttp_async.Request.uri req in let%bind body = Cohttp_async.Body.to_string body in let route = List.tl_exn (String.split ~on:'/' (Uri.path uri)) in let%bind result = + let with_db f = + Caqti_async.Pool.use (fun db -> f ~db) pool + |> Deferred.Result.map_error ~f:(function + | `App e -> + `App e + | `Page_not_found -> + `Page_not_found + | `Connect_failed _e -> + `App (Errors.create (`Sql "Connect failed")) + | `Connect_rejected _e -> + `App (Errors.create (`Sql "Connect rejected")) + | `Post_connect _e -> + `App (Errors.create (`Sql "Post connect error")) ) + in match Yojson.Safe.from_string body with | body -> - router route body ~db ~graphql_uri ~logger + with_db (router route body ~graphql_uri ~logger) | exception Yojson.Json_error "Blank input data" -> - router route `Null ~db ~graphql_uri ~logger + with_db (router route `Null ~graphql_uri ~logger) | exception Yojson.Json_error err -> Errors.create ~context:"JSON in request malformed" (`Json_parse (Some err)) @@ -70,13 +84,13 @@ let command = fun () -> let logger = Logger.create () in Cli.logger_setup log_json log_level ; - match%bind Caqti_async.connect archive_uri with + match Caqti_async.connect_pool ~max_size:128 archive_uri with | Error e -> [%log error] ~metadata:[("error", `String (Caqti_error.show e))] - "Failed to connect to postgresql database. Error: $error" ; + "Failed to create a caqti pool to postgres. Error: $error" ; Deferred.unit - | Ok db -> + | Ok pool -> let%bind server = Cohttp_async.Server.create_expert ~max_connections:128 ~on_handler_error: @@ -88,7 +102,7 @@ let command = [ ("error", `String (Exn.to_string_mach exn)) ; ("context", `String "rest_server") ] )) (Async.Tcp.Where_to_listen.bind_to All_addresses (On_port port)) - (server_handler ~db ~graphql_uri ~logger) + (server_handler ~pool ~graphql_uri ~logger) in [%log info] ~metadata:[("port", `Int port)]