-
-
Notifications
You must be signed in to change notification settings - Fork 26
/
read.jl
73 lines (62 loc) · 2.75 KB
/
read.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
array(p::TensorProto)
Return `p` as an `Array` of the correct type. Second argument can be used to change type of the returned array
"""
function array(p::TensorProto, wrap=Array)
# Copy pasted from jl
# Can probably be cleaned up a bit
# TODO: Add missing datatypes...
if p.data_type === TensorProto_DataType.INT64
if hasproperty(p, :int64_data) && !isempty(p.int64_data)
return reshape(reinterpret(Int64, p.int64_data), reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Int64, p.raw_data), reverse(p.dims)...) |> wrap
end
if p.data_type === TensorProto_DataType.INT32
if hasproperty(p, :int32_data) && !isempty(p.int32_data)
return reshape(p.int32_data , reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Int32, p.raw_data), reverse(p.dims)...) |> wrap
end
if p.data_type === TensorProto_DataType.INT8
return reshape(reinterpret(Int8, p.raw_data), reverse(p.dims)...) |> wrap
end
if p.data_type === TensorProto_DataType.DOUBLE
if hasproperty(p, :double_data) && !isempty(p.double_data)
return reshape(p.double_data , reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Float64, p.raw_data), reverse(p.dims)...) |> wrap
end
if p.data_type === TensorProto_DataType.FLOAT
if hasproperty(p,:float_data) && !isempty(p.float_data)
return reshape(reinterpret(Float32, p.float_data), reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Float32, p.raw_data), reverse(p.dims)...) |> wrap
end
if p.data_type === TensorProto_DataType.FLOAT16
return reshape(reinterpret(Float16, p.raw_data), reverse(p.dims)...) |> wrap
end
end
# `size` methods for the protobuf types. Container types forward to the member that
# carries the shape information.

# Size of a value info is the size of its type description.
function Base.size(vip::ValueInfoProto)
    return size(vip._type)
end

# Size of a type description is the size of its tensor type.
function Base.size(tp::TypeProto)
    return size(tp.tensor_type)
end

# Returns `dims` exactly as stored in the proto (no reversal and no conversion to a tuple).
function Base.size(tp::TensorProto)
    return tp.dims
end

# `missing` when no shape has been set, otherwise the size of the shape.
function Base.size(tp_t::TypeProto_Tensor)
    hasproperty(tp_t, :shape) || return missing
    return size(tp_t.shape)
end

# Tuple with the size of each dimension, in reversed order compared to `tsp.dim`.
function Base.size(tsp::TensorShapeProto)
    return map(size, Tuple(reverse(tsp.dim)))
end

# `missing` when the dimension has no `dim_value` set, otherwise that value.
function Base.size(tsp_d::TensorShapeProto_Dimension)
    hasproperty(tsp_d, :dim_value) || return missing
    return tsp_d.dim_value
end
"""
attribute(p::AttributeProto)
Return attribute in `p` as a name => value pair.
"""
function attribute(p::AttributeProto)
# Copy paste from ONNX.jl
if (p._type != 0)
field = [:f, :i, :s, :t, :g, :floats, :ints, :strings, :tensors, :graphs][p._type]
if field === :s
return Symbol(p.name) => String(getproperty(p, field))
elseif field === :strings
return Symbol(p.name) => String.(getproperty(p, field))
end
return Symbol(p.name) => getproperty(p, field)
end
end
# Collect a vector of attribute protos into a `Dict`, using the pair produced by
# `attribute` for each element as the entry.
function Base.Dict(pa::AbstractVector{AttributeProto})
    return Dict(attribute(attr) for attr in pa)
end