Skip to content

Commit

Permalink
Fancy Indexing Getter
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjkkkjjj committed Jul 19, 2022
1 parent a7eb139 commit 19fa6dd
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 61 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,18 @@ But this is beta version. so, any bug may be ocurred.

Please report me by issue!

**TODO**

- [x] Arithmetic Operation
- [x] Angle, Conjugate and Absolute
- [x] Math (partial: `sin,cos,tan,exp,log`)
- [x] Basic Subscription Getter
- [ ] Basic Subscription Setter
- [x] Boolean Indexing Getter
- [ ] Boolean Indexing Setter
- [x] Fancy Indexing Getter
- [ ] Fancy Indexing Setter

```swift
let real = Matft.arange(start: 0, to: 16, by: 1).reshape([2,2,4])
let imag = Matft.arange(start: 0, to: -16, by: -1).reshape([2,2,4])
Expand Down
3 changes: 3 additions & 0 deletions Sources/Matft/core/protocol/mftypeProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ extension Bool: MfBinary, StoredFloat {

public protocol MfStorable: MfTypable, FloatingPoint{
associatedtype vDSPType: vDSP_ComplexTypable
associatedtype blasType: blas_ComplexTypable

static func num(_ number: Int) -> Self
static func from<T: MfTypable>(_ value: T) -> Self
Expand All @@ -205,6 +206,7 @@ public protocol MfStorable: MfTypable, FloatingPoint{

extension Float: MfStorable{
public typealias vDSPType = DSPSplitComplex
public typealias blasType = DSPComplex

public static func num(_ number: Int) -> Float {
return Float(number)
Expand Down Expand Up @@ -251,6 +253,7 @@ extension Float: MfStorable{
}
extension Double: MfStorable{
public typealias vDSPType = DSPDoubleSplitComplex
public typealias blasType = DSPDoubleComplex

public static func num(_ number: Int) -> Double {
return Double(number)
Expand Down
8 changes: 4 additions & 4 deletions Sources/Matft/core/util/pointer/withptr.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ extension MfArray{

return ret
}
internal func withUnsafeMutableblasPointer<T: blas_ComplexTypable, R>(datatype: T.Type, vDSP_func: vDSP_convertz_func<T.vDSPType, T>, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{
internal func withUnsafeMutableblasPointer<T: blas_ComplexTypable, R>(datatype: T.Type, vDSP_func: vDSP_convert_func<T.vDSPType, T>, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{

let ret = try self.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){ [unowned self](ptr) -> R in
var arr = Array(repeating: T(real: T.T.zero, imag: T.T.zero), count: self.storedSize)
wrap_vDSP_convertz(arr.count, ptr, 1, &arr, 1, vDSP_func)
wrap_vDSP_convert(arr.count, ptr, 1, &arr, 1, vDSP_func)
return try body(&arr)
}

Expand Down Expand Up @@ -131,11 +131,11 @@ extension MfData{

return ret
}
internal func withUnsafeMutableblasPointer<T: blas_ComplexTypable, R>(datatype: T.Type, vDSP_func: vDSP_convertz_func<T.vDSPType, T>, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{
internal func withUnsafeMutableblasPointer<T: blas_ComplexTypable, R>(datatype: T.Type, vDSP_func: vDSP_convert_func<T.vDSPType, T>, _ body: (UnsafeMutablePointer<T>) throws -> R) rethrows -> R{

let ret = try self.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){ [unowned self](ptr) -> R in
var arr = Array(repeating: T(real: T.T.zero, imag: T.T.zero), count: self.storedSize)
wrap_vDSP_convertz(arr.count, ptr, 1, &arr, 1, vDSP_func)
wrap_vDSP_convert(arr.count, ptr, 1, &arr, 1, vDSP_func)
return try body(&arr)
}

Expand Down
9 changes: 9 additions & 0 deletions Sources/Matft/function/method/subscript.swift
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ extension MfArray: MfSubscriptable{
}

private func _set_mfarray(indices: inout [Any], newValue: MfArray){
unsupport_complex(self)
unsupport_complex(newValue)

for index in indices{
if let _ = index as? SubscriptOps{
fatalError("SubscriptOps must not be passed to setter")
Expand Down Expand Up @@ -289,6 +292,7 @@ extension MfArray: MfSubscriptable{


private func _get_mfarray(indices: MfArray) -> MfArray{
unsupport_complex(indices)

switch indices.mftype {
case .Bool:
Expand Down Expand Up @@ -318,6 +322,8 @@ extension MfArray: MfSubscriptable{


private func _fancygetall_mfarray(indices: inout [MfArray]) -> MfArray{
let _ = indices.map{ unsupport_complex($0) }

switch self.storedType {
case .Float:
return fancygetall_by_cblas(self, &indices, cblas_scopy)
Expand All @@ -328,6 +334,9 @@ extension MfArray: MfSubscriptable{
}

private func _fancysetall_mfarray(indices: inout [MfArray], assignedMfarray: MfArray) -> Void{
unsupport_complex(self)
unsupport_complex(assignedMfarray)

switch self.storedType {
case .Float:
fancysetall_by_cblas(self, &indices, assignedMfarray, cblas_scopy)
Expand Down
117 changes: 88 additions & 29 deletions Sources/Matft/library/cblas.swift
Original file line number Diff line number Diff line change
Expand Up @@ -379,27 +379,58 @@ internal func fancyndget_by_cblas<T: MfStorable>(_ mfarray: MfArray, _ indices:

let workSize = shape2size(&workShape)

let newdata = MfData(size: retSize, mftype: mfarray.mftype)
if mfarray.isReal{
let newdata = MfData(size: retSize, mftype: mfarray.mftype)

newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
var dstptrT = dstptrT
let _ = mfarray.withUnsafeMutableStartPointer(datatype: T.self){
[unowned mfarray](srcptr) in

let offsets = (indices.data as! [Int]).map{ get_positive_index($0, axissize: mfarray.shape[0], axis: 0) * mfarray.strides[0] }
for offset in offsets{
wrap_cblas_copy(workSize, srcptr + offset, 1, dstptrT, 1, cblas_func)
dstptrT += workSize
newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
var dstptrT = dstptrT
let _ = mfarray.withUnsafeMutableStartPointer(datatype: T.self){
[unowned mfarray](srcptr) in

let offsets = (indices.data as! [Int]).map{ get_positive_index($0, axissize: mfarray.shape[0], axis: 0) * mfarray.strides[0] }
for offset in offsets{
wrap_cblas_copy(workSize, srcptr + offset, 1, dstptrT, 1, cblas_func)
dstptrT += workSize
}
}
}

let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}
else{
let newdata = MfData(size: retSize, mftype: mfarray.mftype, complex: true)

newdata.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
dstptrT in
var dstptrTr = dstptrT.pointee.realp as! UnsafeMutablePointer<T>
var dstptrTi = dstptrT.pointee.imagp as! UnsafeMutablePointer<T>
let _ = mfarray.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
[unowned mfarray](srcptr) in
let srcptrr = srcptr.pointee.realp as! UnsafeMutablePointer<T>
let srcptri = srcptr.pointee.imagp as! UnsafeMutablePointer<T>

let offsets = (indices.data as! [Int]).map{ get_positive_index($0, axissize: mfarray.shape[0], axis: 0) * mfarray.strides[0] }
for offset in offsets{
wrap_cblas_copy(workSize, srcptrr + offset, 1, dstptrTr, 1, cblas_func)
dstptrTr += workSize

wrap_cblas_copy(workSize, srcptri + offset, 1, dstptrTi, 1, cblas_func)
dstptrTi += workSize
}
}
}

let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}

let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}



/// Getter function for the fancy indexing on a given Interger indices.
/// - Parameters:
/// - mfarray: An inpu mfarray. Must be more than 2d
Expand Down Expand Up @@ -433,23 +464,51 @@ internal func fancygetall_by_cblas<T: MfStorable>(_ mfarray: MfArray, _ indices:
array([0, 1, 2])
*/

let newdata = MfData(size: retSize, mftype: mfarray.mftype)
newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
var dstptrT = dstptrT
let _ = mfarray.withUnsafeMutableStartPointer(datatype: T.self){
srcptr in

for offset in offsets{
wrap_cblas_copy(workSize, srcptr + offset, 1, dstptrT, 1, cblas_func)
dstptrT += workSize
if mfarray.isReal{
let newdata = MfData(size: retSize, mftype: mfarray.mftype)
newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
var dstptrT = dstptrT
let _ = mfarray.withUnsafeMutableStartPointer(datatype: T.self){
srcptr in

for offset in offsets{
wrap_cblas_copy(workSize, srcptr + offset, 1, dstptrT, 1, cblas_func)
dstptrT += workSize
}
}
}

let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}
else{
let newdata = MfData(size: retSize, mftype: mfarray.mftype, complex: true)

newdata.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
dstptrT in
var dstptrTr = dstptrT.pointee.realp as! UnsafeMutablePointer<T>
var dstptrTi = dstptrT.pointee.imagp as! UnsafeMutablePointer<T>
let _ = mfarray.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
srcptr in
let srcptrr = srcptr.pointee.realp as! UnsafeMutablePointer<T>
let srcptri = srcptr.pointee.imagp as! UnsafeMutablePointer<T>

for offset in offsets{
wrap_cblas_copy(workSize, srcptrr + offset, 1, dstptrTr, 1, cblas_func)
dstptrTr += workSize

wrap_cblas_copy(workSize, srcptri + offset, 1, dstptrTi, 1, cblas_func)
dstptrTi += workSize
}
}
}

let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}

let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}

/// Setter function for the fancy indexing on a given Interger indices.
Expand Down
109 changes: 81 additions & 28 deletions Sources/Matft/library/vDSP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1122,28 +1122,59 @@ internal func boolget_by_vDSP<T: MfStorable>(_ mfarray: MfArray, _ indices: MfAr
var retShape = [true_num] + lastShape
let retSize = shape2size(&retShape)

let newdata = MfData(size: retSize, mftype: mfarray.mftype)
newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
indicesT.withUnsafeMutableStartPointer(datatype: T.self){
//[unowned indicesT](indptr) in
indptr in
// note that indices and mfarray is row contiguous
mfarray.withUnsafeMutableStartPointer(datatype: T.self){
srcptr in

for vDSPPrams in OptOffsetParamsSequence(shape: indicesT.shape, bigger_strides: indicesT.strides, smaller_strides: mfarray.strides){
wrap_vDSP_cmprs(vDSPPrams.blocksize, srcptr + vDSPPrams.s_offset, vDSPPrams.s_stride, indptr + vDSPPrams.b_offset, vDSPPrams.b_stride, dstptrT + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)
if mfarray.isReal{
let newdata = MfData(size: retSize, mftype: mfarray.mftype)
newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
indicesT.withUnsafeMutableStartPointer(datatype: T.self){
//[unowned indicesT](indptr) in
indptr in
// note that indices and mfarray is row contiguous
mfarray.withUnsafeMutableStartPointer(datatype: T.self){
srcptr in

for vDSPPrams in OptOffsetParamsSequence(shape: indicesT.shape, bigger_strides: indicesT.strides, smaller_strides: mfarray.strides){
wrap_vDSP_cmprs(vDSPPrams.blocksize, srcptr + vDSPPrams.s_offset, vDSPPrams.s_stride, indptr + vDSPPrams.b_offset, vDSPPrams.b_stride, dstptrT + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)
}
//vDSP_func(srcptr.baseAddress!, vDSP_Stride(1), indptr.baseAddress!, vDSP_Stride(1), dstptrT, vDSP_Stride(1), vDSP_Length(indicesT.size))
}
//vDSP_func(srcptr.baseAddress!, vDSP_Stride(1), indptr.baseAddress!, vDSP_Stride(1), dstptrT, vDSP_Stride(1), vDSP_Length(indicesT.size))
}
}


let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}
else{
let newdata = MfData(size: retSize, mftype: mfarray.mftype, complex: true)
newdata.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
dstptrT in
indicesT.withUnsafeMutableStartPointer(datatype: T.self){
//[unowned indicesT](indptr) in
indptr in
// note that indices and mfarray is row contiguous
mfarray.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
srcptr in
let srcptrr = srcptr.pointee.realp as! UnsafeMutablePointer<T>
let srcptri = srcptr.pointee.imagp as! UnsafeMutablePointer<T>
let dstptrTr = dstptrT.pointee.realp as! UnsafeMutablePointer<T>
let dstptrTi = dstptrT.pointee.imagp as! UnsafeMutablePointer<T>

for vDSPPrams in OptOffsetParamsSequence(shape: indicesT.shape, bigger_strides: indicesT.strides, smaller_strides: mfarray.strides){
wrap_vDSP_cmprs(vDSPPrams.blocksize, srcptrr + vDSPPrams.s_offset, vDSPPrams.s_stride, indptr + vDSPPrams.b_offset, vDSPPrams.b_stride, dstptrTr + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)
wrap_vDSP_cmprs(vDSPPrams.blocksize, srcptri + vDSPPrams.s_offset, vDSPPrams.s_stride, indptr + vDSPPrams.b_offset, vDSPPrams.b_stride, dstptrTi + vDSPPrams.b_offset, vDSPPrams.b_stride, vDSP_func)
}
//vDSP_func(srcptr.baseAddress!, vDSP_Stride(1), indptr.baseAddress!, vDSP_Stride(1), dstptrT, vDSP_Stride(1), vDSP_Length(indicesT.size))
}
}
}


let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}


let newstructure = MfStructure(shape: retShape, mforder: .Row)

return MfArray(mfdata: newdata, mfstructure: newstructure)
}


Expand Down Expand Up @@ -1184,18 +1215,40 @@ internal func fancy1dgetcol_by_vDSP<T: MfStorable>(_ mfarray: MfArray, _ indices
print(c)
//[0.0, 0.0, 3.0]
*/
let newdata = MfData(size: indices.size, mftype: mfarray.mftype)
newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
let _ = mfarray.withUnsafeMutableStartPointer(datatype: T.self){
srcptr in
var offsets = (indices.data as! [Int]).map{ UInt(get_positive_index($0, axissize: mfarray.size, axis: 0) * mfarray.strides[0] + 1) }
wrap_vDSP_gathr(indices.size, srcptr, &offsets, 1, dstptrT, 1, vDSP_func)
if mfarray.isReal{
let newdata = MfData(size: indices.size, mftype: mfarray.mftype)
newdata.withUnsafeMutableStartPointer(datatype: T.self){
dstptrT in
let _ = mfarray.withUnsafeMutableStartPointer(datatype: T.self){
srcptr in
var offsets = (indices.data as! [Int]).map{ UInt(get_positive_index($0, axissize: mfarray.size, axis: 0) * mfarray.strides[0] + 1) }
wrap_vDSP_gathr(indices.size, srcptr, &offsets, 1, dstptrT, 1, vDSP_func)
}
}

let newstructure = MfStructure(shape: indices.shape, strides: indices.strides)
return MfArray(mfdata: newdata, mfstructure: newstructure)
}
else{
let newdata = MfData(size: indices.size, mftype: mfarray.mftype, complex: true)
newdata.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
dstptrT in
let _ = mfarray.withUnsafeMutablevDSPPointer(datatype: T.vDSPType.self){
srcptr in
let srcptrr = srcptr.pointee.realp as! UnsafeMutablePointer<T>
let srcptri = srcptr.pointee.imagp as! UnsafeMutablePointer<T>
let dstptrTr = dstptrT.pointee.realp as! UnsafeMutablePointer<T>
let dstptrTi = dstptrT.pointee.imagp as! UnsafeMutablePointer<T>

var offsets = (indices.data as! [Int]).map{ UInt(get_positive_index($0, axissize: mfarray.size, axis: 0) * mfarray.strides[0] + 1) }
wrap_vDSP_gathr(indices.size, srcptrr, &offsets, 1, dstptrTr, 1, vDSP_func)
wrap_vDSP_gathr(indices.size, srcptri, &offsets, 1, dstptrTi, 1, vDSP_func)
}
}

let newstructure = MfStructure(shape: indices.shape, strides: indices.strides)
return MfArray(mfdata: newdata, mfstructure: newstructure)
}

let newstructure = MfStructure(shape: indices.shape, strides: indices.strides)
return MfArray(mfdata: newdata, mfstructure: newstructure)
}

/*
Expand Down
Loading

0 comments on commit 19fa6dd

Please sign in to comment.