Add InstaHide Attack paper to research folder
Showing 8 changed files with 635 additions and 0 deletions.

**`README.md`** (new file, 51 lines)

Implementation of our reconstruction attack on InstaHide.

An Attack on InstaHide: Is Private Learning Possible with Instance Encoding?
Nicholas Carlini, Samuel Deng, Sanjam Garg, Somesh Jha, Saeed Mahloujifar, Mohammad Mahmoody, Shuang Song, Abhradeep Thakurta, Florian Tramer
https://arxiv.org/abs/2011.05315

## Overview

InstaHide is a recent privacy-preserving machine learning framework.
It takes a (sensitive) dataset and generates encoded images that are intended to be privacy-preserving.
Our attack breaks InstaHide and shows that it does not offer meaningful privacy:
given the encoded dataset, we can recover a near-identical copy of the original images.

This repository implements the attack described in our paper. It consists of a number of
steps that should be run sequentially. It assumes access to pre-trained neural network
classifiers, which should be downloaded following the steps below.
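
For background: InstaHide produces each encoded image by mixing a private image with a few other images using random coefficients that sum to one, and then flipping the sign of every pixel independently at random. Because of the sign flips, the attack only ever looks at absolute pixel values (the `jn.abs` calls in the scripts below); the graph-matching steps then undo the mixing. A minimal sketch of that style of encoding (our own illustration, not the reference InstaHide implementation; the function name and defaults are ours):

```python
# Illustrative sketch of an InstaHide-style encoding (our own code, not the
# reference implementation; the mixing choices here are illustrative only).
import numpy as np

def encode(images_to_mix, rng=np.random):
    """Mix the given images with random coefficients, then flip pixel signs."""
    k = len(images_to_mix)
    lams = rng.dirichlet(np.ones(k))                   # mixing coefficients, sum to 1
    mixed = sum(l * img for l, img in zip(lams, images_to_mix))
    sigma = rng.choice([-1.0, 1.0], size=mixed.shape)  # random pixel-wise sign flips
    return sigma * mixed                               # the "encoded" image the attack sees
```
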
### Requirements

* Python, version ≥ 3.5
* jax
* jaxlib
* objax (https://github.com/google/objax)
* PIL
* sklearn

### Running the attack

To reproduce our results and run the attack, run each of the following scripts in turn (a minimal driver sketch follows the list).

0. Download the necessary dependency files:
    - [encryption.npy](https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0) and [labels.npy](https://www.dropbox.com/sh/8zdsr1sjftia4of/AAA-60TOjGKtGEZrRmbawwqGa?dl=0) from the [InstaHide Challenge](https://github.com/Hazelsuko07/InstaHide_Challenge)
    - The [saved models](https://drive.google.com/file/d/1YfKzGRfnnzKfUKpLjIRXRto8iD4FdwGw/view?usp=sharing) used to run the attack
    - Set up all the requirements as above

1. Run `step_1_create_graph.py`. Produce the similarity graph to pair together encoded images that share an original image.

2. Run `step_2_color_graph.py`. Color the graph to find 50 dense cliques.

3. Run `step_3_second_graph.py`. Create a new bipartite similarity graph.

4. Run `step_4_final_graph.py`. Solve the matching problem to assign encoded images to original images.

5. Run `step_5_reconstruct.py`. Reconstruct the original images.

6. Run `step_6_adjust_color.py`. Adjust the color curves to match.

7. Run `step_7_visualize.py`. Show the final resulting images.
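
All of the scripts read and write their intermediate results under `data/` and load the pre-trained classifiers from `models/`, so the attack can be driven end to end with a loop like the following (our own convenience sketch, not part of the original release):

```python
# Run the attack end to end by invoking each step script in order
# (our convenience sketch; assumes data/ and models/ are already populated).
import subprocess

STEPS = [
    "step_1_create_graph.py",
    "step_2_color_graph.py",
    "step_3_second_graph.py",
    "step_4_final_graph.py",
    "step_5_reconstruct.py",
    "step_6_adjust_color.py",
    "step_7_visualize.py",
]

for script in STEPS:
    print(f"=== {script} ===")
    subprocess.run(["python", script], check=True)  # stop immediately if a step fails
```
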

**`step_1_create_graph.py`** (new file, 77 lines)

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Create the similarity graph given the encoded images by running the similarity
neural network over all pairs of images.
"""

import objax
import numpy as np
import jax.numpy as jn
import functools
import os
import random

from objax.zoo import wide_resnet


def setup():
    global model

    class DoesUseSame(objax.Module):
        def __init__(self):
            fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
            self.model = fn(6, 2)

            model_vars = self.model.vars()
            self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)

            def predict_op(x, y):
                # The model takes the two images and checks if they correspond
                # to the same original image.
                xx = jn.concatenate([jn.abs(x),
                                     jn.abs(y)],
                                    axis=1)
                return self.model(xx, training=False)

            self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
            self.predict_fast = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())

    model = DoesUseSame()
    checkpoint = objax.io.Checkpoint("models/step1/", keep_ckpts=5, makedir=True)
    start_epoch, last_ckpt = checkpoint.restore(model.vars())


def doall():
    global graph
    n = np.load("data/encryption.npy")
    n = np.transpose(n, (0, 3, 1, 2))

    # Compute the similarity between each encoded image and all others.
    # This is n^2 work but should run fairly quickly, especially given
    # more than one GPU. Otherwise about an hour or so.
    graph = []
    with model.vars().replicate():
        for i in range(5000):
            print(i)
            v = model.predict_fast(np.tile(n[i:i+1], (5000, 1, 1, 1)), n)
            graph.append(np.array(v[:, 0] - v[:, 1]))
    graph = np.array(graph)
    np.save("data/graph.npy", graph)


if __name__ == "__main__":
    setup()
    doall()
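
The saved `data/graph.npy` is a 5000×5000 array of logit differences; the later steps sort these values in ascending order, so smaller (more negative) entries are treated as more likely to share an original image. A small inspection sketch (ours), assuming step 1 has already been run:

```python
# Inspect the pairwise similarity graph written by step 1 (our sketch).
import numpy as np

graph = np.load("data/graph.npy")      # shape (5000, 5000)
i = 123                                # index of some encoded image
partners = np.argsort(graph[i])[:10]   # the ten most similar encodings
print("Encodings most similar to", i, ":", partners)
```
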

**`step_2_color_graph.py`** (new file, 95 lines)

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import multiprocessing as mp
import pickle
import random
import collections
import numpy as np


def score(subset):
    sub = graph[subset]
    sub = sub[:, subset]
    return np.sum(sub)


def run(v, return_scores=False):
    if isinstance(v, int):
        v = [v]
    scores = []
    for _ in range(100):
        keep = graph[v, :]
        next_value = np.sum(keep, axis=0)
        to_add = next_value.argsort()
        to_add = [x for x in to_add if x not in v]
        if _ < 1:
            v.append(to_add[random.randint(0, 10)])
        else:
            v.append(to_add[0])
        if return_scores:
            scores.append(score(v) / len(keep))
    if return_scores:
        return v, scores
    else:
        return v


def make_many_clusters():
    # Compute clusters of 100 examples that probably correspond to some original image.
    p = mp.Pool(mp.cpu_count())
    s = p.map(run, range(2000))
    return s


def downselect_clusters(s):
    # Right now we have a lot of clusters, but they probably overlap. Let's remove that.
    # We want to find disjoint clusters, so we'll greedily add them until we have
    # 100 disjoint clusters.

    ss = [set(x) for x in s]

    keep = []
    keep_set = []
    for iteration in range(2):
        for this_set in s:
            # MAGIC NUMBERS...!
            # We want clusters of size 50 because it works,
            # except on iteration 2 where we'll settle for 25 if we haven't
            # found clusters with 50 neighbors that work.
            cur = set(this_set[:50 - 25 * iteration])
            intersections = np.array([len(cur & x) for x in ss])
            good = np.sum(intersections == 50) > 2
            # Good means that this cluster isn't a fluke and some other cluster
            # is like this one.
            if good or iteration == 1:
                print("N")
                # And also make sure we haven't found this cluster (or one like it).
                already_found = np.array([len(cur & x) for x in keep_set])
                if np.all(already_found < len(cur) / 2):
                    print("And is new")
                    keep.append(this_set)
                    keep_set.append(set(this_set))
                    if len(keep) == 100:
                        break
        print("Found", len(keep))
        if len(keep) == 100:
            break

    # keep should now have 100 items.
    # If it doesn't, go and change the 2000 in make_many_clusters to a bigger number.
    return keep


if __name__ == "__main__":
    graph = np.load("data/graph.npy")
    np.save("data/many_clusters", make_many_clusters())
    np.save("data/100_clusters", downselect_clusters(np.load("data/many_clusters.npy")))

**`step_3_second_graph.py`** (new file, 114 lines)

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Create the improved graph mapping each encoded image to an original image.
"""

import objax
import numpy as np
import jax.numpy as jn
import functools
import os
import random

from objax.zoo import wide_resnet


def setup():
    global model

    class DoesUseSame(objax.Module):
        def __init__(self):
            fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
            self.model = fn(3 * 4, 2)

            model_vars = self.model.vars()
            self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)

            def predict_op(x, y):
                # The model takes SEVERAL images and checks if they all correspond
                # to the same original image.
                # Guaranteed that the first N-1 all do, the test is if the last does.
                xx = jn.concatenate([jn.abs(x),
                                     jn.abs(y)],
                                    axis=1)
                return self.model(xx, training=False)

            self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())

    model = DoesUseSame()
    checkpoint = objax.io.Checkpoint("models/step2/", keep_ckpts=5, makedir=True)
    start_epoch, last_ckpt = checkpoint.restore(model.vars())


def step2():
    global v, n, u, nextgraph

    # Start out by loading the encoded images.
    n = np.load("data/encryption.npy")
    n = np.transpose(n, (0, 3, 1, 2))

    # Then load the graph with 100 cluster centers.
    keep = np.array(np.load("data/100_clusters.npy", allow_pickle=True))
    graph = np.load("data/graph.npy")

    # Now we're going to record the distance to each of the cluster centers
    # from every encoded image, so that we can do the matching.

    # To do that, though, first we need to choose the cluster centers.
    # Start out by choosing the best cluster centers.

    distances = []

    for x in keep:
        this_set = x[:50]
        use_elts = graph[this_set]
        distances.append(np.sum(use_elts, axis=0))
    distances = np.array(distances)

    ds = np.argsort(distances, axis=1)

    # Now we record the "prototypes" of each cluster center.
    # We just need three, more might help a little bit but not much.
    # (And then do that ten times, so we can average out noise
    # with respect to which cluster centers we picked.)

    prototypes = []
    for _ in range(10):
        ps = []
        # Choose 3 random samples from each set.
        for i in range(3):
            ps.append(n[ds[:, random.randint(0, 20)]])
        prototypes.append(np.concatenate(ps, 1))
    prototypes = np.concatenate(prototypes, 0)

    # Finally compute the distances from each node to each cluster center.
    nextgraph = []
    for i in range(5000):
        out = model.predict(prototypes, np.tile(n[i:i+1], (1000, 1, 1, 1)))
        out = out.reshape((10, 100, 2))

        v = np.sum(out, axis=0)
        v = v[:, 0] - v[:, 1]
        v = np.array(v)
        nextgraph.append(v)

    np.save("data/nextgraph.npy", nextgraph)


if __name__ == "__main__":
    setup()
    step2()
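
The saved `data/nextgraph.npy` is a 5000×100 array with one row per encoded image and one column per cluster center, using the same convention as before (lower means closer). A small inspection sketch (ours):

```python
# Inspect the image-to-cluster-center scores written by step 3 (our sketch).
import numpy as np

nextgraph = np.load("data/nextgraph.npy")   # shape (5000, 100)
closest = np.argmin(nextgraph, axis=1)      # best cluster center for each encoding
print("Encoding 0 is closest to cluster center", closest[0])
```
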

**`step_4_final_graph.py`** (new file, 51 lines)

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import multiprocessing as mp
import pickle
import random
import numpy as np


labels = np.load("data/label.npy")
nextgraph = np.load("data/nextgraph.npy")

assigned = [[] for _ in range(5000)]
lambdas = [[] for _ in range(5000)]

for i in range(100):
    order = np.argsort(nextgraph[:, i])
    correct = (labels[order[:20]] > 0).sum(axis=0).argmax()

    # Let's create the final graph.
    # Instead of doing a full bipartite matching, let's just greedily
    # choose the closest 80 candidates for each encoded image to pair
    # together and call it a day.
    # This is within a percent or two of doing that, and much easier.

    # Also record the lambdas based on which image it corresponds to,
    # but if they share a label then just guess it's an even 50/50 split.

    for x in order[:80]:
        if labels[x][correct] > 0 and len(assigned[x]) < 2:
            assigned[x].append(i)
            if np.sum(labels[x] > 0) == 1:
                # The same label was mixed in twice; punt.
                lambdas[x].append(labels[x][correct] / 2)
            else:
                lambdas[x].append(labels[x][correct])

np.save("data/predicted_pairings_80.npy", assigned)
np.save("data/predicted_lambdas_80.npy", lambdas)