Skip to content

Commit

Permalink
move spam check into save method
Browse files Browse the repository at this point in the history
  • Loading branch information
John Tordoff committed Feb 8, 2024
1 parent 82ee230 commit 250ef11
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 34 deletions.
2 changes: 1 addition & 1 deletion osf/models/spam.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def do_check_spam(self, author, author_email, content, request_headers):
return

request_kwargs = {
'remote_addr': request_headers.get('Remote-Addr') or request_headers['Host'], # for local testing
'remote_addr': request_headers.get('Remote-Addr') or request_headers.get('Host'), # for local testing
'user_agent': request_headers.get('User-Agent'),
'referer': request_headers.get('Referer'),
}
Expand Down
70 changes: 52 additions & 18 deletions osf/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from website.project import new_bookmark_collection
from website.util.metrics import OsfSourceTags
from importlib import import_module
from osf.utils.requests import get_headers_from_request

SessionStore = import_module(settings.SESSION_ENGINE).SessionStore

Expand Down Expand Up @@ -1025,8 +1026,14 @@ def save(self, *args, **kwargs):

self.update_is_active()
self.username = self.username.lower().strip() if self.username else None
dirty_fields = set(self.get_dirty_fields(check_relationship=True))
ret = super(OSFUser, self).save(*args, **kwargs)
dirty_fields = self.get_dirty_fields(check_relationship=True)
ret = super(OSFUser, self).save(*args, **kwargs) # must save BEFORE spam check, as user needs guid.
if set(self.SPAM_USER_PROFILE_FIELDS.keys()).intersection(dirty_fields):
request = get_current_request()
headers = get_headers_from_request(request)
self.check_spam(dirty_fields, request_headers=headers)

dirty_fields = set(dirty_fields)
if self.SEARCH_UPDATE_FIELDS.intersection(dirty_fields) and self.is_confirmed:
self.update_search()
self.update_search_nodes_contributors()
Expand Down Expand Up @@ -1859,28 +1866,55 @@ def get_node_comment_timestamps(self, target_id):
return self.comments_viewed_timestamp.get(target_id, default_timestamp)

def _get_spam_content(self, saved_fields=None, **unused_kwargs):
"""
Retrieves content for spam checking from specified fields.
The fields come either from validated serializer data or from
dirty_fields.
Parameters:
- saved_fields (dict): Fields that have been saved and their values.
- unused_kwargs: Ignored additional keyword arguments.
Returns:
- str: A string containing the spam check contents, joined by spaces.
"""
# Determine which fields to check for spam.
spam_check_fields = set(self.SPAM_USER_PROFILE_FIELDS.keys())
spam_check_source = {}

# Decide the source of the fields to check: either use 'saved_fields' if provided, or the object's attributes.
if saved_fields:
spam_check_source = saved_fields
spam_check_fields = spam_check_fields.intersection(set(saved_fields.keys()))
# Only check fields that are both in 'saved_fields' and 'SPAM_USER_PROFILE_FIELDS'.
spam_check_fields = spam_check_fields.intersection(saved_fields.keys())
else:
spam_check_source = {field: getattr(self, field) for field in spam_check_fields}

spam_check_contents = []
for spam_field in spam_check_fields:
spam_field_content = spam_check_source[spam_field]
if not spam_field_content:
continue
if spam_field in ['schools', 'jobs']:
spam_check_contents.extend(
_get_nested_spam_check_content(spam_check_source, spam_field)
)
else: # Only other currently checked field is social['profileWebsites']
spam_check_contents.extend(
spam_check_source.get('social', dict()).get('profileWebsites', list())
)
return ' '.join(spam_check_contents).strip()
spam_contents = []
for field in spam_check_fields:
# If the supplied value for this field is empty, it came from
# dirty fields, so fetch the current value from the instance.
if not spam_check_source[field]:
value = getattr(self, field)
# Special handling for the 'social' field to extract 'profileWebsites'.
if field == 'social':
websites = value.get('profileWebsites', [])
spam_contents.extend(websites)
else:
# Handle nested spam check content for other fields.
nested_contents = _get_nested_spam_check_content(spam_check_source, field)
spam_contents.extend(nested_contents)
else:
# For 'schools' and 'jobs', always extract nested spam check content.
if field in ['schools', 'jobs']:
nested_contents = _get_nested_spam_check_content(spam_check_source, field)
spam_contents.extend(nested_contents)
else:
# Extract 'profileWebsites' directly for the 'social' field.
websites = spam_check_source.get('social', {}).get('profileWebsites', [])
spam_contents.extend(websites)

# Join all collected spam check contents into a single string.
return ' '.join(spam_contents).strip()

def check_spam(self, saved_fields, request_headers):
is_spam = False
Expand Down
2 changes: 1 addition & 1 deletion osf/utils/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_headers_from_request(req):
k: v
for k, v in headers.items()
}
headers['Remote-Addr'] = req.remote_addr
headers['Remote-Addr'] = getattr(req, 'remote_addr', None)
return headers


Expand Down
18 changes: 16 additions & 2 deletions osf_tests/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@

SessionStore = import_module(django_conf_settings.SESSION_ENGINE).SessionStore

from osf.external.spam import tasks as spam_tasks


pytestmark = pytest.mark.django_db

Expand Down Expand Up @@ -2191,11 +2193,13 @@ def test_validate_social_profile_website_many_different(self):
with open(url_data_path) as url_test_data:
data = json.load(url_test_data)

previous_number_of_domains = NotableDomain.objects.all().count()
fails_at_end = False
for should_pass in data['testsPositive']:
try:
self.user.social = {'profileWebsites': [should_pass]}
self.user.save()
with mock.patch.object(spam_tasks.requests, 'head'):
self.user.save()
assert self.user.social['profileWebsites'] == [should_pass]
except ValidationError:
fails_at_end = True
Expand All @@ -2205,13 +2209,23 @@ def test_validate_social_profile_website_many_different(self):
self.user.social = {'profileWebsites': [should_fail]}
try:
with pytest.raises(ValidationError):
self.user.save()
with mock.patch.object(spam_tasks.requests, 'head'):
self.user.save()
except AssertionError:
fails_at_end = True
print('\"' + should_fail + '\" passed but should have failed while testing that the validator ' + data['testsNegative'][should_fail])
if fails_at_end:
raise

# Not all permissible domains can be used as spam; some are correctly
# not extracted and not kept as notable domains. So spot check
# some, not all, because not every `testsPositive` URL should end up in
# NotableDomain.
assert NotableDomain.objects.all().count() > previous_number_of_domains
assert NotableDomain.objects.get(domain='definitelyawebsite.com')
assert NotableDomain.objects.get(domain='a.b-c.de')


def test_validate_multiple_profile_websites_valid(self):
self.user.social = {'profileWebsites': ['http://cos.io/', 'http://thebuckstopshere.com', 'http://dinosaurs.com']}
self.user.save()
Expand Down
29 changes: 18 additions & 11 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
QuickFilesNode,
NotableDomain
)
from osf.external.spam import tasks as spam_tasks

from tests.base import (
assert_is_redirect,
Expand Down Expand Up @@ -1140,16 +1141,19 @@ def test_unserialize_social(self):
'twitter': 'howtopizza',
'github': 'frozenpizzacode',
}
self.app.put_json(
url,
payload,
auth=self.user.auth,
)
with mock.patch.object(spam_tasks.requests, 'head'):
resp = self.app.put_json(
url,
payload,
auth=self.user.auth,
)

self.user.reload()
for key, value in payload.items():
assert_equal(self.user.social[key], value)
assert_true(self.user.social['researcherId'] is None)

assert NotableDomain.objects.all()
assert NotableDomain.objects.get(domain='frozen.pizza.com')

# Regression test for help-desk ticket
Expand Down Expand Up @@ -1187,7 +1191,8 @@ def test_unserialize_social_validation_failure(self):
def test_serialize_social_editable(self):
self.user.social['twitter'] = 'howtopizza'
self.user.social['profileWebsites'] = ['http://www.cos.io', 'http://www.osf.io', 'http://www.wordup.com']
self.user.save()
with mock.patch.object(spam_tasks.requests, 'head'):
self.user.save()
url = api_url_for('serialize_social')
res = self.app.get(
url,
Expand All @@ -1202,12 +1207,14 @@ def test_serialize_social_not_editable(self):
user2 = AuthUserFactory()
self.user.social['twitter'] = 'howtopizza'
self.user.social['profileWebsites'] = ['http://www.cos.io', 'http://www.osf.io', 'http://www.wordup.com']
self.user.save()
with mock.patch.object(spam_tasks.requests, 'head'):
self.user.save()
url = api_url_for('serialize_social', uid=self.user._id)
res = self.app.get(
url,
auth=user2.auth,
)
with mock.patch.object(spam_tasks.requests, 'head'):
res = self.app.get(
url,
auth=user2.auth,
)
assert_equal(res.json.get('twitter'), 'howtopizza')
assert_equal(res.json.get('profileWebsites'), ['http://www.cos.io', 'http://www.osf.io', 'http://www.wordup.com'])
assert_true(res.json.get('github') is None)
Expand Down
1 change: 0 additions & 1 deletion website/profile/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,6 @@ def unserialize_social(auth, **kwargs):

try:
user.save()
user.check_spam(saved_fields=None, request_headers=request.headers)
except ValidationError as exc:
raise HTTPError(http_status.HTTP_400_BAD_REQUEST, data=dict(
message_long=exc.messages[0]
Expand Down

0 comments on commit 250ef11

Please sign in to comment.