Skip to content

Commit

Permalink
Simplify nonlinfit by doing float in the beginning
Browse files Browse the repository at this point in the history
  • Loading branch information
singularitti committed May 12, 2020
1 parent 48da7c0 commit 8105c15
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/Fitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ If `eos`, `volumes` and `ydata` are all unitless, `volumes` must have the same u
function nonlinfit(f, volumes, ydata; kwargs...)
eos, property = fieldvalues(f)
T = constructorof(typeof(eos)) # Get the `UnionAll` type
params, volumes, ydata = _preprocess(_Data(eos), _Data(volumes), _Data(ydata))
params, volumes, ydata = _preprocess(_Data(float(eos)), _Data(float(volumes)), _Data(float(ydata)))
model = (x, p) -> map(T(p...)(property), x)
fit = curve_fit(model, volumes, ydata, params; kwargs...)
return _postprocess(T(fit.param...), _Data(eos))
Expand All @@ -75,17 +75,15 @@ struct _Data{S,T}
end
_Data(data::T) where {T} = _Data{eltype(data),T}(data)

_preprocess(eos::_Data{<:Real}, xdata::_Data{<:Real}, ydata::_Data{<:Real}) =
float.(fieldvalues(eos.data)), float(xdata.data), float(ydata.data)
_preprocess(eos, xdata, ydata) = (eos, xdata, ydata)
function _preprocess(
eos::_Data{<:AbstractQuantity},
xdata::_Data{<:AbstractQuantity},
ydata::_Data{<:AbstractQuantity},
)
values = fieldvalues(eos.data)
original_units = unit.(values) # Keep a record of `eos`'s units
f = x -> map(float _ustrip, x) # Convert to preferred units and strip the unit
return map(f, (values, xdata.data, ydata.data))
return map(_ustrip, (values, xdata.data, ydata.data)) # Convert to preferred units and strip the unit
end # function _preprocess

_postprocess(eos, trial_eos::_Data{<:Real}) = eos
Expand All @@ -108,4 +106,6 @@ _upreferred(::typeof(dimension(u"Pa"))) = u"eV/angstrom^3"
_upreferred(::typeof(dimension(u"1/Pa"))) = u"angstrom^3/eV"
_upreferred(::typeof(dimension(u"1/Pa^2"))) = u"angstrom^6/eV^2"

Base.float(eos::EquationOfState) = constructorof(eos)(float.(fieldvalues(eos)))

end # module Fitting

0 comments on commit 8105c15

Please sign in to comment.