diff --git a/enterprise/admin/__init__.py b/enterprise/admin/__init__.py
index f2dd11d9c6..f27b7f8763 100644
--- a/enterprise/admin/__init__.py
+++ b/enterprise/admin/__init__.py
@@ -41,30 +41,14 @@
)
from enterprise.api_client.lms import CourseApiClient, EnrollmentApiClient
from enterprise.config.models import UpdateRoleAssignmentsWithCustomersConfig
-from enterprise.models import (
- AdminNotification,
- AdminNotificationFilter,
- AdminNotificationRead,
- ChatGPTResponse,
- EnrollmentNotificationEmailTemplate,
- EnterpriseCatalogQuery,
- EnterpriseCourseEnrollment,
- EnterpriseCustomer,
- EnterpriseCustomerBrandingConfiguration,
- EnterpriseCustomerCatalog,
- EnterpriseCustomerIdentityProvider,
- EnterpriseCustomerInviteKey,
- EnterpriseCustomerReportingConfiguration,
- EnterpriseCustomerType,
- EnterpriseCustomerUser,
- EnterpriseFeatureUserRoleAssignment,
- PendingEnrollment,
- PendingEnterpriseCustomerAdminUser,
- PendingEnterpriseCustomerUser,
- SystemWideEnterpriseUserRoleAssignment,
-)
-from enterprise.utils import discovery_query_url, get_all_field_names, get_default_catalog_content_filter
+from enterprise import models as ent_models
+from enterprise.utils import (
+ discovery_query_url,
+ localized_utcnow,
+ get_all_field_names,
+ get_default_catalog_content_filter
+)
try:
from enterprise.api_client.enterprise_catalog import EnterpriseCatalogApiClient
except ImportError:
@@ -88,7 +72,7 @@ class EnterpriseCustomerBrandingConfigurationInline(admin.StackedInline):
https://docs.djangoproject.com/en/1.8/ref/contrib/admin/#django.contrib.admin.StackedInline
"""
- model = EnterpriseCustomerBrandingConfiguration
+ model = ent_models.EnterpriseCustomerBrandingConfiguration
can_delete = False
@@ -100,7 +84,7 @@ class EnterpriseCustomerIdentityProviderInline(admin.StackedInline):
https://docs.djangoproject.com/en/1.8/ref/contrib/admin/#django.contrib.admin.StackedInline
"""
- model = EnterpriseCustomerIdentityProvider
+ model = ent_models.EnterpriseCustomerIdentityProvider
form = EnterpriseCustomerIdentityProviderAdminForm
extra = 0
@@ -112,7 +96,7 @@ class EnterpriseCustomerCatalogInline(admin.TabularInline):
https://docs.djangoproject.com/en/1.8/ref/contrib/admin/#django.contrib.admin.StackedInline
"""
- model = EnterpriseCustomerCatalog
+ model = ent_models.EnterpriseCustomerCatalog
form = EnterpriseCustomerCatalogAdminForm
extra = 0
can_delete = False
@@ -128,7 +112,7 @@ class PendingEnterpriseCustomerAdminUserInline(admin.TabularInline):
Django admin inline model for PendingEnterpriseCustomerAdminUser.
"""
- model = PendingEnterpriseCustomerAdminUser
+ model = ent_models.PendingEnterpriseCustomerAdminUser
extra = 0
fieldsets = (
(None, {
@@ -149,14 +133,14 @@ def get_admin_registration_url(self, obj):
return format_html('{0}'.format(obj.admin_registration_url))
-@admin.register(EnterpriseCustomerType)
+@admin.register(ent_models.EnterpriseCustomerType)
class EnterpriseCustomerTypeAdmin(admin.ModelAdmin):
"""
Django admin model for EnterpriseCustomerType.
"""
class Meta:
- model = EnterpriseCustomerType
+ model = ent_models.EnterpriseCustomerType
fields = (
'name',
@@ -166,7 +150,7 @@ class Meta:
search_fields = ('name', )
-@admin.register(EnterpriseCustomer)
+@admin.register(ent_models.EnterpriseCustomer)
class EnterpriseCustomerAdmin(DjangoObjectActions, SimpleHistoryAdmin):
"""
Django admin model for EnterpriseCustomer.
@@ -252,7 +236,7 @@ class EnterpriseCustomerAdmin(DjangoObjectActions, SimpleHistoryAdmin):
form = EnterpriseCustomerAdminForm
class Meta:
- model = EnterpriseCustomer
+ model = ent_models.EnterpriseCustomer
def get_search_results(self, request, queryset, search_term):
original_queryset = queryset
@@ -395,14 +379,14 @@ def get_urls(self):
return customer_urls + super().get_urls()
-@admin.register(EnterpriseCustomerUser)
+@admin.register(ent_models.EnterpriseCustomerUser)
class EnterpriseCustomerUserAdmin(admin.ModelAdmin):
"""
Django admin model for EnterpriseCustomerUser.
"""
class Meta:
- model = EnterpriseCustomerUser
+ model = ent_models.EnterpriseCustomerUser
fields = (
'user_id',
@@ -444,7 +428,7 @@ def get_search_results(self, request, queryset, search_term):
use_distinct = False
if search_term:
- queryset = EnterpriseCustomerUser.objects.filter(
+ queryset = ent_models.EnterpriseCustomerUser.objects.filter(
user_id__in=User.objects.filter(
Q(email__icontains=search_term) | Q(username__icontains=search_term)
)
@@ -493,7 +477,9 @@ def _get_enterprise_course_enrollments(self, enterprise_customer_user):
enterprise_customer_user: The instance of EnterpriseCustomerUser
being rendered with this admin form.
"""
- enrollments = EnterpriseCourseEnrollment.objects.filter(enterprise_customer_user=enterprise_customer_user)
+ enrollments = ent_models.EnterpriseCourseEnrollment.objects.filter(
+ enterprise_customer_user=enterprise_customer_user
+ )
return [enrollment.course_id for enrollment in enrollments]
def _get_all_enrollments(self, enterprise_customer_user):
@@ -538,14 +524,14 @@ def get_enrolled_course_string(self, course_ids):
)
-@admin.register(PendingEnterpriseCustomerUser)
+@admin.register(ent_models.PendingEnterpriseCustomerUser)
class PendingEnterpriseCustomerUserAdmin(admin.ModelAdmin):
"""
Django admin model for PendingEnterpriseCustomerUser
"""
class Meta:
- model = PendingEnterpriseCustomerUser
+ model = ent_models.PendingEnterpriseCustomerUser
fields = (
'user_email',
@@ -560,14 +546,14 @@ class Meta:
)
-@admin.register(PendingEnterpriseCustomerAdminUser)
+@admin.register(ent_models.PendingEnterpriseCustomerAdminUser)
class PendingEnterpriseCustomerAdminUserAdmin(admin.ModelAdmin):
"""
Django admin model for PendingEnterpriseCustomerAdminUser
"""
class Meta:
- model = PendingEnterpriseCustomerAdminUser
+ model = ent_models.PendingEnterpriseCustomerAdminUser
fields = (
'user_email',
@@ -609,7 +595,7 @@ def get_admin_registration_url(self, obj):
return format_html('{0}'.format(obj.admin_registration_url))
-@admin.register(EnrollmentNotificationEmailTemplate)
+@admin.register(ent_models.EnrollmentNotificationEmailTemplate)
class EnrollmentNotificationEmailTemplateAdmin(DjangoObjectActions, admin.ModelAdmin):
"""
Django admin for EnrollmentNotificationEmailTemplate model
@@ -617,7 +603,7 @@ class EnrollmentNotificationEmailTemplateAdmin(DjangoObjectActions, admin.ModelA
change_actions = ("preview_as_course", "preview_as_program")
class Meta:
- model = EnrollmentNotificationEmailTemplate
+ model = ent_models.EnrollmentNotificationEmailTemplate
def get_urls(self):
"""
@@ -667,14 +653,14 @@ def preview_as_program(self, request, obj):
preview_as_program.label = _("Preview (program)")
-@admin.register(EnterpriseCourseEnrollment)
+@admin.register(ent_models.EnterpriseCourseEnrollment)
class EnterpriseCourseEnrollmentAdmin(admin.ModelAdmin):
"""
Django admin model for EnterpriseCourseEnrollment
"""
class Meta:
- model = EnterpriseCourseEnrollment
+ model = ent_models.EnterpriseCourseEnrollment
readonly_fields = (
'enterprise_customer_user',
@@ -735,14 +721,14 @@ def get_urls(self):
return custom_urls + super().get_urls()
-@admin.register(PendingEnrollment)
+@admin.register(ent_models.PendingEnrollment)
class PendingEnrollmentAdmin(admin.ModelAdmin):
"""
Django admin model for PendingEnrollment
"""
class Meta:
- model = PendingEnrollment
+ model = ent_models.PendingEnrollment
readonly_fields = (
'user',
@@ -771,14 +757,14 @@ def has_delete_permission(self, request, obj=None):
return False
-@admin.register(EnterpriseCatalogQuery)
+@admin.register(ent_models.EnterpriseCatalogQuery)
class EnterpriseCatalogQueryAdmin(admin.ModelAdmin):
"""
Django admin model for EnterpriseCatalogQuery.
"""
class Meta:
- model = EnterpriseCatalogQuery
+ model = ent_models.EnterpriseCatalogQuery
def get_urls(self):
"""
@@ -818,7 +804,7 @@ def has_delete_permission(self, request, obj=None):
readonly_fields = ('discovery_query_url', 'uuid')
-@admin.register(EnterpriseCustomerCatalog)
+@admin.register(ent_models.EnterpriseCustomerCatalog)
class EnterpriseCustomerCatalogAdmin(admin.ModelAdmin):
"""
Django admin model for EnterpriseCustomerCatalog.
@@ -827,7 +813,7 @@ class EnterpriseCustomerCatalogAdmin(admin.ModelAdmin):
actions = [refresh_catalog]
class Meta:
- model = EnterpriseCustomerCatalog
+ model = ent_models.EnterpriseCustomerCatalog
class Media:
js = ('enterprise/admin/enterprise_customer_catalog.js',)
@@ -895,7 +881,7 @@ def get_actions(self, request):
return actions
-@admin.register(EnterpriseCustomerReportingConfiguration)
+@admin.register(ent_models.EnterpriseCustomerReportingConfiguration)
class EnterpriseCustomerReportingConfigurationAdmin(admin.ModelAdmin):
"""
Django admin model for EnterpriseCustomerReportingConfiguration.
@@ -918,7 +904,7 @@ class EnterpriseCustomerReportingConfigurationAdmin(admin.ModelAdmin):
form = EnterpriseCustomerReportingConfigAdminForm
class Meta:
- model = EnterpriseCustomerReportingConfiguration
+ model = ent_models.EnterpriseCustomerReportingConfiguration
def get_fields(self, request, obj=None):
"""
@@ -966,7 +952,7 @@ def count(self): # pylint: disable=invalid-overridden-method
return self._whole_table_count
-@admin.register(SystemWideEnterpriseUserRoleAssignment)
+@admin.register(ent_models.SystemWideEnterpriseUserRoleAssignment)
class SystemWideEnterpriseUserRoleAssignmentAdmin(UserRoleAssignmentAdmin):
"""
Django admin model for SystemWideEnterpriseUserRoleAssignment.
@@ -992,10 +978,10 @@ class SystemWideEnterpriseUserRoleAssignmentAdmin(UserRoleAssignmentAdmin):
form = SystemWideEnterpriseUserRoleAssignmentForm
class Meta:
- model = SystemWideEnterpriseUserRoleAssignment
+ model = ent_models.SystemWideEnterpriseUserRoleAssignment
-@admin.register(EnterpriseFeatureUserRoleAssignment)
+@admin.register(ent_models.EnterpriseFeatureUserRoleAssignment)
class EnterpriseFeatureUserRoleAssignmentAdmin(UserRoleAssignmentAdmin):
"""
Django admin model for EnterpriseFeatureUserRoleAssignment.
@@ -1004,45 +990,45 @@ class EnterpriseFeatureUserRoleAssignmentAdmin(UserRoleAssignmentAdmin):
form = EnterpriseFeatureUserRoleAssignmentForm
class Meta:
- model = EnterpriseFeatureUserRoleAssignment
+ model = ent_models.EnterpriseFeatureUserRoleAssignment
admin.site.register(UpdateRoleAssignmentsWithCustomersConfig, ConfigurationModelAdmin)
-@admin.register(AdminNotificationRead)
+@admin.register(ent_models.AdminNotificationRead)
class AdminNotificationReadAdmin(admin.ModelAdmin):
"""
Django admin for AdminNotificationRead model.
"""
- model = AdminNotificationRead
+ model = ent_models.AdminNotificationRead
list_display = ('id', 'enterprise_customer_user', 'admin_notification', 'is_read', 'created', 'modified')
-@admin.register(AdminNotification)
+@admin.register(ent_models.AdminNotification)
class AdminNotificationAdmin(admin.ModelAdmin):
"""
Django admin for AdminNotification model.
"""
- model = AdminNotification
+ model = ent_models.AdminNotification
form = AdminNotificationForm
list_display = ('id', 'title', 'text', 'is_active', 'start_date', 'expiration_date', 'created', 'modified')
filter_horizontal = ('admin_notification_filter',)
-@admin.register(AdminNotificationFilter)
+@admin.register(ent_models.AdminNotificationFilter)
class AdminNotificationFilterAdmin(admin.ModelAdmin):
"""
- Django admin for AdminNotificationFilter model.
+ Django admin for models.AdminNotificationFilter model.
"""
- model = AdminNotificationFilter
+ model = ent_models.AdminNotificationFilter
list_display = ('id', 'filter', 'created', 'modified')
-@admin.register(EnterpriseCustomerInviteKey)
+@admin.register(ent_models.EnterpriseCustomerInviteKey)
class EnterpriseCustomerInviteKeyAdmin(admin.ModelAdmin):
"""
Django admin model for EnterpriseCustomerInviteKey.
@@ -1074,7 +1060,7 @@ class EnterpriseCustomerInviteKeyAdmin(admin.ModelAdmin):
)
class Meta:
- model = EnterpriseCustomerInviteKey
+ model = ent_models.EnterpriseCustomerInviteKey
def get_readonly_fields(self, request, obj=None):
readonly_fields = super().get_readonly_fields(request, obj=obj)
@@ -1085,12 +1071,36 @@ def get_readonly_fields(self, request, obj=None):
return readonly_fields
-@admin.register(ChatGPTResponse)
+@admin.register(ent_models.ChatGPTResponse)
class ChatGPTResponseAdmin(admin.ModelAdmin):
"""
Django admin for ChatGPTResponse model.
"""
- model = ChatGPTResponse
+ model = ent_models.ChatGPTResponse
list_display = ('uuid', 'enterprise_customer', 'prompt_hash', )
readonly_fields = ('prompt', 'response', 'prompt_hash', )
+
+
+@admin.register(ent_models.EnterpriseCustomerSsoConfiguration)
+class EnterpriseCustomerSsoConfigurationAdmin(DjangoObjectActions, admin.ModelAdmin):
+ """
+ Django admin for models.EnterpriseCustomerSsoConfigurationAdmin model.
+ """
+
+ model = ent_models.ChatGPTResponse
+ list_display = ('uuid', 'enterprise_customer', 'active', 'identity_provider', 'created', 'configured_at')
+ change_actions = ['mark_configured']
+
+ @admin.action(
+ description="Allows for marking a config as configured. This is useful for testing while the SSO" \
+ "orchestrator is under constructions.",
+ )
+ def mark_configured(self, request, obj):
+ """
+ Object tool handler method - marks the config as configured.
+ """
+ obj.configured_at = localized_utcnow()
+ obj.save()
+
+ mark_configured.label = "Mark as Configured"
diff --git a/enterprise/api/v1/views/enterprise_customer_sso_configuration.py b/enterprise/api/v1/views/enterprise_customer_sso_configuration.py
index a093fc7fce..3354d65b5f 100644
--- a/enterprise/api/v1/views/enterprise_customer_sso_configuration.py
+++ b/enterprise/api/v1/views/enterprise_customer_sso_configuration.py
@@ -2,6 +2,8 @@
Views for the ``enterprise-customer-sso-configuration`` API endpoint.
"""
+from xml.etree.ElementTree import fromstring
+import requests
from edx_rbac.decorators import permission_required
from rest_framework import permissions, viewsets
from rest_framework.decorators import action
@@ -41,6 +43,18 @@ class EnterpriseCustomerInactiveException(Exception):
"""
+class SsoConfigurationApiError(requests.exceptions.RequestException):
+ """
+ Exception raised when the Sso configuration api encounters an error while fetching provider metadata.
+ """
+
+
+class EntityIdNotFoundError(Exception):
+ """
+ Exception raised by the SSO configuration api when it fails to fetch a customer IDP's entity ID from the metadata.
+ """
+
+
def check_user_part_of_customer(user, enterprise_customer):
"""
Checks if a user is in an enterprise customer.
@@ -67,6 +81,28 @@ def fetch_configuration_record(kwargs):
return EnterpriseCustomerSsoConfiguration.all_objects.filter(pk=kwargs.get('configuration_uuid'))
+def get_metadata_xml_from_url(url):
+ """
+ Gets the metadata xml from the given url.
+ """
+ response = requests.get(url)
+ if response.status_code >= 300:
+ raise SsoConfigurationApiError(f'Error fetching metadata xml from provided url: {url}')
+ return response.text
+
+
+def fetch_entity_id_from_metadata_xml(metadata_xml):
+ """
+ Fetches the entity id from the metadata xml.
+ """
+ root = fromstring(metadata_xml)
+ if entity_id := root.get('entityID'):
+ return entity_id
+ if entity_descriptor_child := root.find('EntityDescriptor'):
+ return entity_descriptor_child.get('entityID')
+ raise EntityIdNotFoundError('Could not find entity ID in metadata xml')
+
+
class EnterpriseCustomerSsoConfigurationViewSet(viewsets.ModelViewSet):
"""
API views for the ``EnterpriseCustomerSsoConfiguration`` model.
@@ -177,6 +213,26 @@ def create(self, request, *args, **kwargs):
request_data['enterprise_customer'] = enterprise_customer
else:
return Response({'error': BAD_CUSTOMER_ERROR}, status=HTTP_400_BAD_REQUEST)
+
+ # Parse the request data to see if the metadata url or xml has changed and update the entity id if so
+ sso_config_metadata_xml = None
+ if request_metadata_url := request_data.get('metadata_url'):
+ # If the metadata url has changed, we need to update the metadata xml
+ try:
+ sso_config_metadata_xml = get_metadata_xml_from_url(request_metadata_url)
+ except SsoConfigurationApiError as e:
+ LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}')
+ return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST)
+ request_data['metadata_xml'] = sso_config_metadata_xml
+ if sso_config_metadata_xml or (sso_config_metadata_xml := request_data.get('metadata_xml')):
+ try:
+ entity_id = fetch_entity_id_from_metadata_xml(sso_config_metadata_xml)
+ except (EntityIdNotFoundError) as e:
+ LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}')
+ return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST)
+
+ request_data['entity_id'] = entity_id
+
try:
new_record = EnterpriseCustomerSsoConfiguration.objects.create(**request_data)
except TypeError as e:
@@ -206,8 +262,32 @@ def update(self, request, *args, **kwargs):
except EnterpriseCustomerInactiveException:
return Response(status=HTTP_403_FORBIDDEN)
+ # Parse the request data to see if the metadata url or xml has changed and update the entity id if so
+ request_data = request.data.dict()
+ sso_config_metadata_xml = None
+ if request_metadata_url := request_data.get('metadata_url'):
+ sso_config_metadata_url = sso_configuration_record.first().metadata_url
+ if request_metadata_url != sso_config_metadata_url:
+ # If the metadata url has changed, we need to update the metadata xml
+ try:
+ sso_config_metadata_xml = get_metadata_xml_from_url(request_metadata_url)
+ except SsoConfigurationApiError as e:
+ LOGGER.error(f'{CONFIG_UPDATE_ERROR} {e}')
+ return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST)
+ request_data['metadata_xml'] = sso_config_metadata_xml
+ if request_metadata_xml := request_data.get('metadata_xml'):
+ if request_metadata_xml != sso_configuration_record.first().metadata_xml:
+ sso_config_metadata_xml = request_data.get('metadata_xml')
+ if sso_config_metadata_xml:
+ try:
+ entity_id = fetch_entity_id_from_metadata_xml(sso_config_metadata_xml)
+ request_data['entity_id'] = entity_id
+ except (EntityIdNotFoundError) as e:
+ LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}')
+ return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST)
+
# If the request includes a customer uuid, ensure the new customer is valid
- if new_customer := request.data.dict().get('enterprise_customer'):
+ if new_customer := request_data.get('enterprise_customer'):
try:
enterprise_customer = EnterpriseCustomer.objects.get(uuid=new_customer)
except EnterpriseCustomer.DoesNotExist:
@@ -219,7 +299,7 @@ def update(self, request, *args, **kwargs):
return Response(status=HTTP_403_FORBIDDEN)
try:
with transaction.atomic():
- sso_configuration_record.update(**request.data.dict())
+ sso_configuration_record.update(**request_data)
sso_configuration_record.first().submit_for_configuration(updating_existing_record=True)
except (TypeError, FieldDoesNotExist, ValidationError) as e:
LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}')
diff --git a/enterprise/migrations/0185_auto_20230920_2041.py b/enterprise/migrations/0185_auto_20230920_2041.py
new file mode 100644
index 0000000000..ade20fdbb2
--- /dev/null
+++ b/enterprise/migrations/0185_auto_20230920_2041.py
@@ -0,0 +1,33 @@
+# Generated by Django 3.2.20 on 2023-09-20 20:41
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('enterprise', '0184_auto_20230914_2057'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='enterprisecustomerssoconfiguration',
+ name='entity_id',
+ field=models.CharField(blank=True, help_text='The entity id of the identity provider.', max_length=255, null=True),
+ ),
+ migrations.AlterField(
+ model_name='enterprisecustomerssoconfiguration',
+ name='metadata_url',
+ field=models.CharField(blank=True, help_text='The metadata url of the identity provider.', max_length=255, null=True),
+ ),
+ migrations.AlterField(
+ model_name='historicalenterprisecustomerssoconfiguration',
+ name='entity_id',
+ field=models.CharField(blank=True, help_text='The entity id of the identity provider.', max_length=255, null=True),
+ ),
+ migrations.AlterField(
+ model_name='historicalenterprisecustomerssoconfiguration',
+ name='metadata_url',
+ field=models.CharField(blank=True, help_text='The metadata url of the identity provider.', max_length=255, null=True),
+ ),
+ ]
diff --git a/enterprise/models.py b/enterprise/models.py
index 59ddc0bf69..c6ac072d03 100644
--- a/enterprise/models.py
+++ b/enterprise/models.py
@@ -3782,8 +3782,8 @@ class Meta:
)
metadata_url = models.CharField(
- blank=False,
- null=False,
+ blank=True,
+ null=True,
max_length=255,
help_text=_(
"The metadata url of the identity provider."
@@ -3799,8 +3799,8 @@ class Meta:
)
entity_id = models.CharField(
- blank=False,
- null=False,
+ blank=True,
+ null=True,
max_length=255,
help_text=_(
"The entity id of the identity provider."
@@ -3987,7 +3987,12 @@ def is_pending_configuration(self):
"""
Returns True if the configuration has been submitted but not completed configuration.
"""
- return self.submitted_at and not self.configured_at
+ if self.submitted_at:
+ if not self.configured_at:
+ return True
+ if self.submitted_at > self.configured_at:
+ return True
+ return False
def submit_for_configuration(self, updating_existing_record=False):
"""
@@ -4003,14 +4008,14 @@ def submit_for_configuration(self, updating_existing_record=False):
)
is_sap = False
sap_data = {}
+ config_data = {}
if self.identity_provider == self.SAP_SUCCESS_FACTORS:
for field in self.sap_config_fields:
sap_data[utils.camelCase(field)] = getattr(self, field)
is_sap = True
-
- config_data = {}
- for field in self.base_saml_config_fields:
- config_data[utils.camelCase(field)] = getattr(self, field)
+ else:
+ for field in self.base_saml_config_fields:
+ config_data[utils.camelCase(field)] = getattr(self, field)
EnterpriseSSOOrchestratorApiClient().configure_sso_orchestration_record(
config_data=config_data,
diff --git a/tests/test_enterprise/api/test_views.py b/tests/test_enterprise/api/test_views.py
index 89d410155c..3df4f98253 100644
--- a/tests/test_enterprise/api/test_views.py
+++ b/tests/test_enterprise/api/test_views.py
@@ -7489,10 +7489,19 @@ def test_sso_configuration_list_customer_filtering_while_staff(self):
# -------------------------- create test suite --------------------------
@responses.activate
- def test_sso_configuration_create_x(self):
+ def test_sso_configuration_create(self):
"""
Test expected response when successfully creating a new sso configuration.
"""
+ xml_metadata = """
+
+
+ """
+ responses.add(
+ responses.GET,
+ "https://example.com/metadata.xml",
+ body=xml_metadata,
+ )
responses.add(
responses.POST,
urljoin(get_sso_orchestrator_api_base_url(), get_sso_orchestrator_configure_path()),
@@ -7562,13 +7571,129 @@ def test_sso_configuration_create_bad_data_format(self):
response = self.post_new_sso_configuration(data)
assert "somewhackyvalue" in response.json()['error']
+ def test_sso_configuration_create_bad_xml_url(self):
+ """
+ Test expected response when creating a new sso configuration with a bad xml url.
+ """
+ responses.add(
+ responses.GET,
+ "https://example.com/metadata.xml",
+ json={'error': 'some error'},
+ status=400,
+ )
+ data = {
+ "metadata_url": "https://example.com/metadata.xml",
+ "enterprise_customer": str(self.enterprise_customer.uuid),
+ "identity_provider": "cornerstone"
+ }
+ self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid)
+ config_pk = uuid.uuid4()
+ EnterpriseCustomerSsoConfigurationFactory(
+ uuid=config_pk,
+ enterprise_customer=self.enterprise_customer,
+ )
+ response = self.update_sso_configuration(config_pk, data)
+ assert response.status_code == 400
+ assert "Error fetching metadata xml" in response.json()['error']
+
+ @responses.activate
+ def test_sso_configuration_create_bad_xml_content(self):
+ """
+ Test expected response when creating a new sso configuration with an xml string that doesn't contain an entity
+ id.
+ """
+ xml_metadata = """
+
+
+ """
+ data = {
+ "metadata_url": "https://example.com/metadata.xml",
+ "enterprise_customer": str(self.enterprise_customer.uuid),
+ "identity_provider": "cornerstone"
+ }
+ self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid)
+ config_pk = uuid.uuid4()
+ EnterpriseCustomerSsoConfigurationFactory(
+ uuid=config_pk,
+ enterprise_customer=self.enterprise_customer,
+ )
+ data = {
+ "metadata_xml": xml_metadata,
+ }
+ response = self.update_sso_configuration(config_pk, data)
+ assert response.status_code == 400
+ assert "Could not find entity ID in metadata xml" in response.json()['error']
+
# -------------------------- update test suite --------------------------
+ @responses.activate
+ def test_sso_configurations_update_bad_xml_content(self):
+ """
+ Test the expected response when updating an sso configuration with an xml string that doesn't contain an entity
+ id.
+ """
+ xml_metadata = """
+
+
+ """
+ responses.add(
+ responses.GET,
+ "https://example.com/metadata.xml",
+ body=xml_metadata,
+ )
+
+ self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid)
+ config_pk = uuid.uuid4()
+ EnterpriseCustomerSsoConfigurationFactory(
+ uuid=config_pk,
+ enterprise_customer=self.enterprise_customer,
+ )
+ data = {
+ "metadata_url": "https://example.com/metadata.xml",
+ }
+ response = self.update_sso_configuration(config_pk, data)
+ assert response.status_code == 400
+
+ @responses.activate
+ def test_sso_configurations_update_bad_xml_url(self):
+ """
+ Test the expected response when updating an sso configuration with a bad xml url.
+ """
+ responses.add(
+ responses.GET,
+ "https://example.com/metadata.xml",
+ json={'error': 'some error'},
+ status=400,
+ )
+
+ self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid)
+ config_pk = uuid.uuid4()
+ EnterpriseCustomerSsoConfigurationFactory(
+ uuid=config_pk,
+ enterprise_customer=self.enterprise_customer,
+ )
+ data = {
+ "metadata_url": "https://example.com/metadata.xml",
+ }
+ response = self.update_sso_configuration(config_pk, data)
+ assert response.status_code == 400
+ assert "Error fetching metadata xml" in response.json()['error']
+
@responses.activate
def test_sso_configurations_update_submitted_config(self):
"""
Test the expected response when updating an sso configuration that's already been submitted for configuration.
"""
+ xml_metadata = """
+
+
+ """
+ responses.add(
+ responses.GET,
+ "https://example.com/metadata.xml",
+ body=xml_metadata,
+ )
+
self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid)
config_pk = uuid.uuid4()
enterprise_sso_orchestration_config = EnterpriseCustomerSsoConfigurationFactory(
@@ -7592,7 +7717,7 @@ def test_sso_configurations_update_submitted_config(self):
enterprise_sso_orchestration_config.save()
response = self.update_sso_configuration(config_pk, data)
assert response.status_code == 200
- sent_body_params = json.loads(responses.calls[0].request.body)
+ sent_body_params = json.loads(responses.calls[2].request.body)
assert sent_body_params['requestIdentifier'] == str(config_pk)
@responses.activate
@@ -7600,6 +7725,16 @@ def test_sso_configuration_update(self):
"""
Test expected response when successfully updating an existing sso configuration.
"""
+ xml_metadata = """
+
+
+ """
+ responses.add(
+ responses.GET,
+ "https://example.com/metadata.xml",
+ body=xml_metadata,
+ )
+
responses.add(
responses.POST,
urljoin(get_sso_orchestrator_api_base_url(), get_sso_orchestrator_configure_path()),