Skip to content

Commit

Permalink
Fixed NewReader to work with auxs
Browse files Browse the repository at this point in the history
  • Loading branch information
richardjgowers committed Dec 1, 2016
1 parent 5535e3b commit c9503b5
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions package/MDAnalysis/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,7 @@ def __init__(self, filename, convert_units=None, **kwargs):
convert_units = flags['convert_lengths']
self.convert_units = convert_units

self._auxs = {}
ts_kwargs = {}
for att in ('dt', 'time_offset'):
try:
Expand All @@ -1657,30 +1658,29 @@ def _update_last_fh_position(self, pos):
"""Update the last known position of the file handle"""
self._last_fh_pos = pos

def rewind(self):
self._reopen()
self.next()

def _full_iter(self):
self._reopen()
with util.openany(self.filename, 'r') as self._file:
while True:
try:
yield self._read_next_timestep()
ts = self._read_next_timestep()
for auxname in self.aux_list:
ts = self._auxs[auxname].update_ts(ts)
yield ts
except (EOFError, IOError):
self.rewind()
raise StopIteration

def _sliced_iter(self, frames):
def _sliced_iter(self, start, stop, step):
with util.openany(self.filename, 'r') as self._file:
for f in frames:
yield self._read_frame(f)
for f in range(start, stop, step):
yield self._read_frame_with_aux(f)
self.rewind()
raise StopIteration

def _goto_frame(self, i):
with util.openany(self.filename, 'r') as self._file:
ts = self._read_frame(i)
ts = self._read_frame_with_aux(i)
return ts

def __iter__(self):
Expand All @@ -1700,7 +1700,10 @@ def apply_limits(frame):
elif isinstance(item, (list, np.ndarray)):
return self._sliced_iter(item)
elif isinstance(item, slice): # TODO Fix me!
return self._sliced_iter(item)
start, stop, step = self.check_slice_indices(
item.start, item.stop, item.step)

return self._sliced_iter(start, stop, step)

def __next__(self):
with util.openany(self.filename, 'r') as self._file:
Expand All @@ -1712,6 +1715,8 @@ def __next__(self):
raise StopIteration
else:
self._last_fh_pos = self._file.tell()
for auxname in self.aux_list:
ts = self._auxs[auxname].update_ts(ts)
return ts

next = __next__
Expand Down

0 comments on commit c9503b5

Please sign in to comment.