Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][SetParameters] Generalize to work for any type, rename to TypeParameterAccessors #1331

Closed
wants to merge 149 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
149 commits
Select commit Hold shift + click to select a range
1724369
Start working on changing the name of `SetParameters`
kmp5VT Feb 5, 2024
0513d41
Merge commit '172436901f883cd939caf78b81ad433bcc466140' into kmp5/ref…
kmp5VT Feb 5, 2024
4ac2a8f
Merge branch 'main' into kmp5/refactor/rename_setparameters
kmp5VT Feb 5, 2024
f78c142
Move unused files to deprecated to better
kmp5VT Feb 5, 2024
80fecae
Start moving get_parameter/parameter
kmp5VT Feb 5, 2024
7bda99a
Remove `set_parameter` functions from other directories
kmp5VT Feb 5, 2024
9cbbcbf
Make parameter function
kmp5VT Feb 5, 2024
93ec2f4
move set_parameter functions to `set_parameters.jl`
kmp5VT Feb 5, 2024
070166d
format
kmp5VT Feb 5, 2024
37c8eec
Merge branch 'main' into kmp5/refactor/rename_setparameters
kmp5VT Feb 5, 2024
27f147a
add comment to work on `default_parameter` [no ci]
kmp5VT Feb 5, 2024
96446b7
Merge branch 'kmp5/refactor/rename_setparameters' of github.com:kmp5V…
kmp5VT Feb 5, 2024
01bf0f0
Rename `SetParameters` to
kmp5VT Feb 5, 2024
8f9bb82
format
kmp5VT Feb 5, 2024
e55611f
`TypeParameterAccessor` -> `TypeParameterAccessors`
kmp5VT Feb 5, 2024
611e350
Update parameters file remove get_parameters
kmp5VT Feb 6, 2024
fe78a5a
Update tests for get parameters
kmp5VT Feb 6, 2024
023770f
Update documentation
kmp5VT Feb 7, 2024
e1c7e04
Fix typo
kmp5VT Feb 7, 2024
403dd72
Move Position based function to Position.jl
kmp5VT Feb 7, 2024
5e6c405
Create Int conversion function
kmp5VT Feb 7, 2024
b15169d
Use new to _datatype function for type stability
kmp5VT Feb 8, 2024
c7a4cad
Update get_parameters tests
kmp5VT Feb 8, 2024
a8b66ff
Move `parameters.jl` before `position.jl`
kmp5VT Feb 8, 2024
354a07b
Move functions that depend on `Position` to `position.jl`
kmp5VT Feb 8, 2024
26e0bc6
Update set_parameters test
kmp5VT Feb 8, 2024
ec69ce4
Move wrapped array definitions to
kmp5VT Feb 8, 2024
4936d3f
Ammending to previous push
kmp5VT Feb 8, 2024
8141575
Remove parenttype_position def is unecessary here
kmp5VT Feb 8, 2024
65b273a
remove uncessary set_ndims call
kmp5VT Feb 8, 2024
415a39d
Export is_wrapped_array and parenttype
kmp5VT Feb 8, 2024
3800125
Fix some issues and dont use `Position`
kmp5VT Feb 8, 2024
0987f8f
Another issue with type stability
kmp5VT Feb 8, 2024
e7dc057
Start defining parenttype position places
kmp5VT Feb 8, 2024
09dda15
Remove exports from `Unwrap`
kmp5VT Feb 8, 2024
bb12557
Remove the `parenttype.jl` file because of compiling requirements
kmp5VT Feb 8, 2024
49fd429
Some updates to specify_parameters (still working here)
kmp5VT Feb 8, 2024
2c3e62f
Updates to base functions to make more type stable
kmp5VT Feb 8, 2024
3842143
grab the name `TypeParameterAccessors` to be able to set `parenttype_…
kmp5VT Feb 8, 2024
6408feb
Update wrapper type to not have stackoverflow
kmp5VT Feb 8, 2024
0edd59e
Update `set_ndim` and `set_eltype`
kmp5VT Feb 8, 2024
4c6e29d
format
kmp5VT Feb 8, 2024
abfb994
Merge branch 'main' into kmp5/refactor/rename_setparameters
kmp5VT Feb 9, 2024
6bfb87d
Some convention updates in the `TypeParameterAccessor` module
kmp5VT Feb 9, 2024
512ad24
Try and remove positoin when possible
kmp5VT Feb 9, 2024
7348682
Remove unecessary code
kmp5VT Feb 9, 2024
e43f3ae
formatting
kmp5VT Feb 9, 2024
3da8963
formatting of `wrappedarray`
kmp5VT Feb 9, 2024
1e112cd
Remove unecessary tests
kmp5VT Feb 9, 2024
1d1b141
formatting/cleanup
kmp5VT Feb 9, 2024
8994541
Start updating the `Specify_parameters`
kmp5VT Feb 9, 2024
cf58b1a
Update `set_parameters` to be more type stable
kmp5VT Feb 12, 2024
eae0a4b
Tests no longer broken
kmp5VT Feb 12, 2024
ad96b1b
Update specify_parameters
kmp5VT Feb 12, 2024
5037009
format
kmp5VT Feb 12, 2024
4d50a3c
format
kmp5VT Feb 12, 2024
97f28fd
Remove unecessary code and old `specify_parameters`
kmp5VT Feb 12, 2024
ed83acf
Update to the position based on function system
kmp5VT Feb 12, 2024
9ce2045
Add interface
kmp5VT Feb 12, 2024
e69657f
Add interface.jl
kmp5VT Feb 12, 2024
02cc6ed
Define setstoragemode for GPU. And update default parameter
kmp5VT Feb 13, 2024
ffa4876
Add a `UndefinedPosition` set_parameter which does nothing
kmp5VT Feb 13, 2024
015f7b5
Rename undefposition -> undefinedposition
kmp5VT Feb 13, 2024
d1b3cce
Make axestype and alloctype parameters
kmp5VT Feb 13, 2024
5f52eec
Spelling
kmp5VT Feb 13, 2024
dc45b4e
Move undefinedPosition set_parameter out of set_parameter.jl. And spe…
kmp5VT Feb 13, 2024
061aeba
format
kmp5VT Feb 13, 2024
f6569cc
Move these generated functions to the correct place
kmp5VT Feb 13, 2024
bfbee6b
No more parenttype_position function
kmp5VT Feb 13, 2024
6b6b850
Update to position(type, function) but there are some tricky circular…
kmp5VT Feb 13, 2024
0bd183e
missing changes from last push in parenttype
kmp5VT Feb 13, 2024
3fee6a9
format
kmp5VT Feb 13, 2024
b5105b1
Merge branch 'kmp5/temp_branch' into kmp5/refactor/rename_setparameters
kmp5VT Feb 13, 2024
11751e0
Update from parenttype position to position
kmp5VT Feb 13, 2024
6603615
More parenttype_position to position
kmp5VT Feb 13, 2024
5e58e0c
Remove typeparameteraccessor file (to_unionall and unspecifyparameters)
kmp5VT Feb 13, 2024
83dadb4
Add some tests to `set_eltype`
kmp5VT Feb 13, 2024
2a8cd90
Fix default parameters for Dense
kmp5VT Feb 13, 2024
a8b4b38
Add back unspecify_parameters to get wrapper name
kmp5VT Feb 13, 2024
726d9d1
Use new system to set alloctype
kmp5VT Feb 13, 2024
e415155
Update UnallocatedArrays tests
kmp5VT Feb 13, 2024
541f349
Update default_alloctype for UnallocatedArrays
kmp5VT Feb 13, 2024
6d8be2c
Comment out default parameter test for now
kmp5VT Feb 13, 2024
23cb8e0
format
kmp5VT Feb 13, 2024
967953a
Some missing updates to the removed file
kmp5VT Feb 13, 2024
d032707
Remove NDTensors. and just say using TypeParameterAccessors
kmp5VT Feb 14, 2024
7db2ca5
Update SimpleTraits using calls
kmp5VT Feb 14, 2024
cab209f
Move before functions
kmp5VT Feb 14, 2024
4f19e7a
Fix broken test
kmp5VT Feb 14, 2024
c06d717
Remove unecessary @generated
kmp5VT Feb 14, 2024
2bff003
Update specify_parameters to be more type stable and no longer need `…
kmp5VT Feb 14, 2024
66cf178
format
kmp5VT Feb 14, 2024
a5064f4
`unwrap_type` -> `unwrap_array_type`
kmp5VT Feb 14, 2024
b29e169
correct the spelling of set_parameter
kmp5VT Feb 14, 2024
69a7561
Add specify_parameter with function
kmp5VT Feb 14, 2024
91d0b1f
Add some comments and remove `specify_parameters`
kmp5VT Feb 14, 2024
7d2e314
format
kmp5VT Feb 14, 2024
07461a8
Fix typo
kmp5VT Feb 14, 2024
1b0a0bd
Working on type stable `DefaultParameter` system
kmp5VT Feb 14, 2024
e4c56c4
Add test to specify_parameter with function in call
kmp5VT Feb 14, 2024
2a5c444
Add specify_defaults test
kmp5VT Feb 14, 2024
5062135
format
kmp5VT Feb 14, 2024
8276521
Move the list of default parameters to `interface` Still workshopping…
kmp5VT Feb 14, 2024
ea4afda
Add set_parameters, where you can put arbitrary parameters and functions
kmp5VT Feb 14, 2024
16e5c17
Add some context and fix the broken case where `Int` isn't wrapped in…
kmp5VT Feb 14, 2024
25d09d5
Fix typo
kmp5VT Feb 14, 2024
09a9dff
Remove default_parameter test and add back unspecifying parameters test
kmp5VT Feb 14, 2024
adbf686
Update specify_parameter system to work better
kmp5VT Feb 14, 2024
efb2269
format
kmp5VT Feb 14, 2024
5ba9b0c
Make a `Function` based parameter function
kmp5VT Feb 16, 2024
bb5d3d0
Fix specify_defaults
kmp5VT Feb 16, 2024
7f8ee72
use function in parameter getting function
kmp5VT Feb 16, 2024
9c21426
Use alloctype over 4
kmp5VT Feb 16, 2024
07e95b9
Rename file and add defaults
kmp5VT Feb 16, 2024
23c81f7
Move undefined position above parameters for compiling
kmp5VT Feb 16, 2024
8b6dbeb
append to previous push
kmp5VT Feb 16, 2024
d8a791f
Responses to Matts comments
kmp5VT Feb 16, 2024
5f1c948
int -> position
kmp5VT Feb 16, 2024
6864abd
Typos
kmp5VT Feb 16, 2024
719a881
Add mytype for testing
kmp5VT Feb 16, 2024
0ba372e
update tests
kmp5VT Feb 16, 2024
9398a90
Revert to default_parameters for now
kmp5VT Feb 16, 2024
92330e7
format
kmp5VT Feb 16, 2024
dc3a9dd
Merge branch 'main' into kmp5/refactor/rename_setparameters
kmp5VT Feb 16, 2024
87bd412
Start making `parameter_function`
kmp5VT Feb 18, 2024
f345715
Add parameter_function for GPUArrays
kmp5VT Feb 18, 2024
94bd0d2
Start fixing things
kmp5VT Feb 18, 2024
d1fb870
Working on other sections of the code
kmp5VT Feb 18, 2024
6b83902
parameter_function -> parameter_name
kmp5VT Feb 19, 2024
3e68db2
Update mytype for testing
kmp5VT Feb 19, 2024
00e954b
Another update to mytype
kmp5VT Feb 19, 2024
9084b24
format
kmp5VT Feb 19, 2024
73f9527
Use `compat: assume_effects`
kmp5VT Feb 19, 2024
f1ae64f
update CUDA cu function to use all three buffers
kmp5VT Feb 19, 2024
ebde22d
Remove extra set_datatype function
kmp5VT Feb 19, 2024
d0a39e0
Migrate `datatype` to `parenttype` funciton so `TensorStorage` is `Wr…
kmp5VT Feb 19, 2024
887b8bc
format
kmp5VT Feb 19, 2024
fc5f639
Updates for metal
kmp5VT Feb 19, 2024
7465f07
Update similartype for wrapped types properly
kmp5VT Feb 19, 2024
36c4e75
format
kmp5VT Feb 19, 2024
4757858
Update to use buffer
kmp5VT Feb 20, 2024
c171f7f
Fix typo
kmp5VT Feb 20, 2024
b2eef5c
Some updates
kmp5VT Feb 20, 2024
c81e9a0
Remove parenttype definition
kmp5VT Feb 20, 2024
104a157
Update ndims
kmp5VT Feb 20, 2024
528142b
format
kmp5VT Feb 20, 2024
75e0969
Get `TypeParameterAccessors` working for NDTensors. There arestill ma…
kmp5VT Feb 21, 2024
7b315fc
Small fixes
kmp5VT Feb 21, 2024
1ce0406
Changes to the typeparameter system and get working for 1.6
kmp5VT Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module NDTensorsCUDAExt

using NDTensors
using NDTensors.SetParameters
using NDTensors.TypeParameterAccessor
using NDTensors.Unwrap
using Adapt
using Functors
Expand All @@ -13,7 +13,6 @@ using CUDA.CUSOLVER
include("imports.jl")
include("default_kwargs.jl")
include("copyto.jl")
include("set_types.jl")
include("iscu.jl")
include("adapt.jl")
include("indexing.jl")
Expand Down
3 changes: 2 additions & 1 deletion NDTensors/ext/NDTensorsCUDAExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import NDTensors: cu, similartype
import NDTensors:
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!, iscu
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter
import NDTensors.TypeParameterAccessor:
nparameters, get_parameter, set_parameter, default_parameter
68 changes: 34 additions & 34 deletions NDTensors/ext/NDTensorsCUDAExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,40 @@
# `SetParameters.jl` overloads.
get_parameter(::Type{<:CuArray{P1}}, ::Position{1}) where {P1} = P1
get_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
get_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3
# # `TypeParameterAccessor.jl` overloads.
# get_parameter(::Type{<:CuArray{P1}}, ::Position{1}) where {P1} = P1
# get_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
# get_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3

# Set parameter 1
set_parameter(::Type{<:CuArray}, ::Position{1}, P1) = CuArray{P1}
set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = CuArray{P1,P2}
function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
return CuArray{P1,<:Any,P3}
end
function set_parameter(::Type{<:CuArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
return CuArray{P1,P2,P3}
end
# # Set parameter 1
# set_parameter(::Type{<:CuArray}, ::Position{1}, P1) = CuArray{P1}
# set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = CuArray{P1,P2}
# function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
# return CuArray{P1,<:Any,P3}
# end
# function set_parameter(::Type{<:CuArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
# return CuArray{P1,P2,P3}
# end

# Set parameter 2
set_parameter(::Type{<:CuArray}, ::Position{2}, P2) = CuArray{<:Any,P2}
set_parameter(::Type{<:CuArray{P1}}, ::Position{2}, P2) where {P1} = CuArray{P1,P2}
function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
return CuArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:CuArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
return CuArray{P1,P2,P3}
end
# # Set parameter 2
# set_parameter(::Type{<:CuArray}, ::Position{2}, P2) = CuArray{<:Any,P2}
# set_parameter(::Type{<:CuArray{P1}}, ::Position{2}, P2) where {P1} = CuArray{P1,P2}
# function set_parameter(::Type{<:CuArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
# return CuArray{<:Any,P2,P3}
# end
# function set_parameter(::Type{<:CuArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
# return CuArray{P1,P2,P3}
# end

# Set parameter 3
set_parameter(::Type{<:CuArray}, ::Position{3}, P3) = CuArray{<:Any,<:Any,P3}
set_parameter(::Type{<:CuArray{P1}}, ::Position{3}, P3) where {P1} = CuArray{P1,<:Any,P3}
function set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
return CuArray{<:Any,P2,P3}
end
set_parameter(::Type{<:CuArray{P1,P2}}, ::Position{3}, P3) where {P1,P2} = CuArray{P1,P2,P3}
# # Set parameter 3
# set_parameter(::Type{<:CuArray}, ::Position{3}, P3) = CuArray{<:Any,<:Any,P3}
# set_parameter(::Type{<:CuArray{P1}}, ::Position{3}, P3) where {P1} = CuArray{P1,<:Any,P3}
# function set_parameter(::Type{<:CuArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
# return CuArray{<:Any,P2,P3}
# end
# set_parameter(::Type{<:CuArray{P1,P2}}, ::Position{3}, P3) where {P1,P2} = CuArray{P1,P2,P3}

default_parameter(::Type{<:CuArray}, ::Position{1}) = Float64
default_parameter(::Type{<:CuArray}, ::Position{2}) = 1
default_parameter(::Type{<:CuArray}, ::Position{3}) = Mem.DeviceBuffer
# default_parameter(::Type{<:CuArray}, ::Position{1}) = Float64
# default_parameter(::Type{<:CuArray}, ::Position{2}) = 1
# default_parameter(::Type{<:CuArray}, ::Position{3}) = Mem.DeviceBuffer

nparameters(::Type{<:CuArray}) = Val(3)
# nparameters(::Type{<:CuArray}) = Val(3)

SetParameters.unspecify_parameters(::Type{<:CuArray}) = CuArray
# TypeParameterAccessor.unspecify_parameters(::Type{<:CuArray}) = CuArray
4 changes: 2 additions & 2 deletions NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ using Adapt
using Functors
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, mul!, qr, eigen, svd
using NDTensors
using NDTensors.SetParameters
using NDTensors.TypeParameterAccessor
using NDTensors.Unwrap: qr_positive, ql_positive, ql

using Metal

include("imports.jl")
include("adapt.jl")
include("set_types.jl")
# include("set_types.jl")
include("indexing.jl")
include("linearalgebra.jl")
include("copyto.jl")
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsMetalExt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ end

# More general than the version in Metal.jl
function Adapt.adapt_storage(arraytype::Type{<:MtlArray}, xs::AbstractArray)
arraytype_specified = specify_parameters(arraytype, get_parameters(xs))
arraytype_specified = specify_parameters(arraytype, parameters(xs))
return isbitstype(typeof(xs)) ? xs : convert(arraytype_specified, xs)
end
3 changes: 2 additions & 1 deletion NDTensors/ext/NDTensorsMetalExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import NDTensors: mtl
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter
import NDTensors.TypeParameterAccessor:
nparameters, get_parameter, set_parameter, default_parameter

using NDTensors.Unwrap: Exposed, unwrap_type, unexpose, expose
using Metal: DefaultStorageMode
Expand Down
80 changes: 40 additions & 40 deletions NDTensors/ext/NDTensorsMetalExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,47 @@
# `SetParameters.jl` overloads.
get_parameter(::Type{<:MtlArray{P1}}, ::Position{1}) where {P1} = P1
get_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
get_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3
# # `TypeParameterAccessor.jl` overloads.
# get_parameter(::Type{<:MtlArray{P1}}, ::Position{1}) where {P1} = P1
# get_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{2}) where {P2} = P2
# get_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{3}) where {P3} = P3

# Set parameter 1
set_parameter(::Type{<:MtlArray}, ::Position{1}, P1) = MtlArray{P1}
set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = MtlArray{P1,P2}
function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
return MtlArray{P1,<:Any,P3}
end
function set_parameter(::Type{<:MtlArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
return MtlArray{P1,P2,P3}
end
# # Set parameter 1
# set_parameter(::Type{<:MtlArray}, ::Position{1}, P1) = MtlArray{P1}
# set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{1}, P1) where {P2} = MtlArray{P1,P2}
# function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{1}, P1) where {P3}
# return MtlArray{P1,<:Any,P3}
# end
# function set_parameter(::Type{<:MtlArray{<:Any,P2,P3}}, ::Position{1}, P1) where {P2,P3}
# return MtlArray{P1,P2,P3}
# end

# Set parameter 2
set_parameter(::Type{<:MtlArray}, ::Position{2}, P2) = MtlArray{<:Any,P2}
set_parameter(::Type{<:MtlArray{P1}}, ::Position{2}, P2) where {P1} = MtlArray{P1,P2}
function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
return MtlArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:MtlArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
return MtlArray{P1,P2,P3}
end
# # Set parameter 2
# set_parameter(::Type{<:MtlArray}, ::Position{2}, P2) = MtlArray{<:Any,P2}
# set_parameter(::Type{<:MtlArray{P1}}, ::Position{2}, P2) where {P1} = MtlArray{P1,P2}
# function set_parameter(::Type{<:MtlArray{<:Any,<:Any,P3}}, ::Position{2}, P2) where {P3}
# return MtlArray{<:Any,P2,P3}
# end
# function set_parameter(::Type{<:MtlArray{P1,<:Any,P3}}, ::Position{2}, P2) where {P1,P3}
# return MtlArray{P1,P2,P3}
# end

# Set parameter 3
set_parameter(::Type{<:MtlArray}, ::Position{3}, P3) = MtlArray{<:Any,<:Any,P3}
set_parameter(::Type{<:MtlArray{P1}}, ::Position{3}, P3) where {P1} = MtlArray{P1,<:Any,P3}
function set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
return MtlArray{<:Any,P2,P3}
end
function set_parameter(::Type{<:MtlArray{P1,P2}}, ::Position{3}, P3) where {P1,P2}
return MtlArray{P1,P2,P3}
end
# # Set parameter 3
# set_parameter(::Type{<:MtlArray}, ::Position{3}, P3) = MtlArray{<:Any,<:Any,P3}
# set_parameter(::Type{<:MtlArray{P1}}, ::Position{3}, P3) where {P1} = MtlArray{P1,<:Any,P3}
# function set_parameter(::Type{<:MtlArray{<:Any,P2}}, ::Position{3}, P3) where {P2}
# return MtlArray{<:Any,P2,P3}
# end
# function set_parameter(::Type{<:MtlArray{P1,P2}}, ::Position{3}, P3) where {P1,P2}
# return MtlArray{P1,P2,P3}
# end

default_parameter(::Type{<:MtlArray}, ::Position{1}) = Float32
default_parameter(::Type{<:MtlArray}, ::Position{2}) = 1
default_parameter(::Type{<:MtlArray}, ::Position{3}) = Metal.DefaultStorageMode
# default_parameter(::Type{<:MtlArray}, ::Position{1}) = Float32
# default_parameter(::Type{<:MtlArray}, ::Position{2}) = 1
# default_parameter(::Type{<:MtlArray}, ::Position{3}) = Metal.DefaultStorageMode

nparameters(::Type{<:MtlArray}) = Val(3)
# nparameters(::Type{<:MtlArray}) = Val(3)

# Metal-specific type parameter setting
function set_storagemode(arraytype::Type{<:MtlArray}, storagemode)
return set_parameter(arraytype, Position(3), storagemode)
end
# # Metal-specific type parameter setting
# function set_storagemode(arraytype::Type{<:MtlArray}, storagemode)
# return set_parameter(arraytype, Position(3), storagemode)
# end

SetParameters.unspecify_parameters(::Type{<:MtlArray}) = MtlArray
# TypeParameterAccessor.unspecify_parameters(::Type{<:MtlArray}) = MtlArray
2 changes: 1 addition & 1 deletion NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ for lib in [
:BaseExtensions,
:UnspecifiedTypes,
:Unwrap,
:SetParameters,
:TypeParameterAccessor,
:BroadcastMapConversion,
:RankFactorization,
:Sectors,
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/abstractarray/fill.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .SetParameters: DefaultParameters, specify_parameters
using .TypeParameterAccessor: DefaultParameters, specify_parameters
using .Unwrap: unwrap_type

function generic_randn(
Expand Down
4 changes: 2 additions & 2 deletions NDTensors/src/abstractarray/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .SetParameters: set_ndims
using .TypeParameterAccessor: set_ndims
"""
# Do we still want to define things like this?
TODO: Use `Accessors.jl` notation:
Expand All @@ -11,7 +11,7 @@ TODO: Use `Accessors.jl` notation:
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function SetParameters.set_ndims(numbertype::Type{<:Number}, ndims)
function TypeParameterAccessor.set_ndims(numbertype::Type{<:Number}, ndims)
return numbertype
end

Expand Down
6 changes: 3 additions & 3 deletions NDTensors/src/array/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using .SetParameters: Position, get_parameter, set_parameters, set_eltype
using .TypeParameterAccessor: Position

SetParameters.eltype_position(::Type{<:SubArray}) = Position(1)
SetParameters.parenttype_position(::Type{<:SubArray}) = Position(3)
TypeParameterAccessor.eltype_position(::Type{<:SubArray}) = Position(1)
TypeParameterAccessor.parenttype_position(::Type{<:SubArray}) = Position(3)

# TODO: Figure out how to define this properly.
# function set_ndims(arraytype::Type{<:SubArray}, ndims)
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/blocksparse/blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function set_datatype(storagetype::Type{<:BlockSparse}, datatype::Type{<:Abstrac
return BlockSparse{eltype(datatype),datatype,ndims(storagetype)}
end

function SetParameters.set_ndims(storagetype::Type{<:BlockSparse}, ndims::Int)
function TypeParameterAccessor.set_ndims(storagetype::Type{<:BlockSparse}, ndims::Int)
return BlockSparse{eltype(storagetype),datatype(storagetype),ndims}
end

Expand Down
27 changes: 6 additions & 21 deletions NDTensors/src/dense/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using .SetParameters:
SetParameters, Position, get_parameters, specify_parameters, unspecify_parameters
using .TypeParameterAccessor:
TypeParameterAccessor, Position, parameters, specify_parameters, unspecify_parameters

function set_datatype(storagetype::Type{<:Dense}, datatype::Type{<:AbstractVector})
return Dense{eltype(datatype),datatype}
Expand All @@ -11,23 +11,8 @@ function set_datatype(storagetype::Type{<:Dense}, datatype::Type{<:AbstractArray
)
end

SetParameters.unspecify_parameters(::Type{<:Dense}) = Dense
# TypeParameterAccessor.unspecify_parameters(::Type{<:Dense}) = Dense

SetParameters.parenttype_position(::Type{<:Dense}) = Position(2)
SetParameters.nparameters(::Type{<:Dense}) = Val(2)
SetParameters.get_parameter(::Type{<:Dense{P1}}, ::Position{1}) where {P1} = P1
SetParameters.get_parameter(::Type{<:Dense{<:Any,P2}}, ::Position{2}) where {P2} = P2
SetParameters.default_parameter(::Type{<:Dense}, ::Position{1}) = Float64
SetParameters.default_parameter(::Type{<:Dense}, ::Position{2}) = Vector

SetParameters.set_parameter(::Type{<:Dense}, ::Position{1}, P1) = Dense{P1}
function SetParameters.set_parameter(
::Type{<:Dense{<:Any,P2}}, ::Position{1}, P1
) where {P2}
return Dense{P1,P2}
end

SetParameters.set_parameter(::Type{<:Dense}, ::Position{2}, P2) = Dense{<:Any,P2}
function SetParameters.set_parameter(::Type{<:Dense{P1}}, ::Position{2}, P2) where {P1}
return Dense{P1,P2}
end
TypeParameterAccessor.parenttype_position(::Type{<:Dense}) = Position(2)
TypeParameterAccessor.default_parameter(::Type{<:Dense}, ::Position{1}) = Float64
TypeParameterAccessor.default_parameter(::Type{<:Dense}, ::Position{2}) = Vector
4 changes: 2 additions & 2 deletions NDTensors/src/diag/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
function SetParameters.set_eltype(storagetype::Type{<:UniformDiag}, eltype::Type)
function TypeParameterAccessor.set_eltype(storagetype::Type{<:UniformDiag}, eltype::Type)
return Diag{eltype,eltype}
end

function SetParameters.set_eltype(
function TypeParameterAccessor.set_eltype(
storagetype::Type{<:NonuniformDiag}, eltype::Type{<:AbstractArray}
)
return Diag{eltype,similartype(storagetype, eltype)}
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/lib/AlgorithmSelection/src/algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Algorithm(s; kwargs...) = Algorithm{Symbol(s)}(NamedTuple(kwargs))

Algorithm(alg::Algorithm) = alg

# TODO: Use `SetParameters`.
# TODO: Use `TypeParameterAccessor`.
algorithm_string(::Algorithm{Alg}) where {Alg} = string(Alg)

function Base.show(io::IO, alg::Algorithm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ end

Base.axes(block_arr::BlockSparseArray) = block_arr.axes
BlockArrays.blocks(a::BlockSparseArray) = a.blocks
# TODO: Use `SetParameters`.
# TODO: Use `TypeParameterAccessor`.
blocktype(a::BlockSparseArray{<:Any,<:Any,A}) where {A} = A

# TODO: Use `SetParameters`.
# TODO: Use `TypeParameterAccessor`.
set_ndims(::Type{<:Array{T}}, n) where {T} = Array{T,n}

# TODO: Move to `AbstractArray` file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Base.axes(a::BlockSparseArray) = a.axes
# BlockArrays `AbstractBlockArray` interface
BlockArrays.blocks(a::BlockSparseArray) = a.blocks

# TODO: Use `SetParameters`.
# TODO: Use `TypeParameterAccessor`.
blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B

# Base interface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dimnames(a::AbstractNamedDimsArray) = error("Not implemented")
Base.parent(::AbstractNamedDimsArray) = error("Not implemented")

# TODO: Use `Unwrap`.
# TODO: Use `SetParameters`.
# TODO: Use `TypeParameterAccessor`.
parenttype(::AbstractNamedDimsArray{<:Any,<:Any,Parent}) where {Parent} = Parent

# Set the names of an unnamed AbstractArray
Expand Down
Loading
Loading