Skip to content

Commit

Permalink
ENH: Update multiple rows in files and DBs with data.update
Browse files Browse the repository at this point in the history
Closes #312
  • Loading branch information
jaidevd committed Sep 28, 2021
1 parent 146b436 commit c6778d8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
19 changes: 10 additions & 9 deletions gramex/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,10 @@ def _filter_frame(data, meta, controls, args, source='select', id=[]):
# Apply filters
data = _filter_frame_col(data, key, col, op, vals, meta)
elif source == 'update':
# Update values should only contain 1 value. 2nd onwards are ignored
if key not in data.columns or len(vals) == 0:
meta['ignored'].append((key, vals))
else:
cols_for_update[key] = vals[0]
if len(vals) > 1:
meta['ignored'].append((key, vals[1:]))
cols_for_update[key] = vals
else:
meta['ignored'].append((key, vals))
meta['count'] = len(data)
Expand Down Expand Up @@ -868,20 +865,24 @@ def _filter_db(engine, table, meta, controls, args, source='select', id=[]):
query = _filter_db_col(query, query.where, key, col, op, vals,
cols[col], cols[col].type.python_type, meta)
elif source == 'update':
# Update values should only contain 1 value. 2nd onwards are ignored
if key not in cols or len(vals) == 0:
meta['ignored'].append((key, vals))
else:
cols_for_update[key] = vals[0]
if len(vals) > 1:
meta['ignored'].append((key, vals[1:]))
cols_for_update[key] = vals
else:
meta['ignored'].append((key, vals))
if source == 'delete':
res = engine.execute(query)
return res.rowcount
elif source == 'update':
query = query.values(cols_for_update)
id_name = id[0]
id_col = getattr(table.c, id_name)
cases = {
k: sa.case(
[(id_col == i, j) for i, j in zip(args[id_name], v)]
) for k, v in cols_for_update.items()
}
query = query.values(**cases)
res = engine.execute(query)
return res.rowcount
else:
Expand Down
43 changes: 43 additions & 0 deletions testlib/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,49 @@ def test_update(self):
gramex.data.update(data, args=args, id=['देश', 'city', 'product'])
ase(types_original, data.dtypes)

def test_update_multiple_file(self):
    # Update several rows of a CSV file in a single call, then read them back.

    # Run against a copy of the actors fixture so the update does not modify
    # the original CSV. The copy is registered in self.tmpfiles
    # (presumably cleaned up by the test harness -- confirm).
    source_csv = os.path.join(folder, '..', 'tests/actors.csv')
    target_csv = os.path.join(folder, 'actors.update.csv')
    shutil.copy(source_csv, target_csv)
    self.tmpfiles.append(target_csv)

    # Column-wise values: position i of each list describes the same row.
    expected = {
        'name': ['Humphrey Bogart', 'James Stewart', 'Audrey Hepburn'],
        'category': ['Stars', 'Thespians', 'Heartthrobs'],
        'rating': [1, 0.99, 1.11],
    }
    # 'name' is the id column; the other keys carry the new values.
    gramex.data.update(target_csv, args=dict(expected), id=['name'])

    # Fetch the same rows by name and compare every updated column.
    result = gramex.data.filter(target_csv, args={'name': expected['name']})
    for column in ('category', 'rating'):
        self.assertEqual(result[column].tolist(), expected[column])

def test_update_multiple_db(self):
    # Update several rows of a database table in a single call, then read them back.

    # Load the actors fixture into an 'actors' table of a SQLite database
    # stored under the test folder. The database file is registered in
    # self.tmpfiles (presumably cleaned up by the test harness -- confirm).
    fixture = gramex.cache.open(os.path.join(folder, '../tests/actors.csv'))
    db_url = f'sqlite:///{folder}/actors.db'
    self.tmpfiles.append(os.path.join(folder, 'actors.db'))
    fixture.to_sql('actors', sa.create_engine(db_url), index=False)

    # Column-wise values: position i of each list describes the same row.
    expected = {
        'name': ['Humphrey Bogart', 'James Stewart', 'Audrey Hepburn'],
        'category': ['Stars', 'Thespians', 'Heartthrobs'],
        'rating': [1, 0.99, 1.11],
    }
    # 'name' is the id column; the other keys carry the new values.
    gramex.data.update(db_url, args=dict(expected), id=['name'], table='actors')

    # Fetch the same rows by name and compare every updated column.
    rows = gramex.data.filter(db_url, args={'name': expected['name']}, table='actors')
    for column in ('category', 'rating'):
        self.assertEqual(rows[column].tolist(), expected[column])

def test_delete(self):
raise SkipTest('TODO: write delete test cases')

Expand Down

0 comments on commit c6778d8

Please sign in to comment.