From 114c88c871f3b29303e62b7e82edbb5f016b1a6b Mon Sep 17 00:00:00 2001 From: NeOzay Date: Sun, 21 Jul 2024 13:34:35 +0200 Subject: [PATCH 1/6] check that the shape of the table corresponds to the class --- script/vm/type.lua | 63 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/script/vm/type.lua b/script/vm/type.lua index d2a859d0d..d79a8094a 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -325,10 +325,20 @@ function vm.isSubType(uri, child, parent, mark, errs) return true else local weakNil = config.get(uri, 'Lua.type.weakNilCheck') + local iscomplextype + for i = #child, 1, -1 do + local name = vm.getNodeName(child[i]) + if name then + iscomplextype = not guide.isBasicType(name) + if iscomplextype then + break + end + end + end for n in child:eachObject() do local nodeName = vm.getNodeName(n) if nodeName - and not (nodeName == 'nil' and weakNil) + and not (nodeName == 'nil' and weakNil) and vm.isSubType(uri, n, parent, mark, errs) == false then if errs then errs[#errs+1] = 'TYPE_ERROR_UNION_DISMATCH' @@ -463,7 +473,53 @@ function vm.isSubType(uri, child, parent, mark, errs) return true end if childName == 'table' and not guide.isBasicType(parentName) then - return true + ------------------------------------------------------------------------- + local requiresKeys = {} + local t = parent + local set = {} + vm.getClassFields(uri, parent, vm.ANY, function (field, isMark) + if not set[field] then + set[field] = true + set[#set+1] = field + end + end) + print('set', set) + for i, field in ipairs(set) do + if not field.optional + and not vm.compileNode(field):isNullable() then + local key = vm.getKeyName(field) + local node = vm.compileNode(field) + if key and not requiresKeys[key] then + requiresKeys[key] = node + requiresKeys[#requiresKeys+1] = key + end + end + end + if #requiresKeys == 0 then + return + end + local refkey = {} + for _, field in ipairs(child) do + local name = vm.getKeyName(field) + local node = vm.compileNode(field) + if name then + refkey[name] = node + end + end + local ok + for _, key in ipairs(requiresKeys) do + if refkey[key] then + ok = vm.isSubType(uri, refkey[key], requiresKeys[key], mark, errs) + else + return false + end + if not ok then + return false + end + end + + ------------------------------- + return true --true end -- check class parent @@ -692,6 +748,7 @@ function vm.canCastType(uri, defNode, refNode, errs) return true end + return false end @@ -753,7 +810,7 @@ function vm.viewTypeErrorMessage(uri, errs) lparams[paramName] = vm.viewKey(value, uri) else lparams[paramName] = vm.getInfer(value):view(uri) - or vm.getInfer(value):view(uri) + or vm.getInfer(value):view(uri) end end index = index + 1 From d74fae4b4a9390b812e2de4608a662b8251f3fc6 Mon Sep 17 00:00:00 2001 From: NeOzay Date: Sun, 21 Jul 2024 21:14:59 +0200 Subject: [PATCH 2/6] Remove and fix some code --- script/vm/type.lua | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/script/vm/type.lua b/script/vm/type.lua index d79a8094a..e29515592 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -325,16 +325,6 @@ function vm.isSubType(uri, child, parent, mark, errs) return true else local weakNil = config.get(uri, 'Lua.type.weakNilCheck') - local iscomplextype - for i = #child, 1, -1 do - local name = vm.getNodeName(child[i]) - if name then - iscomplextype = not guide.isBasicType(name) - if iscomplextype then - break - end - end - end for n in child:eachObject() do local nodeName = vm.getNodeName(n) if nodeName @@ -483,20 +473,16 @@ function vm.isSubType(uri, child, parent, mark, errs) set[#set+1] = field end end) - print('set', set) + for i, field in ipairs(set) do - if not field.optional - and not vm.compileNode(field):isNullable() then local key = vm.getKeyName(field) local node = vm.compileNode(field) if key and not requiresKeys[key] then requiresKeys[key] = node - requiresKeys[#requiresKeys+1] = key - end end end if #requiresKeys == 0 then - return + return --true end local refkey = {} for _, field in ipairs(child) do @@ -507,10 +493,10 @@ function vm.isSubType(uri, child, parent, mark, errs) end end local ok - for _, key in ipairs(requiresKeys) do + for key, node in pairs(requiresKeys) do if refkey[key] then ok = vm.isSubType(uri, refkey[key], requiresKeys[key], mark, errs) - else + elseif not node:isNullable() then return false end if not ok then @@ -519,7 +505,7 @@ function vm.isSubType(uri, child, parent, mark, errs) end ------------------------------- - return true --true + return true end -- check class parent From 10b09d29557098943513b90ef3553d5f0c6431b0 Mon Sep 17 00:00:00 2001 From: NeOzay Date: Mon, 22 Jul 2024 17:39:44 +0200 Subject: [PATCH 3/6] remove unnecessary diff --- script/vm/type.lua | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/script/vm/type.lua b/script/vm/type.lua index e29515592..7e3ba0ff8 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -148,7 +148,7 @@ end ---@param mark table ---@param errs? typecheck.err[] ---@return boolean? -local function checkChildEnum(childName, parent , uri, mark, errs) +local function checkChildEnum(childName, parent, uri, mark, errs) if mark[childName] then return end @@ -168,7 +168,7 @@ local function checkChildEnum(childName, parent , uri, mark, errs) end mark[childName] = true for _, enum in ipairs(enums) do - if not vm.isSubType(uri, vm.compileNode(enum), parent, mark ,errs) then + if not vm.isSubType(uri, vm.compileNode(enum), parent, mark, errs) then mark[childName] = nil return false end @@ -328,7 +328,7 @@ function vm.isSubType(uri, child, parent, mark, errs) for n in child:eachObject() do local nodeName = vm.getNodeName(n) if nodeName - and not (nodeName == 'nil' and weakNil) + and not (nodeName == 'nil' and weakNil) and vm.isSubType(uri, n, parent, mark, errs) == false then if errs then errs[#errs+1] = 'TYPE_ERROR_UNION_DISMATCH' @@ -463,7 +463,6 @@ function vm.isSubType(uri, child, parent, mark, errs) return true end if childName == 'table' and not guide.isBasicType(parentName) then - ------------------------------------------------------------------------- local requiresKeys = {} local t = parent local set = {} @@ -475,14 +474,14 @@ function vm.isSubType(uri, child, parent, mark, errs) end) for i, field in ipairs(set) do - local key = vm.getKeyName(field) - local node = vm.compileNode(field) - if key and not requiresKeys[key] then - requiresKeys[key] = node + local key = vm.getKeyName(field) + local node = vm.compileNode(field) + if key and not requiresKeys[key] then + requiresKeys[key] = node end end - if #requiresKeys == 0 then - return --true + if not next(requiresKeys) then + return true end local refkey = {} for _, field in ipairs(child) do @@ -495,7 +494,7 @@ function vm.isSubType(uri, child, parent, mark, errs) local ok for key, node in pairs(requiresKeys) do if refkey[key] then - ok = vm.isSubType(uri, refkey[key], requiresKeys[key], mark, errs) + ok = vm.isSubType(uri, refkey[key], requiresKeys[key], mark, errs) elseif not node:isNullable() then return false end @@ -503,8 +502,7 @@ function vm.isSubType(uri, child, parent, mark, errs) return false end end - - ------------------------------- + return true end @@ -612,11 +610,11 @@ function vm.getTableValue(uri, tnode, knode, inversion) and field.value and field.tindex == 1 then if inversion then - if vm.isSubType(uri, 'integer', knode) then + if vm.isSubType(uri, 'integer', knode) then result:merge(vm.compileNode(field.value)) end else - if vm.isSubType(uri, knode, 'integer') then + if vm.isSubType(uri, knode, 'integer') then result:merge(vm.compileNode(field.value)) end end @@ -796,7 +794,7 @@ function vm.viewTypeErrorMessage(uri, errs) lparams[paramName] = vm.viewKey(value, uri) else lparams[paramName] = vm.getInfer(value):view(uri) - or vm.getInfer(value):view(uri) + or vm.getInfer(value):view(uri) end end index = index + 1 From e22ee92638966bacef958d050bda7909cbc0ab2a Mon Sep 17 00:00:00 2001 From: NeOzay Date: Sun, 28 Jul 2024 21:06:40 +0200 Subject: [PATCH 4/6] undo branche merge --- script/vm/type.lua | 110 +++++++++++++++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 34 deletions(-) diff --git a/script/vm/type.lua b/script/vm/type.lua index 7e3ba0ff8..afc19984d 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -325,10 +325,30 @@ function vm.isSubType(uri, child, parent, mark, errs) return true else local weakNil = config.get(uri, 'Lua.type.weakNilCheck') + local skipTable for n in child:eachObject() do + if skipTable == nil and n.type == "table" and parent.type == "vm.node" then -- skip table type check if child has class + ---@cast parent vm.node + for _, c in ipairs(child) do + if c.type == 'global' and c.cate == 'type' then + for _, set in ipairs(c:getSets(uri)) do + if set.type == 'doc.class' then + skipTable = true + break + end + end + end + if skipTable then + break + end + end + if not skipTable then + skipTable = false + end + end local nodeName = vm.getNodeName(n) if nodeName - and not (nodeName == 'nil' and weakNil) + and not (nodeName == 'nil' and weakNil) and not (skipTable and n.type == 'table') and vm.isSubType(uri, n, parent, mark, errs) == false then if errs then errs[#errs+1] = 'TYPE_ERROR_UNION_DISMATCH' @@ -463,44 +483,65 @@ function vm.isSubType(uri, child, parent, mark, errs) return true end if childName == 'table' and not guide.isBasicType(parentName) then - local requiresKeys = {} - local t = parent - local set = {} - vm.getClassFields(uri, parent, vm.ANY, function (field, isMark) - if not set[field] then - set[field] = true - set[#set+1] = field + local set = parent:getSets(uri) + local missedKeys = {} + local failedCheck + local myKeys + for _, def in ipairs(set) do + if not def.fields or #def.fields == 0 then + goto continue + end + if not myKeys then + myKeys = {} + for _, field in ipairs(child) do + local key = vm.getKeyName(field) or field.tindex + if key then + myKeys[key] = vm.compileNode(field) + end + end end - end) - for i, field in ipairs(set) do - local key = vm.getKeyName(field) - local node = vm.compileNode(field) - if key and not requiresKeys[key] then - requiresKeys[key] = node + for _, field in ipairs(def.fields) do + local key = vm.getKeyName(field) + if not key then + local fieldnode = vm.compileNode(field.field)[1] + if fieldnode and fieldnode.type == 'doc.type.integer' then + ---@cast fieldnode parser.object + key = vm.getKeyName(fieldnode) + end + end + if not key then + goto continue + end + + local ok + local nodeField = vm.compileNode(field) + if myKeys[key] then + ok = vm.isSubType(uri, myKeys[key], nodeField, mark, errs) + if ok == false then + errs[#errs+1] = 'TYPE_ERROR_PARENT_ALL_DISMATCH' -- error display can be greatly improved + errs[#errs+1] = myKeys[key] + errs[#errs+1] = nodeField + failedCheck = true + end + elseif not nodeField:isNullable() then + if type(key) == "number" then + missedKeys[#missedKeys+1] = ('`[%s]`'):format(key) + else + missedKeys[#missedKeys+1] = ('`%s`'):format(key) + end + failedCheck = true + end end + ::continue:: end - if not next(requiresKeys) then - return true - end - local refkey = {} - for _, field in ipairs(child) do - local name = vm.getKeyName(field) - local node = vm.compileNode(field) - if name then - refkey[name] = node - end + if #missedKeys > 0 then + errs[#errs+1] = 'DIAG_MISSING_FIELDS' + errs[#errs+1] = parent + errs[#errs+1] = table.concat(missedKeys, ', ') end - local ok - for key, node in pairs(requiresKeys) do - if refkey[key] then - ok = vm.isSubType(uri, refkey[key], requiresKeys[key], mark, errs) - elseif not node:isNullable() then - return false - end - if not ok then - return false - end + if failedCheck then + return false end return true @@ -754,6 +795,7 @@ local ErrorMessageMap = { TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER = {'child'}, TYPE_ERROR_NUMBER_TYPE_TO_INTEGER = {}, TYPE_ERROR_DISMATCH = {'child', 'parent'}, + DIAG_MISSING_FIELDS = {"1", "2"}, } ---@param uri uri From 6cd10388ab56fed330706f90628b44c77d778310 Mon Sep 17 00:00:00 2001 From: NeOzay Date: Sun, 28 Jul 2024 21:07:22 +0200 Subject: [PATCH 5/6] update changelog --- changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/changelog.md b/changelog.md index a82efdabe..697b05a97 100644 --- a/changelog.md +++ b/changelog.md @@ -12,6 +12,7 @@ * `NEW` added lua regular expression support for Lua.doc.Name [#2753](https://github.com/LuaLS/lua-language-server/pull/2753) * `FIX` Bad triggering of the `inject-field` diagnostic, when the fields are declared at the creation of the object [#2746](https://github.com/LuaLS/lua-language-server/issues/2746) * `CHG` Change spacing of parameter inlay hints to match other LSPs, like `rust-analyzer` +* `NEW` Add matching checks between the shape of tables and classes, during type checking. [#2768](https://github.com/LuaLS/lua-language-server/pull/2768) ## 3.9.3 `2024-6-11` From b7c5809d6c5b08f295e60c5e53ee8d8e1074667c Mon Sep 17 00:00:00 2001 From: NeOzay Date: Sun, 28 Jul 2024 22:29:32 +0200 Subject: [PATCH 6/6] add test --- test/diagnostics/cast-local-type.lua | 66 +++++++++++++++++++ test/diagnostics/param-type-mismatch.lua | 84 ++++++++++++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/test/diagnostics/cast-local-type.lua b/test/diagnostics/cast-local-type.lua index f79bf48d1..93452a922 100644 --- a/test/diagnostics/cast-local-type.lua +++ b/test/diagnostics/cast-local-type.lua @@ -332,3 +332,69 @@ local x - 类型 `nil` 无法匹配 `'B'` - 类型 `nil` 无法匹配 `'A'`]]) end) + +TEST [[ +---@class A +---@field x string +---@field y number + +local a = {x = "", y = 0} + +---@type A +local v +v = a +]] + +TEST [[ +---@class A +---@field x string +---@field y number + +local a = {x = ""} + +---@type A +local v + = a +]] + +TEST [[ +---@class A +---@field x string +---@field y number + +local a = {x = "", y = ""} + +---@type A +local v + = a +]] + +TEST [[ +---@class A +---@field x string +---@field y? B + +---@class B +---@field x string + +local a = {x = "b", y = {x = "c"}} + +---@type A +local v +v = a +]] + +TEST [[ +---@class A +---@field x string +---@field y B + +---@class B +---@field x string + +local a = {x = "b", y = {}} + +---@type A +local v + = a +]] diff --git a/test/diagnostics/param-type-mismatch.lua b/test/diagnostics/param-type-mismatch.lua index b11068db4..bb602cabf 100644 --- a/test/diagnostics/param-type-mismatch.lua +++ b/test/diagnostics/param-type-mismatch.lua @@ -264,3 +264,87 @@ local function f(v) end f 'x' f 'y' ]] + +TEST [[ +---@class A +---@field x string +---@field y number + +local a = {x = "", y = 0} + +---@param a A +function f(a) end + +f(a) +]] + +TEST [[ +---@class A +---@field x string +---@field y number + +local a = {x = ""} + +---@param a A +function f(a) end + +f() +]] + +TEST [[ +---@class A +---@field x string +---@field y number + +local a = {x = "", y = ""} + +---@param a A +function f(a) end + +f() +]] + +TEST [[ +---@class A +---@field x string +---@field y? B + +---@class B +---@field x string + +local a = {x = "b", y = {x = "c"}} + +---@param a A +function f(a) end + +f(a) +]] + +TEST [[ +---@class A +---@field x string +---@field y B + +---@class B +---@field x string + +local a = {x = "b", y = {}} + +---@param a A +function f(a) end + +f() +]] + +TEST [[ +---@class A +---@field x string + +---@type A +local a = {} + +---@param a A +function f(a) end + +f(a) +]] \ No newline at end of file