Skip to content

Commit

Permalink
fixup!: apply review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
xitij2000 committed Dec 6, 2024
1 parent 8d6d690 commit 1806428
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 25 deletions.
16 changes: 7 additions & 9 deletions openedx/core/djangoapps/agreements/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,10 @@ def get_user_agreements(user: User) -> Iterable[UserAgreementRecordData]:
yield UserAgreementRecordData.from_model(agreement_record)


def get_user_agreement_record(
def get_latest_user_agreement_record(
user: User,
agreement_type: str,
agreement_update_timestamp: datetime = None,
agreed_after: datetime = None,
) -> Optional[UserAgreementRecordData]:
"""
Retrieve the user agreement record for the specified user and agreement type.
Expand All @@ -264,9 +264,9 @@ def get_user_agreement_record(
user=user,
agreement_type=agreement_type,
)
if agreement_update_timestamp:
record_query = record_query.filter(timestamp__gte=agreement_update_timestamp)
record = record_query.get()
if agreed_after:
record_query = record_query.filter(timestamp__gte=agreed_after)
record = record_query.latest("timestamp")
return UserAgreementRecordData.from_model(record)
except UserAgreementRecord.DoesNotExist:
return None
Expand All @@ -277,11 +277,9 @@ def create_user_agreement_record(user: User, agreement_type: str) -> UserAgreeme
Creates a user agreement record if one doesn't already exist, or updates existing
record to current timestamp.
"""
record, _ = UserAgreementRecord.objects.update_or_create(
record = UserAgreementRecord.objects.create(
user=user,
agreement_type=agreement_type,
defaults={
"timestamp": datetime.now(),
},
timestamp=datetime.now(),
)
return UserAgreementRecordData.from_model(record)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 4.2.16 on 2024-11-14 11:47
# Generated by Django 4.2.16 on 2024-12-06 11:34

from django.conf import settings
from django.db import migrations, models
Expand All @@ -21,8 +21,5 @@ class Migration(migrations.Migration):
('timestamp', models.DateTimeField(auto_now_add=True)),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
'unique_together': {('user', 'agreement_type')},
},
),
]
4 changes: 1 addition & 3 deletions openedx/core/djangoapps/agreements/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class UserAgreementRecord(models.Model):
This model stores the agreements a user has accepted or acknowledged.
Each record here represents a user agreeing to the agreement type represented
by `agreement_type`.
by `agreement_type` at a particular time.
.. no_pii:
"""
Expand All @@ -87,5 +87,3 @@ class UserAgreementRecord(models.Model):

class Meta:
app_label = 'agreements'
# A user can only have a single record for a single agreement type.
unique_together = [['user', 'agreement_type']]
4 changes: 2 additions & 2 deletions openedx/core/djangoapps/agreements/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ class UserAgreementsSerializer(serializers.Serializer):
"""
Serializer for UserAgreementRecord model
"""
accepted_at = serializers.DateTimeField()
agreement_type = serializers.CharField(read_only=True)
username = serializers.CharField(read_only=True)
agreement_type = serializers.CharField(read_only=True)
accepted_at = serializers.DateTimeField()
6 changes: 3 additions & 3 deletions openedx/core/djangoapps/agreements/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
get_integrity_signatures_for_course,
get_lti_pii_signature,
get_pii_receiving_lti_tools,
get_user_agreement_record,
get_latest_user_agreement_record,
get_user_agreements
)
from ..models import LTIPIITool
Expand Down Expand Up @@ -214,11 +214,11 @@ def test_get_user_agreements(self, ):

def test_get_user_agreement_record(self):
record = create_user_agreement_record(self.user, 'test_type')
result = get_user_agreement_record(self.user, 'test_type')
result = get_latest_user_agreement_record(self.user, 'test_type')

assert result == record

result = get_user_agreement_record(self.user, 'test_type', datetime.now() + timedelta(days=1))
result = get_latest_user_agreement_record(self.user, 'test_type', datetime.now() + timedelta(days=1))

assert result is None

Expand Down
7 changes: 6 additions & 1 deletion openedx/core/djangoapps/agreements/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,11 @@ class UserAgreementsViewTests(APITestCase):

def setUp(self):
self.user = UserFactory(username="testuser", password="password")
self.client.login(username="testuser", password="password")
self.url = reverse('user_agreements', kwargs={'agreement_type': 'sample_agreement'})
self.login()

def login(self):
self.client.login(username="testuser", password="password")

def test_get_user_agreement_record_no_data(self):
response = self.client.get(self.url)
Expand All @@ -326,6 +329,8 @@ def test_post_user_agreement(self):
response = self.client.post(self.url)
assert response.status_code == status.HTTP_201_CREATED

self.login()

response = self.client.get(self.url)
assert response.status_code == status.HTTP_200_OK

Expand Down
8 changes: 5 additions & 3 deletions openedx/core/djangoapps/agreements/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
create_lti_pii_signature,
create_user_agreement_record,
get_integrity_signature,
get_user_agreement_record
get_latest_user_agreement_record
)
from .serializers import IntegritySignatureSerializer, LTIPIISignatureSerializer, UserAgreementsSerializer
from ...lib.api.view_utils import view_auth_classes


def is_user_course_or_global_staff(user, course_id):
Expand Down Expand Up @@ -167,7 +168,8 @@ def post(self, request, course_id):
return Response(data=serializer.data, status=statusStr)


class UserAgreementsView(AuthenticatedAPIView):
@view_auth_classes(is_authenticated=True)
class UserAgreementsView(APIView):
"""
Endpoint for the user agreements API.
"""
Expand Down Expand Up @@ -207,7 +209,7 @@ def get(self, request, agreement_type):
params = UserAgreementsView.QueryFilterForm(request.query_params)
if not params.is_valid():
return Response(status=status.HTTP_400_BAD_REQUEST)
record = get_user_agreement_record(request.user, agreement_type, params.cleaned_data.get('after'))
record = get_latest_user_agreement_record(request.user, agreement_type, params.cleaned_data.get('after'))
if record is None:
return Response(status=status.HTTP_404_NOT_FOUND)
serializer = UserAgreementsSerializer(record)
Expand Down

0 comments on commit 1806428

Please sign in to comment.