diff --git a/app/modules/net.c b/app/modules/net.c index 8181b947b2..3748bee18c 100644 --- a/app/modules/net.c +++ b/app/modules/net.c @@ -13,6 +13,7 @@ #include "lwip/ip_addr.h" #include "espconn.h" #include "lwip/dns.h" +#include "lwip/udp.h" #define TCP ESPCONN_TCP #define UDP ESPCONN_UDP @@ -1702,6 +1703,334 @@ static int expose_array(lua_State* L, char *array, unsigned short len) { } #endif +typedef struct { + struct udp_pcb *pcb; + int cb_recv_ref; + int self_ref; +} net_udp_ud; + +static void net_udp_recv( net_udp_ud *ud, struct udp_pcb *pcb, struct pbuf *p, ip_addr_t *addr, u16_t port ) { + if ( !ud || !pcb || + ud->pcb != pcb || + ud->self_ref == LUA_NOREF || + ud->cb_recv_ref == LUA_NOREF) + { + pbuf_free(p); + return; + } + + lua_State *L = lua_getstate(); + lua_rawgeti(L, LUA_REGISTRYINDEX, ud->cb_recv_ref); + lua_rawgeti(L, LUA_REGISTRYINDEX, ud->self_ref); + lua_pushinteger(L, port); + char iptmp[20]; + size_t ipl = c_sprintf(iptmp, IPSTR, IP2STR(&addr->addr)); + lua_pushlstring(L, iptmp, ipl); + lua_pushlstring(L, p->payload, p->len); + lua_call(L, 4, 0); + + pbuf_free(p); +} + +static int net_createUDPSocket( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud*) lua_newuserdata(L, sizeof(net_udp_ud)); + + ud->pcb = udp_new(); + udp_recv(ud->pcb, (udp_recv_fn)net_udp_recv, ud); + ud->cb_recv_ref = LUA_NOREF; + ud->self_ref = LUA_NOREF; + + luaL_getmetatable(L, "net.udp"); + lua_setmetatable(L, -2); + + return 1; +} + +static int luaL_lwip_checkerr(err_t res, lua_State *L) { + switch (res) { + case ERR_OK: return 0; \ + case ERR_MEM: return luaL_error(L, "out of memory"); + case ERR_BUF: return luaL_error(L, "buffer error"); + case ERR_TIMEOUT: return luaL_error(L, "timeout"); + case ERR_RTE: return luaL_error(L, "routing error"); + case ERR_INPROGRESS: return luaL_error(L, "operation in progress"); + case ERR_VAL: return luaL_error(L, "illegal value"); + case ERR_WOULDBLOCK: return luaL_error(L, "operation would block"); + case ERR_ABRT: return luaL_error(L, "connection aborted"); + case ERR_RST: return luaL_error(L, "connection reset"); + case ERR_CLSD: return luaL_error(L, "connection closed"); + case ERR_CONN: return luaL_error(L, "not connected"); + case ERR_ARG: return luaL_error(L, "illegal argument"); + case ERR_USE: return luaL_error(L, "address in use"); + case ERR_IF: return luaL_error(L, "low-level netif error"); + case ERR_ISCONN: return luaL_error(L, "already connected"); + default: return luaL_error(L, "invalid LWIP state"); + } +} + +static int net_udp_getaddr( lua_State* L ); + +static int net_udp_bind( lua_State* L ) { + net_udp_ud *ud; + ip_addr_t addr; + u16_t port; + + ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + port = luaL_checkinteger(L, 2); + if (lua_isstring(L, 3)) { + size_t sl; + const char* domain = luaL_checklstring(L, 3, &sl); + addr.addr = ipaddr_addr(domain); + if (addr.addr == 0xFFFFFFFF) { + return luaL_error(L, "Invalid IP address"); + } + } else { + addr.addr = 0; + } + + luaL_unref(L, LUA_REGISTRYINDEX, ud->self_ref); + ud->self_ref = LUA_NOREF; + udp_disconnect(ud->pcb); + + int err; + if ((err = luaL_lwip_checkerr(udp_bind(ud->pcb, &addr, port), L)) != 0) { + return err; + } + + lua_pushvalue(L, 1); // copy to the top of stack + ud->self_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + return net_udp_getaddr(L); +} + +static int net_udp_connect( lua_State* L ) { + net_udp_ud *ud; + ip_addr_t addr; + u16_t port; + size_t sl; + + ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + port = luaL_checkinteger(L, 2); + const char* domain = luaL_checklstring(L, 3, &sl); + + if (port == 0) { + return luaL_error(L, "Invalid port"); + } + addr.addr = ipaddr_addr(domain); + if (domain == 0 || addr.addr == 0 || addr.addr == 0xFFFFFFFF) { + return luaL_error(L, "Invalid IP address"); + } + + int err; + if ((err = luaL_lwip_checkerr(udp_connect(ud->pcb, &addr, port), L)) != 0) { + return err; + } + + if (ud->self_ref == LUA_NOREF) { + lua_pushvalue(L, 1); // copy to the top of stack + ud->self_ref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + return net_udp_getaddr(L); +} + +static int net_udp_send( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + size_t sl; + const char* payload = luaL_checklstring(L, 2, &sl); + if (payload == 0) { + return luaL_error(L, "invalid payload"); + } + + if (ud->pcb->remote_port == 0 || ud->pcb->remote_ip.addr == 0) { + return luaL_lwip_checkerr(ERR_CONN, L); + } + + struct pbuf *buf = pbuf_alloc(PBUF_TRANSPORT, sl, PBUF_RAM); + os_memmove(buf->payload, payload, sl); + int err; + if ((err = luaL_lwip_checkerr(udp_send(ud->pcb, buf), L)) != 0) { + pbuf_free(buf); + return err; + } + pbuf_free(buf); + + return 0; +} + +static int net_udp_sendto( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + size_t sl, il; + const char* payload = luaL_checklstring(L, 2, &sl); + u16_t port = luaL_checkinteger(L, 3); + const char* domain = luaL_checklstring(L, 4, &il); + ip_addr_t addr; + + if (payload == 0) { + return luaL_error(L, "invalid payload"); + } + if (port == 0) { + return luaL_error(L, "Invalid port"); + } + addr.addr = ipaddr_addr(domain); + if (domain == 0 || addr.addr == 0 || addr.addr == 0xFFFFFFFF) { + return luaL_error(L, "Invalid IP address"); + } + + struct pbuf *buf = pbuf_alloc(PBUF_TRANSPORT, sl, PBUF_RAM); + os_memmove(buf->payload, payload, sl); + int err; + if ((err = luaL_lwip_checkerr(udp_sendto(ud->pcb, buf, &addr, port), L)) != 0) { + pbuf_free(buf); + return err; + } + pbuf_free(buf); + + return 0; +} + +static int net_udp_getaddr( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + if (ud->self_ref == LUA_NOREF) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + + char ipaddr[20]; + c_sprintf(ipaddr, IPSTR, IP2STR(&ud->pcb->local_ip.addr)); + + lua_pushinteger(L, ud->pcb->local_port); + lua_pushstring(L, ipaddr); + + return 2; +} + +static int net_udp_getpeer( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + if (ud->self_ref == LUA_NOREF) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + + if (ud->pcb->remote_port == 0 || ud->pcb->remote_ip.addr == 0) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + + char ipaddr[20]; + c_sprintf(ipaddr, IPSTR, IP2STR(&ud->pcb->remote_ip.addr)); + + lua_pushinteger(L, ud->pcb->remote_port); + lua_pushstring(L, ipaddr); + + return 2; +} + +static int net_udp_close( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + udp_disconnect(ud->pcb); + + lua_gc(L, LUA_GCSTOP, 0); + luaL_unref(L, LUA_REGISTRYINDEX, ud->self_ref); + ud->self_ref = LUA_NOREF; + lua_gc(L, LUA_GCRESTART, 0); + + return 0; +} + +static int net_udp_on( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + size_t sl; + const char* event = luaL_checklstring(L, 2, &sl); + + luaL_checkanyfunction(L, 3); + lua_pushvalue(L, 3); // copy argument (func) to the top of stack + + if(c_strcmp(event, "receive") == 0){ + luaL_unref(L, LUA_REGISTRYINDEX, ud->cb_recv_ref); + ud->cb_recv_ref = luaL_ref(L, LUA_REGISTRYINDEX); + } else { + return luaL_error(L, "invalid method"); + } + + return 0; +} + +static int net_udp_delete( lua_State* L ) { + net_udp_ud *ud = (net_udp_ud *)luaL_checkudata(L, 1, "net.udp"); + luaL_argcheck(L, ud, 1, "UDP socket expected"); + if(ud==NULL){ + NODE_DBG("userdata is nil.\n"); + return 0; + } + + if (ud->pcb) { + udp_disconnect(ud->pcb); + udp_remove(ud->pcb); + ud->pcb = 0; + } + luaL_unref(L, LUA_REGISTRYINDEX, ud->cb_recv_ref); + ud->cb_recv_ref = LUA_NOREF; + + lua_gc(L, LUA_GCSTOP, 0); + luaL_unref(L, LUA_REGISTRYINDEX, ud->self_ref); + ud->self_ref = LUA_NOREF; + lua_gc(L, LUA_GCRESTART, 0); + + return 0; +} + // Module function map static const LUA_REG_TYPE net_server_map[] = { { LSTRKEY( "listen" ), LFUNCVAL( net_server_listen ) }, @@ -1751,9 +2080,24 @@ static const LUA_REG_TYPE net_dns_map[] = { { LNILKEY, LNILVAL } }; +static const LUA_REG_TYPE net_udp_map[] = { + { LSTRKEY( "bind" ), LFUNCVAL( net_udp_bind ) }, + { LSTRKEY( "on" ), LFUNCVAL( net_udp_on ) }, + { LSTRKEY( "connect" ), LFUNCVAL( net_udp_connect ) }, + { LSTRKEY( "sendto" ), LFUNCVAL( net_udp_sendto ) }, + { LSTRKEY( "send" ), LFUNCVAL( net_udp_send ) }, + { LSTRKEY( "getaddr" ), LFUNCVAL( net_udp_getaddr ) }, + { LSTRKEY( "getpeer" ), LFUNCVAL( net_udp_getpeer ) }, + { LSTRKEY( "close" ), LFUNCVAL( net_udp_close ) }, + { LSTRKEY( "__gc" ), LFUNCVAL( net_udp_delete ) }, + { LSTRKEY( "__index" ), LROVAL( net_udp_map ) }, + { LNILKEY, LNILVAL } +}; + static const LUA_REG_TYPE net_map[] = { { LSTRKEY( "createServer" ), LFUNCVAL( net_createServer ) }, { LSTRKEY( "createConnection" ), LFUNCVAL( net_createConnection ) }, + { LSTRKEY( "createUDPSocket" ), LFUNCVAL( net_createUDPSocket ) }, { LSTRKEY( "multicastJoin"), LFUNCVAL( net_multicastJoin ) }, { LSTRKEY( "multicastLeave"), LFUNCVAL( net_multicastLeave ) }, { LSTRKEY( "dns" ), LROVAL( net_dns_map ) }, @@ -1778,6 +2122,7 @@ int luaopen_net( lua_State *L ) { #if 0 luaL_rometatable(L, "net.array", (void *)net_array_map); // create metatable for net.array #endif + luaL_rometatable(L, "net.udp", (void *)net_udp_map); // create metatable for net.udp return 0; }