Skip to content

Commit

Permalink
fix minor bugs
Browse files Browse the repository at this point in the history
cagrikymk committed Feb 16, 2024
1 parent 87b37ec commit 859c382
Showing 3 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions jaxreaxff/driver_v2.py
Original file line number Diff line number Diff line change
@@ -167,6 +167,7 @@ def main():

batch_size = args.batch_size
data = align_and_batch_structures(data, max_sizes, batch_size=batch_size, dtype=TYPE)
data = [move_dataclass(d, jnp) for d in data]
total_size = len(data)
train_size = int(total_size * 0.8)
train_data = data[:train_size]
2 changes: 1 addition & 1 deletion jaxreaxff/optimizer.py
Original file line number Diff line number Diff line change
@@ -549,7 +549,7 @@ def random_parameter_search(bounds, sample_count,
else:
selected_params = onp.random.uniform(low=bounds[:,0],high=bounds[:,1])
selected_params = jnp.array(selected_params, dtype=dtype)
return min_params, min_loss
return selected_params



5 changes: 3 additions & 2 deletions jaxreaxff/structure.py
Original file line number Diff line number Diff line change
@@ -155,6 +155,7 @@ def align_and_batch_structures(structures, max_sizes, batch_size, dtype=onp.floa
full_size = len(structures)
batches = []
for bs in range(0,full_size-batch_size,batch_size):
batches.append(align_structures(structures[bs:bs+batch_size],
max_sizes, dtype))
batch = align_structures(structures[bs:bs+batch_size],
max_sizes, dtype)
batches.append(batch)
return batches

0 comments on commit 859c382

Please sign in to comment.