diff --git a/homeassistant/components/reddit/sensor.py b/homeassistant/components/reddit/sensor.py index 1b6a960669cd64..3ba43196551028 100644 --- a/homeassistant/components/reddit/sensor.py +++ b/homeassistant/components/reddit/sensor.py @@ -15,6 +15,7 @@ CONF_CLIENT_ID = 'client_id' CONF_CLIENT_SECRET = 'client_secret' +CONF_SORT_BY = 'sort_by' CONF_SUBREDDITS = 'subreddits' ATTR_ID = 'id' @@ -29,6 +30,10 @@ DEFAULT_NAME = 'Reddit' +DOMAIN = 'reddit' + +LIST_TYPES = ['top', 'controversial', 'hot', 'new'] + SCAN_INTERVAL = timedelta(seconds=300) PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ @@ -37,6 +42,8 @@ vol.Required(CONF_USERNAME): cv.string, vol.Required(CONF_PASSWORD): cv.string, vol.Required(CONF_SUBREDDITS): vol.All(cv.ensure_list, [cv.string]), + vol.Optional(CONF_SORT_BY, default='hot'): + vol.All(cv.string, vol.In(LIST_TYPES)), vol.Optional(CONF_MAXIMUM, default=10): cv.positive_int }) @@ -48,6 +55,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None): subreddits = config[CONF_SUBREDDITS] user_agent = '{}_home_assistant_sensor'.format(config[CONF_USERNAME]) limit = config[CONF_MAXIMUM] + sort_by = config[CONF_SORT_BY] try: reddit = praw.Reddit( @@ -63,18 +71,20 @@ def setup_platform(hass, config, add_entities, discovery_info=None): _LOGGER.error("Reddit error %s", err) return - sensors = [RedditSensor(reddit, sub, limit) for sub in subreddits] + sensors = [RedditSensor(reddit, subreddit, limit, sort_by) + for subreddit in subreddits] add_entities(sensors, True) class RedditSensor(Entity): """Representation of a Reddit sensor.""" - def __init__(self, reddit, subreddit: str, limit: int): + def __init__(self, reddit, subreddit: str, limit: int, sort_by: str): """Initialize the Reddit sensor.""" self._reddit = reddit - self._limit = limit self._subreddit = subreddit + self._limit = limit + self._sort_by = sort_by self._subreddit_data = [] @@ -93,7 +103,8 @@ def device_state_attributes(self): """Return the state attributes.""" return { ATTR_SUBREDDIT: self._subreddit, - ATTR_POSTS: self._subreddit_data + ATTR_POSTS: self._subreddit_data, + CONF_SORT_BY: self._sort_by } @property @@ -109,17 +120,19 @@ def update(self): try: subreddit = self._reddit.subreddit(self._subreddit) - - for submission in subreddit.top(limit=self._limit): - self._subreddit_data.append({ - ATTR_ID: submission.id, - ATTR_URL: submission.url, - ATTR_TITLE: submission.title, - ATTR_SCORE: submission.score, - ATTR_COMMENTS_NUMBER: submission.num_comments, - ATTR_CREATED: submission.created, - ATTR_BODY: submission.selftext - }) + if hasattr(subreddit, self._sort_by): + method_to_call = getattr(subreddit, self._sort_by) + + for submission in method_to_call(limit=self._limit): + self._subreddit_data.append({ + ATTR_ID: submission.id, + ATTR_URL: submission.url, + ATTR_TITLE: submission.title, + ATTR_SCORE: submission.score, + ATTR_COMMENTS_NUMBER: submission.num_comments, + ATTR_CREATED: submission.created, + ATTR_BODY: submission.selftext + }) except praw.exceptions.PRAWException as err: _LOGGER.error("Reddit error %s", err) diff --git a/tests/components/reddit/__init__.py b/tests/components/reddit/__init__.py new file mode 100644 index 00000000000000..67e0db82f42339 --- /dev/null +++ b/tests/components/reddit/__init__.py @@ -0,0 +1 @@ +"""Tests for the the Reddit component.""" diff --git a/tests/components/reddit/test_sensor.py b/tests/components/reddit/test_sensor.py new file mode 100644 index 00000000000000..2bb22a0024bf57 --- /dev/null +++ b/tests/components/reddit/test_sensor.py @@ -0,0 +1,175 @@ +"""The tests for the Reddit platform.""" +import copy +import unittest +from unittest.mock import patch + +from homeassistant.components.reddit.sensor import ( + DOMAIN, ATTR_SUBREDDIT, ATTR_POSTS, CONF_SORT_BY, + ATTR_ID, ATTR_URL, ATTR_TITLE, ATTR_SCORE, ATTR_COMMENTS_NUMBER, + ATTR_CREATED, ATTR_BODY) +from homeassistant.const import (CONF_USERNAME, CONF_PASSWORD, CONF_MAXIMUM) +from homeassistant.setup import setup_component + +from tests.common import (get_test_home_assistant, + MockDependency) + + +VALID_CONFIG = { + 'sensor': { + 'platform': DOMAIN, + 'client_id': 'test_client_id', + 'client_secret': 'test_client_secret', + CONF_USERNAME: 'test_username', + CONF_PASSWORD: 'test_password', + 'subreddits': ['worldnews', 'news'], + + } +} + +VALID_LIMITED_CONFIG = { + 'sensor': { + 'platform': DOMAIN, + 'client_id': 'test_client_id', + 'client_secret': 'test_client_secret', + CONF_USERNAME: 'test_username', + CONF_PASSWORD: 'test_password', + 'subreddits': ['worldnews', 'news'], + CONF_MAXIMUM: 1 + } +} + + +INVALID_SORT_BY_CONFIG = { + 'sensor': { + 'platform': DOMAIN, + 'client_id': 'test_client_id', + 'client_secret': 'test_client_secret', + CONF_USERNAME: 'test_username', + CONF_PASSWORD: 'test_password', + 'subreddits': ['worldnews', 'news'], + 'sort_by': 'invalid_sort_by' + } +} + + +class ObjectView(): + """Use dict properties as attributes.""" + + def __init__(self, d): + """Set dict as internal dict.""" + self.__dict__ = d + + +MOCK_RESULTS = { + 'results': [ + ObjectView({ + 'id': 0, + 'url': 'http://example.com/1', + 'title': 'example1', + 'score': '1', + 'num_comments': '1', + 'created': '', + 'selftext': 'example1 selftext' + }), + ObjectView({ + 'id': 1, + 'url': 'http://example.com/2', + 'title': 'example2', + 'score': '2', + 'num_comments': '2', + 'created': '', + 'selftext': 'example2 selftext' + }) + ] +} + +MOCK_RESULTS_LENGTH = len(MOCK_RESULTS['results']) + + +class MockPraw(): + """Mock class for tmdbsimple library.""" + + def __init__(self, client_id: str, client_secret: + str, username: str, password: str, + user_agent: str): + """Add mock data for API return.""" + self._data = MOCK_RESULTS + + def subreddit(self, subreddit: str): + """Return an instance of a sunbreddit.""" + return MockSubreddit(subreddit, self._data) + + +class MockSubreddit(): + """Mock class for a subreddit instance.""" + + def __init__(self, subreddit: str, data): + """Add mock data for API return.""" + self._subreddit = subreddit + self._data = data + + def top(self, limit): + """Return top posts for a subreddit.""" + return self._return_data(limit) + + def controversial(self, limit): + """Return controversial posts for a subreddit.""" + return self._return_data(limit) + + def hot(self, limit): + """Return hot posts for a subreddit.""" + return self._return_data(limit) + + def new(self, limit): + """Return new posts for a subreddit.""" + return self._return_data(limit) + + def _return_data(self, limit): + """Test method to return modified data.""" + data = copy.deepcopy(self._data) + return data['results'][:limit] + + +class TestRedditSetup(unittest.TestCase): + """Test the Reddit platform.""" + + def setUp(self): + """Initialize values for this testcase class.""" + self.hass = get_test_home_assistant() + + def tearDown(self): # pylint: disable=invalid-name + """Stop everything that was started.""" + self.hass.stop() + + @MockDependency('praw') + @patch('praw.Reddit', new=MockPraw) + def test_setup_with_valid_config(self, mock_praw): + """Test the platform setup with movie configuration.""" + setup_component(self.hass, 'sensor', VALID_CONFIG) + + state = self.hass.states.get('sensor.reddit_worldnews') + assert int(state.state) == MOCK_RESULTS_LENGTH + + state = self.hass.states.get('sensor.reddit_news') + assert int(state.state) == MOCK_RESULTS_LENGTH + + assert state.attributes[ATTR_SUBREDDIT] == 'news' + + assert state.attributes[ATTR_POSTS][0] == { + ATTR_ID: 0, + ATTR_URL: 'http://example.com/1', + ATTR_TITLE: 'example1', + ATTR_SCORE: '1', + ATTR_COMMENTS_NUMBER: '1', + ATTR_CREATED: '', + ATTR_BODY: 'example1 selftext' + } + + assert state.attributes[CONF_SORT_BY] == 'hot' + + @MockDependency('praw') + @patch('praw.Reddit', new=MockPraw) + def test_setup_with_invalid_config(self, mock_praw): + """Test the platform setup with invalid movie configuration.""" + setup_component(self.hass, 'sensor', INVALID_SORT_BY_CONFIG) + assert not self.hass.states.get('sensor.reddit_worldnews')