Skip to content

Commit

Permalink
Refactor GRANT/REVOKE in Hive
Browse files Browse the repository at this point in the history
Leverage newly introduced method for recursive role grants traversal

Extracted-From: prestodb/presto#10904
  • Loading branch information
Andrii Rosa authored and sopel39 committed Jan 29, 2019
1 parent 3dca463 commit cb2da7b
Show file tree
Hide file tree
Showing 19 changed files with 246 additions and 608 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import io.prestosql.plugin.hive.metastore.SortingColumn;
import io.prestosql.plugin.hive.metastore.StorageFormat;
import io.prestosql.plugin.hive.metastore.Table;
import io.prestosql.plugin.hive.metastore.thrift.ThriftMetastoreUtil;
import io.prestosql.plugin.hive.statistics.HiveStatisticsProvider;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
Expand Down Expand Up @@ -182,6 +183,7 @@
import static io.prestosql.plugin.hive.metastore.MetastoreUtil.verifyOnline;
import static io.prestosql.plugin.hive.metastore.StorageFormat.VIEW_STORAGE_FORMAT;
import static io.prestosql.plugin.hive.metastore.StorageFormat.fromHiveStorageFormat;
import static io.prestosql.plugin.hive.metastore.thrift.ThriftMetastoreUtil.listApplicableTablePrivileges;
import static io.prestosql.plugin.hive.util.ConfigurationUtils.toJobConf;
import static io.prestosql.plugin.hive.util.Statistics.ReduceOperator.ADD;
import static io.prestosql.plugin.hive.util.Statistics.createComputedStatisticsToPartitionMap;
Expand Down Expand Up @@ -1796,7 +1798,7 @@ public void revokeRoles(ConnectorSession session, Set<String> roles, Set<PrestoP
@Override
public Set<RoleGrant> listApplicableRoles(ConnectorSession session, PrestoPrincipal principal)
{
return metastore.listApplicableRoles(principal);
return ThriftMetastoreUtil.listApplicableRoles(principal, metastore::listRoleGrants);
}

@Override
Expand Down Expand Up @@ -1830,10 +1832,11 @@ public List<GrantInfo> listTablePrivileges(ConnectorSession session, SchemaTable
{
ImmutableList.Builder<GrantInfo> grantInfos = ImmutableList.builder();
for (SchemaTableName tableName : listTables(session, schemaTablePrefix)) {
Set<PrivilegeInfo> privileges = metastore.getTablePrivileges(session.getUser(), tableName.getSchemaName(), tableName.getTableName()).stream()
.map(HivePrivilegeInfo::toPrivilegeInfo)
.flatMap(Set::stream)
.collect(toImmutableSet());
Set<PrivilegeInfo> privileges =
listApplicableTablePrivileges(metastore, tableName.getSchemaName(), tableName.getTableName(), new PrestoPrincipal(USER, session.getUser())).stream()
.map(HivePrivilegeInfo::toPrivilegeInfo)
.flatMap(Set::stream)
.collect(toImmutableSet());

grantInfos.add(
new GrantInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import static io.prestosql.plugin.hive.metastore.HivePartitionName.hivePartitionName;
import static io.prestosql.plugin.hive.metastore.HiveTableName.hiveTableName;
import static io.prestosql.plugin.hive.metastore.PartitionFilter.partitionFilter;
import static io.prestosql.spi.security.PrincipalType.USER;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

Expand All @@ -83,8 +84,7 @@ public class CachingHiveMetastore
private final LoadingCache<HivePartitionName, Optional<Partition>> partitionCache;
private final LoadingCache<PartitionFilter, Optional<List<String>>> partitionFilterCache;
private final LoadingCache<HiveTableName, Optional<List<String>>> partitionNamesCache;
private final LoadingCache<String, Set<String>> userRolesCache;
private final LoadingCache<UserTableKey, Set<HivePrivilegeInfo>> userTablePrivileges;
private final LoadingCache<UserTableKey, Set<HivePrivilegeInfo>> tablePrivilegesCache;
private final LoadingCache<String, Set<String>> rolesCache;
private final LoadingCache<PrestoPrincipal, Set<RoleGrant>> roleGrantsCache;

Expand Down Expand Up @@ -187,11 +187,8 @@ public Map<HivePartitionName, Optional<Partition>> loadAll(Iterable<? extends Hi
}
}, executor));

userRolesCache = newCacheBuilder(expiresAfterWriteMillis, refreshMills, maximumSize)
.build(asyncReloading(CacheLoader.from(user -> loadRoles(user)), executor));

userTablePrivileges = newCacheBuilder(expiresAfterWriteMillis, refreshMills, maximumSize)
.build(asyncReloading(CacheLoader.from(key -> loadTablePrivileges(key.getUser(), key.getDatabase(), key.getTable())), executor));
tablePrivilegesCache = newCacheBuilder(expiresAfterWriteMillis, refreshMills, maximumSize)
.build(asyncReloading(CacheLoader.from(key -> loadTablePrivileges(key.getDatabase(), key.getTable(), key.getPrincipal())), executor));

rolesCache = newCacheBuilder(expiresAfterWriteMillis, refreshMills, maximumSize)
.build(asyncReloading(CacheLoader.from(() -> loadRoles()), executor));
Expand All @@ -211,10 +208,9 @@ public void flushCache()
tableCache.invalidateAll();
partitionCache.invalidateAll();
partitionFilterCache.invalidateAll();
userTablePrivileges.invalidateAll();
tablePrivilegesCache.invalidateAll();
tableStatisticsCache.invalidateAll();
partitionStatisticsCache.invalidateAll();
userRolesCache.invalidateAll();
rolesCache.invalidateAll();
}

Expand Down Expand Up @@ -504,9 +500,9 @@ protected void invalidateTable(String databaseName, String tableName)
tableCache.invalidate(hiveTableName(databaseName, tableName));
tableNamesCache.invalidate(databaseName);
viewNamesCache.invalidate(databaseName);
userTablePrivileges.asMap().keySet().stream()
tablePrivilegesCache.asMap().keySet().stream()
.filter(userTableKey -> userTableKey.matches(databaseName, tableName))
.forEach(userTablePrivileges::invalidate);
.forEach(tablePrivilegesCache::invalidate);
tableStatisticsCache.invalidate(hiveTableName(databaseName, tableName));
invalidatePartitionCache(databaseName, tableName);
}
Expand Down Expand Up @@ -631,7 +627,6 @@ public void createRole(String role, String grantor)
}
finally {
rolesCache.invalidateAll();
userRolesCache.invalidate(grantor);
}
}

Expand All @@ -643,7 +638,6 @@ public void dropRole(String role)
}
finally {
rolesCache.invalidateAll();
userRolesCache.invalidateAll();
roleGrantsCache.invalidateAll();
}
}
Expand Down Expand Up @@ -707,42 +701,14 @@ private void invalidatePartitionCache(String databaseName, String tableName)
.forEach(partitionStatisticsCache::invalidate);
}

@Override
public Set<String> getRoles(String user)
{
return get(userRolesCache, user);
}

private Set<String> loadRoles(String user)
{
return delegate.getRoles(user);
}

@Override
public Set<HivePrivilegeInfo> getDatabasePrivileges(String user, String databaseName)
{
return delegate.getDatabasePrivileges(user, databaseName);
}

@Override
public Set<HivePrivilegeInfo> getTablePrivileges(String user, String databaseName, String tableName)
{
return get(userTablePrivileges, new UserTableKey(user, tableName, databaseName));
}

private Set<HivePrivilegeInfo> loadTablePrivileges(String user, String databaseName, String tableName)
{
return delegate.getTablePrivileges(user, databaseName, tableName);
}

@Override
public void grantTablePrivileges(String databaseName, String tableName, String grantee, Set<HivePrivilegeInfo> privileges)
{
try {
delegate.grantTablePrivileges(databaseName, tableName, grantee, privileges);
}
finally {
userTablePrivileges.invalidate(new UserTableKey(grantee, tableName, databaseName));
tablePrivilegesCache.invalidate(new UserTableKey(new PrestoPrincipal(USER, grantee), databaseName, tableName));
}
}

Expand All @@ -753,10 +719,21 @@ public void revokeTablePrivileges(String databaseName, String tableName, String
delegate.revokeTablePrivileges(databaseName, tableName, grantee, privileges);
}
finally {
userTablePrivileges.invalidate(new UserTableKey(grantee, tableName, databaseName));
tablePrivilegesCache.invalidate(new UserTableKey(new PrestoPrincipal(USER, grantee), databaseName, tableName));
}
}

@Override
public Set<HivePrivilegeInfo> listTablePrivileges(String databaseName, String tableName, PrestoPrincipal principal)
{
return get(tablePrivilegesCache, new UserTableKey(principal, databaseName, tableName));
}

public Set<HivePrivilegeInfo> loadTablePrivileges(String databaseName, String tableName, PrestoPrincipal principal)
{
return delegate.listTablePrivileges(databaseName, tableName, principal);
}

private static CacheBuilder<Object, Object> newCacheBuilder(OptionalLong expiresAfterWriteMillis, OptionalLong refreshMillis, long maximumSize)
{
CacheBuilder<Object, Object> cacheBuilder = CacheBuilder.newBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,9 @@ public interface ExtendedHiveMetastore

Set<RoleGrant> listRoleGrants(PrestoPrincipal principal);

Set<String> getRoles(String user);

Set<HivePrivilegeInfo> getDatabasePrivileges(String user, String databaseName);

Set<HivePrivilegeInfo> getTablePrivileges(String user, String databaseName, String tableName);

void grantTablePrivileges(String databaseName, String tableName, String grantee, Set<HivePrivilegeInfo> privileges);

void revokeTablePrivileges(String databaseName, String tableName, String grantee, Set<HivePrivilegeInfo> privileges);

Set<HivePrivilegeInfo> listTablePrivileges(String databaseName, String tableName, PrestoPrincipal principal);
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ public class RecordingHiveMetastore
private final Cache<HiveTableName, Optional<List<String>>> partitionNamesCache;
private final Cache<PartitionFilter, Optional<List<String>>> partitionNamesByPartsCache;
private final Cache<Set<HivePartitionName>, Map<String, Optional<Partition>>> partitionsByNamesCache;
private final Cache<String, Set<String>> rolesCache;
private final Cache<UserDatabaseKey, Set<HivePrivilegeInfo>> databasePrivilegesCache;
private final Cache<UserTableKey, Set<HivePrivilegeInfo>> tablePrivilegesCache;
private final Cache<UserTableKey, Set<HivePrivilegeInfo>> listTablePrivilegesCache;
private final Cache<String, Set<String>> listRolesCache;
private final Cache<PrestoPrincipal, Set<RoleGrant>> listRoleGrantsCache;

Expand All @@ -100,9 +98,7 @@ public RecordingHiveMetastore(@ForRecordingHiveMetastore ExtendedHiveMetastore d
partitionNamesCache = createCache(hiveClientConfig);
partitionNamesByPartsCache = createCache(hiveClientConfig);
partitionsByNamesCache = createCache(hiveClientConfig);
rolesCache = createCache(hiveClientConfig);
databasePrivilegesCache = createCache(hiveClientConfig);
tablePrivilegesCache = createCache(hiveClientConfig);
listTablePrivilegesCache = createCache(hiveClientConfig);
listRolesCache = createCache(hiveClientConfig);
listRoleGrantsCache = createCache(hiveClientConfig);

Expand All @@ -129,9 +125,7 @@ void loadRecording()
partitionNamesCache.putAll(toMap(recording.getPartitionNames()));
partitionNamesByPartsCache.putAll(toMap(recording.getPartitionNamesByParts()));
partitionsByNamesCache.putAll(toMap(recording.getPartitionsByNames()));
rolesCache.putAll(toMap(recording.getRoles()));
databasePrivilegesCache.putAll(toMap(recording.getDatabasePrivileges()));
tablePrivilegesCache.putAll(toMap(recording.getTablePrivileges()));
listTablePrivilegesCache.putAll(toMap(recording.getListTablePrivileges()));
listRolesCache.putAll(toMap(recording.getListRoles()));
listRoleGrantsCache.putAll(toMap(recording.getListRoleGrants()));
}
Expand Down Expand Up @@ -169,9 +163,7 @@ public void writeRecording()
toPairs(partitionNamesCache),
toPairs(partitionNamesByPartsCache),
toPairs(partitionsByNamesCache),
toPairs(rolesCache),
toPairs(databasePrivilegesCache),
toPairs(tablePrivilegesCache),
toPairs(listTablePrivilegesCache),
toPairs(listRolesCache),
toPairs(listRoleGrantsCache));
new ObjectMapperProvider().get()
Expand Down Expand Up @@ -394,27 +386,12 @@ public void alterPartition(String databaseName, String tableName, PartitionWithS
}

@Override
public Set<String> getRoles(String user)
{
return loadValue(rolesCache, user, () -> delegate.getRoles(user));
}

@Override
public Set<HivePrivilegeInfo> getDatabasePrivileges(String user, String databaseName)
{
return loadValue(
databasePrivilegesCache,
new UserDatabaseKey(user, databaseName),
() -> delegate.getDatabasePrivileges(user, databaseName));
}

@Override
public Set<HivePrivilegeInfo> getTablePrivileges(String user, String databaseName, String tableName)
public Set<HivePrivilegeInfo> listTablePrivileges(String databaseName, String tableName, PrestoPrincipal principal)
{
return loadValue(
tablePrivilegesCache,
new UserTableKey(user, databaseName, tableName),
() -> delegate.getTablePrivileges(user, databaseName, tableName));
listTablePrivilegesCache,
new UserTableKey(principal, databaseName, tableName),
() -> delegate.listTablePrivileges(databaseName, tableName, principal));
}

@Override
Expand Down Expand Up @@ -518,9 +495,7 @@ public static class Recording
private final List<Pair<HiveTableName, Optional<List<String>>>> partitionNames;
private final List<Pair<PartitionFilter, Optional<List<String>>>> partitionNamesByParts;
private final List<Pair<Set<HivePartitionName>, Map<String, Optional<Partition>>>> partitionsByNames;
private final List<Pair<String, Set<String>>> roles;
private final List<Pair<UserDatabaseKey, Set<HivePrivilegeInfo>>> databasePrivileges;
private final List<Pair<UserTableKey, Set<HivePrivilegeInfo>>> tablePrivileges;
private final List<Pair<UserTableKey, Set<HivePrivilegeInfo>>> listTablePrivileges;
private final List<Pair<String, Set<String>>> listRoles;
private final List<Pair<PrestoPrincipal, Set<RoleGrant>>> listRoleGrants;

Expand All @@ -538,9 +513,7 @@ public Recording(
@JsonProperty("partitionNames") List<Pair<HiveTableName, Optional<List<String>>>> partitionNames,
@JsonProperty("partitionNamesByParts") List<Pair<PartitionFilter, Optional<List<String>>>> partitionNamesByParts,
@JsonProperty("partitionsByNames") List<Pair<Set<HivePartitionName>, Map<String, Optional<Partition>>>> partitionsByNames,
@JsonProperty("roles") List<Pair<String, Set<String>>> roles,
@JsonProperty("databasePrivileges") List<Pair<UserDatabaseKey, Set<HivePrivilegeInfo>>> databasePrivileges,
@JsonProperty("tablePrivileges") List<Pair<UserTableKey, Set<HivePrivilegeInfo>>> tablePrivileges,
@JsonProperty("listTablePrivileges") List<Pair<UserTableKey, Set<HivePrivilegeInfo>>> listTablePrivileges,
@JsonProperty("listRoles") List<Pair<String, Set<String>>> listRoles,
@JsonProperty("listRoleGrants") List<Pair<PrestoPrincipal, Set<RoleGrant>>> listRoleGrants)
{
Expand All @@ -556,9 +529,7 @@ public Recording(
this.partitionNames = partitionNames;
this.partitionNamesByParts = partitionNamesByParts;
this.partitionsByNames = partitionsByNames;
this.roles = roles;
this.databasePrivileges = databasePrivileges;
this.tablePrivileges = tablePrivileges;
this.listTablePrivileges = listTablePrivileges;
this.listRoles = listRoles;
this.listRoleGrants = listRoleGrants;
}
Expand Down Expand Up @@ -636,21 +607,9 @@ public List<Pair<Set<HivePartitionName>, Map<String, Optional<Partition>>>> getP
}

@JsonProperty
public List<Pair<String, Set<String>>> getRoles()
{
return roles;
}

@JsonProperty
public List<Pair<UserDatabaseKey, Set<HivePrivilegeInfo>>> getDatabasePrivileges()
{
return databasePrivileges;
}

@JsonProperty
public List<Pair<UserTableKey, Set<HivePrivilegeInfo>>> getTablePrivileges()
public List<Pair<UserTableKey, Set<HivePrivilegeInfo>>> getListTablePrivileges()
{
return tablePrivileges;
return listTablePrivileges;
}

@JsonProperty
Expand Down
Loading

0 comments on commit cb2da7b

Please sign in to comment.