# Spatial-Temporal Multi-Head Graph Attention Networks for Traffic Forecasting
This is a PyTorch implementation of Spatial-Temporal Multi-Head Graph Attention Networks (STGAT) for Traffic Forecasting, which combines graph attention (GAT) layers with a gated dilated-convolution structure.
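The repository's actual layers live in `model_stgat.py`. As a rough orientation only, below is a minimal PyTorch sketch of the two ingredients named above: a WaveNet-style gated dilated temporal convolution and a single-head graph-attention layer. The class names, tensor shapes, and sizes are illustrative assumptions, not the code in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedDilatedTemporalConv(nn.Module):
    """Gated dilated 1-D convolution over the time axis: tanh(filter) * sigmoid(gate)."""
    def __init__(self, channels, kernel_size=2, dilation=1):
        super().__init__()
        self.filter_conv = nn.Conv2d(channels, channels, (1, kernel_size), dilation=(1, dilation))
        self.gate_conv = nn.Conv2d(channels, channels, (1, kernel_size), dilation=(1, dilation))

    def forward(self, x):                        # x: (batch, channels, nodes, time)
        return torch.tanh(self.filter_conv(x)) * torch.sigmoid(self.gate_conv(x))

class SimpleGraphAttention(nn.Module):
    """Single-head GAT-style attention over the node dimension (illustrative)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim, bias=False)
        self.attn = nn.Linear(2 * out_dim, 1, bias=False)

    def forward(self, h, adj):                   # h: (batch, nodes, in_dim), adj: (nodes, nodes)
        z = self.proj(h)                         # (batch, nodes, out_dim)
        n = z.size(1)
        zi = z.unsqueeze(2).expand(-1, -1, n, -1)
        zj = z.unsqueeze(1).expand(-1, n, -1, -1)
        e = F.leaky_relu(self.attn(torch.cat([zi, zj], dim=-1)).squeeze(-1))
        e = e.masked_fill(adj == 0, float("-inf"))   # attend only to graph neighbours
        alpha = torch.softmax(e, dim=-1)
        return torch.relu(alpha @ z)

if __name__ == "__main__":
    x = torch.randn(8, 32, 207, 12)              # batch, channels, nodes (METR-LA has 207 sensors), time
    adj = (torch.rand(207, 207) > 0.9).float() + torch.eye(207)
    t = GatedDilatedTemporalConv(32, dilation=2)(x)                 # (8, 32, 207, 10)
    g = SimpleGraphAttention(32, 32)(t.mean(-1).transpose(1, 2), adj)
    print(t.shape, g.shape)
```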
## Requirements

- See `requirements.txt` for the package dependencies.
## Data Preparation

- Download the METR-LA and PEMS-BAY data from the Google Drive or Baidu Yun links provided by DCRNN. Put `metr-la.h5` and `pems-bay.h5` into the `data/` folder.
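Before running the preparation script, it can help to sanity-check the downloads. The DCRNN `.h5` files are pandas HDF stores (a DataFrame of speed readings indexed by timestamp, one column per sensor), so something like the following should work; the shape in the comment is the commonly reported METR-LA size and is given only as a rough check.

```python
import pandas as pd

# metr-la.h5 / pems-bay.h5 are pandas HDF stores: rows are 5-minute timestamps,
# columns are sensor IDs, values are traffic speeds.
df = pd.read_hdf("data/metr-la.h5")
print(df.shape)        # roughly (34272, 207) for METR-LA
print(df.head())
```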
```bash
# Create data directories
mkdir -p data/{METR-LA,PEMS-BAY}

# METR-LA
python data_preparation.py --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5

# PEMS-BAY
python data_preparation.py --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5
```
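Assuming `data_preparation.py` follows the DCRNN data-generation convention (per-split `.npz` files containing `x`/`y` arrays), the generated splits can be inspected as sketched below; the file and key names are assumptions about that convention rather than guarantees about this particular script.

```python
import numpy as np

# Hypothetical check, assuming DCRNN-style outputs train/val/test.npz with
# 'x' (inputs) and 'y' (targets) of shape (num_samples, seq_len, num_nodes, num_features).
for split in ("train", "val", "test"):
    data = np.load(f"data/METR-LA/{split}.npz")
    print(split, data["x"].shape, data["y"].shape)
```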
## Run the Pre-trained Models

```bash
# Change 'base_path' and 'best_model_path' to choose which dataset to evaluate
python run_demo.py --base_path=./pre_train_model/BAY_dataset --best_model_path=stgat_1.45.pkl

# Run the pre-trained baseline models (STGCN and GWNet models for the PEMS-BAY dataset will be added)
python run_demo_baselines.py
```
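If you want to inspect a pre-trained checkpoint outside of `run_demo.py`, the sketch below assumes the `.pkl` files are ordinary PyTorch checkpoints written with `torch.save`; whether they hold a full model object or a `state_dict` depends on how they were saved, so both cases are handled.

```python
import torch

# Hypothetical inspection of a pre-trained checkpoint; assumes the .pkl file
# was written with torch.save (either a whole model or a state_dict).
ckpt = torch.load("pre_train_model/BAY_dataset/stgat_1.45.pkl", map_location="cpu")
if isinstance(ckpt, dict):
    # state_dict-like case: print parameter names and shapes
    for name, value in ckpt.items():
        shape = tuple(value.shape) if torch.is_tensor(value) else type(value).__name__
        print(name, shape)
else:
    # full-model case
    print(ckpt)
```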
## Repository Structure

```
├── baselines
│   ├── experiment_base
│   ├── __init__.py
│   ├── gwnet.py
│   ├── rnn.py
│   ├── run_demo_baselines.py
│   ├── stgcn.py
│   └── train_base.csv
├── data
│   ├── METR-LA
│   ├── PEMS-BAY
│   ├── sensor_graph
│   ├── metr-la.h5
│   └── pems-bay.h5
├── experiment
├── pre_train_model
│   ├── BAY_dataset
│   └── LA_dataset
├── data_preparation.py
├── draw.py
├── model_stgat.py
├── README.md
├── requirements.txt
├── run_demo.py
├── train.py
└── util.py
```