Skip to content

Commit

Permalink
Test AuthorizeReturnObject in Reactive
Browse files Browse the repository at this point in the history
  • Loading branch information
jzheaux committed Sep 12, 2024
1 parent 91053c5 commit 86ef0b6
Showing 1 changed file with 222 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -34,12 +38,18 @@
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.access.PermissionEvaluator;
import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler;
import org.springframework.security.access.expression.method.MethodSecurityExpressionHandler;
import org.springframework.security.access.hierarchicalroles.RoleHierarchy;
import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl;
import org.springframework.security.access.prepost.PostAuthorize;
import org.springframework.security.access.prepost.PostFilter;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.access.prepost.PreFilter;
import org.springframework.security.authorization.AuthorizationDeniedException;
import org.springframework.security.authorization.method.AuthorizationAdvisorProxyFactory;
import org.springframework.security.authorization.method.AuthorizeReturnObject;
import org.springframework.security.authorization.method.PrePostTemplateDefaults;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
Expand All @@ -49,6 +59,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
Expand Down Expand Up @@ -320,6 +331,84 @@ public void methodWhenPostFilterMetaAnnotationThenFilters(Class<?> config) {
.containsExactly("dave");
}

@Test
@WithMockUser(authorities = "airplane:read")
public void findByIdWhenAuthorizedResultThenAuthorizes() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
Flight flight = flights.findById("1").block();
assertThatNoException().isThrownBy(flight::getAltitude);
assertThatNoException().isThrownBy(flight::getSeats);
}

@Test
@WithMockUser(authorities = "seating:read")
public void findByIdWhenUnauthorizedResultThenDenies() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
Flight flight = flights.findById("1").block();
assertThatNoException().isThrownBy(flight::getSeats);
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> flight.getAltitude().block());
}

@Test
@WithMockUser(authorities = "seating:read")
public void findAllWhenUnauthorizedResultThenDenies() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll().collectList().block().forEach((flight) -> {
assertThatNoException().isThrownBy(flight::getSeats);
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> flight.getAltitude().block());
});
}

@Test
public void removeWhenAuthorizedResultThenRemoves() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.remove("1");
}

@Test
@WithMockUser(authorities = "airplane:read")
public void findAllWhenPostFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll()
.collectList()
.block()
.forEach((flight) -> assertThat(flight.getPassengers().collectList().block())
.extracting((p) -> p.getName().block())
.doesNotContain("Kevin Mitnick"));
}

@Test
@WithMockUser(authorities = "airplane:read")
public void findAllWhenPreFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll().collectList().block().forEach((flight) -> {
flight.board(Flux.just("John")).block();
assertThat(flight.getPassengers().collectList().block()).extracting((p) -> p.getName().block())
.doesNotContain("John");
flight.board(Flux.just("John Doe")).block();
assertThat(flight.getPassengers().collectList().block()).extracting((p) -> p.getName().block())
.contains("John Doe");
});
}

@Test
@WithMockUser(authorities = "seating:read")
public void findAllWhenNestedPreAuthorizeThenAuthorizes() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll().collectList().block().forEach((flight) -> {
List<Passenger> passengers = flight.getPassengers().collectList().block();
passengers.forEach((passenger) -> assertThatExceptionOfType(AccessDeniedException.class)
.isThrownBy(() -> passenger.getName().block()));
});
}

@Configuration
@EnableReactiveMethodSecurity
static class MethodSecurityServiceEnabledConfig {
Expand Down Expand Up @@ -484,4 +573,137 @@ static class EntityClass {

}

@EnableReactiveMethodSecurity
@Configuration
public static class AuthorizeResultConfig {

@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
static Customizer<AuthorizationAdvisorProxyFactory> skipValueTypes() {
return (f) -> f.setTargetVisitor(AuthorizationAdvisorProxyFactory.TargetVisitor.defaultsSkipValueTypes());
}

@Bean
FlightRepository flights() {
FlightRepository flights = new FlightRepository();
Flight one = new Flight("1", 35000d, 35);
one.board(Flux.just("Marie Curie", "Kevin Mitnick", "Ada Lovelace")).block();
flights.save(one).block();
Flight two = new Flight("2", 32000d, 72);
two.board(Flux.just("Albert Einstein")).block();
flights.save(two).block();
return flights;
}

@Bean
static MethodSecurityExpressionHandler expressionHandler() {
RoleHierarchy hierarchy = RoleHierarchyImpl.withRolePrefix("")
.role("airplane:read")
.implies("seating:read")
.build();
DefaultMethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler();
expressionHandler.setRoleHierarchy(hierarchy);
return expressionHandler;
}

@Bean
Authz authz() {
return new Authz();
}

public static class Authz {

public Mono<Boolean> isNotKevinMitnick(Passenger passenger) {
return passenger.getName().map((n) -> !"Kevin Mitnick".equals(n));
}

}

}

@AuthorizeReturnObject
static class FlightRepository {

private final Map<String, Flight> flights = new ConcurrentHashMap<>();

Flux<Flight> findAll() {
return Flux.fromIterable(this.flights.values());
}

Mono<Flight> findById(String id) {
return Mono.just(this.flights.get(id));
}

Mono<Flight> save(Flight flight) {
this.flights.put(flight.getId(), flight);
return Mono.just(flight);
}

Mono<Void> remove(String id) {
this.flights.remove(id);
return Mono.empty();
}

}

@AuthorizeReturnObject
static class Flight {

private final String id;

private final Double altitude;

private final Integer seats;

private final List<Passenger> passengers = new ArrayList<>();

Flight(String id, Double altitude, Integer seats) {
this.id = id;
this.altitude = altitude;
this.seats = seats;
}

String getId() {
return this.id;
}

@PreAuthorize("hasAuthority('airplane:read')")
Mono<Double> getAltitude() {
return Mono.just(this.altitude);
}

@PreAuthorize("hasAuthority('seating:read')")
Mono<Integer> getSeats() {
return Mono.just(this.seats);
}

@PostAuthorize("hasAuthority('seating:read')")
@PostFilter("@authz.isNotKevinMitnick(filterObject)")
Flux<Passenger> getPassengers() {
return Flux.fromIterable(this.passengers);
}

@PreAuthorize("hasAuthority('seating:read')")
@PreFilter("filterObject.contains(' ')")
Mono<Void> board(Flux<String> passengers) {
return passengers.doOnNext((passenger) -> this.passengers.add(new Passenger(passenger))).then(Mono.empty());
}

}

public static class Passenger {

String name;

public Passenger(String name) {
this.name = name;
}

@PreAuthorize("hasAuthority('airplane:read')")
public Mono<String> getName() {
return Mono.just(this.name);
}

}

}

0 comments on commit 86ef0b6

Please sign in to comment.