Skip to content

Commit

Permalink
[feature] normalize stft and remove complex dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
mikkel committed Mar 30, 2016
1 parent 734adf4 commit 9722937
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 26 deletions.
9 changes: 6 additions & 3 deletions convert_to_istft.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import tensorflow_wav
import numpy as np
import math
if(len(sys.argv)<2):
print("You have to pick a file")
no_file_Picked_Exception
Expand All @@ -22,9 +23,11 @@
nframes = wav['nframes']
print('shape', wav['data'].shape)
time = 192
print(np.min(wav['data']), np.max(wav['data']), np.mean(wav['data']), np.std(wav['data']))
#wav['data'] = np.exp(wav['data'])
data = istft(wav['data'],fs, time, hop)
print(wav)
wav['data']=data
print(np.min(data), np.max(data))
wav['data'] = data*3
#print(wav)
#wav['data'] = np.sign(wav['data'])*np.sqrt(wav['data'])
res= tensorflow_wav.save_wav(wav, file+".istft")
print(file+".istft"+" is written")
7 changes: 5 additions & 2 deletions convert_to_stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import tensorflow_wav
import numpy as np
import math
if(len(sys.argv)<2):
print("You have to pick a file")
no_file_Picked_Exception
Expand All @@ -21,7 +22,9 @@
wav= tensorflow_wav.get_wav(file)
data = stft(wav['data'],fs, framesz, hop)
print(wav)
wav['data']=data
print(np.min(data), np.max(data))
wav['data']=data.real
print(wav['data'])
#wav['data'] = np.sign(wav['data'])*np.power(wav['data'], 2)
print(np.min(wav['data']), np.max(wav['data']))
res= tensorflow_wav.save_stft(wav, file+".stft")
print(file+".stft"+" is written")
20 changes: 9 additions & 11 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
BITRATE=4096
class DCGAN(object):
def __init__(self, sess, wav_size=WAV_SIZE, is_crop=True,
batch_size=64, sample_size = 2, wav_shape=[WAV_SIZE, WAV_HEIGHT, 2],
batch_size=64, sample_size = 2, wav_shape=[WAV_SIZE, WAV_HEIGHT, 1],
y_dim=None, z_dim=64, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=2, dataset_name='default',
gfc_dim=1024, dfc_dim=1024, c_dim=1, dataset_name='default',
checkpoint_dir='checkpoint'):
"""
Expand Down Expand Up @@ -74,7 +74,7 @@ def build_model(self):
if self.y_dim:
self.y= tf.placeholder(tf.float32, [None, self.y_dim], name='y')

self.wavs = tf.placeholder(tf.complex64, [self.batch_size, BITRATE],
self.wavs = tf.placeholder(tf.float32, [self.batch_size, BITRATE],
name='real_wavs')

self.z = tf.placeholder(tf.float32, [None, self.z_dim],
Expand All @@ -91,7 +91,7 @@ def build_model(self):
self.sampler = self.sampler(self.z)
self.sampler = tf.reshape(self.sampler,[-1])
#self.sampler = tensorflow_wav.decode(self.sampler)
encoded_G = tensorflow_wav.compose(self.G)#tensorflow_wav.encode(self.G)
encoded_G = self.G#tensorflow_wav.encode(self.G)
self.D_ = self.discriminator(encoded_G, reuse=True)


Expand Down Expand Up @@ -202,14 +202,12 @@ def get_wav_content(files):
for repeat in range(errd_range):
#print("discrim ", errd_range)
# Update D network
print("Running discriminator with min/max", batch_wavs.min(), batch_wavs.max())
_= self.sess.run([d_optim],
feed_dict={ self.wavs: batch_wavs, self.z: batch_z })
#self.writer.add_summary(summary_str, counter)

# Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
#if(errG > 8):
# errg_range = 2
#else:
errg_range=1
for repeat in range(errg_range):
#print("generating ", errg_range)
Expand All @@ -227,8 +225,8 @@ def get_wav_content(files):
% (epoch, idx, batch_idxs,
time.time() - start_time, errD_fake, errD_real, errG))

SAVE_COUNT=100
SAMPLE_COUNT=20
SAVE_COUNT=500
SAMPLE_COUNT=100

print("Batch ", counter)
if np.mod(counter, SAVE_COUNT) == SAVE_COUNT-3:
Expand Down Expand Up @@ -309,7 +307,7 @@ def generator(self, z, y=None):
print('h3',h3.get_shape())

h4= deconv2d(h3,
[self.batch_size, WAV_SIZE, WAV_HEIGHT, 2], name='g_h4', with_w=False, no_bias=False)
[self.batch_size, WAV_SIZE, WAV_HEIGHT, 1], name='g_h4', with_w=False, no_bias=False)

print('h4',h4.get_shape())
tanh = tf.nn.tanh(h4)
Expand Down Expand Up @@ -338,7 +336,7 @@ def sampler(self, z, y=None):
h3 = tf.nn.relu(self.g_bn3(h3, train=False))
print('h3', h3.get_shape())

h4 = deconv2d(h3, [self.batch_size, 64, 64, 2], name='g_h4', no_bias=False)
h4 = deconv2d(h3, [self.batch_size, 64, 64, 1], name='g_h4', no_bias=False)
print('h4', h4.get_shape())

#tanh = tf.nn.tanh(h4)
Expand Down
11 changes: 3 additions & 8 deletions tensorflow_wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def save_wav(in_wav, path):
wav.writeframes(processed)

def save_stft(in_wav, path):
in_wav['data'] = in_wav['data']/(1e7+1e7j)
f = open(path, "wb")
try:
pickle.dump(in_wav, f, pickle.HIGHEST_PROTOCOL)
Expand All @@ -52,16 +51,12 @@ def save_stft(in_wav, path):
def get_stft(filename):
    """Load a pickled STFT dictionary from ``filename`` and rescale its data.

    The file is expected to hold a single pickled dictionary with a
    ``'data'`` entry. That entry is multiplied by ``1e7+1e7j`` before the
    dictionary is returned.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    open files that come from a trusted source.
    """
    # The context manager closes the handle even when unpickling fails,
    # which the previous explicit open()/close() pair did not guarantee.
    with open(filename, "rb") as f:
        data = pickle.load(f)
    data['data'] = data['data']*(1e7+1e7j)
    return data


def decompose(input, rank=3):
    """Build a complex tensor from the two pieces of ``input``.

    ``input`` is split in two with ``tf.split(rank, 2, input)``; the first
    piece is used as the real part and the second as the imaginary part of
    the tensor returned by ``tf.complex``.
    """
    # NOTE(review): this is the legacy tf.split argument order; presumably
    # `rank` is the dimension to split along -- confirm against the
    # TensorFlow version in use.
    halves = tf.split(rank, 2, input)
    real_part, imag_part = halves
    return tf.complex(real_part, imag_part)
def compose(input, rank=3):
    """Return ``input`` unchanged.

    ``rank`` is accepted so existing callers keep working, but it is not
    used. The statements that followed the early return (taking
    ``tf.real``/``tf.imag`` of the input and concatenating them along
    ``rank``) could never run and have been dropped.
    """
    return input
Expand All @@ -74,6 +69,6 @@ def encode(input,bitrate=4096):
return output

def scale_up(input):
    """Squash ``input`` with ``tf.nn.tanh`` and scale the result by 65535.

    Previously the tanh result was computed and then discarded by an early
    ``return decompose(input)``, which left the scaled value unreachable.
    """
    output = tf.nn.tanh(input)*(65535)
    return output

12 changes: 10 additions & 2 deletions test_stft_cpu.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import scipy, pylab
import numpy as np
import tensorflow_wav
from math import sqrt

def fft(x):
    """Return the discrete Fourier transform of ``x`` scaled by ``1/sqrt(n)``.

    ``n`` is ``x.shape[0]``. For a one-dimensional input this scaling
    preserves the signal energy (sum of squared magnitudes).
    """
    n = x.shape[0]
    # numpy.fft.fft is called directly: the root-level ``scipy.fft(...)``
    # callable is no longer available in current SciPy releases, where
    # ``scipy.fft`` is a module.
    return np.fft.fft(x) / sqrt(n)
def stft(x, fs, framesz, hop):
    """Compute a short-time Fourier transform of ``x``.

    ``framesz`` and ``hop`` are each multiplied by ``fs`` and truncated to
    an int to obtain the frame length and the hop size in samples. Every
    frame is multiplied by a Hanning window of the frame length and passed
    through the module's normalized ``fft``.

    Frames start at ``0, hopsamp, 2*hopsamp, ...`` for start indices
    strictly below ``len(x) - framesamp``.

    Returns an array with one row per frame.
    """
    framesamp = int(framesz*fs)
    hopsamp = int(hop*fs)
    # numpy equivalents replace the scipy.hanning / scipy.array aliases,
    # which are no longer exported from the SciPy root namespace.
    w = np.hanning(framesamp)
    # The inner helper used to return the raw scipy.fft result first, which
    # left the normalized fft() call unreachable; only the normalized
    # transform is kept.
    X = np.array([fft(w*x[i:i+framesamp])
                  for i in range(0, len(x)-framesamp, hopsamp)])
    return X

def ifft(x):
    """Invert the module's normalized ``fft`` for a spectrum ``x``.

    The forward transform divides by ``sqrt(n)`` (``n = x.shape[0]``), and
    ``numpy.fft.ifft`` already divides by ``n``, so the result has to be
    multiplied by ``sqrt(n)`` for ``ifft(fft(v))`` to give back ``v``.
    The previous ``1/sqrt(n)`` factor returned ``v / n`` instead.
    """
    n = x.shape[0]
    # numpy.fft.ifft replaces the ``scipy.ifft`` alias, which is no longer
    # exported from the SciPy root namespace.
    return np.fft.ifft(x) * sqrt(n)
def istft(X, fs, T, hop):
x = scipy.zeros(T*fs)
framesamp = X.shape[1]
Expand All @@ -23,7 +31,7 @@ def istft(X, fs, T, hop):
if(n>=X.shape[0]):
break
#print("setting i to i+framesamp", n, i, framesamp, i+framesamp, len(X), len(x))
x[i:i+framesamp] += scipy.real(scipy.ifft(X[n]))
x[i:i+framesamp] += scipy.real(ifft(X[n]))
#print(x.shape)
return x

Expand Down

0 comments on commit 9722937

Please sign in to comment.