diff --git a/.DS_Store b/.DS_Store index 7015617..8f2dd23 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index d285b0d..27d3d48 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ weights/ +__pycache__/ +src/__pycache__/ diff --git a/VPRTempo-Saliency.py b/VPRTempo-Saliency.py deleted file mode 100644 index 72acbd9..0000000 --- a/VPRTempo-Saliency.py +++ /dev/null @@ -1,765 +0,0 @@ -#MIT License - -#Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer - -#Permission is hereby granted, free of charge, to any person obtaining a copy -#of this software and associated documentation files (the "Software"), to deal -#in the Software without restriction, including without limitation the rights -#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -#copies of the Software, and to permit persons to whom the Software is -#furnished to do so, subject to the following conditions: - -#The above copyright notice and this permission notice shall be included in all -#copies or substantial portions of the Software. - -#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -#SOFTWARE. - -''' -Imports -''' -import cv2 -import pickle -import os -import torch -import gc -import math -import timeit -import sys -sys.path.append('./src') -sys.path.append('./weights') -sys.path.append('./settings') -sys.path.append('./output') - -import numpy as np -import blitnet_open as blitnet -import blitnet_ensemble as ensemble -import validation as validate -import matplotlib.pyplot as plt -import TranSalNet_Dense as sal - -from os import path -from alive_progress import alive_bar -from metrics import createPR -from timeit import default_timer as timer -from data_process import preprocess_img, postprocess_img -from torchvision import transforms, utils, models - - -''' -Spiking network model class -''' -class snn_model(): - def __init__(self): - super().__init__() - - ''' - USER SETTINGS - ''' - self.trainingPath = '/home/adam/data/hpc/' # training datapath - self.testPath = '/home/adam/data/testing_data/' # testing datapath - self.number_training_images = 100 # alter number of training images - self.number_ensembles = 30 # number of ensemble networks - self.ensemble_max = self.number_training_images*self.number_ensembles # maximum number of images per ensemble_net - self.location_repeat = 2 # Number of training locations that are the same - self.locations = ["spring","fall"] # which datasets are used in the training - self.test_location = "summer" - - # Image and patch normalization settings - self.imWidth = 28 # image width for patch norm - self.imHeight = 28 # image height for patch norm - self.num_patches = 7 # number of patches - self.intensity = 255 # divide pixel values to get spikes in range [0,1] - - # Network and training settings - self.input_layer = (self.imWidth*self.imHeight) # number of input layer neurons - self.feature_layer = int(self.input_layer*3) # number of feature layer neurons - self.output_layer = self.number_training_images # number of output layer neurons (match training images) - self.train_img = self.output_layer # number of training images - self.epoch = 4 # number of training iterations - self.test_t = self.output_layer # number of testing time points - self.cuda = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # saliency calculating on cpu or gpu - self.T = int(self.number_training_images*self.epoch) # number of training steps - self.annl_pow = 2 # learning rate anneal power - self.filter = 8 # filter images every 8 seconds (equivelant to 8 images) - - # Select training images from list - with open('./nordland_imageNames.txt') as file: - self.imageNames = [line.rstrip() for line in file] - - # Filter the loading images based on self.filter - self.filteredNames = [] - for n in range(0,len(self.imageNames),self.filter): - self.filteredNames.append(self.imageNames[n]) - del self.filteredNames[self.ensemble_max*self.number_ensembles:len(self.filteredNames)] - - # Get the full training and testing data paths - self.fullTrainPaths = [] - for n in self.locations: - self.fullTrainPaths.append(self.trainingPath+n+'/') - - # Hyperparamters - self.theta_max = 0.25 # maximum threshold value - self.n_init = 0.01 # initial learning rate value - self.n_itp = 0.25 # initial intrinsic threshold plasticity rate[[0.9999]] - self.f_rate = [0.0012,0.2]# firing rate range - self.p_exc = 0.025 # probability of excitatory connection - self.p_inh = 0.125 # probability of inhibitory connection - self.c= 0.125 # constant input - - # Test settings - self.test_true = False # leave default to False - self.validation = True # test this network against other methods, default True - - # Print network details - print('////////////') - print('VPRTempo - Temporally Encoded Visual Place Recognition v0.1') - print('Queensland University of Technology, Centre for Robotics') - print('\\\\\\\\\\\\\\\\\\\\\\\\') - print('Theta: '+str(self.theta_max)) - print('Initial learning rate: '+str(self.n_init)) - print('ITP Learning: '+str(self.n_itp)) - print('Firing rate: '+str(self.f_rate[0]) +' to '+ str(self.f_rate[1])) - print('Excitatory p: '+str(self.p_exc)) - print('Inhibitory p: '+str(self.p_inh)) - print('Constant input '+str(self.c)) - print('CUDA available: '+str(torch.cuda.is_available())) - if torch.cuda.is_available() == True: - current_device = torch.cuda.current_device() - print('Current device is: '+str(torch.cuda.get_device_name(current_device))) - else: - print('Current device is: CPU') - - # Network weights name - self.training_out = './weights/'+str(self.input_layer)+'i'+\ - str(self.feature_layer)+\ - 'f'+str(self.output_layer)+\ - 'o'+str(self.epoch)+'/' - - # Start the TranSalNet model - self.model = sal.TranSalNet() - self.model.load_state_dict(torch.load(r'TranSalNet_Dense.pth')) - self.model = self.model.to(self.cuda) - self.model.eval() - - # Get the 2D patches or the patch normalization - def get_patches2D(self): - - if self.patch_size[0] % 2 == 0: - nrows = self.image_pad.shape[0] - self.patch_size[0] + 2 - ncols = self.image_pad.shape[1] - self.patch_size[1] + 2 - else: - nrows = self.image_pad.shape[0] - self.patch_size[0] + 1 - ncols = self.image_pad.shape[1] - self.patch_size[1] + 1 - self.patches = np.lib.stride_tricks.as_strided(self.image_pad , self.patch_size + (nrows, ncols), - self.image_pad.strides + self.image_pad.strides).reshape(self.patch_size[0]*self.patch_size[1],-1) - - # Run patch normalization on imported RGB images - def patch_normalise_pad(self): - - self.patch_size = (self.num_patches, self.num_patches) - patch_half_size = [int((p-1)/2) for p in self.patch_size ] - - self.image_pad = np.pad(np.float64(self.img), patch_half_size, 'constant', - constant_values=np.nan) - - nrows = self.img.shape[0] - ncols = self.img.shape[1] - self.get_patches2D() - mus = np.nanmean(self.patches, 0) - stds = np.nanstd(self.patches, 0) - - with np.errstate(divide='ignore', invalid='ignore'): - self.im_norm = (self.img - mus.reshape(nrows, ncols)) / stds.reshape(nrows, ncols) - - self.im_norm[np.isnan(self.im_norm)] = 0.0 - self.im_norm[self.im_norm < -1.0] = -1.0 - self.im_norm[self.im_norm > 1.0] = 1.0 - - # Process the loaded images - resize, normalize color, & patch normalize - def processImage(self): - - self.sal = preprocess_img(self.fullpath) - self.sal = np.array(self.sal)/255. - self.sal = np.expand_dims(np.transpose(self.sal,(2,0,1)),axis=0) - - # convert numpy array to torch - self.sal = torch.from_numpy(self.sal) - self.sal = self.sal.type(torch.FloatTensor).to(self.cuda) - - # run model and make predictions - pred_saliency = self.model(self.sal) - toPIL = transforms.ToPILImage() - pic = toPIL(pred_saliency.squeeze()) - self.sal = postprocess_img(pic,self.img) - - # filter saliency pixels < certain value - #self.sal = self.sal-((np.mean(self.sal)/np.std(self.sal))+np.median(self.sal)) - self.sal[self.sal<10] = 0 - - # reshape 2D images into 1D vector - self.img = cv2.cvtColor(self.img,cv2.COLOR_RGB2GRAY) - imgX = len(self.img) - imgY = len(self.img[1]) - self.img = np.reshape(self.img,(imgX*imgY,)) - self.sal = np.reshape(self.sal,(imgX*imgY,)) - - # find where saliency map = 0 and remove equivelant pixels from image - index = np.where(self.sal==0) - self.img[tuple(index)] = 0 - self.img = np.reshape(self.img,(imgX,imgY)) - - - # gamma correct images - mid = 0.5 - mean = np.mean(self.img) - gamma = math.log(mid*255)/math.log(mean) - self.img = np.power(self.img,gamma).clip(0,255).astype(np.uint8) - - # resize image to 28x28 and patch normalize - self.img = cv2.resize(self.img,(self.imWidth, self.imHeight)) - self.patch_normalise_pad() - self.img = np.uint8(255.0 * (1 + self.im_norm) / 2.0) - - - # Image loader function - runs all image import functions - def loadImages(self): - - # Create dictionary of images - self.imgs = {'training':[],'testing':[]} - self.ids = {'training':[],'testing':[]} - - if self.test_true: - self.fullTrainPaths = [self.testPath+self.test_location+'/'] - - for paths in self.fullTrainPaths: - self.dataPath = paths - if self.test_location in self.dataPath: - dictEntry = 'testing' - else: - dictEntry = 'training' - for m in self.loadNames: - self.fullpath = self.dataPath+m - # read and convert image from BGR to RGB - self.img = cv2.imread(self.fullpath)[:,:,::-1] - # convert image - self.processImage() - self.imgs[dictEntry].append(self.img) - self.ids[dictEntry].append(m) - - - def setSpikeRates(self, images, ids): - data = {'x': np.array(images), 'y': np.array(ids), - 'rows': self.imWidth, 'cols': self.imHeight} - num_testing_imgs = data['x'].shape[0] - self.num_examples = num_testing_imgs - n_input = self.imWidth * self.imHeight - - # Set the spike rates based on the number of example training images - self.spike_rates = [] - self.init_rates = [] - for jdx, j in enumerate(range(int(self.num_examples))): - self.init_rates.append((data['x'][j%num_testing_imgs,:,:].reshape((n_input))/self.intensity)) - for n in range(int(self.epoch)): - self.spike_rates.extend(self.init_rates) - if not self.test_true: - self.spike_rates.extend(self.spike_rates) - - def checkTrainTest(self): - # Check if pre-trained network exists, prompt if retrain or run - if path.isdir(self.training_out) and len(os.listdir(self.training_out)) == self.number_ensembles or os.path.isfile(self.training_out + 'ensemble_net1.pkl'): - retrain = input("A network with these parameters exists, re-train ensembles? (y/n):\n") - else: - retrain = 'y' - return retrain - - ''' - Run the training network - ''' - def train(self): - - # check if training required - check = self.checkTrainTest() - - if check == 'y': - - if os.path.isfile(self.training_out + 'ensemble_net.pkl'): - os.remove(self.training_out+'ensemble_net.pkl') - - - def train_start(): - - self.copy_max = self.ensemble_max - self.ensemble_netNum = 1 - breakUp = int(self.ensemble_max/self.number_training_images) - - for ensemble_num, ndx in enumerate(range(self.number_ensembles)): - start_range = self.number_training_images * ndx - end_range = self.number_training_images * (ndx+1) - - self.loadNames = self.filteredNames[start_range:end_range] - self.loadImages() - self.setSpikeRates(self.imgs['training'],self.ids['training']) - # create new network - net = blitnet.newNet() - iLayer = blitnet.addLayer(net,[self.input_layer,1],0.0,0.0,0.0,0.0,0.0, - False) - fLayer = blitnet.addLayer(net,[self.feature_layer,1],[0,self.theta_max], - [self.f_rate[0],self.f_rate[1]],self.n_itp, - [0,self.c],0,False) - # sequentially set the feature firing rates - fstep = (self.f_rate[1]-self.f_rate[0])/self.feature_layer - for i in range(self.feature_layer): - net['fire_rate'][fLayer][i] = self.f_rate[0]+fstep*(i+1) - - # create the excitatory and inhibitory connections - - # Excitatory weights - idx = blitnet.addWeights(net,iLayer,fLayer,[0,1],self.p_exc,self.n_init, False) - ex_weight = [] - ex_weight.append(idx) - - # Inhibitory weights - inh_weight = [] - idx = blitnet.addWeights(net,iLayer,fLayer,[-1,0],self.p_inh,self.n_init,False) - inh_weight.append(idx) - - # Set the spikes times for the input images - spikeTimes = [] - for n, ndx in enumerate(self.spike_rates): - nz_indicies = np.nonzero(ndx) - tempspikes = ndx[nz_indicies[0]] - tempspikes[tempspikes>=1] = 0.999 - spiketime = (n+1)+tempspikes - spike_neuron = (np.column_stack((spiketime,nz_indicies[0]))).tolist() - spikeTimes.extend(spike_neuron) - - spikeTimes = torch.from_numpy(np.array(spikeTimes)) - # set input spikes - blitnet.setSpikeTimes(net,0,spikeTimes) - - # Train the input to feature layer - # Train the feature layer - for t in range(int(self.T/10)): - blitnet.runSim(net,10) - # anneal learning rates - if np.mod(t,10)==0: - pt = pow(float(self.T-t)/self.T,self.annl_pow) - net['eta_ip'][fLayer] = self.n_itp*pt - net['eta_stdp'][ex_weight[-1]] = self.n_init*pt - net['eta_stdp'][inh_weight[-1]] = -1*self.n_init*pt - - # Turn off learning between input and feature layer - net['eta_ip'][fLayer] = 0.0 - if self.p_exc > 0.0: net['eta_stdp'][ex_weight[-1]] = 0.0 - if self.p_inh > 0.0: net['eta_stdp'][inh_weight[-1]] = 0.0 - - # Create and train the output layer with the feature layer - oLayer = blitnet.addLayer(net,[self.output_layer,1],0.0,0.0,0.0,0.0,0.0,False) - - # Add the excitatory connections - idx = blitnet.addWeights(net,fLayer,oLayer,[0.0,1.0],1.0,self.n_init,False) - ex_weight.append(idx) - - # Add the inhibitory connections - idx = blitnet.addWeights(net,fLayer,oLayer,[-1.0,0.0],1.0,-self.n_init,False) - inh_weight.append(idx) - - # Output spikes for spike forcing (final layer) - out_spks = np.zeros([(self.output_layer),2]) - append_spks = np.zeros([(self.output_layer),2]) - - for n in range(self.output_layer): - out_spks[n] = [(n)+1.5,n] - append_spks[n] = [(n)+1.5,n] - - if self.location_repeat != 0: - base_spks = np.copy(out_spks) - for n in range(1,self.location_repeat): - base_spks[:,0] = base_spks[:,0] + self.test_t - out_spks = np.append(out_spks,base_spks,axis=0) - - for n in range(self.epoch): - out_spks[:,0] += self.output_layer - - append_spks= torch.from_numpy(np.concatenate((append_spks,out_spks),axis=0)) - - # Set the output spikes (spike forcing) - append_spks[:,0] += self.T - blitnet.setSpikeTimes(net,oLayer,append_spks) - - # Train the feature to output layer - for t in range(self.T): - blitnet.runSim(net,1) - # Anneal learning rates - if np.mod(t,10)==0: - pt = pow(float(self.T-t)/(self.T),self.annl_pow) - net['eta_ip'][oLayer] = self.n_itp*pt - net['eta_stdp'][ex_weight[-1]] = self.n_init*pt - net['eta_stdp'][inh_weight[-1]] = -1*self.n_init*pt - - # Turn off learning - net['eta_ip'][oLayer] = 0.0 - net['eta_stdp'][ex_weight[-1]] = 0.0 - net['eta_stdp'][inh_weight[-1]] = 0.0 - - # Clear the network output spikes - blitnet.setSpikeTimes(net,oLayer,[]) - - # Reset network details - net['set_spks'][0] = [] - #net['rec_spks'] = [True,True,True] - net['sspk_idx'] = [0,0,0] - net['step_num'] = 0 - net['spikes'] = [[],[],[]] - - # check if output dir exsist, create if not - if not path.isdir(self.training_out): - os.mkdir(self.training_out) - - # Output the trained network - outputPkl = self.training_out + str(ensemble_num) + '.pkl' - with open(outputPkl, 'wb') as f: - pickle.dump(net, f) - - breakflag = False - # when ensemble training is done, pickle entire ensemble into a dictionary - if self.ensemble_max == (ensemble_num+1)*self.number_training_images or ensemble_num == range(self.number_ensembles)[-1]: - for n in range((self.ensemble_netNum-1)*breakUp,(self.ensemble_netNum*breakUp)): - pickleName = self.training_out + str(n) + '.pkl' - # Select training images from list - net = [] - if os.path.isfile(pickleName): - with open(pickleName, 'rb') as f: - net = pickle.load(f) - if n == range((self.ensemble_netNum-1)*breakUp,(self.ensemble_netNum*breakUp))[0]: - ensemble_net = net - for m in range(oLayer+2): - ensemble_net['W'][m] = torch.unsqueeze(ensemble_net['W'][m],-1) - if m <= (len(ensemble_net['thr'])-1): - ensemble_net['thr'][m] = torch.unsqueeze(ensemble_net['thr'][m],-1) - ensemble_net['const_inp'][m] = torch.unsqueeze(ensemble_net['const_inp'][m],-1) - ensemble_net['fire_rate'][m] = torch.unsqueeze(ensemble_net['fire_rate'][m],-1) - else: - for m in range(oLayer+2): - ensemble_net['W'][m] = torch.concat((ensemble_net['W'][m],torch.unsqueeze(net['W'][m],-1)),-1) - if m <= (len(ensemble_net['thr'])-1): - ensemble_net['thr'][m] = torch.concat((ensemble_net['thr'][m],torch.unsqueeze(net['thr'][m],-1)),-1) - ensemble_net['const_inp'][m] = torch.concat((ensemble_net['const_inp'][m],torch.unsqueeze(net['const_inp'][m],-1)),-1) - ensemble_net['fire_rate'][m] = torch.concat((ensemble_net['fire_rate'][m],torch.unsqueeze(net['fire_rate'][m],-1)),-1) - - # delete the individual pickled net - os.remove(pickleName) - # pickle the ensemble network - outputPkl = self.training_out + 'ensemble_net'+str(self.ensemble_netNum)+'.pkl' - with open(outputPkl, 'wb') as f: - pickle.dump(ensemble_net, f) - - self.ensemble_max += self.copy_max - self.ensemble_netNum += 1 - - yield - - print('Training the ensembles') - with alive_bar(self.number_ensembles) as sbar: - for i in train_start(): - sbar() - - ''' - Run the testing network - ''' - - def networktester(self): - ''' - Network tester functions - ''' - # set the input spikes - def set_spikes(): - spikeTimes = [] - - for n, ndx in enumerate(self.spike_rates): - nz_indicies = np.nonzero(ndx) - tempspikes = ndx[nz_indicies[0]] - tempspikes[tempspikes>=1] = 0.999 - spiketime = (n+1)+tempspikes - spike_neuron = (np.column_stack((spiketime,nz_indicies[0]))).tolist() - spikeTimes.extend(spike_neuron) - - spikeTimes = torch.from_numpy(np.array(spikeTimes)) - - for m in netDict: - print('Setting spikes') - # set input spikes - blitnet.setSpikeTimes(netDict[str(m)],0,spikeTimes) - - netDict[str(m)]['set_spks'][0] = torch.unsqueeze(netDict[str(m)]['set_spks'][0],-1) - tempspikes = torch.clone(netDict[str(m)]['set_spks'][0]) - for n in range(int(self.ensemble_max/self.number_training_images)-1): - netDict[str(m)]['set_spks'][0] = torch.concat((netDict[str(m)]['set_spks'][0],tempspikes),-1) - - # calculate and plot distance matrices - def plotit(netx,name): - reshape_mat = np.reshape(netx,(self.test_t,int(self.train_img/self.location_repeat))) - # plot the matrix - fig = plt.figure() - plt.matshow(reshape_mat,fig, cmap=plt.cm.gist_yarg) - plt.colorbar(label="Spike amplitude") - fig.suptitle("Similarity "+name,fontsize = 12) - plt.xlabel("Query",fontsize = 12) - plt.ylabel("Database",fontsize = 12) - plt.show() - - # calculate PR curves - - - # network validation using alternative place matching algorithms and P@R calculation - def network_validator(): - # reload training images for the comparisons - self.test_true= False # reset the test true flag - self.test_imgs = self.ims.copy() - self.dataPath = '/home/adam/data/hpc/' - self.fullTrainPaths = [] - for n in self.locations: - self.fullTrainPaths.append(self.trainingPath+n) - self.loadImages() - # run sum of absolute differences caluclation - validate.SAD(self) - - ''' - Setup & running network tester - ''' - - # Alter network running parameters for testing - self.epoch = 1 # Only run the network once - self.location_repeat = 1 # One location repeat for testing - self.test_true = True # Flag for multiple data functions - - # unpickle the network - print('Unpickling the ensemble network') - self.ensemble_netNum = len([entry for entry in os.listdir(self.training_out) if os.path.isfile(os.path.join(self.training_out, entry))]) - netDict = {} - for n in range(self.ensemble_netNum): - with open(self.training_out+'ensemble_net'+str(n+1)+'.pkl', 'rb') as f: - netDict[str(n)] = pickle.load(f) - - # Load the network training images and set the input spikes - print('Loading dataset images') - self.loadNames = self.filteredNames - self.loadImages() - - self.setSpikeRates(self.imgs['testing'],self.ids['testing']) - set_spikes() - - #_net['rec_spks'] = [True,True,True]'= - numcorrect = 0 - - try: - self.ensemble_max = self.copy_max - except AttributeError: - self.ensemble_max = self.ensemble_max - - for n in netDict: - netDict[str(n)]['n_ensemble'] = int(self.ensemble_max/self.number_training_images) - - # Combine all the networks and set each input image sequentially - start = timeit.default_timer() - - avAccurate = np.zeros(self.number_ensembles) - ensembleNum = self.number_training_images-1 - ensembleInd = 0 - for t in range(self.test_t*self.number_ensembles): - tonump = np.array([]) - for g in netDict: - _net = netDict[str(g)] - ensemble.runSim(_net,1) - # output the index of highest amplitude spike - tonump = np.append(tonump,np.reshape(_net['x'][-1].detach().cpu().numpy(),(self.test_t*int(self.ensemble_max/self.number_training_images),1),order='F')) - - nidx = np.argmax(tonump) - if nidx == t: - numcorrect += 1 - print('\033[32m'+"!Match! for image - "+self.ids['testing'][t]+': '+str(t)+' - '+str(nidx)) - - avAccurate[ensembleInd] += 1 - else: - print('\033[31m'+":( fail for image - "+self.ids['testing'][t]+': '+str(t)+' - '+str(nidx)) - - if t == ensembleNum: - ensembleNum += self.number_training_images - ensembleInd += 1 - #if nidx == t: - # numcorrect += 1 - print('\033[0m'+"It took this long ",timeit.default_timer() - start) - print("Number of correct places "+str((numcorrect/(self.test_t*self.number_ensembles))*100)+"%") - avPerc = 100*(avAccurate/self.number_training_images) - ensemblename = [] - for n in range(self.number_ensembles): - ensemblename.append(str(n+1)) - - fig = plt.figure() - plt.bar(ensemblename,avPerc) - plt.xlabel("Ensemble number") - plt.ylabel("P@100R (%)") - plt.title("Performance of individual ensemble networks") - plt.xticks(rotation=45) - plt.tick_params(axis='x', which='major',labelsize=7) - plt.tight_layout() - plt.show() - - # plot the similarity matrices for each location repetition - append_mat = [] - for n in self.mat_dict: - if int(n) != 0: - append_mat = append_mat + self.mat_dict[str(n)] - else: - append_mat = np.copy(self.mat_dict[str(n)]) - plot_name = "training images" - #plotit(append_mat,plot_name) - #plotit(self.net_x,plot_name) - - # pickle the ground truth matrix - reshape_mat = np.reshape(append_mat,(self.test_t,int(self.train_img/self.location_repeat))) - boolval = reshape_mat > 0 - GTsoft = boolval.astype(int) - GT = np.zeros((self.test_t,self.test_t), dtype=int) - - for n in range(len(GT)): - GT[n,n] = 1 - plot_name = "Similarity absolute ground truth" - #fig = plt.figure() - #plt.matshow(GT,fig, cmap=plt.cm.gist_yarg) - #plt.colorbar(label="Spike amplitude") - #fig.suptitle(plot_name,fontsize = 12) - #plt.xlabel("Query",fontsize = 12) - #plt.ylabel("Database",fontsize = 12) - #plt.show() - #with open('./output/GT.pkl', 'wb') as f: - # pickle.dump(GT, f) - - self.VPRTempo_correct = 100*self.numcorrect/self.test_t - #print(self.VPRTempo_correct,'% correct') - - # Clear the network output spikes - blitnet.setSpikeTimes(net,2,[]) - - # Reset network details - net['set_spks'][0] = [] - net['rec_spks'] = [True,True,True] - net['sspk_idx'] = [0,0,0] - net['step_num'] = 0 - net['spikes'] = [[],[],[]] - - # Load the testing images - #self.test_true = True # Set image path to the testing images - self.test_true = True - self.loadImages() - - # Set input layer spikes as the testing images - set_spikes() - - # run the test netowrk - test_network() - - # store and print out number of correctly identified places - self.VPRTempo_correct = 100*self.numcorrect/self.test_t - - # plot the similarity matrices for each location repetition - append_mat = [] - for n in self.mat_dict: - if int(n) != 0: - append_mat = append_mat + self.mat_dict[str(n)] - else: - append_mat = np.copy(self.mat_dict[str(n)]) - plot_name = "VPRTempo" - #plotit(append_mat,plot_name) - #plotit(self.net_x,plot_name) - # pickle the ground truth matrix - S_in = np.reshape(append_mat,(self.test_t,int(self.train_img/self.location_repeat))) - with open('./output/S_in.pkl', 'wb') as f: - pickle.dump(S_in, f) - - # calculate the precision of the system - self.precision = self.tp/(self.tp+self.fp) - self.recall = self.tp/self.test_t - #P, R = createPR(S_in, GT, GT) - # plot PR curve - #fig = plt.figure() - #plt.plot(R,P) - #fig.suptitle("VPRTempo Precision Recall curve",fontsize = 12) - #plt.xlabel("Recall",fontsize = 12) - #plt.ylabel("Precision",fontsize = 12) - #plt.show() - - # plot spikes if they were recorded - #if net['rec_spks'][0] == True: - # blitnet.plotSpikes(net,0) - - # clear the CUDA cache - torch.cuda.empty_cache() - gc.collect() - - # if validation is set to True, run comparison methods - if self.validation: - network_validator() - - sadcorrect = self.sad_correct - return self.numcorrect, append_mat, sadcorrect - -''' -Run the network -''' -if __name__ == "__main__": - model = snn_model() # Instantiate model - model.train() - model.networktester() - # unpickle the network - print('Unpickling the ensemble network') - with open(outfold+'ensemble_net.pkl', 'rb') as f: - ensemble_net = pickle.load(f) - - # unpickle the image names - print('Unpickling the image training names') - if not bool(trained_imgs): - with open(outfold+'img_ids.pkl', 'rb') as f: - trained_imgs = pickle.load(f) - - # run supervised testing - def supervisedTest(): - global totCorrect, network_mat, sadCorrect - totCorrect = 0 - sadCorrect = 0 - network_mat = np.zeros(3300*3300) - for t in ensemble_range: - filter_names = trained_imgs[str(t)] - correct = [] - correct, mat, sad = model.networktester(ensemble_net[str(t)],filter_names) - totCorrect = totCorrect + correct - sadCorrect = sadCorrect + sad - #tempmat = np.zeros(625*132) - #start = (len(mat)*t) - #end = (len(mat)*(t+1)) - #tempmat[start:end] = mat - #network_mat[82500*t:82500*(t+1)] = tempmat - yield - - print('Testing ensemble network') - with alive_bar(n_ensembles) as sbar: - for i in supervisedTest(): - sbar() - - #reshape_mat = np.reshape(network_mat,(3300,3300)) - # plot_name = "Similarity VPRTempo" - #fig = plt.figure() - #plt.matshow(reshape_mat,fig, cmap=plt.cm.gist_yarg) - #plt.colorbar(label="Pixel intensity") - #fig.suptitle(plot_name,fontsize = 12) - #plt.xlabel("Query",fontsize = 12) - #plt.ylabel("Database",fontsize = 12) - #plt.show() - print('P@100R '+str(100*(totCorrect/3300))) - print('P@100R for SAD '+str(100*(sadCorrect/3300))) \ No newline at end of file diff --git a/VPRTempo.py b/VPRTempo.py index fc36e34..c9f7b8a 100644 --- a/VPRTempo.py +++ b/VPRTempo.py @@ -27,6 +27,7 @@ import pickle import os import torch +import blitnet import gc import math import timeit @@ -38,8 +39,6 @@ sys.path.append('./output') import numpy as np -import BlitnetDense as blitnet -import blitnet_ensemble as ensemble import validation as validate import matplotlib.pyplot as plt @@ -58,8 +57,8 @@ def __init__(self): ''' USER SETTINGS ''' - self.trainingPath = '/home/adam/data/hpc/' # training datapath - self.testPath = '/home/adam/data/testing_data/' # testing datapath + self.trainingPath = '/Users/adam/data/train/' # training datapath + self.testPath = '/Users/adam/data/test/' # testing datapath self.number_training_images =1000 # alter number of training images self.number_testing_images = 100# alter number of testing images self.number_modules = 40 # number of module networks @@ -120,7 +119,7 @@ def __init__(self): # Print network details print('////////////') - print('VPRTempo - Temporally Encoded Visual Place Recognition v0.1') + print('VPRTempo - Temporally Encoded Visual Place Recognition v1.0.0-alpha') print('Queensland University of Technology, Centre for Robotics') print('\\\\\\\\\\\\\\\\\\\\\\\\') print('Theta: '+str(self.theta_max)) @@ -327,7 +326,7 @@ def train_start(): net['fire_rate'][fLayer][x][:,i] = self.f_rate[0]+fstep*(i+1) # create the excitatory and inhibitory connections - idx = blitnet.addWeights(net,iLayer,fLayer,[-1,0,1],[self.p_exc,self.p_inh],self.n_init, False) + idx = blitnet.addWeights(net,iLayer,fLayer,[-1,0,1],[self.p_exc,self.p_inh],self.n_init) weight = [] weight.append(idx-1) weight.append(idx) @@ -372,7 +371,7 @@ def train_start(): oLayer = blitnet.addLayer(net,[self.number_modules,1,self.output_layer],0.0,0.0,0.0,0.0,0.0,False) # Add excitatory and inhibitory connections - idx = blitnet.addWeights(net,fLayer,oLayer,[-1.0,0.0,1.0],[1.0,1.0],self.n_init,False) + idx = blitnet.addWeights(net,fLayer,oLayer,[-1.0,0.0,1.0],[1.0,1.0],self.n_init) weight.append(idx) # Output spikes for spike forcing (final layer) diff --git a/__pycache__/VPRTempo.cpython-37.pyc b/__pycache__/VPRTempo.cpython-37.pyc deleted file mode 100644 index 2de2883..0000000 Binary files a/__pycache__/VPRTempo.cpython-37.pyc and /dev/null differ diff --git a/output/.DS_Store b/output/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/output/.DS_Store differ diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000..d121930 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/BlitnetSparse.py b/src/BlitnetSparse.py deleted file mode 100644 index b7e57cd..0000000 --- a/src/BlitnetSparse.py +++ /dev/null @@ -1,522 +0,0 @@ -#MIT License - -#Copyright (c) 2023 Adam Hines, Peter Stratton, Michael Milford, Tobias Fischer - -#Permission is hereby granted, free of charge, to any person obtaining a copy -#of this software and associated documentation files (the "Software"), to deal -#in the Software without restriction, including without limitation the rights -#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -#copies of the Software, and to permit persons to whom the Software is -#furnished to do so, subject to the following conditions: - -#The above copyright notice and this permission notice shall be included in all -#copies or substantial portions of the Software. - -#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -#SOFTWARE. - -''' -Imports -''' -import numpy as np -import pdb -import torch - -import matplotlib.pyplot as plt - - -################################## -# Return a new empty BITnet instance - -def newNet(): - - np.random.seed() # new random seed - - # ** NEURON FIELDS ** - # x = activations - # x_input = total inputs - # x_prev = previous activations - # x_calc = calculated activations - # x_fastinp = total inputs including fast inhib - # dim = dimensions - # thr = thresholds for each neuron - # fire_rate = target firing rate for each neuron - # have_rate = have a target firing rate - # mean_rate = running avg firing rate for each neuron - # eta_ip = IP (threshold) learning rate - # const_inp = constant input to each neuron - # nois = noise st.dev. - # set_spks = pre-defined spike times (if any) - # sspk_idx = current index into set_spks - # spikes = spike events - # rec_spks = record spikes? - # - # ** CONNECTION FIELDS ** - # W = weights (-ve for inhib synapses) - # I = synaptic currents - # is_inhib = inhib weights flag - # fast_inhib = fast inhib weights flag - # W_lyr = pre and post layer numbers - # eta_stdp = STDP learning rate (-ve for inhib synapses) - # - # ** SIMULATION FIELDS ** - # step_num = current step - - #pdb.set_trace() - net = dict(x=[],x_input=[],x_prev=[],x_calc=[],x_fastinp=[],dim=[],thr=[], - fire_rate=[],have_rate=[],mean_rate=[],eta_ip=[],const_inp=[],nois=[], - set_spks=[],sspk_idx=[],spikes=[],rec_spks=[], - W=[],I=[],is_inhib=[],fast_inhib=[],W_lyr=[],eta_stdp=[], - step_num=0, num_modules = []) - - return net - -################################## -# Add a neuron layer (ie a neuron population) -# net: BITnet instance -# dim: layer dimensions [x,y,...] -# thr_range: initial threshold range -# fire_rate: target firing rate (0=no target) -# ip_rate: intrinsic threshold plasticity (IP) rate (0=no IP) -# const_inp: constant input to each neuron (0=none) -# nois: noise variance (0=no noise) -# rec_spks: record spikes? - -def addLayer(net,dims,thr_range,fire_rate,ip_rate,const_inp,nois,rec_spks): - - # Check constraints etc - if np.isscalar(thr_range): thr_range = [thr_range,thr_range] - if np.isscalar(fire_rate): fire_rate = [fire_rate,fire_rate] - if np.isscalar(const_inp): const_inp = [const_inp,const_inp] - - net['dim'].append(np.array(dims,int)) - net['x'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_prev'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_calc'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_input'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_fastinp'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['mean_rate'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['eta_ip'].append(ip_rate) - net['thr'].append(torch.from_numpy(np.random.uniform(thr_range[0],thr_range[1], - int(np.prod(dims))))) - net['fire_rate'].append(torch.from_numpy(np.random.uniform(fire_rate[0],fire_rate[1], - int(np.prod(dims))))) - net['have_rate'].append(any(net['fire_rate'][-1]>0.0)) - - net['const_inp'].append(torch.from_numpy(np.random.uniform(const_inp[0],const_inp[1], - int(np.prod(dims))))) - - net['nois'].append(nois) - net['set_spks'].append([]) - net['sspk_idx'].append(0) - net['spikes'].append(torch.empty([],dtype=torch.float64)) - net['rec_spks'].append(rec_spks) - - return len(net['x'])-1 - -################################## -# Add a set of random connections between layers -# net: BITnet instance -# layer_pre: presynaptic layer -# layer_post: postsynaptic layer -# W_range: weight range [lo,hi] -# p: initial connection probability -# stdp_rate: STDP rate (0=no STDP) -# fast_inhib: is this fast inhibition (ie inhib applied at same timestep) - -def addWeights(net,layer_pre,layer_post,W_range,p,stdp_rate,fast_inhib): - - # Check constraints etc - if np.isscalar(W_range): W_range = [W_range,W_range] - - nrow =net['x'][layer_pre].size(dim=0) - ncol = net['x'][layer_post].size(dim=0) - - Wmn = (W_range[0]+W_range[1])/2.0 - Wsd = (W_range[1]-W_range[0])/6.0 - net['W'].append(torch.from_numpy(np.random.normal(Wmn,Wsd,[nrow,ncol]))) - if Wmn > 0.0: - net['W'][-1][net['W'][-1]<0.0] = 0.0 - else: - net['W'][-1][net['W'][-1]>0.0] = 0.0 - setzero = np.random.rand(nrow,ncol) > p - if layer_pre==layer_post: # no self connections allowed - setzero = np.logical_or(setzero,np.identity(nrow)) - net['W'][-1][setzero] = 0.0 - net['W_lyr'].append([layer_pre,layer_post]) - net['I'].append(np.zeros(ncol)) - net['eta_stdp'].append(stdp_rate) - net['is_inhib'].append(W_range[0]<0.0 and W_range[1]<=0.0) - net['fast_inhib'].append(fast_inhib and net['is_inhib'][-1]) - p_nz = p - if p_nz==0.0: p_nz = 1.0 - - # Normalise the weights (except fast inhib weights) - if not net['fast_inhib'][-1]: - nrm = np.linalg.norm(net['W'][-1],ord=1,axis=0) - nrm[nrm==0.0] = 1.0 - net['W'][-1] = net['W'][-1]/nrm - - net['W'][-1] = net['W'][-1].to_sparse() - net['W'][-1] = net['W'][-1].coalesce() - - return len(net['W'])-1 - -################################## -# Set defined spike times for a neuron layer (ie a neuron population) -# net: BITnet instance -# layer: layer number -# times: 2-column matrix (col 1 = step num (ordered); col 2 = neuron num to spike) -# NOTE for spike forcing an output layer ensure that: eta_ip=0 and target fire_rate=0 -# FOLLOWING training ensure that: forced spikes array is removed, ie: setSpikeTimes(n,l,[]) - -def setSpikeTimes(net,layer,times): - if isinstance(times,list): - net['set_spks'][layer] = times.copy() - else: - net['set_spks'][layer] = times.detach().clone() - net['sspk_idx'][layer] = 0 - -################################## -# Normalise all the firing rates -# net: BITnet instance - -def norm_rates(net): - - for i,rate in enumerate(net['fire_rate']): - if rate.any() and net['eta_ip'][i] > 0.0: - net['thr'][i] = net['thr'][i] + net['eta_ip'][i]*(net['x'][i]-rate) - #xxx net['thr'][i] = net['thr'][i] + net['eta_ip'][i]*(net['x'][i]-rate) - net['thr'][i][net['thr'][i]<0.0] = 0.0 #xxx - -################################## -# Normalise inhib weights to balance input currents -# net: BITnet instance - -def norm_inhib(net): - - #return #xxx no norm_inhib - for i,W in enumerate(net['W']): - if net['eta_stdp'][i] < 0: # and not net['fast_inhib'][i]: - #if net['is_inhib'][i]: # and not net['fast_inhib'][i]: - lyr = net['W_lyr'][i][1] - #wadj = np.multiply(W,np.sign(net['x_input'][lyr]))*-net['eta_stdp'][i]*10 - #wadj = np.multiply(W,net['x_input'][lyr]-net['fire_rate'][lyr])*-net['eta_stdp'][i]*100 - try: - W = W.to_dense() - wadj = np.multiply(W,net['x_input'][lyr])*-net['eta_stdp'][i]*50 #0.5 #100 - temp_W = net['W'][i].to_dense() - temp_W += wadj - temp_W[W>0.0] = -0.000001 - net['W'][i] = temp_W.to_sparse() - except RuntimeWarning: - print("norm_inhib err") - pdb.set_trace() - -################################## -# Propagate spikes thru the network -# net: SORN instance - -def calc_spikes(net): - # Start with the noise and constant input in the neurons of each layer - for i,nois in enumerate(net['nois']): - if nois > 0: - net['x_input'][i] = np.random.normal(0.0,nois,int(np.prod(net['dim'][i]))) - else: - net['x_input'][i] = torch.full_like(net['x_input'][i],0.0) - net['x_input'][i] += net['const_inp'][i].detach().clone() - # Find the threshold crossings (overwritten later if needed) - net['x'][i] = torch.clamp((net['x_input'][i]-net['thr'][i]),0.0,0.9) - # Loop thru layers to insert any predefined spikes - for i in range(len(net['set_spks'])): - if len(net['set_spks'][i]): - net['x'][i] = torch.full_like(net['x'][i],0.0) - sidx = net['sspk_idx'][i] - if sidx < len(net['set_spks'][i]): stim = net['set_spks'][i][sidx,0] - while sidx < len(net['set_spks'][i]) and int(stim) <= net['step_num']: - net['x'][i][int(net['set_spks'][i][sidx,1])] = torch.fmod(stim,1) - sidx += 1 - if sidx < len(net['set_spks'][i]): - stim = net['set_spks'][i][sidx,0] - #else: - # net['set_spks'][i] = [] - net['sspk_idx'][i] = sidx - net['x'][i] = net['x'][i].to_sparse() - # Loop thru weight matrices, propagating spikes through. - # The idea is to process all weight matrices going into a layer (ie the nett input to that layer) - # then calculate that layer's spikes (threshold crossings), then move to the next group of weight - # matrices for the next layer. A group is defined as a contiguous set of weight matrices all ending - # on the same layer. This scheme is designed to propagate spikes rapidly up a feedfoward - # hierarachy. It won't work for layers with recurrent connections even if they are in the same - # weight group, since the spikes won't be recurrently p[numnrocessed until the next timestep, so fast - # inhibition is still needed for that. For feedback connections (ie the same layer being in - # different weight groups) this code will do a double timestep for those layers (not desirable). - #ipdb.set_trace() - for i,W in enumerate(net['W']): - if not net['fast_inhib'][i]: - layers = net['W_lyr'][i] - - # Synaptic currents last for 1 timestep - if layers[0]!=layers[1]: - net['I'][i] = torch.sparse.mm(net['x'][layers[0]],W) - else: - net['I'][i] = torch.matmul(net['x_prev'][layers[0]],W) - - net['x_input'][layers[1]] += net['I'][i] - - # Do spikes if this is the last weight matrix or if the next one has a different post layer - # or the next one is fast inhib,### UNLESS this is a recurrent layer - do_spikes = (i==len(net['W'])-1) - if not do_spikes: - do_spikes = not(layers[1]==net['W_lyr'][i+1][1]) or net['fast_inhib'][i+1] - #if do_spikes: - # do_spikes = layers[0]!=layers[1] - if do_spikes: - j = layers[1] - - # Find threshold crossings - if layers[0]!=layers[1]: - net['x_prev'][j] = net['x'][j][:] - if not len(net['set_spks'][j]): - # No predefined spikes for this layer - net['x'][j] = np.clip(net['x_input'][j]-net['thr'][j],a_min=0.0,a_max=0.9) - else: - # Predefined spikes exist for this layer, remember the calculated ones - net['x_calc'][j] = np.clip(net['x_input'][j]-net['thr'][j],a_min=0.0,a_max=0.9) - if layers[0]==layers[1]: - net['x_prev'][j] = net['x'][j][:] - # If the next weight matrix is fast inhib for this layer, process it now - if i < len(net['W'])-1: - if net['fast_inhib'][i+1] and layers[1]==net['W_lyr'][i+1][1]: - flyrs = net['W_lyr'][i+1] - net['x_fastinp'][flyrs[1]] = net['x_input'][flyrs[1]].copy() - if flyrs[0]==flyrs[1]: - postsyn_spks = np.tile(net['x'][flyrs[0]],[len(net['x'][flyrs[0]]),1]) - presyn_spks = np.transpose(postsyn_spks) - presyn_spks[presyn_spks < postsyn_spks] = 0.0 - net['x_fastinp'][flyrs[1]] += np.sum((presyn_spks)*net['W'][i+1],0) - else: - net['x_fastinp'][flyrs[1]] += np.matmul(net['x'][flyrs[0]],net['W'][i+1]) - if not len(net['set_spks'][j]): - # No predefined spikes for this layer - net['x'][flyrs[1]] = np.clip(net['x_fastinp'][flyrs[1]]-net['thr'][flyrs[1]], - a_min=0.0,a_max=0.9) - else: - # Predefined spikes exist for this layer, remember the calculated ones - net['x_calc'][flyrs[1]] = np.clip(net['x_fastinp'][flyrs[1]]-net['thr'][flyrs[1]], - a_min=0.0,a_max=0.9) - - # Finally, update mean firing rates and record all spikes if needed - for i,eta in enumerate(net['eta_ip']): - - if eta > 0.0: - net['mean_rate'][i] = net['mean_rate'][i]*(1.0-eta) +\ - (net['x'][i]>0.0)*eta - if net['rec_spks'][i]: - outspk = (net['x'][i]).detach().cpu().numpy() - if i == 2: - outspk[outspk<0.05] = 0 - n_idx = np.nonzero(outspk) - net['spikes'][i].extend([net['step_num']+net['x'][i][n].detach().cpu().numpy(),n] - for n in n_idx) - -################################## -# Calculate STDP -# net: BITnet instance - -def calc_stdp(net): - - # Loop thru weight matrices that have non-zero learning rate - for i,W in enumerate(net['W']): - if net['eta_stdp'][i] != 0: - - # Remember layer numbers and weight matrix shape - layers = net['W_lyr'][i] - shape = W.size() - - # - # Spike Forcing has special rules to make calculated and forced spikes match - # - if len(net['set_spks'][layers[1]]): - - # Diff between forced and calculated spikes - xdiff = net['x'][layers[1]] - net['x_calc'][layers[1]] - # Modulate learning rate by firing rate (low firing rate = high learning rate) - #if net['have_rate'][layers[1]]: - # xdiff /= net['fire_rate'][layers[1]] - - # Threshold rules - lower it if calced spike is smaller (and vice versa) - net['thr'][layers[1]] -= np.sign(xdiff)*np.abs(net['eta_stdp'][i])/10 - net['thr'][layers[1]][net['thr'][layers[1]]<0.0] = 0.0 # don't go -ve - - # A little bit of threshold decay - #net['thr'][layers[1]] *= (1-net['eta_stdp'][i]/100) - - # Pre and Post spikes tiled across and down for all synapses - if net['have_rate'][layers[0]]: - # Modulate learning rate by firing rate (low firing rate = high learning rate) - mpre = net['x'][layers[0]]/net['fire_rate'][layers[0]] - else: - mpre = net['x'][layers[0]] - pre = torch.from_numpy(np.tile(np.reshape(mpre, [shape[0],1]),[1,shape[1]])) - post = torch.from_numpy(np.tile(np.reshape(xdiff,[1,shape[1]]),[shape[0],1])) - - # Excitatory connections - if net['eta_stdp'][i] > 0: - W = W.to_dense() - havconn = W>0 - inc_stdp = pre*post*havconn - inc_stdp = inc_stdp.to_sparse() - # Inhibitory connections - else: - W = W.to_dense() - havconn = W<0 - inc_stdp = -pre*post*havconn - inc_stdp = inc_stdp.to_sparse() - - # Apply the weight changes - net['W'][i].values()[inc_stdp.indices()] += inc_stdp.values()*net['eta_stdp'][i] - - # - # Normal STDP - # - elif not net['fast_inhib'][i]: - - pre = torch.from_numpy(np.tile(np.reshape(net['x'][layers[0]],[shape[0],1]),[1,shape[1]])) - if net['have_rate'][layers[1]]: - # Modulate learning rate by firing rate (low firing rate = high learning rate) - mpost = net['x'][layers[1]] #/net['fire_rate'][layers[1]] - else: - mpost = net['x'][layers[1]] - post = torch.from_numpy(np.tile(np.reshape(mpost,[1,shape[1]]),[shape[0],1])) - - # Excitatory synapses - if net['eta_stdp'][i] > 0: - W = W.to_dense() - havconn = W>0 - inc_stdp = (0.5-post)*(pre>0)*(post>0)*havconn - inc_stdp = inc_stdp.to_sparse().coalesce() - # Inhibitory synapses - elif not net['fast_inhib'][i]: # and False: - W = W.to_dense() - havconn = W<0 - inc_stdp = (0.5-post)*(pre>0)*(post>0)*havconn - inc_stdp = inc_stdp.to_sparse().coalesce() - # Apply the weight changes - - net['W'][i].values()[inc_stdp.indices()] += inc_stdp.values()*net['eta_stdp'][i] - - # - # Fast inhibitory synapses, xxx update for firing rate modulation of eta_stdp? - # - else: - - # Store weight changes - inc_stdp = np.zeros(shape) - dec_stdp = np.zeros(shape) - - # Loop thru firing pre neurons - for pre in np.where(net['x'][layers[0]])[0]: - # Loop thru ALL post neurons - for post in range(len(net['x'][layers[1]])): - if net['W'][i][pre,post]!=0: - if net['x'][layers[1]][post] > 0.0: - if net['x'][layers[0]][pre] >\ - net['x'][layers[1]][post]: - # Synapse gets stronger if pre fires before post - inc_stdp[pre,post] = 0.5 #0.1 #/\ - #net['mean_rate'][layers[1]][post] - #net['mean_rate'][layers[0]][pre] - else: - # Synapse gets weaker if pre fires after post - dec_stdp[pre,post] = 0.5 *\ - (1.0-net['mean_rate'][layers[1]][post]) - else: - # Also gets weaker if pre fires and not post - dec_stdp[pre,post] = 0.5*\ - net['mean_rate'][layers[1]][post] - - # Apply the weight changes - net['W'][i] += inc_stdp*net['eta_stdp'][i] - net['W'][i] -= dec_stdp*net['eta_stdp'][i] - - # - # Finish - # - - # Try weight decay? - #net['W'][i] = net['W'][i]-net['eta_stdp'][i]/10 # * (1-net['eta_stdp'][i]) - if net['eta_stdp'][i] > 0.0: - # Excitation - pruning and synaptogenesis (structural plasticity) - net['W'][i].values()[ net['W'][i].values()<0.0] = 0.000001 #xxx - net['W'][i].values()[ net['W'][i].values()>10.0] = 10.0 #xxx - if np.random.rand() < 0.0: #0.1: #xxx TEMPORARILY OFF - synap = (np.random.rand(2)*shape).astype(int) - if net['W'][i][synap[0]][synap[1]] == 0: - net['W'][i][synap[0]][synap[1]] = 0.001 - else: - # Inhibition - must not go +ve - net['W'][i].values()[ net['W'][i].values()>0.0] = -0.000001 #xxx - net['W'][i].values()[ net['W'][i].values()<-10.0] = -10.0 #xxx - - # Finally clear out any predefined spikes that are used up (so calculated network spikes can take over) - for i in range(len(net['set_spks'])): - if len(net['set_spks'][i]): - if len(net['set_spks'][i]) <= net['sspk_idx'][i]: - net['set_spks'][i] = [] - -################################## -# Run the simulation -# net: BITnet instance -# n_steps: number of steps - -def runSim(net,n_steps): - - # Loop - for step in range(n_steps): - - # Inc step count - net['step_num'] += 1 - - # Propagate spikes from pre to post neurons - calc_spikes(net) - - # Calculate STDP weight changes - calc_stdp(net) - - # Normalise firing rates and inhibitory balance - norm_rates(net) - norm_inhib(net) - -################################## -# Plot recorded spikes in current subplot -# net: BITnet instance - -def subplotSpikes(net,cutoff): - - n_tot = 0 - for i,sp in enumerate(net['spikes']): - x=[]; y=[] - for n in sp: - x.extend(list(n[0])) - y.extend(list(n[1]+n_tot)) - - plt.plot(x,y,'.',ms=1) - n_tot += np.size(net['x'][i].detach().cpu().numpy()) - -################################## -# Plot recorded spikes in new figure -# net: BITnet instance - -def plotSpikes(net,cutoff): - - plt.figure() - subplotSpikes(net,cutoff) - plt.show(block=False) - -################################## diff --git a/src/TranSalNet_Dense.py b/src/TranSalNet_Dense.py deleted file mode 100644 index 8904460..0000000 --- a/src/TranSalNet_Dense.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import torch -import numpy as np -import pandas as pd -from torch.utils.data import Dataset, DataLoader -from skimage import io, transform -from PIL import Image -import torch.nn as nn -from torchvision import transforms, utils, models -import torch.nn.functional as F -import densenet as densenet - -from TransformerEncoder import Encoder - - - -cfg1 = { -"hidden_size" : 768, -"mlp_dim" : 768*4, -"num_heads" : 12, -"num_layers" : 2, -"attention_dropout_rate" : 0, -"dropout_rate" : 0.0, -} - -cfg2 = { -"hidden_size" : 768, -"mlp_dim" : 768*4, -"num_heads" : 12, -"num_layers" : 2, -"attention_dropout_rate" : 0, -"dropout_rate" : 0.0, -} - -cfg3 = { -"hidden_size" : 512, -"mlp_dim" : 512*4, -"num_heads" : 8, -"num_layers" : 2, -"attention_dropout_rate" : 0, -"dropout_rate" : 0.0, -} - - -class TranSalNet(nn.Module): - - def __init__(self): - super(TranSalNet, self).__init__() - self.encoder = _Encoder() - self.decoder = _Decoder() - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -class _Encoder(nn.Module): - def __init__(self): - super(_Encoder, self).__init__() - base_model = densenet.densenet161(pretrained=True) - base_layers = list(base_model.children())[0][:-1] - self.encoder = nn.ModuleList(base_layers).eval() - - def forward(self, x): - outputs = [] - for ii,layer in enumerate(self.encoder): - x = layer(x) - if ii in {6, 8, 10}: - outputs.append(x) - return outputs - - -class _Decoder(nn.Module): - - def __init__(self): - super(_Decoder, self).__init__() - self.conv1 = nn.Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.conv2 = nn.Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.conv3 = nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.conv4 = nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.conv5 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.conv6 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.conv7 = nn.Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - - self.batchnorm1 = nn.BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - self.batchnorm2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - self.batchnorm3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - self.batchnorm4 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - self.batchnorm5 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - self.batchnorm6 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - - self.TransEncoder1 = TransEncoder(in_channels=2208, spatial_size=9*12, cfg=cfg1) - self.TransEncoder2 = TransEncoder(in_channels=2112, spatial_size=18*24, cfg=cfg2) - self.TransEncoder3 = TransEncoder(in_channels=768, spatial_size=36*48, cfg=cfg3) - - self.add = torch.add - self.relu = nn.ReLU(True) - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - x3, x4, x5 = x - - x5 = self.TransEncoder1(x5) - x5 = self.conv1(x5) - x5 = self.batchnorm1(x5) - x5 = self.relu(x5) - x5 = self.upsample(x5) - - x4_a = self.TransEncoder2(x4) - x4 = x5 * x4_a - x4 = self.relu(x4) - x4 = self.conv2(x4) - x4 = self.batchnorm2(x4) - x4 = self.relu(x4) - x4 = self.upsample(x4) - - x3_a = self.TransEncoder3(x3) - x3 = x4 * x3_a - x3 = self.relu(x3) - x3 = self.conv3(x3) - x3 = self.batchnorm3(x3) - x3 = self.relu(x3) - x3 = self.upsample(x3) - - x2 = self.conv4(x3) - x2 = self.batchnorm4(x2) - x2 = self.relu(x2) - x2 = self.upsample(x2) - x2 = self.conv5(x2) - x2 = self.batchnorm5(x2) - x2 = self.relu(x2) - - x1 = self.upsample(x2) - x1 = self.conv6(x1) - x1 = self.batchnorm6(x1) - x1 = self.relu(x1) - x1 = self.conv7(x1) - x = self.sigmoid(x1) - - return x - - -class TransEncoder(nn.Module): - - def __init__(self, in_channels, spatial_size, cfg): - super(TransEncoder, self).__init__() - - self.patch_embeddings = nn.Conv2d(in_channels=in_channels, - out_channels=cfg['hidden_size'], - kernel_size=1, - stride=1) - self.position_embeddings = nn.Parameter(torch.zeros(1, spatial_size, cfg['hidden_size'])) - - self.transformer_encoder = Encoder(cfg) - - def forward(self, x): - a, b = x.shape[2], x.shape[3] - x = self.patch_embeddings(x) - x = x.flatten(2) - x = x.transpose(-1, -2) - - embeddings = x + self.position_embeddings - x = self.transformer_encoder(embeddings) - B, n_patch, hidden = x.shape - x = x.permute(0, 2, 1) - x = x.contiguous().view(B, hidden, a, b) - - return x - diff --git a/src/TransformerEncoder.py b/src/TransformerEncoder.py deleted file mode 100644 index f2e8e07..0000000 --- a/src/TransformerEncoder.py +++ /dev/null @@ -1,137 +0,0 @@ -# coding=utf-8 - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import copy -import logging -import math - -from os.path import join as pjoin - -import torch -import torch.nn as nn -import numpy as np - -from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm -from torch.nn.modules.utils import _pair -from scipy import ndimage - - -ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu} - - -class Attention(nn.Module): - def __init__(self, config): - super(Attention, self).__init__() - self.num_attention_heads = config["num_heads"] # 12 - self.attention_head_size = int(config['hidden_size'] / self.num_attention_heads) # 42 - self.all_head_size = self.num_attention_heads * self.attention_head_size # 12*42=504 - - self.query = Linear(config['hidden_size'], self.all_head_size) # (512, 504) - self.key = Linear(config['hidden_size'], self.all_head_size) - self.value = Linear(config['hidden_size'], self.all_head_size) - - # self.out = Linear(config['hidden_size'], config['hidden_size']) - self.out = Linear(self.all_head_size, config['hidden_size']) - self.attn_dropout = Dropout(config["attention_dropout_rate"]) - self.proj_dropout = Dropout(config["attention_dropout_rate"]) - - self.softmax = Softmax(dim=-1) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states): - - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_probs = self.softmax(attention_scores) - attention_probs = self.attn_dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - attention_output = self.out(context_layer) - attention_output = self.proj_dropout(attention_output) - return attention_output - - -class Mlp(nn.Module): - def __init__(self, config): - super(Mlp, self).__init__() - self.fc1 = Linear(config['hidden_size'], config["mlp_dim"]) - self.fc2 = Linear(config["mlp_dim"], config['hidden_size']) - self.act_fn = ACT2FN["gelu"] - self.dropout = Dropout(config["dropout_rate"]) - self._init_weights() - - def _init_weights(self): - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) - nn.init.normal_(self.fc1.bias, std=1e-6) - nn.init.normal_(self.fc2.bias, std=1e-6) - - def forward(self, x): - x = self.fc1(x) - x = self.act_fn(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - - -class Block(nn.Module): - def __init__(self, config): - super(Block, self).__init__() - self.flag = config['num_heads'] - self.hidden_size = config['hidden_size'] - self.ffn_norm = LayerNorm(config['hidden_size'], eps=1e-6) - self.ffn = Mlp(config) - self.attn = Attention(config) - self.attention_norm = LayerNorm(config['hidden_size'], eps=1e-6) - - def forward(self, x): - h = x - - x = self.attention_norm(x) - x = self.attn(x) - x = x + h - - h = x - x = self.ffn_norm(x) - x = self.ffn(x) - x = x + h - return x - - -class Encoder(nn.Module): - def __init__(self, config): - super(Encoder, self).__init__() - - self.layer = nn.ModuleList() - self.encoder_norm = LayerNorm(config['hidden_size'], eps=1e-6) - for _ in range(config["num_layers"]): - layer = Block(config) - self.layer.append(copy.deepcopy(layer)) - - def forward(self, hidden_states): - for layer_block in self.layer: - hidden_states = layer_block(hidden_states) - encoded = self.encoder_norm(hidden_states) - - return encoded - - diff --git a/src/__pycache__/BlitnetDense.cpython-311.pyc b/src/__pycache__/BlitnetDense.cpython-311.pyc deleted file mode 100644 index 50bcffc..0000000 Binary files a/src/__pycache__/BlitnetDense.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/BlitnetSparse.cpython-311.pyc b/src/__pycache__/BlitnetSparse.cpython-311.pyc deleted file mode 100644 index 6696afe..0000000 Binary files a/src/__pycache__/BlitnetSparse.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/TranSalNet_Dense.cpython-311.pyc b/src/__pycache__/TranSalNet_Dense.cpython-311.pyc deleted file mode 100644 index 0c3aa3e..0000000 Binary files a/src/__pycache__/TranSalNet_Dense.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/TranSalNet_Dense.cpython-37.pyc b/src/__pycache__/TranSalNet_Dense.cpython-37.pyc deleted file mode 100644 index d4d7c76..0000000 Binary files a/src/__pycache__/TranSalNet_Dense.cpython-37.pyc and /dev/null differ diff --git a/src/__pycache__/TranSalNet_Dense.cpython-38.pyc b/src/__pycache__/TranSalNet_Dense.cpython-38.pyc deleted file mode 100644 index 7541134..0000000 Binary files a/src/__pycache__/TranSalNet_Dense.cpython-38.pyc and /dev/null differ diff --git a/src/__pycache__/TransformerEncoder.cpython-311.pyc b/src/__pycache__/TransformerEncoder.cpython-311.pyc deleted file mode 100644 index 300d75b..0000000 Binary files a/src/__pycache__/TransformerEncoder.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/TransformerEncoder.cpython-37.pyc b/src/__pycache__/TransformerEncoder.cpython-37.pyc deleted file mode 100644 index 454691a..0000000 Binary files a/src/__pycache__/TransformerEncoder.cpython-37.pyc and /dev/null differ diff --git a/src/__pycache__/TransformerEncoder.cpython-38.pyc b/src/__pycache__/TransformerEncoder.cpython-38.pyc deleted file mode 100644 index 9f081eb..0000000 Binary files a/src/__pycache__/TransformerEncoder.cpython-38.pyc and /dev/null differ diff --git a/src/__pycache__/blitnet_ensemble.cpython-311.pyc b/src/__pycache__/blitnet_ensemble.cpython-311.pyc deleted file mode 100644 index 8eef6eb..0000000 Binary files a/src/__pycache__/blitnet_ensemble.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/blitnet_ensemble.cpython-39.pyc b/src/__pycache__/blitnet_ensemble.cpython-39.pyc deleted file mode 100644 index 1f03745..0000000 Binary files a/src/__pycache__/blitnet_ensemble.cpython-39.pyc and /dev/null differ diff --git a/src/__pycache__/blitnet_open.cpython-311.pyc b/src/__pycache__/blitnet_open.cpython-311.pyc deleted file mode 100644 index 1d52831..0000000 Binary files a/src/__pycache__/blitnet_open.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/blitnet_open.cpython-37.pyc b/src/__pycache__/blitnet_open.cpython-37.pyc deleted file mode 100644 index 98be91d..0000000 Binary files a/src/__pycache__/blitnet_open.cpython-37.pyc and /dev/null differ diff --git a/src/__pycache__/blitnet_open.cpython-38.pyc b/src/__pycache__/blitnet_open.cpython-38.pyc deleted file mode 100644 index 7909952..0000000 Binary files a/src/__pycache__/blitnet_open.cpython-38.pyc and /dev/null differ diff --git a/src/__pycache__/blitnet_open.cpython-39.pyc b/src/__pycache__/blitnet_open.cpython-39.pyc deleted file mode 100644 index c87c435..0000000 Binary files a/src/__pycache__/blitnet_open.cpython-39.pyc and /dev/null differ diff --git a/src/__pycache__/blitnet_open_dense.cpython-311.pyc b/src/__pycache__/blitnet_open_dense.cpython-311.pyc deleted file mode 100644 index fd2928c..0000000 Binary files a/src/__pycache__/blitnet_open_dense.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/data_process.cpython-311.pyc b/src/__pycache__/data_process.cpython-311.pyc deleted file mode 100644 index 8676c18..0000000 Binary files a/src/__pycache__/data_process.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/data_process.cpython-37.pyc b/src/__pycache__/data_process.cpython-37.pyc deleted file mode 100644 index 15c1d7d..0000000 Binary files a/src/__pycache__/data_process.cpython-37.pyc and /dev/null differ diff --git a/src/__pycache__/data_process.cpython-38.pyc b/src/__pycache__/data_process.cpython-38.pyc deleted file mode 100644 index 6d8576b..0000000 Binary files a/src/__pycache__/data_process.cpython-38.pyc and /dev/null differ diff --git a/src/__pycache__/densenet.cpython-311.pyc b/src/__pycache__/densenet.cpython-311.pyc deleted file mode 100644 index 6016186..0000000 Binary files a/src/__pycache__/densenet.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/densenet.cpython-37.pyc b/src/__pycache__/densenet.cpython-37.pyc deleted file mode 100644 index a81ace5..0000000 Binary files a/src/__pycache__/densenet.cpython-37.pyc and /dev/null differ diff --git a/src/__pycache__/densenet.cpython-38.pyc b/src/__pycache__/densenet.cpython-38.pyc deleted file mode 100644 index 668d4b0..0000000 Binary files a/src/__pycache__/densenet.cpython-38.pyc and /dev/null differ diff --git a/src/__pycache__/metrics.cpython-311.pyc b/src/__pycache__/metrics.cpython-311.pyc deleted file mode 100644 index 668619c..0000000 Binary files a/src/__pycache__/metrics.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/metrics.cpython-39.pyc b/src/__pycache__/metrics.cpython-39.pyc deleted file mode 100644 index e16adc1..0000000 Binary files a/src/__pycache__/metrics.cpython-39.pyc and /dev/null differ diff --git a/src/__pycache__/utils.cpython-311.pyc b/src/__pycache__/utils.cpython-311.pyc deleted file mode 100644 index ef17046..0000000 Binary files a/src/__pycache__/utils.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/utils.cpython-37.pyc b/src/__pycache__/utils.cpython-37.pyc deleted file mode 100644 index 30a49fd..0000000 Binary files a/src/__pycache__/utils.cpython-37.pyc and /dev/null differ diff --git a/src/__pycache__/validation.cpython-311.pyc b/src/__pycache__/validation.cpython-311.pyc deleted file mode 100644 index 052e9af..0000000 Binary files a/src/__pycache__/validation.cpython-311.pyc and /dev/null differ diff --git a/src/__pycache__/validation.cpython-37.pyc b/src/__pycache__/validation.cpython-37.pyc deleted file mode 100644 index 1c04cc1..0000000 Binary files a/src/__pycache__/validation.cpython-37.pyc and /dev/null differ diff --git a/src/__pycache__/validation.cpython-39.pyc b/src/__pycache__/validation.cpython-39.pyc deleted file mode 100644 index c3ddf1e..0000000 Binary files a/src/__pycache__/validation.cpython-39.pyc and /dev/null differ diff --git a/src/BlitnetDense.py b/src/blitnet.py similarity index 95% rename from src/BlitnetDense.py rename to src/blitnet.py index 9349a71..351ef3a 100644 --- a/src/BlitnetDense.py +++ b/src/blitnet.py @@ -26,7 +26,6 @@ import numpy as np import pdb import torch -import timeit import matplotlib.pyplot as plt @@ -61,7 +60,6 @@ def newNet(modules, dims): # W = weights (-ve for inhib synapses) # I = synaptic currents # is_inhib = inhib weights flag - # fast_inhib = fast inhib weights flag # W_lyr = pre and post layer numbers # eta_stdp = STDP learning rate (-ve for inhib synapses) # @@ -72,7 +70,7 @@ def newNet(modules, dims): net = dict(x=[],x_input=[],x_prev=[],x_calc=[],x_fastinp=[],dim=[],thr=[], fire_rate=[],have_rate=[],mean_rate=[],eta_ip=[],const_inp=[],nois=[], set_spks=[],sspk_idx=[],spikes=[],rec_spks=[], - W=[],I=[],is_inhib=[],fast_inhib=[],W_lyr=[],eta_stdp=[], + W=[],I=[],is_inhib=[],W_lyr=[],eta_stdp=[], step_num=0, num_modules = modules, spike_dims = dims) return net @@ -131,9 +129,8 @@ def addLayer(net,dims,thr_range,fire_rate,ip_rate,const_inp,nois,rec_spks): # W_range: weight range [lo,hi] # p: initial connection probability # stdp_rate: STDP rate (0=no STDP) -# fast_inhib: is this fast inhibition (ie inhib applied at same timestep) -def addWeights(net,layer_pre,layer_post,W_range,p,stdp_rate,fast_inhib): +def addWeights(net,layer_pre,layer_post,W_range,p,stdp_rate): # get torch device device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -182,7 +179,6 @@ def addWeights(net,layer_pre,layer_post,W_range,p,stdp_rate,fast_inhib): net['eta_stdp'].append(stdp_rate) net['eta_stdp'].append(-stdp_rate) net['is_inhib'].append(W_range[0]<0.0 and W_range[1]<=0.0) - net['fast_inhib'].append(fast_inhib and net['is_inhib'][-1]) else: net['I'][Iindex] = torch.concat((net['I'][Iindex], torch.unsqueeze(torch.zeros(nrow, device=device),0)),0) @@ -194,13 +190,12 @@ def addWeights(net,layer_pre,layer_post,W_range,p,stdp_rate,fast_inhib): net['W'][inhIndex][n,:,:][setzeroInh] = 0.0 # inhibitory connections # Normalise the weights (except fast inhib weights) - if not net['fast_inhib'][-1]: - nrmExc = torch.linalg.norm(net['W'][excIndex][len(net['W'][excIndex])-1],ord=1,axis=0) - nrmInh = torch.linalg.norm(net['W'][inhIndex][len(net['W'][inhIndex])-1],ord=1,axis=0) - nrmExc[nrmExc==0.0] = 1.0 - nrmInh[nrmInh==0.0] = 1.0 - net['W'][excIndex][n] = net['W'][excIndex][n,:,:]/nrmExc - net['W'][inhIndex][n] = net['W'][inhIndex][n,:,:]/nrmInh + nrmExc = torch.linalg.norm(net['W'][excIndex][len(net['W'][excIndex])-1],ord=1,axis=0) + nrmInh = torch.linalg.norm(net['W'][inhIndex][len(net['W'][inhIndex])-1],ord=1,axis=0) + nrmExc[nrmExc==0.0] = 1.0 + nrmInh[nrmInh==0.0] = 1.0 + net['W'][excIndex][n] = net['W'][excIndex][n,:,:]/nrmExc + net['W'][inhIndex][n] = net['W'][inhIndex][n,:,:]/nrmInh return len(net['W'])-1 diff --git a/src/blitnet_ensemble.py b/src/blitnet_ensemble.py deleted file mode 100644 index 95c0614..0000000 --- a/src/blitnet_ensemble.py +++ /dev/null @@ -1,349 +0,0 @@ -#MIT License - -#Copyright (c) 2023 Adam Hines, Peter Stratton, Michael Milford, Tobias Fischer - -#Permission is hereby granted, free of charge, to any person obtaining a copy -#of this software and associated documentation files (the "Software"), to deal -#in the Software without restriction, including without limitation the rights -#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -#copies of the Software, and to permit persons to whom the Software is -#furnished to do so, subject to the following conditions: - -#The above copyright notice and this permission notice shall be included in all -#copies or substantial portions of the Software. - -#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -#SOFTWARE. - -''' -Imports -''' -import numpy as np -import pdb -import torch - - -################################## -# Propagate spikes thru the network -# net: SORN instance - -def calc_spikes(net): - - # Start with the noise and constant input in the neurons of each layer - for i,nois in enumerate(net['nois']): - if nois > 0: - net['x_input'][i] = np.random.normal(0.0,nois,int(np.prod(net['dim'][i]))) - else: - if net['x_input'][i].dim() < 2: - net['x_input'][i] = torch.unsqueeze(torch.full_like(net['x_input'][i],0.0),-1) - elif net['x_input'][i].dim() > 2: - net['x_input'][i] = torch.squeeze(torch.full_like(net['x_input'][i],0.0),-1) - else: - net['x_input'][i] = torch.full_like(net['x_input'][i],0.0) - - if net['x_input'][i].size(dim=1) > net['n_ensemble']: - temptens = torch.hsplit(net['x_input'][i],2) - net['x_input'][i] = temptens[0] - elif net['x_input'][i].size(dim=1) < net['n_ensemble']: - tempinput = torch.clone(net['x_input'][i]) - for n in range(len(net['const_inp'][0][-1])-1): - net['x_input'][i] = torch.concat((net['x_input'][i],tempinput),-1) - - net['x_input'][i] += net['const_inp'][i].detach().clone() - # Find the threshold crossings (overwritten later if needed) - net['x'][i] = torch.clamp((net['x_input'][i]-net['thr'][i]),0.0,0.9) - - # Loop thru layers to insert any predefined spikes - for i in range(len(net['set_spks'])): - if len(net['set_spks'][i]): - net['x'][i] = torch.full_like(net['x'][i],0.0) - sidx = net['sspk_idx'][i] - if sidx < len(net['set_spks'][i]): stim = net['set_spks'][i][sidx,0,:] - while sidx < len(net['set_spks'][i]) and int(stim[0]) <= net['step_num']: - net['x'][i][int(net['set_spks'][i][sidx,1][0])] = torch.fmod(stim,1) - sidx += 1 - if sidx < len(net['set_spks'][i]): - stim = net['set_spks'][i][sidx,0] - #else: - # net['set_spks'][i] = [] - net['sspk_idx'][i] = sidx - pause=1 - - # Loop thru weight matrices, propagating spikes through. - # The idea is to process all weight matrices going into a layer (ie the nett input to that layer) - # then calculate that layer's spikes (threshold crossings), then move to the next group of weight - # matrices for the next layer. A group is defined as a contiguous set of weight matrices all ending - # on the same layer. This scheme is designed to propagate spikes rapidly up a feedfoward - # hierarachy. It won't work for layers with recurrent connections even if they are in the same - # weight group, since the spikes won't be recurrently p[numnrocessed until the next timestep, so fast - # inhibition is still needed for that. For feedback connections (ie the same layer being in - # different weight groups) this code will do a double timestep for those layers (not desirable). - #ipdb.set_trace() - def batch_mm(matrix, vector_batch): - batch_size = vector_batch.shape[0] - # Stack the vector batch into columns. (b, n, 1) -> (n, b) - vectors = vector_batch.transpose(0, 1).reshape(-1, batch_size) - - # A matrix-matrix product is a batched matrix-vector product of the columns. - # And then reverse the reshaping. (m, b) -> (b, m, 1) - return matrix.mm(vectors).transpose(1, 0).reshape(batch_size, -1, 1) - - for i,W in enumerate(net['W']): - if not net['fast_inhib'][i]: - layers = net['W_lyr'][i] - - # Synaptic currents last for 1 timestep - if layers[0]!=layers[1]: - net['I'][i] = torch.matmul(net['x'][layers[0]],W) - #net['I'][i] = torch.einsum('bi,bji->ji',net['x'][layers[0]],W) - else: - net['I'][i] = torch.einsum('bi,bji->ji',net['x_prev'][layers[0]],W) - - net['x_input'][layers[1]] += net['I'][i] - a = torch.squeeze(net['x_input'][layers[1]],-1).numpy() - # Do spikes if this is the last weight matrix or if the next one has a different post layer - # or the next one is fast inhib,### UNLESS this is a recurrent layer - do_spikes = (i==len(net['W'])-1) - if not do_spikes: - do_spikes = not(layers[1]==net['W_lyr'][i+1][1]) or net['fast_inhib'][i+1] - #if do_spikes: - # do_spikes = layers[0]!=layers[1] - if do_spikes: - j = layers[1] - - # Find threshold crossings - if layers[0]!=layers[1]: - net['x_prev'][j] = net['x'][j][:] - if not len(net['set_spks'][j]): - # No predefined spikes for this layer - net['x'][j] = np.clip(net['x_input'][j]-net['thr'][j],a_min=0.0,a_max=0.9) - else: - # Predefined spikes exist for this layer, remember the calculated ones - net['x_calc'][j] = np.clip(net['x_input'][j]-net['thr'][j],a_min=0.0,a_max=0.9) - if layers[0]==layers[1]: - net['x_prev'][j] = net['x'][j][:] - # If the next weight matrix is fast inhib for this layer, process it now - if i < len(net['W'])-1: - if net['fast_inhib'][i+1] and layers[1]==net['W_lyr'][i+1][1]: - flyrs = net['W_lyr'][i+1] - net['x_fastinp'][flyrs[1]] = net['x_input'][flyrs[1]].copy() - if flyrs[0]==flyrs[1]: - postsyn_spks = np.tile(net['x'][flyrs[0]],[len(net['x'][flyrs[0]]),1]) - presyn_spks = np.transpose(postsyn_spks) - presyn_spks[presyn_spks < postsyn_spks] = 0.0 - net['x_fastinp'][flyrs[1]] += np.sum((presyn_spks)*net['W'][i+1],0) - else: - net['x_fastinp'][flyrs[1]] += np.matmul(net['x'][flyrs[0]],net['W'][i+1]) - if not len(net['set_spks'][j]): - # No predefined spikes for this layer - net['x'][flyrs[1]] = np.clip(net['x_fastinp'][flyrs[1]]-net['thr'][flyrs[1]], - a_min=0.0,a_max=0.9) - else: - # Predefined spikes exist for this layer, remember the calculated ones - net['x_calc'][flyrs[1]] = np.clip(net['x_fastinp'][flyrs[1]]-net['thr'][flyrs[1]], - a_min=0.0,a_max=0.9) - - # Finally, update mean firing rates and record all spikes if needed - for i,eta in enumerate(net['eta_ip']): - - if eta > 0.0: - net['mean_rate'][i] = net['mean_rate'][i]*(1.0-eta) +\ - (net['x'][i]>0.0)*eta - if net['rec_spks'][i]: - outspk = (net['x'][i]).detach().cpu().numpy() - if i == 2: - outspk[outspk<0.05] = 0 - n_idx = np.nonzero(outspk) - net['spikes'][i].extend([net['step_num']+net['x'][i][n].detach().cpu().numpy(),n] - for n in n_idx) - -def calc_stdp(net): - - # Loop thru weight matrices that have non-zero learning rate - for i,W in enumerate(net['W']): - if net['eta_stdp'][i] != 0: - - # Remember layer numbers and weight matrix shape - layers = net['W_lyr'][i] - shape = W.size() - - # - # Spike Forcing has special rules to make calculated and forced spikes match - # - if len(net['set_spks'][layers[1]]): - - # Diff between forced and calculated spikes - xdiff = net['x'][layers[1]] - net['x_calc'][layers[1]] - # Modulate learning rate by firing rate (low firing rate = high learning rate) - #if net['have_rate'][layers[1]]: - # xdiff /= net['fire_rate'][layers[1]] - - # Threshold rules - lower it if calced spike is smaller (and vice versa) - net['thr'][layers[1]] -= np.sign(xdiff)*np.abs(net['eta_stdp'][i])/10 - net['thr'][layers[1]][net['thr'][layers[1]]<0.0] = 0.0 # don't go -ve - - # A little bit of threshold decay - #net['thr'][layers[1]] *= (1-net['eta_stdp'][i]/100) - - # Pre and Post spikes tiled across and down for all synapses - if net['have_rate'][layers[0]]: - # Modulate learning rate by firing rate (low firing rate = high learning rate) - mpre = net['x'][layers[0]]/net['fire_rate'][layers[0]] - else: - mpre = net['x'][layers[0]] - pre = torch.from_numpy(np.tile(np.reshape(mpre, [shape[0],1]),[1,shape[1]])) - post = torch.from_numpy(np.tile(np.reshape(xdiff,[1,shape[1]]),[shape[0],1])) - - # Excitatory connections - if net['eta_stdp'][i] > 0: - havconn = W>0 - inc_stdp = pre*post*havconn - # Inhibitory connections - else: - havconn = W<0 - inc_stdp = -pre*post*havconn - - # Apply the weight changes - net['W'][i] += inc_stdp*net['eta_stdp'][i] - - # - # Normal STDP - # - elif not net['fast_inhib'][i]: - - pre = torch.from_numpy(np.tile(np.reshape(net['x'][layers[0]],[shape[0],1]),[1,shape[1]])) - if net['have_rate'][layers[1]]: - # Modulate learning rate by firing rate (low firing rate = high learning rate) - mpost = net['x'][layers[1]] #/net['fire_rate'][layers[1]] - else: - mpost = net['x'][layers[1]] - post = torch.from_numpy(np.tile(np.reshape(mpost,[1,shape[1]]),[shape[0],1])) - - # Excitatory synapses - if net['eta_stdp'][i] > 0: - havconn = W>0 - inc_stdp = (0.5-post)*(pre>0)*(post>0)*havconn - # Inhibitory synapses - elif not net['fast_inhib'][i]: # and False: - havconn = W<0 - inc_stdp = (0.5-post)*(pre>0)*(post>0)*havconn - - # Apply the weight changes - net['W'][i] += inc_stdp*net['eta_stdp'][i] - - # - # Fast inhibitory synapses, xxx update for firing rate modulation of eta_stdp? - # - else: - - # Store weight changes - inc_stdp = np.zeros(shape) - dec_stdp = np.zeros(shape) - - # Loop thru firing pre neurons - for pre in np.where(net['x'][layers[0]])[0]: - # Loop thru ALL post neurons - for post in range(len(net['x'][layers[1]])): - if net['W'][i][pre,post]!=0: - if net['x'][layers[1]][post] > 0.0: - if net['x'][layers[0]][pre] >\ - net['x'][layers[1]][post]: - # Synapse gets stronger if pre fires before post - inc_stdp[pre,post] = 0.5 #0.1 #/\ - #net['mean_rate'][layers[1]][post] - #net['mean_rate'][layers[0]][pre] - else: - # Synapse gets weaker if pre fires after post - dec_stdp[pre,post] = 0.5 *\ - (1.0-net['mean_rate'][layers[1]][post]) - else: - # Also gets weaker if pre fires and not post - dec_stdp[pre,post] = 0.5*\ - net['mean_rate'][layers[1]][post] - - # Apply the weight changes - net['W'][i] += inc_stdp*net['eta_stdp'][i] - net['W'][i] -= dec_stdp*net['eta_stdp'][i] - - # - # Finish - # - - # Try weight decay? - #net['W'][i] = net['W'][i]-net['eta_stdp'][i]/10 # * (1-net['eta_stdp'][i]) - if net['eta_stdp'][i] > 0.0: - # Excitation - pruning and synaptogenesis (structural plasticity) - net['W'][i][W<0.0] = 0.000001 #xxx - net['W'][i][W>10.0] = 10.0 #xxx - if np.random.rand() < 0.0: #0.1: #xxx TEMPORARILY OFF - synap = (np.random.rand(2)*shape).astype(int) - if net['W'][i][synap[0]][synap[1]] == 0: - net['W'][i][synap[0]][synap[1]] = 0.001 - else: - # Inhibition - must not go +ve - net['W'][i][W>0.0] = -0.000001 #xxx - net['W'][i][W<-10.0] = -10.0 #xxx - - # Finally clear out any predefined spikes that are used up (so calculated network spikes can take over) - for i in range(len(net['set_spks'])): - if len(net['set_spks'][i]): - if len(net['set_spks'][i]) <= net['sspk_idx'][i]: - net['set_spks'][i] = [] - -def norm_rates(net): - - for i,rate in enumerate(net['fire_rate']): - if rate.any() and net['eta_ip'][i] > 0.0: - net['thr'][i] = net['thr'][i] + net['eta_ip'][i]*(net['x'][i]-rate) - #xxx net['thr'][i] = net['thr'][i] + net['eta_ip'][i]*(net['x'][i]-rate) - net['thr'][i][net['thr'][i]<0.0] = 0.0 #xxx - -################################## -# Normalise inhib weights to balance input currents -# net: BITnet instance - -def norm_inhib(net): - - #return #xxx no norm_inhib - for i,W in enumerate(net['W']): - if net['eta_stdp'][i] < 0: # and not net['fast_inhib'][i]: - #if net['is_inhib'][i]: # and not net['fast_inhib'][i]: - lyr = net['W_lyr'][i][1] - #wadj = np.multiply(W,np.sign(net['x_input'][lyr]))*-net['eta_stdp'][i]*10 - #wadj = np.multiply(W,net['x_input'][lyr]-net['fire_rate'][lyr])*-net['eta_stdp'][i]*100 - try: - wadj = np.multiply(W,net['x_input'][lyr])*-net['eta_stdp'][i]*50 #0.5 #100 - net['W'][i] += wadj - net['W'][i][W>0.0] = -0.000001 - except RuntimeWarning: - print("norm_inhib err") - pdb.set_trace() - -################################## -# Run the simulation -# net: BITnet instance -# n_steps: number of steps - -def runSim(net,n_steps): - - # Loop - for step in range(n_steps): - - # Inc step count - net['step_num'] += 1 - - # Propagate spikes from pre to post neurons - calc_spikes(net) - - - # Calculate STDP weight changes - calc_stdp(net) - - # Normalise firing rates and inhibitory balance - norm_rates(net) - norm_inhib(net) \ No newline at end of file diff --git a/src/blitnet_open.py b/src/blitnet_open.py deleted file mode 100644 index 5b95a0b..0000000 --- a/src/blitnet_open.py +++ /dev/null @@ -1,509 +0,0 @@ -#MIT License - -#Copyright (c) 2023 Adam Hines, Peter Stratton, Michael Milford, Tobias Fischer - -#Permission is hereby granted, free of charge, to any person obtaining a copy -#of this software and associated documentation files (the "Software"), to deal -#in the Software without restriction, including without limitation the rights -#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -#copies of the Software, and to permit persons to whom the Software is -#furnished to do so, subject to the following conditions: - -#The above copyright notice and this permission notice shall be included in all -#copies or substantial portions of the Software. - -#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -#SOFTWARE. - -''' -Imports -''' -import numpy as np -import pdb -import torch - -import matplotlib.pyplot as plt - - -################################## -# Return a new empty BITnet instance - -def newNet(): - - np.random.seed() # new random seed - - # ** NEURON FIELDS ** - # x = activations - # x_input = total inputs - # x_prev = previous activations - # x_calc = calculated activations - # x_fastinp = total inputs including fast inhib - # dim = dimensions - # thr = thresholds for each neuron - # fire_rate = target firing rate for each neuron - # have_rate = have a target firing rate - # mean_rate = running avg firing rate for each neuron - # eta_ip = IP (threshold) learning rate - # const_inp = constant input to each neuron - # nois = noise st.dev. - # set_spks = pre-defined spike times (if any) - # sspk_idx = current index into set_spks - # spikes = spike events - # rec_spks = record spikes? - # - # ** CONNECTION FIELDS ** - # W = weights (-ve for inhib synapses) - # I = synaptic currents - # is_inhib = inhib weights flag - # fast_inhib = fast inhib weights flag - # W_lyr = pre and post layer numbers - # eta_stdp = STDP learning rate (-ve for inhib synapses) - # - # ** SIMULATION FIELDS ** - # step_num = current step - - #pdb.set_trace() - net = dict(x=[],x_input=[],x_prev=[],x_calc=[],x_fastinp=[],dim=[],thr=[], - fire_rate=[],have_rate=[],mean_rate=[],eta_ip=[],const_inp=[],nois=[], - set_spks=[],sspk_idx=[],spikes=[],rec_spks=[], - W=[],I=[],is_inhib=[],fast_inhib=[],W_lyr=[],eta_stdp=[], - step_num=0) - - return net - -################################## -# Add a neuron layer (ie a neuron population) -# net: BITnet instance -# dim: layer dimensions [x,y,...] -# thr_range: initial threshold range -# fire_rate: target firing rate (0=no target) -# ip_rate: intrinsic threshold plasticity (IP) rate (0=no IP) -# const_inp: constant input to each neuron (0=none) -# nois: noise variance (0=no noise) -# rec_spks: record spikes? - -def addLayer(net,dims,thr_range,fire_rate,ip_rate,const_inp,nois,rec_spks): - - # Check constraints etc - if np.isscalar(thr_range): thr_range = [thr_range,thr_range] - if np.isscalar(fire_rate): fire_rate = [fire_rate,fire_rate] - if np.isscalar(const_inp): const_inp = [const_inp,const_inp] - - net['dim'].append(np.array(dims,int)) - net['x'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_prev'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_calc'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_input'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['x_fastinp'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['mean_rate'].append(torch.from_numpy(np.zeros(int(np.prod(dims))))) - net['eta_ip'].append(ip_rate) - net['thr'].append(torch.from_numpy(np.random.uniform(thr_range[0],thr_range[1], - int(np.prod(dims))))) - net['fire_rate'].append(torch.from_numpy(np.random.uniform(fire_rate[0],fire_rate[1], - int(np.prod(dims))))) - net['have_rate'].append(any(net['fire_rate'][-1]>0.0)) - - net['const_inp'].append(torch.from_numpy(np.random.uniform(const_inp[0],const_inp[1], - int(np.prod(dims))))) - - net['nois'].append(nois) - net['set_spks'].append([]) - net['sspk_idx'].append(0) - net['spikes'].append(torch.empty([],dtype=torch.float64)) - net['rec_spks'].append(rec_spks) - - return len(net['x'])-1 - -################################## -# Add a set of random connections between layers -# net: BITnet instance -# layer_pre: presynaptic layer -# layer_post: postsynaptic layer -# W_range: weight range [lo,hi] -# p: initial connection probability -# stdp_rate: STDP rate (0=no STDP) -# fast_inhib: is this fast inhibition (ie inhib applied at same timestep) - -def addWeights(net,layer_pre,layer_post,W_range,p,stdp_rate,fast_inhib): - - # Check constraints etc - if np.isscalar(W_range): W_range = [W_range,W_range] - - nrow =net['x'][layer_pre].size(dim=0) - ncol = net['x'][layer_post].size(dim=0) - - Wmn = (W_range[0]+W_range[1])/2.0 - Wsd = (W_range[1]-W_range[0])/6.0 - net['W'].append(torch.from_numpy(np.random.normal(Wmn,Wsd,[nrow,ncol]))) - if Wmn > 0.0: - net['W'][-1][net['W'][-1]<0.0] = 0.0 - else: - net['W'][-1][net['W'][-1]>0.0] = 0.0 - setzero = np.random.rand(nrow,ncol) > p - if layer_pre==layer_post: # no self connections allowed - setzero = np.logical_or(setzero,np.identity(nrow)) - net['W'][-1][setzero] = 0.0 - net['W_lyr'].append([layer_pre,layer_post]) - net['I'].append(np.zeros(ncol)) - net['eta_stdp'].append(stdp_rate) - net['is_inhib'].append(W_range[0]<0.0 and W_range[1]<=0.0) - net['fast_inhib'].append(fast_inhib and net['is_inhib'][-1]) - p_nz = p - if p_nz==0.0: p_nz = 1.0 - - # Normalise the weights (except fast inhib weights) - if not net['fast_inhib'][-1]: - nrm = np.linalg.norm(net['W'][-1],ord=1,axis=0) - nrm[nrm==0.0] = 1.0 - net['W'][-1] = net['W'][-1]/nrm - - return len(net['W'])-1 - -################################## -# Set defined spike times for a neuron layer (ie a neuron population) -# net: BITnet instance -# layer: layer number -# times: 2-column matrix (col 1 = step num (ordered); col 2 = neuron num to spike) -# NOTE for spike forcing an output layer ensure that: eta_ip=0 and target fire_rate=0 -# FOLLOWING training ensure that: forced spikes array is removed, ie: setSpikeTimes(n,l,[]) - -def setSpikeTimes(net,layer,times): - if isinstance(times,list): - net['set_spks'][layer] = times.copy() - else: - net['set_spks'][layer] = times.detach().clone() - net['sspk_idx'][layer] = 0 - -################################## -# Normalise all the firing rates -# net: BITnet instance - -def norm_rates(net): - - for i,rate in enumerate(net['fire_rate']): - if rate.any() and net['eta_ip'][i] > 0.0: - net['thr'][i] = net['thr'][i] + net['eta_ip'][i]*(net['x'][i]-rate) - #xxx net['thr'][i] = net['thr'][i] + net['eta_ip'][i]*(net['x'][i]-rate) - net['thr'][i][net['thr'][i]<0.0] = 0.0 #xxx - -################################## -# Normalise inhib weights to balance input currents -# net: BITnet instance - -def norm_inhib(net): - - #return #xxx no norm_inhib - for i,W in enumerate(net['W']): - if net['eta_stdp'][i] < 0: # and not net['fast_inhib'][i]: - #if net['is_inhib'][i]: # and not net['fast_inhib'][i]: - lyr = net['W_lyr'][i][1] - #wadj = np.multiply(W,np.sign(net['x_input'][lyr]))*-net['eta_stdp'][i]*10 - #wadj = np.multiply(W,net['x_input'][lyr]-net['fire_rate'][lyr])*-net['eta_stdp'][i]*100 - try: - wadj = np.multiply(W,net['x_input'][lyr])*-net['eta_stdp'][i]*50 #0.5 #100 - net['W'][i] += wadj - net['W'][i][W>0.0] = -0.000001 - except RuntimeWarning: - print("norm_inhib err") - pdb.set_trace() - -################################## -# Propagate spikes thru the network -# net: SORN instance - -def calc_spikes(net): - # Start with the noise and constant input in the neurons of each layer - for i,nois in enumerate(net['nois']): - if nois > 0: - net['x_input'][i] = np.random.normal(0.0,nois,int(np.prod(net['dim'][i]))) - else: - net['x_input'][i] = torch.full_like(net['x_input'][i],0.0) - net['x_input'][i] += net['const_inp'][i].detach().clone() - # Find the threshold crossings (overwritten later if needed) - net['x'][i] = torch.clamp((net['x_input'][i]-net['thr'][i]),0.0,0.9) - # Loop thru layers to insert any predefined spikes - for i in range(len(net['set_spks'])): - if len(net['set_spks'][i]): - net['x'][i] = torch.full_like(net['x'][i],0.0) - sidx = net['sspk_idx'][i] - if sidx < len(net['set_spks'][i]): stim = net['set_spks'][i][sidx,0] - while sidx < len(net['set_spks'][i]) and int(stim) <= net['step_num']: - net['x'][i][int(net['set_spks'][i][sidx,1])] = torch.fmod(stim,1) - sidx += 1 - if sidx < len(net['set_spks'][i]): - stim = net['set_spks'][i][sidx,0] - #else: - # net['set_spks'][i] = [] - net['sspk_idx'][i] = sidx - - # Loop thru weight matrices, propagating spikes through. - # The idea is to process all weight matrices going into a layer (ie the nett input to that layer) - # then calculate that layer's spikes (threshold crossings), then move to the next group of weight - # matrices for the next layer. A group is defined as a contiguous set of weight matrices all ending - # on the same layer. This scheme is designed to propagate spikes rapidly up a feedfoward - # hierarachy. It won't work for layers with recurrent connections even if they are in the same - # weight group, since the spikes won't be recurrently p[numnrocessed until the next timestep, so fast - # inhibition is still needed for that. For feedback connections (ie the same layer being in - # different weight groups) this code will do a double timestep for those layers (not desirable). - #ipdb.set_trace() - for i,W in enumerate(net['W']): - if not net['fast_inhib'][i]: - layers = net['W_lyr'][i] - - # Synaptic currents last for 1 timestep - if layers[0]!=layers[1]: - net['I'][i] = torch.matmul(net['x'][layers[0]],W) - else: - net['I'][i] = torch.matmul(net['x_prev'][layers[0]],W) - - net['x_input'][layers[1]] += net['I'][i] - - # Do spikes if this is the last weight matrix or if the next one has a different post layer - # or the next one is fast inhib,### UNLESS this is a recurrent layer - do_spikes = (i==len(net['W'])-1) - if not do_spikes: - do_spikes = not(layers[1]==net['W_lyr'][i+1][1]) or net['fast_inhib'][i+1] - #if do_spikes: - # do_spikes = layers[0]!=layers[1] - if do_spikes: - j = layers[1] - - # Find threshold crossings - if layers[0]!=layers[1]: - net['x_prev'][j] = net['x'][j][:] - if not len(net['set_spks'][j]): - # No predefined spikes for this layer - net['x'][j] = np.clip(net['x_input'][j]-net['thr'][j],a_min=0.0,a_max=0.9) - else: - # Predefined spikes exist for this layer, remember the calculated ones - net['x_calc'][j] = np.clip(net['x_input'][j]-net['thr'][j],a_min=0.0,a_max=0.9) - if layers[0]==layers[1]: - net['x_prev'][j] = net['x'][j][:] - - # If the next weight matrix is fast inhib for this layer, process it now - if i < len(net['W'])-1: - if net['fast_inhib'][i+1] and layers[1]==net['W_lyr'][i+1][1]: - flyrs = net['W_lyr'][i+1] - net['x_fastinp'][flyrs[1]] = net['x_input'][flyrs[1]].copy() - if flyrs[0]==flyrs[1]: - postsyn_spks = np.tile(net['x'][flyrs[0]],[len(net['x'][flyrs[0]]),1]) - presyn_spks = np.transpose(postsyn_spks) - presyn_spks[presyn_spks < postsyn_spks] = 0.0 - net['x_fastinp'][flyrs[1]] += np.sum((presyn_spks)*net['W'][i+1],0) - else: - net['x_fastinp'][flyrs[1]] += np.matmul(net['x'][flyrs[0]],net['W'][i+1]) - if not len(net['set_spks'][j]): - # No predefined spikes for this layer - net['x'][flyrs[1]] = np.clip(net['x_fastinp'][flyrs[1]]-net['thr'][flyrs[1]], - a_min=0.0,a_max=0.9) - else: - # Predefined spikes exist for this layer, remember the calculated ones - net['x_calc'][flyrs[1]] = np.clip(net['x_fastinp'][flyrs[1]]-net['thr'][flyrs[1]], - a_min=0.0,a_max=0.9) - - # Finally, update mean firing rates and record all spikes if needed - for i,eta in enumerate(net['eta_ip']): - - if eta > 0.0: - net['mean_rate'][i] = net['mean_rate'][i]*(1.0-eta) +\ - (net['x'][i]>0.0)*eta - if net['rec_spks'][i]: - outspk = (net['x'][i]).detach().cpu().numpy() - if i == 2: - outspk[outspk<0.05] = 0 - n_idx = np.nonzero(outspk) - net['spikes'][i].extend([net['step_num']+net['x'][i][n].detach().cpu().numpy(),n] - for n in n_idx) - -################################## -# Calculate STDP -# net: BITnet instance - -def calc_stdp(net): - - # Loop thru weight matrices that have non-zero learning rate - for i,W in enumerate(net['W']): - if net['eta_stdp'][i] != 0: - - # Remember layer numbers and weight matrix shape - layers = net['W_lyr'][i] - shape = W.size() - - # - # Spike Forcing has special rules to make calculated and forced spikes match - # - if len(net['set_spks'][layers[1]]): - - # Diff between forced and calculated spikes - xdiff = net['x'][layers[1]] - net['x_calc'][layers[1]] - # Modulate learning rate by firing rate (low firing rate = high learning rate) - #if net['have_rate'][layers[1]]: - # xdiff /= net['fire_rate'][layers[1]] - - # Threshold rules - lower it if calced spike is smaller (and vice versa) - net['thr'][layers[1]] -= np.sign(xdiff)*np.abs(net['eta_stdp'][i])/10 - net['thr'][layers[1]][net['thr'][layers[1]]<0.0] = 0.0 # don't go -ve - - # A little bit of threshold decay - #net['thr'][layers[1]] *= (1-net['eta_stdp'][i]/100) - - # Pre and Post spikes tiled across and down for all synapses - if net['have_rate'][layers[0]]: - # Modulate learning rate by firing rate (low firing rate = high learning rate) - mpre = net['x'][layers[0]]/net['fire_rate'][layers[0]] - else: - mpre = net['x'][layers[0]] - pre = torch.from_numpy(np.tile(np.reshape(mpre, [shape[0],1]),[1,shape[1]])) - post = torch.from_numpy(np.tile(np.reshape(xdiff,[1,shape[1]]),[shape[0],1])) - - # Excitatory connections - if net['eta_stdp'][i] > 0: - havconn = W>0 - inc_stdp = pre*post*havconn - # Inhibitory connections - else: - havconn = W<0 - inc_stdp = -pre*post*havconn - - # Apply the weight changes - net['W'][i] += inc_stdp*net['eta_stdp'][i] - - # - # Normal STDP - # - elif not net['fast_inhib'][i]: - - pre = torch.from_numpy(np.tile(np.reshape(net['x'][layers[0]],[shape[0],1]),[1,shape[1]])) - if net['have_rate'][layers[1]]: - # Modulate learning rate by firing rate (low firing rate = high learning rate) - mpost = net['x'][layers[1]] #/net['fire_rate'][layers[1]] - else: - mpost = net['x'][layers[1]] - post = torch.from_numpy(np.tile(np.reshape(mpost,[1,shape[1]]),[shape[0],1])) - - # Excitatory synapses - if net['eta_stdp'][i] > 0: - havconn = W>0 - inc_stdp = (0.5-post)*(pre>0)*(post>0)*havconn - # Inhibitory synapses - elif not net['fast_inhib'][i]: # and False: - havconn = W<0 - inc_stdp = (0.5-post)*(pre>0)*(post>0)*havconn - - # Apply the weight changes - net['W'][i] += inc_stdp*net['eta_stdp'][i] - - # - # Fast inhibitory synapses, xxx update for firing rate modulation of eta_stdp? - # - else: - - # Store weight changes - inc_stdp = np.zeros(shape) - dec_stdp = np.zeros(shape) - - # Loop thru firing pre neurons - for pre in np.where(net['x'][layers[0]])[0]: - # Loop thru ALL post neurons - for post in range(len(net['x'][layers[1]])): - if net['W'][i][pre,post]!=0: - if net['x'][layers[1]][post] > 0.0: - if net['x'][layers[0]][pre] >\ - net['x'][layers[1]][post]: - # Synapse gets stronger if pre fires before post - inc_stdp[pre,post] = 0.5 #0.1 #/\ - #net['mean_rate'][layers[1]][post] - #net['mean_rate'][layers[0]][pre] - else: - # Synapse gets weaker if pre fires after post - dec_stdp[pre,post] = 0.5 *\ - (1.0-net['mean_rate'][layers[1]][post]) - else: - # Also gets weaker if pre fires and not post - dec_stdp[pre,post] = 0.5*\ - net['mean_rate'][layers[1]][post] - - # Apply the weight changes - net['W'][i] += inc_stdp*net['eta_stdp'][i] - net['W'][i] -= dec_stdp*net['eta_stdp'][i] - - # - # Finish - # - - # Try weight decay? - #net['W'][i] = net['W'][i]-net['eta_stdp'][i]/10 # * (1-net['eta_stdp'][i]) - if net['eta_stdp'][i] > 0.0: - # Excitation - pruning and synaptogenesis (structural plasticity) - net['W'][i][W<0.0] = 0.000001 #xxx - net['W'][i][W>10.0] = 10.0 #xxx - if np.random.rand() < 0.0: #0.1: #xxx TEMPORARILY OFF - synap = (np.random.rand(2)*shape).astype(int) - if net['W'][i][synap[0]][synap[1]] == 0: - net['W'][i][synap[0]][synap[1]] = 0.001 - else: - # Inhibition - must not go +ve - net['W'][i][W>0.0] = -0.000001 #xxx - net['W'][i][W<-10.0] = -10.0 #xxx - - # Finally clear out any predefined spikes that are used up (so calculated network spikes can take over) - for i in range(len(net['set_spks'])): - if len(net['set_spks'][i]): - if len(net['set_spks'][i]) <= net['sspk_idx'][i]: - net['set_spks'][i] = [] - -################################## -# Run the simulation -# net: BITnet instance -# n_steps: number of steps - -def runSim(net,n_steps): - - # Loop - for step in range(n_steps): - - # Inc step count - net['step_num'] += 1 - - # Propagate spikes from pre to post neurons - calc_spikes(net) - - # Calculate STDP weight changes - calc_stdp(net) - - # Normalise firing rates and inhibitory balance - norm_rates(net) - norm_inhib(net) - -################################## -# Plot recorded spikes in current subplot -# net: BITnet instance - -def subplotSpikes(net,cutoff): - - n_tot = 0 - for i,sp in enumerate(net['spikes']): - x=[]; y=[] - for n in sp: - x.extend(list(n[0])) - y.extend(list(n[1]+n_tot)) - - plt.plot(x,y,'.',ms=1) - n_tot += np.size(net['x'][i].detach().cpu().numpy()) - -################################## -# Plot recorded spikes in new figure -# net: BITnet instance - -def plotSpikes(net,cutoff): - - plt.figure() - subplotSpikes(net,cutoff) - plt.show(block=False) - -################################## diff --git a/src/data_process.py b/src/data_process.py deleted file mode 100644 index 57a378d..0000000 --- a/src/data_process.py +++ /dev/null @@ -1,117 +0,0 @@ -import cv2 -from PIL import Image -import numpy as np -import os -import torch -from torch.utils.data import Dataset, DataLoader - - -def preprocess_img(img_dir, channels=3): - - if channels == 1: - img = cv2.imread(img_dir, 0) - elif channels == 3: - img = cv2.imread(img_dir) - - channels = img.shape[2] - shape_r = 288 - shape_c = 384 - img_padded = np.ones((shape_r, shape_c, channels), dtype=np.uint8) - if channels == 1: - img_padded = np.zeros((shape_r, shape_c), dtype=np.uint8) - original_shape = img.shape - rows_rate = original_shape[0] / shape_r - cols_rate = original_shape[1] / shape_c - if rows_rate > cols_rate: - new_cols = (original_shape[1] * shape_r) // original_shape[0] - img = cv2.resize(img, (new_cols, shape_r)) - if new_cols > shape_c: - new_cols = shape_c - img_padded[:, - ((img_padded.shape[1] - new_cols) // 2):((img_padded.shape[1] - new_cols) // 2 + new_cols)] = img - else: - new_rows = (original_shape[0] * shape_c) // original_shape[1] - img = cv2.resize(img, (shape_c, new_rows)) - - if new_rows > shape_r: - new_rows = shape_r - img_padded[((img_padded.shape[0] - new_rows) // 2):((img_padded.shape[0] - new_rows) // 2 + new_rows), - :] = img - - return img_padded - - -def postprocess_img(pred,img): - pred = np.array(pred) - - shape_r = img.shape[0] - shape_c = img.shape[1] - predictions_shape = pred.shape - - rows_rate = shape_r / predictions_shape[0] - cols_rate = shape_c / predictions_shape[1] - - if rows_rate > cols_rate: - new_cols = (predictions_shape[1] * shape_r) // predictions_shape[0] - pred = cv2.resize(pred, (new_cols, shape_r)) - img = pred[:, ((pred.shape[1] - shape_c) // 2):((pred.shape[1] - shape_c) // 2 + shape_c)] - else: - new_rows = (predictions_shape[0] * shape_c) // predictions_shape[1] - pred = cv2.resize(pred, (shape_c, new_rows)) - img = pred[((pred.shape[0] - shape_r) // 2):((pred.shape[0] - shape_r) // 2 + shape_r), :] - - return img - - -class MyDataset(Dataset): - """Load dataset.""" - - def __init__(self, ids, stimuli_dir, saliency_dir, fixation_dir, transform=None): - """ - Args: - csv_file (string): Path to the csv file with annotations. - root_dir (string): Directory with all the images. - transform (callable, optional): Optional transform to be applied - on a sample. - """ - self.ids = ids - self.stimuli_dir = stimuli_dir - self.saliency_dir = saliency_dir - self.fixation_dir = fixation_dir - self.transform = transform - - def __len__(self): - return len(self.ids) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - - im_path = self.stimuli_dir + self.ids.iloc[idx, 0] - image = Image.open(im_path).convert('RGB') - img = np.array(image) / 255. - img = np.transpose(img, (2, 0, 1)) - img = torch.from_numpy(img) - # if self.transform: - # img = self.transform(image) - - smap_path = self.saliency_dir + self.ids.iloc[idx, 1] - saliency = Image.open(smap_path) - - smap = np.expand_dims(np.array(saliency) / 255., axis=0) - smap = torch.from_numpy(smap) - - fmap_path = self.fixation_dir + self.ids.iloc[idx, 2] - fixation = Image.open(fmap_path) - - fmap = np.expand_dims(np.array(fixation) / 255., axis=0) - fmap = torch.from_numpy(fmap) - - sample = {'image': img, 'saliency': smap, 'fixation': fmap} - - return sample - - - - - diff --git a/src/densenet.py b/src/densenet.py deleted file mode 100644 index fb9cde6..0000000 --- a/src/densenet.py +++ /dev/null @@ -1,287 +0,0 @@ -import re -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as cp -from collections import OrderedDict -# from .utils import load_state_dict_from_url -from torch import Tensor -from torch.jit.annotations import List - - -__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] - -model_urls = { - 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', - 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', - 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', - 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', -} - - -class _DenseLayer(nn.Module): - def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): - super(_DenseLayer, self).__init__() - self.add_module('norm1', nn.BatchNorm2d(num_input_features)), - self.add_module('relu1', nn.ReLU(inplace=True)), - self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * - growth_rate, kernel_size=1, stride=1, - bias=False)), - self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), - self.add_module('relu2', nn.ReLU(inplace=True)), - self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, - kernel_size=3, stride=1, padding=1, - bias=False)), - self.drop_rate = float(drop_rate) - self.memory_efficient = memory_efficient - - def bn_function(self, inputs): - # type: (List[Tensor]) -> Tensor - concated_features = torch.cat(inputs, 1) - bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 - return bottleneck_output - - # todo: rewrite when torchscript supports any - def any_requires_grad(self, input): - # type: (List[Tensor]) -> bool - for tensor in input: - if tensor.requires_grad: - return True - return False - - @torch.jit.unused # noqa: T484 - def call_checkpoint_bottleneck(self, input): - # type: (List[Tensor]) -> Tensor - def closure(*inputs): - return self.bn_function(inputs) - - return cp.checkpoint(closure, *input) - - @torch.jit._overload_method # noqa: F811 - def forward(self, input): - # type: (List[Tensor]) -> (Tensor) - pass - - @torch.jit._overload_method # noqa: F811 - def forward(self, input): - # type: (Tensor) -> (Tensor) - pass - - # torchscript does not yet support *args, so we overload method - # allowing it to take either a List[Tensor] or single Tensor - def forward(self, input): # noqa: F811 - if isinstance(input, Tensor): - prev_features = [input] - else: - prev_features = input - - if self.memory_efficient and self.any_requires_grad(prev_features): - if torch.jit.is_scripting(): - raise Exception("Memory Efficient not supported in JIT") - - bottleneck_output = self.call_checkpoint_bottleneck(prev_features) - else: - bottleneck_output = self.bn_function(prev_features) - - new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) - if self.drop_rate > 0: - new_features = F.dropout(new_features, p=self.drop_rate, - training=self.training) - return new_features - - -class _DenseBlock(nn.ModuleDict): - _version = 2 - - def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False): - super(_DenseBlock, self).__init__() - for i in range(num_layers): - layer = _DenseLayer( - num_input_features + i * growth_rate, - growth_rate=growth_rate, - bn_size=bn_size, - drop_rate=drop_rate, - memory_efficient=memory_efficient, - ) - self.add_module('denselayer%d' % (i + 1), layer) - - def forward(self, init_features): - features = [init_features] - for name, layer in self.items(): - new_features = layer(features) - features.append(new_features) - return torch.cat(features, 1) - - -class _Transition(nn.Sequential): - def __init__(self, num_input_features, num_output_features): - super(_Transition, self).__init__() - self.add_module('norm', nn.BatchNorm2d(num_input_features)) - self.add_module('relu', nn.ReLU(inplace=True)) - self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, - kernel_size=1, stride=1, bias=False)) - self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) - - -class DenseNet(nn.Module): - r"""Densenet-BC model class, based on - `"Densely Connected Convolutional Networks" `_ - - Args: - growth_rate (int) - how many filters to add each layer (`k` in paper) - block_config (list of 4 ints) - how many layers in each pooling block - num_init_features (int) - the number of filters to learn in the first convolution layer - bn_size (int) - multiplicative factor for number of bottle neck layers - (i.e. bn_size * k features in the bottleneck layer) - drop_rate (float) - dropout rate after each dense layer - num_classes (int) - number of classification classes - memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, - but slower. Default: *False*. See `"paper" `_ - """ - - def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), - num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False): - - super(DenseNet, self).__init__() - - # First convolution - self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, - padding=3, bias=False)), - ('norm0', nn.BatchNorm2d(num_init_features)), - ('relu0', nn.ReLU(inplace=True)), - ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), - ])) - - # Each denseblock - num_features = num_init_features - for i, num_layers in enumerate(block_config): - block = _DenseBlock( - num_layers=num_layers, - num_input_features=num_features, - bn_size=bn_size, - growth_rate=growth_rate, - drop_rate=drop_rate, - memory_efficient=memory_efficient - ) - self.features.add_module('denseblock%d' % (i + 1), block) - num_features = num_features + num_layers * growth_rate - if i != len(block_config) - 1: - trans = _Transition(num_input_features=num_features, - num_output_features=num_features // 2) - self.features.add_module('transition%d' % (i + 1), trans) - num_features = num_features // 2 - - # Final batch norm - self.features.add_module('norm5', nn.BatchNorm2d(num_features)) - - # Linear layer - self.classifier = nn.Linear(num_features, num_classes) - - # Official init from torch repo. - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.constant_(m.bias, 0) - - def forward(self, x): - features = self.features(x) - out = F.relu(features, inplace=True) - out = F.adaptive_avg_pool2d(out, (1, 1)) - out = torch.flatten(out, 1) - out = self.classifier(out) - return out - - -def _load_state_dict(model, model_url, progress, flag): - # '.'s are no longer allowed in module names, but previous _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. - # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. - pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - if flag == "densenet161": - state_dict = torch.load(r'densenet161-8d451a50.pth') - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - - -def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, - **kwargs): - model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) - if pretrained: - if arch == 'densenet161': - _load_state_dict(model, model_urls[arch], progress, 'densenet161') - else: - _load_state_dict(model, model_urls[arch], progress, 0) - return model - - -def densenet121(pretrained=False, progress=True, **kwargs): - r"""Densenet-121 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, - but slower. Default: *False*. See `"paper" `_ - """ - return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, - **kwargs) - - - -def densenet161(pretrained=False, progress=True, **kwargs): - r"""Densenet-161 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, - but slower. Default: *False*. See `"paper" `_ - """ - return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, - **kwargs) - - - -def densenet169(pretrained=False, progress=True, **kwargs): - r"""Densenet-169 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, - but slower. Default: *False*. See `"paper" `_ - """ - return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, - **kwargs) - - - -def densenet201(pretrained=False, progress=True, **kwargs): - r"""Densenet-201 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, - but slower. Default: *False*. See `"paper" `_ - """ - return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, - **kwargs) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py deleted file mode 100644 index 01e2e21..0000000 --- a/src/utils.py +++ /dev/null @@ -1,172 +0,0 @@ -###################################################### -# -# Useful utility functions -# - -import numpy as np -import statsmodels.api as sm -import matplotlib.pyplot as plt - -##################################################### -# Function to calculate normalised dot product -# - -def dotprod(vec1,vec2): - - # Normalise - n1 = np.linalg.norm(vec1) - n2 = np.linalg.norm(vec2) - - # Calc - if n1 > 0.0 and n2 > 0.0: - dot = np.dot(vec1/n1,vec2/n2) - else: - dot = 0.0 - - return dot - -##################################################### -# Function to build and test a linear classifier -# - -def testmodel(X,Y,X_tst,Y_tst): - - # Construct decoder and get predictions - Xc = sm.add_constant(X,has_constant='add') - Xc_tst = sm.add_constant(X_tst,has_constant='add') - model = sm.OLS(Y,Xc) - results = model.fit() - Yp = results.predict(Xc_tst) - - # How many correct classifications - sum = 0 - for i in range(len(Y_tst)): - sum += Y_tst[i,np.argmax(Yp[i,:])] - - return sum/len(Y_tst) - -##################################################### -# Function to build and test a linear predictor (not a strict classifier) -# - -def testmodel_arb(X,Y,X_tst,Y_tst): - - # Construct decoder and get predictions - Xc = sm.add_constant(X,has_constant='add') - Xc_tst = sm.add_constant(X_tst,has_constant='add') - model = sm.OLS(Y,Xc) - results = model.fit() - Yp = results.predict(Xc_tst) - - # How many correct predictions - sum = 0 - for i in range(len(Y_tst)): - sum += dotprod(Y_tst[i,:],Yp[i,:]) - - return sum/len(Y_tst) - -##################################################### -# Function to build and test a linear binary classifier -# - -def testmodel_bin(X,y,X_tst,y_tst): - - # Construct decoder and get predictions - Xc = sm.add_constant(X,has_constant='add') - Xc_tst = sm.add_constant(X_tst,has_constant='add') - model = sm.OLS(y,Xc) - results = model.fit() - yp = np.round(results.predict(Xc_tst)) - - # How many correct classifications - sum = np.sum(y_tst==yp) - - return sum/len(y_tst) - -##################################################### -# Function to create fit line -# - -def fitlin(x,y,nb,pctl): - minx = min(x); maxx = max(x); stp = (maxx-minx)/nb; minx += stp/2 - fx = []; fy = [] - for i in range(nb): - fx.append(minx) - #fy.append(np.median(y[np.logical_and(x>=minx-stp/2,x<=minx+stp/2)])) - xrange = np.logical_and(x>=minx-stp/2,x<=minx+stp/2) - if np.any(xrange): - if pctl: - fy.append(np.percentile(y[xrange],[10,90])) - else: - fy.append(np.median(y[xrange])) - else: - if pctl: - fy.append(np.array([0,0])) - else: - fy.append(0) - minx += stp - return(np.array(fx),np.array(fy)) - -##################################################### -# Function to perform PCA (from askpython.com/python/examples/principal-component-analysis) -# - -def PCA(X,num_components): - - #Step-1 - X_meaned = X - np.mean(X,axis=0) - #Step-2 - cov_mat = np.cov(X_meaned,rowvar=False) - #Step-3 - eigen_values,eigen_vectors = np.linalg.eigh(cov_mat) - #Step-4 - sorted_index = np.argsort(eigen_values)[::-1] - sorted_eigenvalue = eigen_values[sorted_index] - sorted_eigenvectors = eigen_vectors[:,sorted_index] - #Step-5 - eigenvector_subset = sorted_eigenvectors[:,0:num_components] - #Step-6 - X_reduced = np.dot(eigenvector_subset.transpose(),X_meaned.transpose() ).transpose() - # Done - return X_reduced - -##################################################### -# Function to build and test a logistic classifier -# - -def testmodel_logit(X,y,X_tst,y_tst): - - try: - # Build model and fit data - #X = sm.add_constant(X,has_constant='add') - #X_tst = sm.add_constant(X_tst,has_constant='add') - log_reg = sm.Logit(y,X).fit() - # Get predictions - yp = np.round(log_reg.predict(X_tst)) - - # How many correct classifications - sum = np.sum(y_tst==yp) - except: - sum = 0 - - return sum/len(y_tst) - -##################################################### -# Function to build and test a LM classifier -# - -def testmodel_LM(X,y,X_tst,y_tst): - - # Build model and fit data - X = sm.add_constant(X,has_constant='add') - X_tst = sm.add_constant(X_tst,has_constant='add') - results = sm.GLM(y,X).fit() - # Get predictions - yp = np.round(results.predict(X_tst)) - - # How many correct classifications - sum = np.sum(y_tst==yp) - - return sum/len(y_tst) - -#####################################################