Fix #167: more unicode issues
* Modularize the CSV handling into an object that's unicode-aware. This not only fixes a file mode bug and catches unicode decoding issues, but also makes us ready for py3, where the csv module handles unicode strings natively. (A short sketch of the underlying py2 pattern appears after the commit info below.)
* NOTE: because email addresses cannot contain non-ASCII characters, the stray-user files don't need an encoding on input or output.
* Make the LDAP attribute formatters fully unicode aware. Before, they didn't realize that the format strings were themselves unicode, so they were re-encoding the results of formatting.
adobeDan committed Jun 6, 2017
1 parent 8715b69 commit 6d0b988
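
For background on the core py2 issue these changes address: Python 2's csv module parses and emits byte strings only, so a unicode-aware wrapper must open files in binary mode and decode each cell after reading. A minimal stand-alone sketch of that pattern, assuming UTF-8 input; the function name and defaults here are illustrative, not this commit's API:

    # -*- coding: utf-8 -*-
    # Minimal py2 sketch of the decode-after-read pattern; the function name
    # and defaults are illustrative, not part of this commit.
    import csv

    def read_unicode_rows(path, encoding='utf8', delimiter=','):
        with open(path, 'rb') as f:  # binary mode: py2's csv wants bytes
            for row in csv.DictReader(f, delimiter=delimiter):
                try:
                    # decode each byte-string cell to unicode after parsing
                    yield dict((k, v.decode(encoding) if v else v)
                               for k, v in row.iteritems())
                except UnicodeError as e:
                    raise ValueError("Encoding error in file '%s': %s" % (path, e))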
Showing 4 changed files with 125 additions and 90 deletions.
user_sync/connector/directory_csv.py (17 changes: 10 additions & 7 deletions)
@@ -21,9 +21,8 @@
 import user_sync.config
 import user_sync.connector.helper
 import user_sync.error
-import user_sync.helper
 import user_sync.identity_type
-
+from user_sync.helper import CSVAdapter
 
 def connector_metadata():
     metadata = {
@@ -127,10 +126,11 @@ def get_column_name(key):
         recognized_column_names += extended_attributes
 
         line_read = 0
-        rows = user_sync.helper.iter_csv_rows(file_path,
-                                              delimiter=options['delimiter'],
-                                              recognized_column_names=recognized_column_names,
-                                              logger=logger)
+        rows = CSVAdapter.read_csv_rows(file_path,
+                                        recognized_column_names=recognized_column_names,
+                                        logger=logger,
+                                        encoding=self.encoding,
+                                        delimiter=options['delimiter'])
         for row in rows:
             line_read += 1
             email = self.get_column_value(row, email_column_name)
@@ -199,4 +199,7 @@ def get_column_value(self, row, column_name):
         :type column_name: str
         """
         value = row.get(column_name)
-        return value.decode(self.encoding) if value else None
+        if not value:
+            return None
+        else:
+            return value.decode(self.encoding)
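
A hypothetical call site for the read path above; the file name, column names, and encoding are invented for illustration, but the keyword arguments match the diff:

    # Hypothetical usage of CSVAdapter.read_csv_rows as called above;
    # 'users.csv' and the column names are invented for illustration.
    import logging

    from user_sync.helper import CSVAdapter

    logger = logging.getLogger('example')
    rows = CSVAdapter.read_csv_rows('users.csv',
                                    recognized_column_names=['firstname', 'lastname', 'email'],
                                    logger=logger,
                                    encoding='utf8',
                                    delimiter=',')
    for row in rows:
        value = row.get('email')
        # same decode-if-present idiom as get_column_value above
        email = value.decode('utf8') if value else None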
user_sync/connector/directory_ldap.py (13 changes: 8 additions & 5 deletions)
@@ -383,7 +383,7 @@ class LDAPValueFormatter(object):

     def __init__(self, string_format):
         """
-        :type string_format: str
+        :type string_format: unicode
         """
         if string_format is None:
             attribute_names = []
@@ -402,7 +402,7 @@ def get_attribute_names(self):
     def generate_value(self, record):
         """
         :type record: dict
-        :rtype (str, str)
+        :rtype (unicode, unicode)
         """
         result = None
         attribute_name = None
@@ -415,17 +415,20 @@
                     break
                 values[attribute_name] = value
             if values is not None:
-                result = self.string_format.format(**values).decode(self.encoding)
+                result = self.string_format.format(**values)
         return result, attribute_name
 
     @classmethod
     def get_attribute_value(cls, attributes, attribute_name):
         """
         :type attributes: dict
-        :type attribute_name: str
+        :type attribute_name: unicode
         """
         if attribute_name in attributes:
             attribute_value = attributes[attribute_name]
             if len(attribute_value) > 0:
-                return attribute_value[0].decode(cls.encoding)
+                try:
+                    return attribute_value[0].decode(cls.encoding)
+                except UnicodeError as e:
+                    raise AssertionException("Encoding error in value of attribute '%s': %s" % (attribute_name, e))
         return None
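
The generate_value fix above hinges on a py2 detail: formatting with a unicode format string already yields unicode, so the old .decode(self.encoding) on the result was wrong once the inputs were decoded. A tiny stand-alone illustration (py2; the attribute names and values are invented):

    # py2 illustration of the generate_value fix; names and values invented.
    values = {'givenName': b'J\xc3\xbcrgen'.decode('utf8'), 'sn': u'Meier'}
    string_format = u'{givenName}.{sn}@example.com'
    result = string_format.format(**values)
    assert isinstance(result, unicode)  # already unicode; no .decode() needed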
user_sync/helper.py (122 changes: 76 additions & 46 deletions)
@@ -22,19 +22,7 @@
 import datetime
 import os
 
-import user_sync.error
-
-
-def open_file(name, mode, buffering=-1):
-    """
-    :type name: str
-    :type mode: str
-    :type buffering: int
-    """
-    try:
-        return open(str(name), mode, buffering)
-    except IOError as e:
-        raise user_sync.error.AssertionException(str(e))
+from user_sync.error import AssertionException
 
 
 def normalize_string(string_value):
@@ -46,43 +34,85 @@ def normalize_string(string_value):
     return string_value.strip().lower() if string_value is not None else None
 
 
-def guess_delimiter_from_filename(filename):
-    """
-    :type filename
-    :rtype str
-    """
-    _base_name, extension = os.path.os.path.splitext(filename)
-    normalized_extension = normalize_string(extension)
-    if normalized_extension == '.csv':
-        return ','
-    if normalized_extension == '.tsv':
-        return '\t'
-    return '\t'
-
-
-def iter_csv_rows(file_path, delimiter=None, recognized_column_names=None, logger=None):
-    """
-    :type file_path: str
-    :type delimiter: str
-    :type recognized_column_names: list(str)
-    :type logger: logging.Logger
-    """
-    with open_file(file_path, 'r', 1) as input_file:
-        if delimiter is None:
-            delimiter = guess_delimiter_from_filename(file_path)
-        reader = csv.DictReader(input_file, delimiter=delimiter)
-
-        if recognized_column_names is not None:
-            unrecognized_column_names = [column_name for column_name in reader.fieldnames
-                                         if column_name not in recognized_column_names]
-            if len(unrecognized_column_names) > 0 and logger is not None:
-                logger.warn("In file '%s': unrecognized column names: %s", file_path, unrecognized_column_names)
-
-        for row in reader:
-            yield row
-
-
-class JobStats(object):
+class CSVAdapter:
+    """
+    Read and write CSV files to and from lists of dictionaries
+    """
+
+    @staticmethod
+    def open_csv_file(name, mode, encoding=None):
+        """
+        :type name: str
+        :type mode: str
+        :type encoding: str, but ignored in py2
+        :rtype file
+        """
+        try:
+            if mode == 'r':
+                return open(str(name), 'rb', buffering=1)
+            elif mode == 'w':
+                return open(str(name), 'wb')
+            else:
+                raise ValueError("File mode (%s) must be 'r' or 'w'" % mode)
+        except IOError as e:
+            raise AssertionException("Can't open file '%s': %s" % (name, e))
+
+    @staticmethod
+    def guess_delimiter_from_filename(filename):
+        """
+        :type filename
+        :rtype str
+        """
+        _base_name, extension = os.path.splitext(filename)
+        normalized_extension = normalize_string(extension)
+        if normalized_extension == '.csv':
+            return ','
+        if normalized_extension == '.tsv':
+            return '\t'
+        return '\t'
+
+    @classmethod
+    def read_csv_rows(cls, file_path, recognized_column_names=None, logger=None, encoding=None, delimiter=None):
+        """
+        :type file_path: str
+        :type recognized_column_names: list(str)
+        :type logger: logging.Logger
+        :type encoding: str
+        :type delimiter: str
+        """
+        with cls.open_csv_file(file_path, 'r', encoding) as input_file:
+            if delimiter is None:
+                delimiter = cls.guess_delimiter_from_filename(file_path)
+            try:
+                reader = csv.DictReader(input_file, delimiter=delimiter)
+                if recognized_column_names is not None:
+                    unrecognized_column_names = [column_name for column_name in reader.fieldnames
+                                                 if column_name not in recognized_column_names]
+                    if len(unrecognized_column_names) > 0 and logger is not None:
+                        logger.warn("In file '%s': unrecognized column names: %s", file_path, unrecognized_column_names)
+                for row in reader:
+                    yield row
+            except UnicodeError as e:
+                raise AssertionException("Encoding error in file '%s': %s" % (file_path, e))
+
+    @classmethod
+    def write_csv_rows(cls, file_path, field_names, rows, encoding=None, delimiter=None):
+        """
+        :type file_path: str
+        :type field_names: list(str)
+        :type rows: list(dict)
+        :type encoding: str
+        :type delimiter: str
+        """
+        with cls.open_csv_file(file_path, 'w', encoding=encoding) as output_file:
+            if delimiter is None:
+                delimiter = cls.guess_delimiter_from_filename(file_path)
+            writer = csv.DictWriter(output_file, fieldnames=field_names, delimiter=delimiter)
+            writer.writeheader()
+            for row in rows:
+                writer.writerow(row)
+
+
+class JobStats:
     line_left_count = 10
     line_width = 60

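A quick round trip through the new adapter, as one might exercise it in a py2 shell; the file name and row data are invented, but the signatures are as defined above:

    # Hypothetical round trip through CSVAdapter; file name and data invented.
    from user_sync.helper import CSVAdapter

    rows = [{'type': 'federatedID', 'username': 'jdoe@example.com', 'domain': 'example.com'}]
    CSVAdapter.write_csv_rows('strays.csv', ['type', 'username', 'domain'], rows)
    # read_csv_rows is a generator; per open_csv_file's docstring, the
    # encoding argument is accepted but ignored in py2.
    for row in CSVAdapter.read_csv_rows('strays.csv', encoding='utf8'):
        print(row)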
user_sync/rules.py (63 changes: 31 additions & 32 deletions)
@@ -12,19 +12,19 @@
 #
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-import csv
 import logging
 
 import user_sync.connector.umapi
 import user_sync.error
-import user_sync.helper
 import user_sync.identity_type
+from user_sync.helper import normalize_string, CSVAdapter, JobStats
 
 GROUP_NAME_DELIMITER = '::'
 PRIMARY_UMAPI_NAME = None
@@ -149,7 +149,7 @@ def run(self, directory_groups, directory_connector, umapi_connectors):
         self.prepare_umapi_infos()
 
         if directory_connector is not None:
-            load_directory_stats = user_sync.helper.JobStats("Load from Directory", divider="-")
+            load_directory_stats = JobStats("Load from Directory", divider="-")
             load_directory_stats.log_start(logger)
             self.read_desired_user_groups(directory_groups, directory_connector)
             load_directory_stats.log_end(logger)
@@ -158,7 +158,7 @@
             # no directory users to sync with
             should_sync_umapi_users = False
 
-        umapi_stats = user_sync.helper.JobStats("Sync Umapi", divider="-")
+        umapi_stats = JobStats("Sync Umapi", divider="-")
         umapi_stats.log_start(logger)
         if should_sync_umapi_users:
             self.process_umapi_users(umapi_connectors)
@@ -791,7 +791,7 @@ def normalize_groups(group_names):
         result = set()
         if group_names is not None:
             for group_name in group_names:
-                normalized_group_name = user_sync.helper.normalize_string(group_name)
+                normalized_group_name = normalize_string(group_name)
                 result.add(normalized_group_name)
         return result

@@ -881,9 +881,9 @@ def get_user_key(self, id_type, username, domain, email=None):
         :return: string "id_type,username,domain" (or None)
         """
         id_type = user_sync.identity_type.parse_identity_type(id_type)
-        email = user_sync.helper.normalize_string(email) if email else None
-        username = user_sync.helper.normalize_string(username) or email
-        domain = user_sync.helper.normalize_string(domain)
+        email = normalize_string(email) if email else None
+        username = normalize_string(username) or email
+        domain = normalize_string(domain)
 
         if not id_type:
             return None
@@ -917,13 +917,13 @@ def read_stray_key_map(self, file_path, delimiter=None):
         user_column_name = 'username'
         domain_column_name = 'domain'
         ummapi_name_column_name = 'umapi'
-        rows = user_sync.helper.iter_csv_rows(file_path,
-                                              delimiter=delimiter,
-                                              recognized_column_names=[
-                                                  id_type_column_name, user_column_name, domain_column_name,
-                                                  ummapi_name_column_name,
-                                              ],
-                                              logger=self.logger)
+        rows = CSVAdapter.read_csv_rows(file_path,
+                                        recognized_column_names=[
+                                            id_type_column_name, user_column_name, domain_column_name,
+                                            ummapi_name_column_name,
+                                        ],
+                                        logger=self.logger,
+                                        delimiter=delimiter)
         for row in rows:
             umapi_name = row.get(ummapi_name_column_name) or PRIMARY_UMAPI_NAME
             id_type = row.get(id_type_column_name)
@@ -952,26 +952,25 @@ def write_stray_key_map(self):
         # figure out if we should include a umapi column
         secondary_count = 0
         fieldnames = ['type', 'username', 'domain']
+        rows = []
         # count the secondaries, and if there are any add the name as a column
         for umapi_name in self.stray_key_map:
             if umapi_name != PRIMARY_UMAPI_NAME and self.get_stray_keys(umapi_name):
                 if not secondary_count:
                     fieldnames.append('umapi')
                 secondary_count += 1
-        with open(file_path, 'wb') as output_file:
-            delimiter = user_sync.helper.guess_delimiter_from_filename(file_path)
-            writer = csv.DictWriter(output_file, fieldnames=fieldnames, delimiter=delimiter)
-            writer.writeheader()
-            # None sorts before strings, so sorting the keys in the map
-            # puts the primary umapi first in the output, which is handy
-            for umapi_name in sorted(self.stray_key_map.keys()):
-                for user_key in self.get_stray_keys(umapi_name):
-                    id_type, username, domain = self.parse_user_key(user_key)
-                    umapi = umapi_name if umapi_name else ""
-                    if secondary_count:
-                        row_dict = {'type': id_type, 'username': username, 'domain': domain, 'umapi': umapi}
-                    else:
-                        row_dict = {'type': id_type, 'username': username, 'domain': domain}
-                    writer.writerow(row_dict)
+        # None sorts before strings, so sorting the keys in the map
+        # puts the primary umapi first in the output, which is handy
+        for umapi_name in sorted(self.stray_key_map.keys()):
+            for user_key in self.get_stray_keys(umapi_name):
+                id_type, username, domain = self.parse_user_key(user_key)
+                umapi = umapi_name if umapi_name else ""
+                if secondary_count:
+                    row_dict = {'type': id_type, 'username': username, 'domain': domain, 'umapi': umapi}
+                else:
+                    row_dict = {'type': id_type, 'username': username, 'domain': domain}
+                rows.append(row_dict)
+        CSVAdapter.write_csv_rows(file_path, fieldnames, rows)
         user_count = len(self.stray_key_map.get(PRIMARY_UMAPI_NAME, []))
         user_plural = "" if user_count == 1 else "s"
         if secondary_count > 0:
@@ -1116,7 +1115,7 @@ def add_mapped_group(self, group):
"""
:type group: str
"""
normalized_group_name = user_sync.helper.normalize_string(group)
normalized_group_name = normalize_string(group)
self.mapped_groups.add(normalized_group_name)

def get_mapped_groups(self):
@@ -1141,7 +1140,7 @@ def add_desired_group_for(self, user_key, group):
         if desired_groups is None:
             self.desired_groups_by_user_key[user_key] = desired_groups = set()
         if group is not None:
-            normalized_group_name = user_sync.helper.normalize_string(group)
+            normalized_group_name = normalize_string(group)
             desired_groups.add(normalized_group_name)
 
     def add_umapi_user(self, user_key, user):
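The user keys that read_stray_key_map and write_stray_key_map shuttle through these files are comma-joined strings, per get_user_key's docstring. A simplified sketch of that round trip; the real methods also normalize and validate their inputs, so this is illustrative only:

    # Simplified sketch of the "id_type,username,domain" key format;
    # these helpers are stand-ins for the real get_user_key/parse_user_key.
    def make_user_key(id_type, username, domain):
        return u','.join([id_type, username, domain or u''])

    def parse_user_key(user_key):
        return user_key.split(u',', 2)  # -> [id_type, username, domain]

    key = make_user_key(u'federatedID', u'jdoe@example.com', u'example.com')
    assert parse_user_key(key) == [u'federatedID', u'jdoe@example.com', u'example.com']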
