diff --git a/lms/djangoapps/discussion/rest_api/api.py b/lms/djangoapps/discussion/rest_api/api.py index 552460dba6e3..d9e53ceafb1e 100644 --- a/lms/djangoapps/discussion/rest_api/api.py +++ b/lms/djangoapps/discussion/rest_api/api.py @@ -1465,7 +1465,7 @@ def update_comment(request, comment_id, update_data): return api_comment -def get_thread(request, thread_id, requested_fields=None): +def get_thread(request, thread_id, requested_fields=None, course_id=None): """ Retrieve a thread. @@ -1476,6 +1476,8 @@ def get_thread(request, thread_id, requested_fields=None): thread_id: The id for the thread to retrieve + course_id: the id of the course the threads belongs to + requested_fields: Indicates which additional fields to return for thread. (i.e. ['profile_image']) """ @@ -1489,6 +1491,8 @@ def get_thread(request, thread_id, requested_fields=None): "user_id": str(request.user.id), } ) + if course_id and course_id != cc_thread.course_id: + raise ThreadNotFoundError("Thread not found.") return _serialize_discussion_entities(request, context, [cc_thread], requested_fields, DiscussionEntity.thread)[0] diff --git a/lms/djangoapps/discussion/rest_api/tests/test_api.py b/lms/djangoapps/discussion/rest_api/tests/test_api.py index aab2be06c27a..8204ca0218f0 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_api.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_api.py @@ -3993,6 +3993,14 @@ def test_group_access(self, role_name, course_is_cohorted, thread_group_state): except ThreadNotFoundError: assert expected_error + def test_course_id_mismatch(self): + """ + Test if the api throws not found exception if course_id from params mismatches course_id in thread + """ + self.register_thread() + get_thread(self.request, self.thread_id, 'different_course_id') + assert ThreadNotFoundError + @mock.patch('lms.djangoapps.discussion.rest_api.api._get_course', mock.Mock()) class CourseTopicsV2Test(ModuleStoreTestCase): diff --git a/lms/djangoapps/discussion/rest_api/views.py b/lms/djangoapps/discussion/rest_api/views.py index 6831765fbf87..daed977e2690 100644 --- a/lms/djangoapps/discussion/rest_api/views.py +++ b/lms/djangoapps/discussion/rest_api/views.py @@ -527,7 +527,8 @@ def retrieve(self, request, thread_id=None): Implements the GET method for thread ID """ requested_fields = request.GET.get('requested_fields') - return Response(get_thread(request, thread_id, requested_fields)) + course_id = request.GET.get('course_id') + return Response(get_thread(request, thread_id, requested_fields, course_id)) def create(self, request): """