From 2c0d3db48e39117052c5e6d36ede893e7c4007f2 Mon Sep 17 00:00:00 2001 From: Haibao Tang Date: Tue, 18 Jun 2024 15:06:03 -0700 Subject: [PATCH] Add numpy as np --- goatools/nt_utils.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/goatools/nt_utils.py b/goatools/nt_utils.py index 9cfe3323..e25045ad 100644 --- a/goatools/nt_utils.py +++ b/goatools/nt_utils.py @@ -7,11 +7,13 @@ import datetime import collections as cx + def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""): """Return a new dict of namedtuples by combining "dicts" of namedtuples or objects.""" assert len(ids) == len(set(ids)), "NOT ALL IDs ARE UNIQUE: {IDs}".format(IDs=ids) assert len(flds) == len(set(flds)), "DUPLICATE FIELDS: {IDs}".format( - IDs=cx.Counter(flds).most_common()) + IDs=cx.Counter(flds).most_common() + ) usr_id_nt = [] # 1. Instantiate namedtuple object ntobj = cx.namedtuple("Nt", " ".join(flds)) @@ -23,6 +25,7 @@ def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""): usr_id_nt.append((item_id, ntobj._make(vals))) return cx.OrderedDict(usr_id_nt) + def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""): """Return a new list of namedtuples by combining "dicts" of namedtuples or objects.""" combined_nt_list = [] @@ -36,41 +39,53 @@ def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""): combined_nt_list.append(ntobj._make(vals)) return combined_nt_list + def combine_nt_lists(lists, flds, dflt_null=""): """Return a new list of namedtuples by zipping "lists" of namedtuples or objects.""" combined_nt_list = [] # Check that all lists are the same length lens = [len(lst) for lst in lists] - assert len(set(lens)) == 1, \ - "LIST LENGTHS MUST BE EQUAL: {Ls}".format(Ls=" ".join(str(l) for l in lens)) + assert len(set(lens)) == 1, "LIST LENGTHS MUST BE EQUAL: {Ls}".format( + Ls=" ".join(str(l) for l in lens) + ) # 1. Instantiate namedtuple object ntobj = cx.namedtuple("Nt", " ".join(flds)) # 2. Loop through zipped list for lst0_lstn in zip(*lists): # 2a. Combine various namedtuples into a single namedtuple - combined_nt_list.append(ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null))) + combined_nt_list.append( + ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null)) + ) return combined_nt_list + def wr_py_nts(fout_py, nts, docstring=None, varname="nts"): """Save namedtuples into a Python module.""" if nts: - with open(fout_py, 'w') as prt: + with open(fout_py, "w") as prt: prt.write('"""{DOCSTRING}"""\n\n'.format(DOCSTRING=docstring)) prt.write("# Created: {DATE}\n".format(DATE=str(datetime.date.today()))) prt_nts(prt, nts, varname) - sys.stdout.write(" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py)) + sys.stdout.write( + " {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py) + ) -def prt_nts(prt, nts, varname, spc=' '): + +def prt_nts(prt, nts, varname, spc=" "): """Print namedtuples into a Python module.""" first_nt = nts[0] nt_name = type(first_nt).__name__ prt.write("import collections as cx\n\n") + prt.write("import numpy as np\n\n") prt.write("NT_FIELDS = [\n") for fld in first_nt._fields: prt.write('{SPC}"{F}",\n'.format(SPC=spc, F=fld)) prt.write("]\n\n") - prt.write('{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format( - NtName=nt_name)) + prt.write( + '{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format( + NtName=nt_name + ) + ) prt.write("# {N:,} items\n".format(N=len(nts))) prt.write("# pylint: disable=line-too-long\n") prt.write("{VARNAME} = [\n".format(VARNAME=varname)) @@ -78,6 +93,7 @@ def prt_nts(prt, nts, varname, spc=' '): prt.write("{SPC}{NT},\n".format(SPC=spc, NT=ntup)) prt.write("]\n") + def get_unique_fields(fld_lists): """Get unique namedtuple fields, despite potential duplicates in lists of fields.""" flds = [] @@ -93,6 +109,7 @@ def get_unique_fields(fld_lists): assert len(flds) == len(fld_set) return flds + # -- Internal methods ---------------------------------------------------------------- def _combine_nt_vals(lst0_lstn, flds, dflt_null): """Given a list of lists of nts, return a single namedtuple.""" @@ -110,4 +127,5 @@ def _combine_nt_vals(lst0_lstn, flds, dflt_null): vals.append(dflt_null) return vals + # Copyright (C) 2016-2018, DV Klopfenstein, H Tang. All rights reserved.