Skip to content

Commit

Permalink
Merge pull request #2768 from NeOzay/cast-table-to-class
Browse files Browse the repository at this point in the history
check that the shape of the table corresponds to the class
  • Loading branch information
sumneko authored Aug 15, 2024
2 parents b71cb7a + 2c79870 commit 7c9a24f
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 5 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased
<!-- Add all new changes here. They will be moved under a version at release -->
* `NEW` Add matching checks between the shape of tables and classes, during type checking. [#2768](https://github.com/LuaLS/lua-language-server/pull/2768
* `FIX` Error `attempt to index a nil value` when `Lua.hint.semicolon == 'All'` [#2788](https://github.com/LuaLS/lua-language-server/issues/2788)

## 3.10.3
Expand Down
93 changes: 88 additions & 5 deletions script/vm/type.lua
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end
---@param mark table
---@param errs? typecheck.err[]
---@return boolean?
local function checkChildEnum(childName, parent , uri, mark, errs)
local function checkChildEnum(childName, parent, uri, mark, errs)
if mark[childName] then
return
end
Expand All @@ -168,7 +168,7 @@ local function checkChildEnum(childName, parent , uri, mark, errs)
end
mark[childName] = true
for _, enum in ipairs(enums) do
if not vm.isSubType(uri, vm.compileNode(enum), parent, mark ,errs) then
if not vm.isSubType(uri, vm.compileNode(enum), parent, mark, errs) then
mark[childName] = nil
return false
end
Expand Down Expand Up @@ -325,10 +325,30 @@ function vm.isSubType(uri, child, parent, mark, errs)
return true
else
local weakNil = config.get(uri, 'Lua.type.weakNilCheck')
local skipTable
for n in child:eachObject() do
if skipTable == nil and n.type == "table" and parent.type == "vm.node" then -- skip table type check if child has class
---@cast parent vm.node
for _, c in ipairs(child) do
if c.type == 'global' and c.cate == 'type' then
for _, set in ipairs(c:getSets(uri)) do
if set.type == 'doc.class' then
skipTable = true
break
end
end
end
if skipTable then
break
end
end
if not skipTable then
skipTable = false
end
end
local nodeName = vm.getNodeName(n)
if nodeName
and not (nodeName == 'nil' and weakNil)
and not (nodeName == 'nil' and weakNil) and not (skipTable and n.type == 'table')
and vm.isSubType(uri, n, parent, mark, errs) == false then
if errs then
errs[#errs+1] = 'TYPE_ERROR_UNION_DISMATCH'
Expand Down Expand Up @@ -463,6 +483,67 @@ function vm.isSubType(uri, child, parent, mark, errs)
return true
end
if childName == 'table' and not guide.isBasicType(parentName) then
local set = parent:getSets(uri)
local missedKeys = {}
local failedCheck
local myKeys
for _, def in ipairs(set) do
if not def.fields or #def.fields == 0 then
goto continue
end
if not myKeys then
myKeys = {}
for _, field in ipairs(child) do
local key = vm.getKeyName(field) or field.tindex
if key then
myKeys[key] = vm.compileNode(field)
end
end
end

for _, field in ipairs(def.fields) do
local key = vm.getKeyName(field)
if not key then
local fieldnode = vm.compileNode(field.field)[1]
if fieldnode and fieldnode.type == 'doc.type.integer' then
---@cast fieldnode parser.object
key = vm.getKeyName(fieldnode)
end
end
if not key then
goto continue
end

local ok
local nodeField = vm.compileNode(field)
if myKeys[key] then
ok = vm.isSubType(uri, myKeys[key], nodeField, mark, errs)
if ok == false then
errs[#errs+1] = 'TYPE_ERROR_PARENT_ALL_DISMATCH' -- error display can be greatly improved
errs[#errs+1] = myKeys[key]
errs[#errs+1] = nodeField
failedCheck = true
end
elseif not nodeField:isNullable() then
if type(key) == "number" then
missedKeys[#missedKeys+1] = ('`[%s]`'):format(key)
else
missedKeys[#missedKeys+1] = ('`%s`'):format(key)
end
failedCheck = true
end
end
::continue::
end
if #missedKeys > 0 then
errs[#errs+1] = 'DIAG_MISSING_FIELDS'
errs[#errs+1] = parent
errs[#errs+1] = table.concat(missedKeys, ', ')
end
if failedCheck then
return false
end

return true
end

Expand Down Expand Up @@ -570,11 +651,11 @@ function vm.getTableValue(uri, tnode, knode, inversion)
and field.value
and field.tindex == 1 then
if inversion then
if vm.isSubType(uri, 'integer', knode) then
if vm.isSubType(uri, 'integer', knode) then
result:merge(vm.compileNode(field.value))
end
else
if vm.isSubType(uri, knode, 'integer') then
if vm.isSubType(uri, knode, 'integer') then
result:merge(vm.compileNode(field.value))
end
end
Expand Down Expand Up @@ -692,6 +773,7 @@ function vm.canCastType(uri, defNode, refNode, errs)
return true
end


return false
end

Expand All @@ -713,6 +795,7 @@ local ErrorMessageMap = {
TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER = {'child'},
TYPE_ERROR_NUMBER_TYPE_TO_INTEGER = {},
TYPE_ERROR_DISMATCH = {'child', 'parent'},
DIAG_MISSING_FIELDS = {"1", "2"},
}

---@param uri uri
Expand Down
66 changes: 66 additions & 0 deletions test/diagnostics/cast-local-type.lua
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,69 @@ local x
- 类型 `nil` 无法匹配 `'B'`
- 类型 `nil` 无法匹配 `'A'`]])
end)

TEST [[
---@class A
---@field x string
---@field y number
local a = {x = "", y = 0}
---@type A
local v
v = a
]]

TEST [[
---@class A
---@field x string
---@field y number
local a = {x = ""}
---@type A
local v
<!v!> = a
]]

TEST [[
---@class A
---@field x string
---@field y number
local a = {x = "", y = ""}
---@type A
local v
<!v!> = a
]]

TEST [[
---@class A
---@field x string
---@field y? B
---@class B
---@field x string
local a = {x = "b", y = {x = "c"}}
---@type A
local v
v = a
]]

TEST [[
---@class A
---@field x string
---@field y B
---@class B
---@field x string
local a = {x = "b", y = {}}
---@type A
local v
<!v!> = a
]]
84 changes: 84 additions & 0 deletions test/diagnostics/param-type-mismatch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,87 @@ local function f(v) end
f 'x'
f 'y'
]]

TEST [[
---@class A
---@field x string
---@field y number
local a = {x = "", y = 0}
---@param a A
function f(a) end
f(a)
]]

TEST [[
---@class A
---@field x string
---@field y number
local a = {x = ""}
---@param a A
function f(a) end
f(<!a!>)
]]

TEST [[
---@class A
---@field x string
---@field y number
local a = {x = "", y = ""}
---@param a A
function f(a) end
f(<!a!>)
]]

TEST [[
---@class A
---@field x string
---@field y? B
---@class B
---@field x string
local a = {x = "b", y = {x = "c"}}
---@param a A
function f(a) end
f(a)
]]

TEST [[
---@class A
---@field x string
---@field y B
---@class B
---@field x string
local a = {x = "b", y = {}}
---@param a A
function f(a) end
f(<!a!>)
]]

TEST [[
---@class A
---@field x string
---@type A
local a = {}
---@param a A
function f(a) end
f(a)
]]

0 comments on commit 7c9a24f

Please sign in to comment.