Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanaxen committed Nov 16, 2022
1 parent 5a5e171 commit 7aef019
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 7 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ Through the C# UI, you can easily perform time series forecasting with the
At the core of this application, tft is ![](./images/554068490.png) https://github.com/mlverse.
It relies heavily on https://github.com/mlverse/tft, a wonderful library implemented in R developed at

<img src="./images/image03.gif" width=70%>

<img src="./images/image03.gif" width=80%>
---
### Feature importance Plot
The permutation feature importance algorithm based on Fisher, Rudin, and Dominici (2018)
<img src="./images/image01.png" width=80%>
# Requirements

[webview2](https://developer.microsoft.com/ja-jp/microsoft-edge/webview2/)
Expand Down
Binary file added images/image01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 26 additions & 3 deletions script/tft_util.r
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ permutationFeatureImportance<- function(fitted, test, validation=F, base_name=""

FI = NULL
FI_s = NULL
sampling_n = 10
sampling_n = 1

for ( k in 1:sampling_n )
{
Expand Down Expand Up @@ -958,7 +958,30 @@ permutationFeatureImportance<- function(fitted, test, validation=F, base_name=""
g2 <- ggplot(data = x, aes(x = date, y = key , fill = importance)) +
geom_tile()+
scale_fill_gradient2(low = "springgreen4", mid = "yellow", high = "red", midpoint = 0.5)
ggsave(file = paste(base_name,"_feature_importance_time.png", sep=""), plot = g2, dpi = 100, width = 6.4, height = 4.8*length(name)/10)
ggsave(file = paste(base_name,"_feature_importance_time1.png", sep=""), plot = g2, dpi = 100, width = 6.4, height = 4.8*length(name)/10)

FI_s$date <- NULL
FI_s <- (FI_s - min(FI_s))/(max(FI_s) - min(FI_s))
FI_s$date <- test_tmp$date
x<-horizontally_to_vertically(FI_s, ids_cols=c('date'), key=name)
x$importance <- x$target
x$target <- NULL

g3 <- x %>%
ggplot(aes(x = date, y = importance, color=key))+
geom_line()+
scale_x_datetime(breaks = date_breaks(unit), labels = date_format("%Y-%m-%d %H")) +
theme(axis.text.x = element_text(angle = 45, hjust = 1))

g4 <- ggplot(x, aes(x = date, y = importance, fill = key))
g4 <- g4 + geom_bar(stat = "identity", position = "fill")
g4 <- g4 + scale_y_continuous(labels = percent)
plot(g4)

g5 <- ggplot(x, aes(x = date, y = importance, fill = key))
g5 <- g5 + geom_bar(stat = "identity")
plot(g5)
ggsave(file = paste(base_name,"_feature_importance_time.png", sep=""), plot = g4, dpi = 100, width = 6.4, height = 4.8*length(name)/10)

if ( FALSE )
{
Expand All @@ -975,7 +998,7 @@ permutationFeatureImportance<- function(fitted, test, validation=F, base_name=""
heatmap(as.matrix(FI_s2),Colv = NA, Rowv=NA, scale='col',col=c(rgb(seq(0.9,0.2,-0.001),0, seq(0.0,0.3,0.001))))
heatmap(as.matrix(FI_s2),Colv = NA, Rowv=NA, scale='col',col=c(rgb(seq(0.9,0.2,-0.001),0, seq(0.0,0.2,0.001))))
}
return( list(n, g1, g2))
return( list(n, g1, g2, g3, g4, g5))
}


Expand Down
1 change: 1 addition & 0 deletions tft/Form1.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions tft/Form1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ void Plot()
try
{
pictureBox6.Image = CreateImage("tft_predict_measure_" + base_name + ".png");
pictureBox6.Refresh();
}
catch { }
}
Expand All @@ -181,6 +182,7 @@ void Plot()
try
{
pictureBox7.Image = CreateImage( base_name + "_feature_importance.png");
pictureBox7.Refresh();
}
catch { }
}
Expand All @@ -189,6 +191,7 @@ void Plot()
try
{
pictureBox8.Image = CreateImage(base_name + "_feature_importance_time.png");
pictureBox8.Refresh();
}
catch { }
}
Expand Down Expand Up @@ -354,6 +357,7 @@ void Proc_Exited(object sender, EventArgs e)
{
pictureBox6.Image = CreateImage("tft_predict_measure_" + base_name + ".png");
// pictureBox6.SizeMode = PictureBoxSizeMode.StretchImage;
pictureBox6.Refresh();
}
catch { }
}
Expand All @@ -363,6 +367,7 @@ void Proc_Exited(object sender, EventArgs e)
try
{
pictureBox7.Image = CreateImage(base_name + "_feature_importance.png");
pictureBox7.Refresh();
}
catch { }
}
Expand All @@ -372,6 +377,7 @@ void Proc_Exited(object sender, EventArgs e)
try
{
pictureBox8.Image = CreateImage(base_name + "_feature_importance_time.png");
pictureBox8.Refresh();
}
catch { }
}
Expand Down Expand Up @@ -1619,10 +1625,10 @@ string tft_test()
cmd += "\r\n";
cmd += "\r\n";
cmd += "fi <- permutationFeatureImportance(fitted, test, validation=validation, base_name ='" + base_name + "')\r\n";
cmd += "fi_plot <- gridExtra::grid.arrange(fi[[2]], fi[[3]], ncol = 1)\r\n";
cmd += "fi_plot <- gridExtra::grid.arrange(fi[[2]], fi[[5]], ncol = 1)\r\n";
cmd += "ggsave(file = \"tft_" + base_name + "_fi.png\", plot = fi_plot,dpi=100, width= 1.5*6.4,height=0.09*4.8" + "*fi[[1]], limitsize = FALSE)\r\n";
cmd += "fi_plot1 <- ggplotly(fi[[2]])\r\n";
cmd += "fi_plot2 <- ggplotly(fi[[3]])\r\n";
cmd += "fi_plot2 <- ggplotly(fi[[5]])\r\n";
cmd += "fi_plotly <- subplot(fi_plot1, fi_plot2, nrows = 2)\r\n";
cmd += "print(fi_plotly)\r\n";
cmd += "htmlwidgets::saveWidget(as_widget(fi_plotly), \"tft_" + base_name + "_fi.html\", selfcontained = F)\r\n";
Expand Down

0 comments on commit 7aef019

Please sign in to comment.