diff --git a/gramex/data.py b/gramex/data.py index c3cdbfbe3..6c8cb1293 100644 --- a/gramex/data.py +++ b/gramex/data.py @@ -741,13 +741,10 @@ def _filter_frame(data, meta, controls, args, source='select', id=[]): # Apply filters data = _filter_frame_col(data, key, col, op, vals, meta) elif source == 'update': - # Update values should only contain 1 value. 2nd onwards are ignored if key not in data.columns or len(vals) == 0: meta['ignored'].append((key, vals)) else: - cols_for_update[key] = vals[0] - if len(vals) > 1: - meta['ignored'].append((key, vals[1:])) + cols_for_update[key] = vals else: meta['ignored'].append((key, vals)) meta['count'] = len(data) @@ -868,20 +865,24 @@ def _filter_db(engine, table, meta, controls, args, source='select', id=[]): query = _filter_db_col(query, query.where, key, col, op, vals, cols[col], cols[col].type.python_type, meta) elif source == 'update': - # Update values should only contain 1 value. 2nd onwards are ignored if key not in cols or len(vals) == 0: meta['ignored'].append((key, vals)) else: - cols_for_update[key] = vals[0] - if len(vals) > 1: - meta['ignored'].append((key, vals[1:])) + cols_for_update[key] = vals else: meta['ignored'].append((key, vals)) if source == 'delete': res = engine.execute(query) return res.rowcount elif source == 'update': - query = query.values(cols_for_update) + id_name = id[0] + id_col = getattr(table.c, id_name) + cases = { + k: sa.case( + [(id_col == i, j) for i, j in zip(args[id_name], v)] + ) for k, v in cols_for_update.items() + } + query = query.values(**cases) res = engine.execute(query) return res.rowcount else: diff --git a/testlib/test_data.py b/testlib/test_data.py index b38bf49b4..fd7e15b8d 100644 --- a/testlib/test_data.py +++ b/testlib/test_data.py @@ -597,6 +597,49 @@ def test_update(self): gramex.data.update(data, args=args, id=['देश', 'city', 'product']) ase(types_original, data.dtypes) + def test_update_multiple_file(self): + # Test on a file + + update_file = os.path.join(folder, 'actors.update.csv') + shutil.copy(os.path.join(folder, '..', 'tests/actors.csv'), update_file) + self.tmpfiles.append(update_file) + + names = ['Humphrey Bogart', 'James Stewart', 'Audrey Hepburn'] + categories = ['Stars', 'Thespians', 'Heartthrobs'] + ratings = [1, 0.99, 1.11] + gramex.data.update( + update_file, + args={ + 'name': names, + 'category': categories, + 'rating': ratings + }, id=['name'] + ) + df = gramex.data.filter(update_file, args={'name': names}) + self.assertEqual(df['category'].tolist(), categories) + self.assertEqual(df['rating'].tolist(), ratings) + + def test_update_multiple_db(self): + actors = gramex.cache.open(os.path.join(folder, '../tests/actors.csv')) + temp_db = f'sqlite:///{folder}/actors.db' + self.tmpfiles.append(os.path.join(folder, 'actors.db')) + actors.to_sql('actors', sa.create_engine(temp_db), index=False) + + names = ['Humphrey Bogart', 'James Stewart', 'Audrey Hepburn'] + categories = ['Stars', 'Thespians', 'Heartthrobs'] + ratings = [1, 0.99, 1.11] + gramex.data.update( + temp_db, + args={ + 'name': names, + 'category': categories, + 'rating': ratings + }, id=['name'], table='actors' + ) + df = gramex.data.filter(temp_db, args={'name': names}, table='actors') + self.assertEqual(df['category'].tolist(), categories) + self.assertEqual(df['rating'].tolist(), ratings) + def test_delete(self): raise SkipTest('TODO: write delete test cases')