diff --git a/lua/plenary/async_lib/api.lua b/lua/plenary/async_lib/api.lua new file mode 100644 index 000000000..619a03681 --- /dev/null +++ b/lua/plenary/async_lib/api.lua @@ -0,0 +1,13 @@ +local a = require('plenary.async_lib.async') +local async, await = a.async, a.await + +return setmetatable({}, {__index = function(t, k) + return async(function(...) + -- if we are in a fast event await the scheduler + if vim.in_fast_event() then + await(a.scheduler()) + end + + vim.api[k](...) + end) +end}) diff --git a/lua/plenary/async_lib/async.lua b/lua/plenary/async_lib/async.lua new file mode 100644 index 000000000..0b47de719 --- /dev/null +++ b/lua/plenary/async_lib/async.lua @@ -0,0 +1,201 @@ +local co = coroutine +local errors = require('plenary.errors') +local traceback_error = errors.traceback_error + +local M = {} + +---@class Future +---Something that will give a value when run + +---Executes a future with a callback when it is done +---@param future Future: the future to execute +---@param callback function: the callback to call when done +local execute = function(future, callback) + assert(type(future) == "function", "type error :: expected func") + local thread = co.create(future) + + local step + step = function(...) + local res = {co.resume(thread, ...)} + local stat = res[1] + local ret = {select(2, unpack(res))} + + if not stat then + error(string.format("The coroutine failed with this message: %s", ret[1])) + end + + if co.status(thread) == "dead" then + (callback or function() end)(unpack(ret)) + else + assert(#ret == 1, "expected a single return value") + local returned_future = ret[1] + assert(type(returned_future) == "function", "type error :: expected func") + returned_future(step) + end + end + + step() +end + +---Creates an async function with a callback style function. +---@param func function: A callback style function to be converted. The last argument must be the callback. +---@param argc number: The number of arguments of func. Must be included. +---@return function: Returns an async function +M.wrap = function(func, argc) + if type(func) ~= "function" then + traceback_error("type error :: expected func, got " .. type(func)) + end + + if type(argc) ~= "number" and argc ~= "vararg" then + traceback_error("expected argc to be a number or string literal 'vararg'") + end + + return function(...) + local params = {...} + + local function future(step) + if step then + if type(argc) == "number" then + params[argc] = step + else + table.insert(params, step) -- change once not optional + end + return func(unpack(params)) + else + return co.yield(future) + end + end + return future + end +end + +---Return a new future that when run will run all futures concurrently. +---@param futures table: the futures that you want to join +---@return Future: returns a future +M.join = M.wrap(function(futures, step) + local len = #futures + local results = {} + local done = 0 + + if len == 0 then + return step(results) + end + + for i, future in ipairs(futures) do + assert(type(future) == "function", "type error :: future must be function") + + local callback = function(...) + results[i] = {...} + done = done + 1 + if done == len then + step(results) + end + end + + future(callback) + end +end, 2) + +---Returns a future that when run will select the first future that finishes +---@param futures table: The future that you want to select +---@return Future +M.select = M.wrap(function(futures, step) + local selected = false + + for _, future in ipairs(futures) do + assert(type(future) == "function", "type error :: future must be function") + + local callback = function(...) + if not selected then + selected = true + step(...) + end + end + + future(callback) + end +end, 2) + +---Use this to either run a future concurrently and then do something else +---or use it to run a future with a callback in a non async context +---@param future Future +---@param callback function +M.run = function(future, callback) + future(callback or function() end) +end + +---Same as run but runs multiple futures +---@param futures table +---@param callback function +M.run_all = function(futures, callback) + M.run(M.join(futures), callback) +end + +---Await a future, yielding the current function +---@param future Future +---@return any: returns the result of the future when it is done +M.await = function(future) + assert(type(future) == "function", "type error :: expected function to await") + return future(nil) +end + +---Same as await but can await multiple futures. +---If the futures have libuv leaf futures they will be run concurrently +---@param futures table +---@return table: returns a table of results that each future returned. Note that if the future returns multiple values they will be packed into a table. +M.await_all = function(futures) + assert(type(futures) == "table", "type error :: expected table") + return M.await(M.join(futures)) +end + +---suspend a coroutine +M.suspend = co.yield + +---create a async scope +M.scope = function(func) + M.run(M.future(func)) +end + +--- Future a :: a -> (a -> ()) +--- turns this signature +--- ... -> Future a +--- into this signature +--- ... -> () +M.void = function(async_func) + return function(...) + async_func(...)(function() end) + end +end + +---creates an async function +---@param func function +---@return function: returns an async function +M.async = function(func) + if type(func) ~= "function" then + traceback_error("type error :: expected func, got " .. type(func)) + end + + return function(...) + local args = {...} + local function future(step) + if step == nil then + return func(unpack(args)) + else + execute(future, step) + end + end + return future + end +end + +---creates a future +---@param func function +---@return Future +M.future = function(func) + return M.async(func)() +end + +---An async function that when awaited will await the scheduler to be able to call the api. +M.scheduler = M.wrap(vim.schedule, 1) + +return M diff --git a/lua/plenary/async_lib/init.lua b/lua/plenary/async_lib/init.lua new file mode 100644 index 000000000..3a8f277c3 --- /dev/null +++ b/lua/plenary/async_lib/init.lua @@ -0,0 +1,36 @@ +local exports = require('plenary.async_lib.async') +exports.uv = require('plenary.async_lib.uv_async') +exports.util = require('plenary.async_lib.util') +exports.lsp = require('plenary.async_lib.lsp') +exports.api = require('plenary.async_lib.api') +exports.tests = require('plenary.async_lib.tests') + +exports.tests.add_globals = function() + a = exports + async = exports.async + await = exports.await + await_all = exports.await_all + + -- must prefix with a or stack overflow, plenary.test harness already added it + a.describe = exports.tests.describe + -- must prefix with a or stack overflow + a.it = exports.tests.it +end + +exports.tests.add_to_env = function() + local env = getfenv(2) + + env.a = exports + env.async = exports.async + env.await = exports.await + env.await_all = exports.await_all + + -- must prefix with a or stack overflow, plenary.test harness already added it + env.a.describe = exports.tests.describe + -- must prefix with a or stack overflow + env.a.it = exports.tests.it + + setfenv(2, env) +end + +return exports diff --git a/lua/plenary/async_lib/lsp.lua b/lua/plenary/async_lib/lsp.lua new file mode 100644 index 000000000..6f96354e3 --- /dev/null +++ b/lua/plenary/async_lib/lsp.lua @@ -0,0 +1,8 @@ +local a = require('plenary.async_lib.async') + +local M = {} + +---Same as vim.lsp.buf_request but works with async await +M.buf_request = a.wrap(vim.lsp.buf_request, 4) + +return M diff --git a/lua/plenary/async_lib/structs.lua b/lua/plenary/async_lib/structs.lua new file mode 100644 index 000000000..a5c2f97f4 --- /dev/null +++ b/lua/plenary/async_lib/structs.lua @@ -0,0 +1,112 @@ +local M = {} + +Deque = {} +Deque.__index = Deque + +---@class Deque +---A double ended queue +--- +---@return Deque +function Deque.new() + -- the indexes are created with an offset so that the indices are consequtive + -- otherwise, when both pushleft and pushright are used, the indices will have a 1 length hole in the middle + return setmetatable({first = 0, last = -1}, Deque) +end + +---push to the left of the deque +---@param value any +function Deque:pushleft(value) + local first = self.first - 1 + self.first = first + self[first] = value +end + +---push to the right of the deque +---@param value any +function Deque:pushright(value) + local last = self.last + 1 + self.last = last + self[last] = value +end + +---pop from the left of the deque +---@return any +function Deque:popleft() + local first = self.first + if first > self.last then return nil end + local value = self[first] + self[first] = nil -- to allow garbage collection + self.first = first + 1 + return value +end + +---pops from the right of the deque +---@return any +function Deque:popright() + local last = self.last + if self.first > last then return nil end + local value = self[last] + self[last] = nil -- to allow garbage collection + self.last = last - 1 + return value +end + +---checks if the deque is empty +---@return boolean +function Deque:is_empty() + return self:len() == 0 +end + +---returns the number of elements of the deque +---@return number +function Deque:len() + return self.last - self.first + 1 +end + +---returns and iterator of the indices and values starting from the left +---@return function +function Deque:ipairs_left() + local i = self.first + + return function() + local res = self[i] + local idx = i + + if res then + i = i + 1 + + return idx, res + end + end +end + +---returns and iterator of the indices and values starting from the right +---@return function +function Deque:ipairs_right() + local i = self.last + + return function() + local res = self[i] + local idx = i + + if res then + i = i - 1 -- advance the iterator before we return + + return idx, res + end + end +end + +---removes all values from the deque +---@return nil +function Deque:clear() + for i, _ in self:ipairs_left() do + self[i] = nil + end + self.first = 0 + self.last = -1 +end + +M.Deque = Deque + +return M diff --git a/lua/plenary/async_lib/tests.lua b/lua/plenary/async_lib/tests.lua new file mode 100644 index 000000000..61b6282f0 --- /dev/null +++ b/lua/plenary/async_lib/tests.lua @@ -0,0 +1,14 @@ +local a = require('plenary.async_lib.async') +local util = require('plenary.async_lib.util') + +local M = {} + +M.describe = function(s, func) + describe(s, util.will_block(a.future(func))) +end + +M.it = function(s, func) + it(s, util.will_block(a.future(func))) +end + +return M diff --git a/lua/plenary/async_lib/util.lua b/lua/plenary/async_lib/util.lua new file mode 100644 index 000000000..917dfef87 --- /dev/null +++ b/lua/plenary/async_lib/util.lua @@ -0,0 +1,333 @@ +local a = require('plenary.async_lib.async') +local await = a.await +local async = a.async +local co = coroutine +local Deque = require('plenary.async_lib.structs').Deque +local uv = vim.loop + +local M = {} + +---Sleep for milliseconds +---@param ms number +M.sleep = a.wrap(function(ms, callback) + local timer = uv.new_timer() + uv.timer_start(timer, ms, 0, function() + uv.timer_stop(timer) + uv.close(timer) + callback() + end) +end, 2) + +---Takes a future and a millisecond as the timeout. +---If the time is reached and the future hasn't completed yet, it will short circuit the future +---NOTE: the future will still be running in libuv, we are just not waiting for it to complete +---thats why you should call this on a leaf future only to avoid unexpected results +---@param future Future +---@param ms number +M.timeout = a.wrap(function(future, ms, callback) + -- make sure that the callback isn't called twice, or else the coroutine can be dead + local done = false + + local timeout_callback = function(...) + if not done then + done = true + callback(false, ...) -- false because it has run normally + end + end + + vim.defer_fn(function() + if not done then + done = true + callback(true) -- true because it has timed out + end + end, ms) + + a.run(future, timeout_callback) +end, 3) + +---create an async function timer +---@param ms number +M.timer = function(ms) + return async(function() + await(M.sleep(ms)) + end) +end + +---id function that can be awaited +---@param nil ... +---@return ... +M.id = async(function(...) + return ... +end) + +---Running this function will yield now and do nothing else +M.yield_now = async(function() + await(M.id()) +end) + +local Condvar = {} +Condvar.__index = Condvar + +---@class Condvar +---@return Condvar +function Condvar.new() + return setmetatable({handles = {}}, Condvar) +end + +---`blocks` the thread until a notification is received +Condvar.wait = a.wrap(function(self, callback) + -- not calling the callback will block the coroutine + table.insert(self.handles, callback) +end, 2) + +---notify everyone that is waiting on this Condvar +function Condvar:notify_all() + if #self.handles == 0 then return end + + for _, callback in ipairs(self.handles) do + callback() + end + self.handles = {} -- reset all handles as they have been used up +end + +---notify randomly one person that is waiting on this Condvar +function Condvar:notify_one() + if #self.handles == 0 then return end + + local idx = math.random(#self.handles) + self.handles[idx]() + table.remove(self.handles, idx) +end + +M.Condvar = Condvar + +local Semaphore = {} +Semaphore.__index = Semaphore + +---@class Semaphore +---@param initial_permits number: the number of permits that it can give out +---@return Semaphore +function Semaphore.new(initial_permits) + vim.validate { + initial_permits = { + initial_permits, + function(n) return n > 0 end, + 'number greater than 0' + } + } + + return setmetatable({permits = initial_permits, handles = {}}, Semaphore) +end + +---async function, blocks until a permit can be acquired +---example: +---local semaphore = Semaphore.new(1024) +---local permit = await(semaphore:acquire()) +---permit:forget() +---when a permit can be acquired returns it +---call permit:forget() to forget the permit +Semaphore.acquire = a.wrap(function(self, callback) + self.permits = self.permits - 1 + + if self.permits <= 0 then + table.insert(self.handles, callback) + return + end + + local permit = {} + + permit.forget = function(self_permit) + self.permits = self.permits + 1 + + if self.permits > 0 and #self.handles > 0 then + local callback = table.remove(self.handles) + callback(self_permit) + self.permits = self.permits - 1 + end + end + + callback(permit) +end, 2) + +M.Semaphore = Semaphore + +M.channel = {} + +---Creates a oneshot channel +---returns a sender and receiver function +---the sender is not async while the receiver is +---@return function, function +M.channel.oneshot = function() + local val = nil + local saved_callback = nil + local sent = false + local received = false + + --- sender is not async + --- sends a value + local sender = function(...) + if sent then + error("Oneshot channel can only send once") + end + + sent = true + + local args = {...} + + if saved_callback then + saved_callback(unpack(val or args)) + else + val = args + end + end + + --- receiver is async + --- blocks until a value is received + local receiver = a.wrap(function(callback) + if received then + error('Oneshot channel can only send one value!') + end + + if val then + received = true + callback(unpack(val)) + else + saved_callback = callback + end + end, 1) + + return sender, receiver +end + +---A counter channel. +---Basically a channel that you want to use only to notify and not to send any actual values. +---@return function: sender +---@return function: receiver +M.channel.counter = function() + local counter = 0 + local condvar = Condvar.new() + + local Sender = {} + + function Sender:send() + counter = counter + 1 + condvar:notify_all() + end + + local Receiver = {} + + Receiver.recv = async(function() + if counter == 0 then + await(condvar:wait()) + end + counter = counter - 1 + end) + + Receiver.last = async(function() + if counter == 0 then + await(condvar:wait()) + end + counter = 0 + end) + + return Sender, Receiver +end + +---A multiple producer single consumer channel +---@return table +---@return table +M.channel.mpsc = function() + local deque = Deque.new() + local condvar = Condvar.new() + + local Sender = {} + + function Sender.send(...) + deque:pushleft({...}) + condvar:notify_all() + end + + local Receiver = {} + + Receiver.recv = async(function() + if deque:is_empty() then + await(condvar:wait()) + end + return unpack(deque:popright()) + end) + + Receiver.last = async(function() + if deque:is_empty() then + await(condvar:wait()) + end + local val = deque:popright() + deque:clear() + return unpack(val) + end) + + return Sender, Receiver +end + +local pcall_wrap = function(func) + return function(...) + return pcall(func, ...) + end +end + +---Makes a future protected. It is like pcall but for futures. +---Only works for non-leaf futures +M.protected_non_leaf = async(function(future) + return await(pcall_wrap(future)) +end) + +---Makes a future protected. It is like pcall but for futures. +---@param future Future +---@return Future +M.protected = async(function(future) + local tx, rx = M.channel.oneshot() + + stat, ret = pcall(future, tx) + + if stat == true then + return stat, await(rx()) + else + return stat, ret + end +end) + +---This will COMPLETELY block neovim +---please just use a.run unless you have a very special usecase +---for example, in plenary test_harness you must use this +---@param future Future +---@param timeout number: Stop blocking if the timeout was surpassed. Default 2000. +M.block_on = function(future, timeout) + future = M.protected(future) + + local stat, ret + a.run(future, function(_stat, ...) + stat = _stat + ret = {...} + end) + + local function check() + if stat == false then + error("Blocking on future failed " .. unpack(ret)) + end + return stat == true + end + + if not vim.wait(timeout or 2000, check, 20, false) then + error("Blocking on future timed out or was interrupted") + end + + return unpack(ret) +end + +---Returns a new future that WILL BLOCK +---@param future Future +---@return Future +M.will_block = async(function(future) + return M.block_on(future) +end) + +return M diff --git a/lua/plenary/async_lib/uv_async.lua b/lua/plenary/async_lib/uv_async.lua new file mode 100644 index 000000000..368223c83 --- /dev/null +++ b/lua/plenary/async_lib/uv_async.lua @@ -0,0 +1,82 @@ +local a = require('plenary.async_lib.async') +local uv = vim.loop + +local M = {} + +local function add(name, argc) + local success, ret = pcall(a.wrap, uv[name], argc) + + if not success then + error("Failed to add function with name " .. name) + end + + M[name] = ret +end + +add('close', 4) -- close a handle + +-- filesystem operations +add('fs_open', 4) +add('fs_read', 4) +add('fs_close', 2) +add('fs_unlink', 2) +add('fs_write', 4) +add('fs_mkdir', 3) +add('fs_mkdtemp', 2) +-- 'fs_mkstemp', +add('fs_rmdir', 2) +add('fs_scandir', 2) +add('fs_stat', 2) +add('fs_fstat', 2) +add('fs_lstat', 2) +add('fs_rename', 3) +add('fs_fsync', 2) +add('fs_fdatasync', 2) +add('fs_ftruncate', 3) +add('fs_sendfile', 5) +add('fs_access', 3) +add('fs_chmod', 3) +add('fs_fchmod', 3) +add('fs_utime', 4) +add('fs_futime', 4) +-- 'fs_lutime', +add('fs_link', 3) +add('fs_symlink', 4) +add('fs_readlink', 2) +add('fs_realpath', 2) +add('fs_chown', 4) +add('fs_fchown', 4) +-- 'fs_lchown', +add('fs_copyfile', 4) +-- add('fs_opendir', 3) -- TODO: fix this one +add('fs_readdir', 2) +add('fs_closedir', 2) +-- 'fs_statfs', + +-- stream +add('shutdown', 2) +add('listen', 3) +-- add('read_start', 2) -- do not do this one, the callback is made multiple times +add('write', 3) +add('write2', 4) +add('shutdown', 2) + +-- tcp +add('tcp_connect', 4) +-- 'tcp_close_reset', + +-- pipe +add('pipe_connect', 3) + +-- udp +add('udp_send', 5) +add('udp_recv_start', 2) + +-- fs event (wip make into async await event) +-- fs poll event (wip make into async await event) + +-- dns +add('getaddrinfo', 4) +add('getnameinfo', 2) + +return M diff --git a/lua/plenary/errors.lua b/lua/plenary/errors.lua new file mode 100644 index 000000000..fef139174 --- /dev/null +++ b/lua/plenary/errors.lua @@ -0,0 +1,15 @@ +local M = {} + +M.traceback_error = function(s, level) + local traceback = debug.traceback() + traceback = traceback .. '\n' .. s + error(traceback, (level or 1) + 1) +end + +M.info_error = function(s, func_info, level) + local info = debug.getinfo(func_info) + info = info .. '\n' .. s + error(info, (level or 1) + 1) +end + +return M diff --git a/scratch/async.lua b/scratch/async.lua new file mode 100644 index 000000000..38c78819f --- /dev/null +++ b/scratch/async.lua @@ -0,0 +1,5 @@ +local a = require('plenary.async_lib') +local async = a.async + +async(nil) +-- a.wrap(function() end, nil) diff --git a/tests/plenary/async_lib/channel_spec.lua b/tests/plenary/async_lib/channel_spec.lua new file mode 100644 index 000000000..f7da47981 --- /dev/null +++ b/tests/plenary/async_lib/channel_spec.lua @@ -0,0 +1,116 @@ +require('plenary.async_lib').tests.add_to_env() +local channel = a.util.channel +local eq = assert.are.same +local protected = a.util.protected + +a.describe('channel', function() + a.describe('oneshot', function() + a.it('should work when rx is used first', function() + local tx, rx = channel.oneshot() + + a.run(a.future(function() + local got = await(rx()) + eq("sent value", got) + end)) + + tx("sent value") + end) + + a.it('should work when tx is used first', function() + local tx, rx = channel.oneshot() + + tx("sent value") + + a.run(a.future(function() + local got = await(rx()) + eq("sent value", got) + end)) + end) + + a.it('should work with multiple returns', function() + local tx, rx = channel.oneshot() + + a.run(a.future(function() + local got, got2 = await(rx()) + eq("sent value", got) + eq("another sent value", got2) + end)) + + tx("sent value", "another sent value") + end) + + a.it('should work when sending a nil value', function () + local tx, rx = channel.oneshot() + + tx(nil) + + local res = await(rx()) + eq(res, nil) + + local stat, ret = await(protected(rx())) + eq(stat, false) + local stat, ret = await(protected(rx())) + eq(stat, false) + end) + + a.it('should block sending mulitple times', function() + local tx, rx = channel.oneshot() + + tx() + local stat = pcall(tx) + eq(stat, false) + end) + + a.it('should block receiving multiple times', function () + local tx, rx = channel.oneshot() + tx() + await(rx()) + local stat = await(protected(rx())) + eq(stat, false) + end) + end) + + a.describe('counter', function() + a.it('should work', function() + local tx, rx = channel.counter() + + tx.send() + tx.send() + tx.send() + + local counter = 0 + + local recv_stuff = async(function() + for i = 1, 3 do + await(rx.recv()) + counter = counter + 1 + end + end) + + a.run(recv_stuff()) + + eq(counter, 3) + end) + + a.it('should work when getting last', function() + local tx, rx = channel.counter() + + tx.send() + tx.send() + tx.send() + + local counter = 0 + + local recv_stuff = async(function() + for i = 1, 3 do + await(rx.last()) + counter = counter + 1 + end + end) + + a.run(recv_stuff()) + + eq(counter, 1) + end) + end) +end) diff --git a/tests/plenary/async_lib/condvar_spec.lua b/tests/plenary/async_lib/condvar_spec.lua new file mode 100644 index 000000000..1793c293e --- /dev/null +++ b/tests/plenary/async_lib/condvar_spec.lua @@ -0,0 +1,127 @@ +require('plenary.async_lib').tests.add_to_env() +local Condvar = a.util.Condvar +local eq = assert.are.same + +a.describe('condvar', function() + a.it('should allow blocking', function() + local var = false + + local condvar = Condvar.new() + + local blocking = async(function() + await(condvar:wait()) + var = true + end) + + a.run(blocking()) + + eq(var, false) + + condvar:notify_one() + + eq(var, true) + end) + + a.it('should be able to notify one when running', function() + local counter = 0 + + local condvar = Condvar.new() + + local first = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + local second = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + local third = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + a.run_all { first(), second(), third() } + + eq(0, counter) + + condvar:notify_one() + + eq(1, counter) + + condvar:notify_one() + + eq(counter, 2) + + condvar:notify_one() + + eq(counter, 3) + end) + + a.it('should allow notify_one to work when using await_all', function() + local counter = 0 + + local condvar = Condvar.new() + + local first = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + local second = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + local third = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + a.run_all { first(), second(), third() } + + eq(0, counter) + + condvar:notify_one() + + eq(1, counter) + + condvar:notify_one() + + eq(counter, 2) + + condvar:notify_one() + + eq(counter, 3) + end) + + a.it('should notify_all', function() + local counter = 0 + + local condvar = Condvar.new() + + local first = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + local second = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + local third = async(function() + await(condvar:wait()) + counter = counter + 1 + end) + + a.run_all { first(), second(), third() } + + eq(0, counter) + + condvar:notify_all() + + eq(3, counter) + end) +end) diff --git a/tests/plenary/async_lib/deque_spec.lua b/tests/plenary/async_lib/deque_spec.lua new file mode 100644 index 000000000..ae14f8332 --- /dev/null +++ b/tests/plenary/async_lib/deque_spec.lua @@ -0,0 +1,91 @@ +local Deque = require('plenary.async_lib.structs').Deque +local eq = assert.are.same + +-- just a helper to create the test deque +local function new_deque() + local deque = Deque.new() + eq(deque:len(), 0) + + deque:pushleft(1) + eq(deque:len(), 1) + + deque:pushleft(2) + eq(deque:len(), 2) + + deque:pushright(3) + eq(deque:len(), 3) + + deque:pushright(4) + eq(deque:len(), 4) + + deque:pushright(5) + eq(deque:len(), 5) + + return deque +end + +describe('deque', function() + it('should allow pushing and popping and finding len', function() + new_deque() + end) + + it('should be able to iterate from left', function() + local deque = new_deque() + + local iter = deque:ipairs_left() + + local i, v = iter() + eq(i, -2) + eq(v, 2) + + i, v = iter() + eq(i, -1) + eq(v, 1) + + i, v = iter() + eq(i, 0) + eq(v, 3) + + i, v = iter() + eq(i, 1) + eq(v, 4) + + i, v = iter() + eq(i, 2) + eq(v, 5) + end) + + it('should be able to iterate from right', function() + local deque = new_deque() + + local iter = deque:ipairs_right() + + local i, v = iter() + eq(i, 2) + eq(v, 5) + + i, v = iter() + eq(i, 1) + eq(v, 4) + + i, v = iter() + eq(i, 0) + eq(v, 3) + + i, v = iter() + eq(i, -1) + eq(v, 1) + + i, v = iter() + eq(i, -2) + eq(v, 2) + end) + + it('should allow clearing', function() + local deque = new_deque() + + deque:clear() + + assert(deque:is_empty()) + end) +end) diff --git a/tests/plenary/async_lib/read_file_bench.lua b/tests/plenary/async_lib/read_file_bench.lua new file mode 100644 index 000000000..c4b0d37bc --- /dev/null +++ b/tests/plenary/async_lib/read_file_bench.lua @@ -0,0 +1,97 @@ +-- do not run this for now, the asset was removed +local a = require('plenary.async_lib') +local async = a.async +local await = a.await +local await_all = a.await_all +local uv = vim.loop + +local plenary_init = vim.api.nvim_get_runtime_file('lua/plenary/init.lua', false)[1] +local plenary_dir = vim.fn.fnamemodify(plenary_init, ":h:h:h") +local assets_dir = plenary_dir .. '/' .. 'tests/plenary/async_lib/assets/' + +local read_file = async(function(path) + local err, fd = await(a.uv.fs_open(path, "r", 438)) + assert(not err, err) + + print('fd', fd) + + local err, stat = await(a.uv.fs_fstat(fd)) + assert(not err, err) + + local err, data = await(a.uv.fs_read(fd, stat.size, 0)) + assert(not err, err) + + local err = await(a.uv.fs_close(fd)) + assert(not err, err) + + return data +end) + +local read_file_other = function(path, callback) + uv.fs_open(path, "r", 438, function(err, fd) + assert(not err, err) + uv.fs_fstat(fd, function(err, stat) + assert(not err, err) + uv.fs_read(fd, stat.size, 0, function(err, data) + assert(not err, err) + uv.fs_close(fd, function(err) + assert(not err, err) + return callback(data) + end) + end) + end) + end) +end + +local first_bench = async(function() + + local futures = {} + + for i = 1, 300 do futures[i] = read_file(assets_dir .. 'syn.json') end + + local start = os.clock() + + await_all(futures) + + print("Elapsed time: ", os.clock() - start) +end) + +local second_bench = function() + local results = {} + + local start = os.clock() + + for i = 1, 300 do + read_file_other(assets_dir .. 'syn.json', function(data) + results[i] = data + if #results == 300 then + print("Elapsed time: ", os.clock() - start) + end + end) + end +end + +local call_api = async(function() + local futures = {} + + local read_and_api = async(function() + local res = await(read_file(assets_dir .. 'syn.json')) + print('first in fast event', vim.in_fast_event()) + await(a.api.nvim_notify("Hello", 1, {})) + print('second in fast event', vim.in_fast_event()) + print('third in fast event', vim.in_fast_event()) + return res + end) + + for i = 1, 300 do + futures[i] = read_and_api() + end + + local start = os.clock() + + local res = await_all(futures) + + print("Elapsed time: ", os.clock() - start) +end) + +a.run(call_api()) diff --git a/tests/plenary/async_lib/semaphore_spec.lua b/tests/plenary/async_lib/semaphore_spec.lua new file mode 100644 index 000000000..f2908fbb5 --- /dev/null +++ b/tests/plenary/async_lib/semaphore_spec.lua @@ -0,0 +1,17 @@ +require('plenary.async_lib').tests.add_to_env() +local Semaphore = a.util.Semaphore + +local eq = assert.are.same + +a.describe('semaphore', function() + a.it('should validate arguments', function() + local status = pcall(Semaphore.new, -1) + eq(status, false) + + local status = pcall(Semaphore.new) + eq(status, false) + end) + + a.it('should count properly', function() + end) +end) diff --git a/tests/plenary/async_lib/util_spec.lua b/tests/plenary/async_lib/util_spec.lua new file mode 100644 index 000000000..ccbc1b657 --- /dev/null +++ b/tests/plenary/async_lib/util_spec.lua @@ -0,0 +1,104 @@ +require('plenary.async_lib').tests.add_to_env() +local block_on = a.util.block_on +local eq = assert.are.same +local id = a.util.id + +a.describe('async await util', function() + a.describe('block_on', function() + a.it('should block_on', function() + local fn = async(function() + await(a.util.sleep(100)) + return 'hello' + end) + + local res = block_on(fn()) + eq(res, 'hello') + end) + + a.it('should work even when failing', function () + local nonleaf = async(function() + eq(true, false) + end) + + local stat = pcall(block_on, nonleaf()) + eq(stat, false) + end) + end) + + a.describe('protect', function() + a.it('should be able to protect a non-leaf future', function() + local nonleaf = async(function() + error("This should error") + return 'return' + end) + + local stat, ret = await(a.util.protected_non_leaf(nonleaf())) + eq(false, stat) + assert(ret:match("This should error")) + end) + + a.it('should be able to protect a non-leaf future that doesnt fail', function() + local nonleaf = async(function() + return 'didnt fail' + end) + + local stat, ret = await(a.util.protected_non_leaf(nonleaf())) + eq(stat, true) + eq(ret, 'didnt fail') + end) + + a.it('should be able to protect a leaf future', function() + local leaf = a.wrap(function(callback) + error("This should error") + callback() + end, 1) + + local stat, ret = await(a.util.protected(leaf())) + eq(stat, false) + assert(ret:match("This should error") ~= nil) + end) + + a.it('should be able to protect a leaf future that doesnt fail', function() + local fn = a.wrap(function(callback) + callback('didnt fail') + end, 1) + + local stat, ret = await(a.util.protected(fn())) + eq(stat, true) + eq(ret, 'didnt fail') + end) + end) + + a.describe('timeout', function() + a.it('should block one and work', function() + local timed_out = await(a.util.timeout(a.util.sleep(1000), 500)) + + print('timed out 2:', timed_out) + + assert(timed_out == true) + end) + + a.it('should work when timeout is longer', function () + local timed_out = await(a.util.timeout(a.util.sleep(100), 1000)) + + print('timed out 1:', timed_out) + + assert(timed_out == false) + end) + end) + + a.it('id should work', function() + eq(await(id(1, 2, 3, 4, 5)), 1, 2, 3, 4, 5) + end) + + a.it('yield_now should work', function () + local yield_now = a.util.yield_now + yield_now() + yield_now() + yield_now() + yield_now() + yield_now() + yield_now() + yield_now() + end) +end)