diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
index d9e062b0240..34c016f9792 100644
--- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
+++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -196,6 +196,7 @@ CropResize
CropToBoundingBox
CrossEntropyLoss
Curran
+CustomDataset
CustomObj
CvAClvFfyA
DBMDZ
@@ -1671,6 +1672,7 @@ gpus
grafftti
graphDef
graphdef
+graphsage
grappler
grey
groupnorm
@@ -2198,6 +2200,7 @@ postprocess
postprocessed
postprocessing
powersave
+ppi
pplm
ppn
pragma
@@ -2440,6 +2443,7 @@ ssd
sshleifer
sst
stackoverflow
+stanford
startswith
startup
stderr
diff --git a/examples/.config/model_params_tensorflow.json b/examples/.config/model_params_tensorflow.json
index dee09cf20d1..9b5fa9e4949 100644
--- a/examples/.config/model_params_tensorflow.json
+++ b/examples/.config/model_params_tensorflow.json
@@ -1807,6 +1807,13 @@
"input_model": "/tf_dataset/tensorflow/vit/HF-ViT-Base16-Img224-frozen.pb",
"main_script": "main.py",
"batch_size": 32
+ },
+ "GraphSage": {
+ "model_src_dir": "graph_networks/graphsage/quantization/ptq",
+ "dataset_location": "/tf_dataset/dataset/ppi",
+ "input_model": "/tf_dataset/tensorflow/graphsage/graphsage_frozen_model.pb",
+ "main_script": "main.py",
+ "batch_size": 1000
}
}
}
diff --git a/examples/README.md b/examples/README.md
index 8d3a06dfc0a..c7e6e43829d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -279,6 +279,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
      <td>Post-Training Static Quantization</td>
      <td>pb</td>
+    <tr>
+      <td>GraphSage</td>
+      <td>Graph Networks</td>
+      <td>Post-Training Static Quantization</td>
+      <td><a href="./tensorflow/graph_networks/graphsage/quantization/ptq">pb</a></td>
+    </tr>
diff --git a/examples/tensorflow/graph_networks/graphsage/quantization/ptq/README.md b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/README.md
new file mode 100644
index 00000000000..775790a08b5
--- /dev/null
+++ b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/README.md
@@ -0,0 +1,126 @@
+Step-by-Step
+============
+
+This document lists the steps to reproduce the TensorFlow GraphSage model tuning results. This example can run on Intel CPUs and GPUs.
+
+# Prerequisite
+
+
+## 1. Environment
+Python 3.6 or a higher version is recommended.
+
+### Install Intel® Neural Compressor
+```shell
+pip install neural-compressor
+```
+
+### Install Intel TensorFlow
+```shell
+pip install intel-tensorflow
+```
+> Note: Validated TensorFlow [Version](/docs/source/installation_guide.md#validated-software-environment).
+
+### Install Dependency Packages
+```shell
+cd examples/tensorflow/graph_networks/graphsage/quantization/ptq
+pip install -r requirements.txt
+```
+
+### Install Intel Extension for TensorFlow
+
+#### Quantizing the model on Intel GPU (Mandatory to install ITEX)
+Intel Extension for TensorFlow is mandatory for quantizing the model on Intel GPUs.
+
+```shell
+pip install --upgrade intel-extension-for-tensorflow[gpu]
+```
+For more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel/intel-extension-for-tensorflow/blob/main/docs/install/install_for_gpu.md#install-gpu-drivers).
+
+#### Quantizing the model on Intel CPU (Optional to install ITEX)
+Intel Extension for TensorFlow for Intel CPUs is currently experimental. It is not mandatory for quantizing the model on Intel CPUs.
+
+```shell
+pip install --upgrade intel-extension-for-tensorflow[cpu]
+```
+
+> **Note**:
+> The version compatibility of stock TensorFlow and ITEX can be checked [here](https://github.com/intel/intel-extension-for-tensorflow#compatibility-table). Please make sure you have installed compatible versions of TensorFlow and ITEX.
+
+## 2. Prepare Model
+Download Frozen graph:
+```shell
+wget https://storage.googleapis.com/intel-optimized-tensorflow/models/2_12_0/graphsage_frozen_model.pb
+```
+
+## 3. Prepare Dataset
+
+```shell
+wget https://snap.stanford.edu/graphsage/ppi.zip
+unzip ppi.zip
+```
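+
+After extraction, the `ppi` directory should contain the four GraphSAGE-format files that `dataloader.py` reads: `ppi-G.json`, `ppi-feats.npy`, `ppi-id_map.json` and `ppi-class_map.json`. The following optional sketch (not part of the example scripts) checks that the layout matches what `main.py` expects:
+
+```python
+import os
+
+# main.py calls dataloader.load_data(prefix="<dataset_location>/ppi"),
+# so these files must live inside the extracted ppi directory.
+expected = ["ppi-G.json", "ppi-feats.npy", "ppi-id_map.json", "ppi-class_map.json"]
+missing = [f for f in expected if not os.path.exists(os.path.join("ppi", f))]
+print("missing files:", missing if missing else "none")
+```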
+
+# Run
+
+## Quantization Config
+
+The Quantization Config class has default parameter settings for running on Intel CPUs. To run this example on Intel GPUs, set the 'backend' parameter to 'itex' and the 'device' parameter to 'gpu', as shown below.
+
+```python
+config = PostTrainingQuantConfig(
+ device="gpu",
+ backend="itex",
+ ...
+ )
+```
+
+## 1. Quantization
+
+ ```shell
+    # The command to run quantization of the graphsage model
+ bash run_quant.sh --input_model=./graphsage_frozen_model.pb --output_model=./nc_graphsage_int8_model.pb --dataset_location=./ppi
+ ```
+
+## 2. Benchmark
+ ```shell
+ bash run_benchmark.sh --input_model=./nc_graphsage_int8_model.pb --dataset_location=./ppi --mode=performance
+ ```
+
+Details of enabling Intel® Neural Compressor on graphsage for TensorFlow
+=========================
+
+This is a tutorial on how to enable the graphsage model with Intel® Neural Compressor.
+## User Code Analysis
+The user specifies an FP32 *model*, a calibration dataset *calib_dataloader*, and a custom *eval_func* that encapsulates the evaluation dataset and metric by itself.
+
+For graphsage, we applied the custom *eval_func* approach because our philosophy is to enable the model with minimal changes. Hence we only need to make two changes to the original code: implement the dataloader used for calibration and make the necessary changes to the *eval_func*.
+
+### Code update
+
+After the preparation steps are done, we just need to update main.py as shown below.
+```python
+ if args.tune:
+ from neural_compressor import quantization
+ from neural_compressor.data import DataLoader
+ from neural_compressor.config import PostTrainingQuantConfig
+ dataset = CustomDataset()
+        calib_dataloader = DataLoader(framework='tensorflow', dataset=dataset,
+                                      batch_size=1, collate_fn=collate_function)
+        conf = PostTrainingQuantConfig()
+        q_model = quantization.fit(args.input_graph, conf=conf,
+                                   calib_dataloader=calib_dataloader, eval_func=evaluate)
+ q_model.save(args.output_graph)
+
+ if args.benchmark:
+ if args.mode == 'performance':
+ from neural_compressor.benchmark import fit
+ from neural_compressor.config import BenchmarkConfig
+ conf = BenchmarkConfig()
+ fit(args.input_graph, conf, b_func=evaluate)
+ elif args.mode == 'accuracy':
+ acc_result = evaluate(args.input_graph)
+ print("Batch size = %d" % args.batch_size)
+ print("Accuracy: %.5f" % acc_result)
+
+```
+
+The quantization.fit() function returns the best quantized model found within the timeout constraint.
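+
+If you want to bound the tuning process explicitly, `PostTrainingQuantConfig` accepts a `TuningCriterion`. A minimal sketch (the timeout and trial values here are illustrative, not taken from this example):
+
+```python
+from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
+
+# Stop tuning after 100 trials or 36000 seconds, whichever comes first.
+tuning_criterion = TuningCriterion(timeout=36000, max_trials=100)
+conf = PostTrainingQuantConfig(tuning_criterion=tuning_criterion)
+```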
diff --git a/examples/tensorflow/graph_networks/graphsage/quantization/ptq/dataloader.py b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/dataloader.py
new file mode 100644
index 00000000000..3e237dd972e
--- /dev/null
+++ b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/dataloader.py
@@ -0,0 +1,80 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import random
+import json
+import sys
+import os
+
+import networkx as nx
+from networkx.readwrite import json_graph
+
+
+def load_data(prefix, normalize=True, load_walks=False):
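+    """Load a GraphSAGE-format dataset.
+
+    Expects four files next to ``prefix``: ``prefix-G.json`` (networkx
+    node-link graph), ``prefix-feats.npy`` (node feature matrix),
+    ``prefix-id_map.json`` (node id -> feature row index) and
+    ``prefix-class_map.json`` (node id -> label).
+    """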
+ G_data = json.load(open(prefix + "-G.json"))
+ G = json_graph.node_link_graph(G_data)
+ if isinstance(list(G.nodes())[0], int):
+ conversion = lambda n : int(n)
+ else:
+ conversion = lambda n : n
+
+ if os.path.exists(prefix + "-feats.npy"):
+ feats = np.load(prefix + "-feats.npy")
+ else:
+ print("No features present.. Only identity features will be used.")
+ feats = None
+ id_map = json.load(open(prefix + "-id_map.json"))
+ id_map = {conversion(k):int(v) for k,v in id_map.items()}
+ walks = []
+ class_map = json.load(open(prefix + "-class_map.json"))
+ if isinstance(list(class_map.values())[0], list):
+ lab_conversion = lambda n : n
+ else:
+ lab_conversion = lambda n : int(n)
+
+ class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()}
+
+ ## Remove all nodes that do not have val/test annotations
+ ## (necessary because of networkx weirdness with the Reddit data)
+ broken_count = 0
+    for node in list(G.nodes()):
+        if 'val' not in G.nodes[node] or 'test' not in G.nodes[node]:
+ G.remove_node(node)
+ broken_count += 1
+ print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))
+
+ ## Make sure the graph has edge train_removed annotations
+ ## (some datasets might already have this..)
+ print("Loaded data.. now preprocessing..")
+ for edge in G.edges():
+ if (G.nodes[edge[0]]['val'] or G.nodes[edge[1]]['val'] or
+ G.nodes[edge[0]]['test'] or G.nodes[edge[1]]['test']):
+ G[edge[0]][edge[1]]['train_removed'] = True
+ else:
+ G[edge[0]][edge[1]]['train_removed'] = False
+
+    if normalize and feats is not None:
+ from sklearn.preprocessing import StandardScaler
+ train_ids = np.array([id_map[n] for n in G.nodes() if not G.nodes[n]['val'] and not G.nodes[n]['test']])
+ train_feats = feats[train_ids]
+ scaler = StandardScaler()
+ scaler.fit(train_feats)
+ feats = scaler.transform(feats)
+
+ return G, feats, id_map, walks, class_map
diff --git a/examples/tensorflow/graph_networks/graphsage/quantization/ptq/main.py b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/main.py
new file mode 100644
index 00000000000..75a054c565f
--- /dev/null
+++ b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/main.py
@@ -0,0 +1,194 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import time
+import utils
+import dataloader
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.platform import tf_logging
+
+from argparse import ArgumentParser
+
+np.random.seed(123)
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+arg_parser = ArgumentParser(description='Parse args')
+arg_parser.add_argument('-g', "--input-graph",
+ help='Specify the input graph for the transform tool',
+ dest='input_graph')
+arg_parser.add_argument("--output-graph",
+ help='Specify tune result model save dir',
+ dest='output_graph')
+arg_parser.add_argument('--benchmark', dest='benchmark', action='store_true', help='run benchmark')
+arg_parser.add_argument('--mode', dest='mode', default='performance', help='benchmark mode')
+arg_parser.add_argument('--tune', dest='tune', action='store_true', help='use neural_compressor to tune.')
+arg_parser.add_argument('--dataset_location', dest='dataset_location',
+ help='location of calibration dataset and evaluate dataset')
+arg_parser.add_argument('-e', "--num-inter-threads",
+ help='The number of inter-thread.',
+ dest='num_inter_threads', type=int, default=0)
+
+arg_parser.add_argument('-a', "--num-intra-threads",
+ help='The number of intra-thread.',
+ dest='num_intra_threads', type=int, default=0)
+arg_parser.add_argument('--batch_size', type=int, default=1000, dest='batch_size', help='batch_size of benchmark')
+arg_parser.add_argument('--iters', type=int, default=100, dest='iters', help='iterations')
+args = arg_parser.parse_args()
+
+def prepare_dataset():
+    data_location = args.dataset_location
+    # load_data returns (G, feats, id_map, walks, class_map)
+    G, features, id_map, context_pairs, class_map = dataloader.load_data(prefix=data_location + '/ppi')
+    if isinstance(list(class_map.values())[0], list):
+        num_classes = len(list(class_map.values())[0])
+    else:
+        num_classes = len(set(class_map.values()))
+
+ placeholders = utils.construct_placeholders(num_classes)
+ minibatch = utils.NodeMinibatchIterator(G,
+ id_map,
+ placeholders,
+ class_map,
+ num_classes,
+ batch_size=args.batch_size,
+ context_pairs = context_pairs)
+ return minibatch
+
+class CustomDataset(object):
+ def __init__(self):
+ self.batch1 = []
+ self.batch_labels = []
+        minibatch = prepare_dataset()
+ self.parse_minibatch(minibatch)
+
+ def parse_minibatch(self, minibatch):
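+        # Materialize every test feed_dict up front so the calibration
+        # dataloader can index batches like a map-style dataset.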
+ iter_num = 0
+ finished = False
+ while not finished:
+ feed_dict_val, batch_labels, finished, _ = minibatch.incremental_node_val_feed_dict(args.batch_size, iter_num, test=True)
+ self.batch1.append(feed_dict_val['batch1:0'])
+ self.batch_labels.append(batch_labels)
+ iter_num += 1
+
+ def __getitem__(self, index):
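+        # Each sample pairs the model inputs (node indices and their count,
+        # matching the 'batch1:0' and 'batch_size:0' placeholders) with labels.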
+ return (self.batch1[index], len(self.batch1[index])), self.batch_labels[index]
+
+ def __len__(self):
+ return len(self.batch1)
+
+def evaluate(model):
+ """Custom evaluate function to estimate the accuracy of the model.
+
+ Args:
+ model (tf.Graph_def): The input model graph
+
+ Returns:
+ accuracy (float): evaluation result, the larger is better.
+ """
+ from neural_compressor.model import Model
+ model = Model(model)
+ output_tensor = model.output_tensor if len(model.output_tensor)>1 else \
+ model.output_tensor[0]
+ iteration = -1
+    minibatch = prepare_dataset()
+ if args.benchmark and args.mode == 'performance':
+ iteration = args.iters
+
+ def eval_func(size, output_tensor, minibatch, test):
+ t_test = time.time()
+ val_losses = []
+ val_preds = []
+ labels = []
+ iter_num = 0
+ finished = False
+ total_time = 0
+ while not finished:
+ feed_dict_val, batch_labels, finished, _ = minibatch.incremental_node_val_feed_dict(size, iter_num, test=True)
+ tf_logging.warn('\n---> Start iteration {0}'.format(str(iter_num)))
+ start_time = time.time()
+ node_outs_val = model.sess.run([output_tensor],feed_dict=feed_dict_val)
+ time_consume = time.time() - start_time
+ val_preds.append(node_outs_val[0])
+ labels.append(batch_labels)
+ iter_num += 1
+ total_time += time_consume
+            # Only cap the iteration count in performance mode; iteration is
+            # -1 otherwise, which must not trigger an early break.
+            if iteration > 0 and iter_num >= iteration:
+ break
+ tf_logging.warn('\n---> Stop iteration {0}'.format(str(iter_num)))
+ val_preds = np.vstack(val_preds)
+ labels = np.vstack(labels)
+ f1_scores = utils.calc_f1(labels, val_preds)
+ time_average = total_time / iter_num
+ return f1_scores, (time.time() - t_test)/iter_num, time_average
+
+ test_f1_micro, duration, time_average = eval_func(args.batch_size, output_tensor, minibatch, test=True)
+ if args.benchmark and args.mode == 'performance':
+ latency = time_average / args.batch_size
+ print("Batch size = {}".format(args.batch_size))
+ print("Latency: {:.3f} ms".format(latency * 1000))
+ print("Throughput: {:.3f} images/sec".format(1. / latency))
+ return test_f1_micro
+
+def collate_function(batch):
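+    # calib_dataloader is created with batch_size=1, so `batch` is a list with
+    # one ((node_indices, batch_size), labels) sample; unwrap it for the model.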
+ return (batch[0][0][0], batch[0][0][1]), batch[0][1]
+
+class eval_graphsage_optimized_graph:
+ """Evaluate image classifier with optimized TensorFlow graph."""
+
+ def run(self):
+ """This is neural_compressor function include tuning, export and benchmark option."""
+ from neural_compressor import set_random_seed
+ set_random_seed(9527)
+
+ if args.tune:
+ from neural_compressor import quantization
+ from neural_compressor.data import DataLoader
+ from neural_compressor.config import PostTrainingQuantConfig
+ dataset = CustomDataset()
+            calib_dataloader = DataLoader(framework='tensorflow', dataset=dataset,
+                                          batch_size=1, collate_fn=collate_function)
+            conf = PostTrainingQuantConfig()
+            q_model = quantization.fit(args.input_graph, conf=conf,
+                                       calib_dataloader=calib_dataloader, eval_func=evaluate)
+ q_model.save(args.output_graph)
+
+ if args.benchmark:
+ if args.mode == 'performance':
+ from neural_compressor.benchmark import fit
+ from neural_compressor.config import BenchmarkConfig
+ conf = BenchmarkConfig()
+ fit(args.input_graph, conf, b_func=evaluate)
+ elif args.mode == 'accuracy':
+ acc_result = evaluate(args.input_graph)
+ print("Batch size = %d" % args.batch_size)
+ print("Accuracy: %.5f" % acc_result)
+
+if __name__ == "__main__":
+ evaluate_opt_graph = eval_graphsage_optimized_graph()
+ evaluate_opt_graph.run()
diff --git a/examples/tensorflow/graph_networks/graphsage/quantization/ptq/requirements.txt b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/requirements.txt
new file mode 100644
index 00000000000..a6c2afe448c
--- /dev/null
+++ b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/requirements.txt
@@ -0,0 +1,2 @@
+networkx
+scikit-learn
\ No newline at end of file
diff --git a/examples/tensorflow/graph_networks/graphsage/quantization/ptq/run_benchmark.sh b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/run_benchmark.sh
new file mode 100644
index 00000000000..89c7cc19b6e
--- /dev/null
+++ b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/run_benchmark.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+set -x
+
+function main {
+
+ init_params "$@"
+ run_benchmark
+
+}
+
+# init params
+function init_params {
+ batch_size=1000
+ iters=100
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --mode=*)
+ mode=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo "$var" |cut -f2 -d=)
+ ;;
+ --batch_size=*)
+ batch_size=$(echo $var |cut -f2 -d=)
+ ;;
+ --iters=*)
+ iters=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+
+ python main.py \
+            --input-graph "${input_model}" \
+            --mode "${mode}" \
+            --dataset_location "${dataset_location}" \
+            --batch_size "${batch_size}" \
+            --iters "${iters}" \
+ --benchmark
+}
+
+main "$@"
diff --git a/examples/tensorflow/graph_networks/graphsage/quantization/ptq/run_quant.sh b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/run_quant.sh
new file mode 100644
index 00000000000..f7046cc3df7
--- /dev/null
+++ b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/run_quant.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+set -x
+
+function main {
+
+ init_params "$@"
+
+ run_tuning
+
+}
+
+# init params
+function init_params {
+
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo "$var" |cut -f2 -d=)
+ ;;
+ --output_model=*)
+ output_model=$(echo "$var" |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo "$var" |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function run_tuning {
+ python main.py \
+ --input-graph "${input_model}" \
+ --output-graph "${output_model}" \
+ --dataset_location "${dataset_location}" \
+ --tune
+}
+
+main "$@"
diff --git a/examples/tensorflow/graph_networks/graphsage/quantization/ptq/utils.py b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/utils.py
new file mode 100644
index 00000000000..7ff4311b80d
--- /dev/null
+++ b/examples/tensorflow/graph_networks/graphsage/quantization/ptq/utils.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import tensorflow as tf
+from sklearn import metrics
+
+def calc_f1(y_true, y_pred):
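+    # Binarize the sigmoid outputs at 0.5, then score the multi-label
+    # predictions with micro-averaged F1.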
+ y_pred[y_pred > 0.5] = 1
+ y_pred[y_pred <= 0.5] = 0
+ return metrics.f1_score(y_true, y_pred, average="micro")
+
+def construct_placeholders(num_classes):
+ # Define placeholders
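+    # The names must match the input tensors baked into the frozen graph
+    # ('batch1:0' and 'batch_size:0' are fed in batch_feed_dict below).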
+ tf.compat.v1.disable_eager_execution()
+ placeholders = {
+ 'labels' : tf.compat.v1.placeholder(tf.float32, shape=(None, num_classes), name='labels'),
+ 'batch' : tf.compat.v1.placeholder(tf.int32, shape=(None), name='batch1'),
+ 'batch_size' : tf.compat.v1.placeholder(tf.int32, name='batch_size'),
+ }
+ return placeholders
+
+
+class NodeMinibatchIterator(object):
+
+ """
+ This minibatch iterator iterates over nodes for supervised learning.
+
+ G -- networkx graph
+ id2idx -- dict mapping node ids to integer values indexing feature tensor
+ placeholders -- standard tensorflow placeholders object for feeding
+ label_map -- map from node ids to class values (integer or list)
+ num_classes -- number of output classes
+ batch_size -- size of the minibatches
+ max_degree -- maximum size of the downsampled adjacency lists
+ """
+ def __init__(self, G, id2idx,
+ placeholders, label_map, num_classes,
+ batch_size=100, max_degree=25,
+ **kwargs):
+
+ self.G = G
+ self.nodes = G.nodes()
+ self.id2idx = id2idx
+ self.placeholders = placeholders
+ self.batch_size = batch_size
+ self.max_degree = max_degree
+ self.batch_num = 0
+ self.label_map = label_map
+ self.num_classes = num_classes
+ self.test_nodes = [n for n in self.G.nodes() if self.G.nodes[n]['test']]
+
+ def _make_label_vec(self, node):
+ label = self.label_map[node]
+ if isinstance(label, list):
+ label_vec = np.array(label)
+ else:
+ label_vec = np.zeros((self.num_classes))
+ class_ind = self.label_map[node]
+ label_vec[class_ind] = 1
+ return label_vec
+
+    def batch_feed_dict(self, batch_nodes, val=False):
+ batch1id = batch_nodes
+ batch1 = [self.id2idx[n] for n in batch1id]
+
+ labels = np.vstack([self._make_label_vec(node) for node in batch1id])
+ feed_dict = dict()
+ feed_dict.update({'batch1:0': batch1})
+ feed_dict.update({'batch_size:0' : len(batch1)})
+ return feed_dict, labels
+
+ def incremental_node_val_feed_dict(self, size, iter_num, test=False):
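+        # NOTE: this example always calls with test=True; self.val_nodes is
+        # never populated by this iterator, so test=False would fail here.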
+ if test:
+ val_nodes = self.test_nodes
+ else:
+ val_nodes = self.val_nodes
+ val_node_subset = val_nodes[iter_num*size:min((iter_num+1)*size,
+ len(val_nodes))]
+
+ # add a dummy neighbor
+ ret_val = self.batch_feed_dict(val_node_subset)
+ return ret_val[0], ret_val[1], (iter_num+1)*size >= len(val_nodes), val_node_subset
diff --git a/neural_compressor/adaptor/tf_utils/graph_converter.py b/neural_compressor/adaptor/tf_utils/graph_converter.py
index ec1390ea4eb..88260cd816e 100644
--- a/neural_compressor/adaptor/tf_utils/graph_converter.py
+++ b/neural_compressor/adaptor/tf_utils/graph_converter.py
@@ -244,7 +244,12 @@ def _inference(self, model):
# we should check and pair them
def check_shape(tensor, data):
# scalar or 1 dim default True
- if tensor.shape is None or len(tensor.shape.dims) == 1 or not hasattr(data, "shape"):
+ if (
+ tensor.shape is None
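+                # TensorShape(None) means the rank itself is unknown (e.g. a
+                # placeholder created with shape=None), which should also pass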
+ or tensor.shape == tf.TensorShape(None)
+ or len(tensor.shape.dims) == 1
+ or not hasattr(data, "shape")
+ ):
return True
tensor_shape = tuple(tensor.shape)
data_shape = tuple(data.shape)