Skip to content

Commit

Permalink
now subtyped as <: MMI.Model
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Oct 12, 2022
1 parent 06b6cea commit 05f27f2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 7 deletions.
6 changes: 2 additions & 4 deletions src/ConformalModels/ConformalModels.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
module ConformalModels

using MLJ
using MLJModelInterface
import MLJModelInterface as MMI
import MLJModelInterface: predict, fit, save, restore
import MLJBase

"An abstract base type for conformal models."
abstract type ConformalModel end
abstract type ConformalModel <: MMI.Model end
export ConformalModel

const MMI = MLJModelInterface

include("conformal_models.jl")

include("regression.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/ConformalModels/conformal_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ export conformal_model
Wrapper function to fit the underlying MLJ model.
"""
function MMI.fit(conf_model::ConformalModel, verbosity, X, y)
fitresult, cache, report = fit(conf_model.model, verbosity, MMI.reformat(X, y))
fitresult, cache, report = fit(conf_model.model, verbosity, MMI.reformat(conf_model.model, X, y)...)
return (fitresult, cache, report)
end
export fit

# Calibration
"""
Expand Down
3 changes: 1 addition & 2 deletions src/ConformalPrediction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ module ConformalPrediction
# conformal models
include("ConformalModels/ConformalModels.jl")
using .ConformalModels
export conformal_model
export calibrate!
export conformal_model, fit, calibrate!
export NaiveConformalRegressor
export LABELConformalClassifier

Expand Down
14 changes: 14 additions & 0 deletions test/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ available_models = ConformalPrediction.ConformalModels.available_models[:classif
conf_model.fitresult = mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

# Use generic fit() method:
conf_model.fitresult = nothing
_mach = machine(conf_model, X, y)
fit!(_mach, rows=train)
conf_model.fitresult = _mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

@test !isnothing(conf_model.scores)
predict(conf_model, selectrows(X, test))
end
Expand All @@ -41,6 +48,13 @@ available_models = ConformalPrediction.ConformalModels.available_models[:classif
conf_model.fitresult = mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

# Use generic fit() method:
conf_model.fitresult = nothing
_mach = machine(conf_model, X, y)
fit!(_mach, rows=train)
conf_model.fitresult = _mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

@test !isnothing(conf_model.scores)
predict(conf_model, selectrows(X, test))
end
Expand Down
14 changes: 14 additions & 0 deletions test/regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ available_models = ConformalPrediction.ConformalModels.available_models[:regress
conf_model.fitresult = mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

# Use generic fit() method:
conf_model.fitresult = nothing
_mach = machine(conf_model, X, y)
fit!(_mach, rows=train)
conf_model.fitresult = _mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

@test !isnothing(conf_model.scores)
predict(conf_model, selectrows(X, test))
end
Expand All @@ -40,6 +47,13 @@ available_models = ConformalPrediction.ConformalModels.available_models[:regress
# Use fitresult from machine:
conf_model.fitresult = mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

# Use generic fit() method:
conf_model.fitresult = nothing
_mach = machine(conf_model, X, y)
fit!(_mach, rows=train)
conf_model.fitresult = _mach.fitresult
calibrate!(conf_model, selectrows(X, calibration), y[calibration])

@test !isnothing(conf_model.scores)
predict(conf_model, selectrows(X, test))
Expand Down

1 comment on commit 05f27f2

@pat-alt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's an initial stab at #5

A few things I'm unsure about:

  1. My goal is to use an abstract supertype ConformalModel <: MMI.Model for which the compulsory MMI.fit and MMI.predict will be implemented. Subtypes then only need to implement generic methods that are relevant to how exactly conformal prediction is implemented (like here). The idea is that this way the compulsory methods don't have to be defined each time for each different conformal predictor. The motivation is that this should make it easier for contributors to add new conformal predictors. With respect to fitting, this seems to be working well: I can wrap both the NaiveConformalRegressor and the LABELConformalClassifier as a machine and call the fit! method on it. But perhaps what I have in mind is not a good idea.
  2. I have not yet added the compulsory predict method, because conformal predictions are interval-/set-valued. In other words, it isn't obvious how to produce predictions yhat in the conventional format (e.g. yhat is a vector of Distribution in the case of classification). Not sure how/if these types of predictions can be used for tuning the conformal model. I can always just define a predict function that just calls the underlying model, much like I've done with the fit method here, but then tuning does not actually address the "conformal" aspect of the predictions.

Any guidance would be much appreciated @ablaom 🙏🏽

Please sign in to comment.