Skip to content

Commit

Permalink
Merge pull request #19 from invenia/wct/tweak-predict-docs
Browse files Browse the repository at this point in the history
Clarify `predict` docs
  • Loading branch information
willtebbutt authored Aug 11, 2020
2 parents 9fa0caa + f3a6c79 commit 1c87cf7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Models"
uuid = "e6388cff-ecff-480c-9b53-83211bf7812a"
authors = ["Invenia Technical Computing Corporation"]
version = "0.2.1"
version = "0.2.2"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
11 changes: 9 additions & 2 deletions docs/src/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,25 @@ One does not have to carry both a [`Model`](@ref) type, and a varying collection

```julia
model = StatsBase.fit(
template,
template::Template,
outputs::AbstractMatrix, # always Features x Observations
inputs::AbstractMatrix, # always Variates x Observations
weights=uweights(Float32, size(outputs, 2))
)::Model
```

```julia
# estimate_type(model) == PointEsimate
outputs = StatsBase.predict(
model,
model::Model,
inputs::AbstractMatrix # always Features x Observations
)::AbstractMatrix # always Variates x Observations

# estimate_type(model) == DistributionEstimate
outputs = StatsBase.predict(
model::Model,
inputs::AbstractMatrix # always Features x Observations
)::AbstractVector{<:Distribution} # length Observations
```

[`fit`](@ref) takes in a [`Template`](@ref) and some *data* and returns a [`Model`](@ref) that has been fit to the data.
Expand Down
12 changes: 8 additions & 4 deletions src/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Defined as well are the traits:
abstract type Model end

"""
fit(::Template, output, input, [weights]) -> Model
fit(::Template, output::AbstractMatrix, input::AbstractMatrix, [weights]) -> Model
Fit the [`Template`](@ref) to the `output` and `input` data and return a trained
[`Model`](@ref).
Expand All @@ -37,11 +37,15 @@ Convention is that `weights` defaults to `StatsBase.uweights(Float32, size(outpu
function fit end

"""
predict(::Model, input)
predict(model::Model, inputs::AbstractMatrix)
Predict targets for the provided `input` and [`Model`](@ref).
Predict targets for the provided the collection of `inputs` and [`Model`](@ref).
Returns a predictive distribution or point estimates depending on the [`Model`](@ref).
If the `estimate_type(model)` is [`PointEstimate`](@ref) then this function should return
another `AbstractMatrix` in which each column contains the prediction for a single input.
If the `estimate_type(model)` is [`DistributionEstimate`](@ref) then this function should
return a `AbstractVector{<:Distribution}`.
"""
function predict end

Expand Down

2 comments on commit 1c87cf7

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/19330

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.2 -m "<description of version>" 1c87cf7f90af796fc04cf0687e28635fd095782e
git push origin v0.2.2

Please sign in to comment.