diff --git a/setup.py b/setup.py index 3bf462d..666c7a9 100644 --- a/setup.py +++ b/setup.py @@ -2,22 +2,33 @@ from Cython.Build import cythonize import numpy as np import platform +import sys numpy_include = np.get_include() extra_compile_args = [] extra_link_args = [] -if platform.machine().startswith('arm'): - if platform.architecture()[0] == '32bit': - extra_compile_args.extend(["-march=armv7-a", "-mfpu=neon"]) - extra_link_args.extend(["-march=armv7-a", "-mfpu=neon"]) - else: # 64-bit ARM - extra_compile_args.extend(["-march=armv8-a+simd"]) - extra_link_args.extend(["-march=armv8-a+simd"]) -elif platform.machine() in ["x86_64", "AMD64"]: - extra_compile_args.extend(["-march=native", "-mpopcnt"]) - extra_link_args.extend(["-march=native", "-mpopcnt"]) +if platform.system() == "Darwin": + if platform.machine() == "arm64": + extra_compile_args.extend(["-arch", "arm64", "-O3", "-ffast-math"]) + extra_link_args.extend(["-arch", "arm64"]) + else: + extra_compile_args.extend(["-arch", "x86_64", "-O3", "-ffast-math"]) + extra_link_args.extend(["-arch", "x86_64"]) +elif platform.system() == "Windows": + extra_compile_args.extend(["/O2"]) +else: # Linux and others + if platform.machine().startswith("arm"): + if platform.architecture()[0] == "32bit": + extra_compile_args.extend(["-march=armv7-a", "-mfpu=neon"]) + extra_link_args.extend(["-march=armv7-a", "-mfpu=neon"]) + else: # 64-bit ARM + extra_compile_args.extend(["-march=armv8-a"]) + extra_link_args.extend(["-march=armv8-a"]) + elif platform.machine() in ["x86_64", "AMD64"]: + extra_compile_args.extend(["-march=native", "-mpopcnt"]) + extra_link_args.extend(["-march=native", "-mpopcnt"]) extra_compile_args.extend(["-O3", "-ffast-math"]) @@ -37,16 +48,16 @@ extra_link_args=extra_link_args, ), Extension( - "wordllama.algorithms.kmeans_helpers", - ["wordllama/algorithms/kmeans_helpers.pyx"], + "wordllama.algorithms.deduplicate_helpers", + ["wordllama/algorithms/deduplicate_helpers.pyx"], include_dirs=[numpy_include], define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], extra_compile_args=extra_compile_args, extra_link_args=extra_link_args, ), Extension( - "wordllama.algorithms.deduplicate_helpers", - ["wordllama/algorithms/deduplicate_helpers.pyx"], + "wordllama.algorithms.kmeans_helpers", + ["wordllama/algorithms/kmeans_helpers.pyx"], include_dirs=[numpy_include], define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], extra_compile_args=extra_compile_args, @@ -58,10 +69,13 @@ name="Text Processing Tools", ext_modules=cythonize( extensions, - compiler_directives={"language_level": "3", "boundscheck": False, "wraparound": False}, - annotate=True + compiler_directives={ + "language_level": "3", + "boundscheck": False, + "wraparound": False, + }, + annotate=True, ), zip_safe=False, install_requires=["numpy"], ) -