Skip to content

Commit

Permalink
Expose reunitarize_sigma to I/O interfaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
SaltyChiang committed Dec 18, 2024
1 parent 7d075aa commit 0901006
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
12 changes: 6 additions & 6 deletions pyquda_io/_field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def gaugeProject(gauge: numpy.ndarray):
pass


def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: float):
gauge = numpy.ascontiguousarray(gauge.transpose(5, 6, 0, 1, 2, 3, 4))
row0_abs = numpy.linalg.norm(gauge[0], axis=0)
gauge[0] /= row0_abs
Expand All @@ -182,7 +182,7 @@ def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
row1_abs = numpy.linalg.norm(gauge[1], axis=0)
gauge[1] /= row1_abs
row2 = numpy.cross(gauge[0], gauge[1], axis=0).conjugate()
if reunitarize_sigma:
if reunitarize_sigma > 0:
assert (
MPI.COMM_WORLD.allreduce(
numpy.sqrt(
Expand All @@ -193,13 +193,13 @@ def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
).max(),
MPI.MAX,
)
< 2e-7 # sqrt(Nc) * fp32 machine epsilon
< reunitarize_sigma
)
gauge[2] = row2
return gauge.transpose(2, 3, 4, 5, 6, 0, 1)


def gaugeReunitarizeReconstruct12(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
def gaugeReunitarizeReconstruct12(gauge: numpy.ndarray, reunitarize_sigma: float):
"""gauge shape (Nd, Lt, Lz, Ly, Lx, Nc - 1, Nc)"""
gauge_ = gauge.transpose(5, 6, 0, 1, 2, 3, 4)
gauge = numpy.empty((Nc, *gauge_.shape[1:]), "<c16")
Expand All @@ -211,13 +211,13 @@ def gaugeReunitarizeReconstruct12(gauge: numpy.ndarray, reunitarize_sigma: bool
row1_abs = numpy.linalg.norm(gauge[1], axis=0)
gauge[1] /= row1_abs
row2 = numpy.cross(gauge[0], gauge[1], axis=0).conjugate()
if reunitarize_sigma:
if reunitarize_sigma > 0:
assert (
MPI.COMM_WORLD.allreduce(
numpy.sqrt((1 - row0_abs) ** 2 + numpy.abs(row0_row1) ** 2 + (1 - row1_abs) ** 2).max(),
MPI.MAX,
)
< 2e-7 # sqrt(Nc) * fp32 machine epsilon
< reunitarize_sigma
)
gauge[2] = row2
return gauge.transpose(2, 3, 4, 5, 6, 0, 1)
Expand Down
4 changes: 2 additions & 2 deletions pyquda_io/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def checksum_qio(latt_size: List[int], grid_size: List[int], data):
return sum29, sum31


def readQIOGauge(filename: str, grid_size: List[int], checksum: bool = True):
def readQIOGauge(filename: str, grid_size: List[int], checksum: bool = True, reunitarize_sigma: float = 2e-7):
from .lime import Lime

lime = Lime(filename)
Expand Down Expand Up @@ -68,7 +68,7 @@ def readQIOGauge(filename: str, grid_size: List[int], checksum: bool = True):
), f"Bad checksum for {filename}"
gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
if precision == 4:
gauge = gaugeReunitarize(gauge)
gauge = gaugeReunitarize(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
return latt_size, gauge


Expand Down
4 changes: 2 additions & 2 deletions pyquda_io/milc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def checksum_qio(latt_size: List[int], grid_size: List[int], data):
return sum29, sum31


def readGauge(filename: str, grid_size: List[int], checksum: bool = True, reunitarize_sigma: bool = True):
def readGauge(filename: str, grid_size: List[int], checksum: bool = True, reunitarize_sigma: float = 2e-7):
filename = path.expanduser(path.expandvars(filename))
with open(filename, "rb") as f:
magic = f.read(4)
Expand All @@ -79,7 +79,7 @@ def readGauge(filename: str, grid_size: List[int], checksum: bool = True, reunit
sum31,
), f"Bad checksum for {filename}"
gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
gauge = gaugeReunitarize(gauge, reunitarize_sigma)
gauge = gaugeReunitarize(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
return latt_size, gauge


Expand Down
6 changes: 3 additions & 3 deletions pyquda_io/nersc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def readGauge(
checksum: bool = True,
plaquette: bool = True,
link_trace: bool = True,
reunitarize_sigma: bool = True,
reunitarize_sigma: float = 2e-7,
):
filename = path.expanduser(path.expandvars(filename))
header: Dict[str, str] = {}
Expand Down Expand Up @@ -63,15 +63,15 @@ def readGauge(
assert checksum_nersc(gauge.reshape(-1)) == int(header["CHECKSUM"], 16), f"Bad checksum for {filename}"
gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
if float_nbytes == 4:
gauge = gaugeReunitarize(gauge, reunitarize_sigma)
gauge = gaugeReunitarize(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
elif header["DATATYPE"] == "4D_SU3_GAUGE":
gauge = readMPIFile(filename, dtype, offset, (Lt, Lz, Ly, Lx, Nd, Nc - 1, Nc), (3, 2, 1, 0), grid_size)
gauge = gauge.astype(f"<c{2 * float_nbytes}")
if checksum:
assert checksum_nersc(gauge.reshape(-1)) == int(header["CHECKSUM"], 16), f"Bad checksum for {filename}"
gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
if float_nbytes == 4:
gauge = gaugeReunitarizeReconstruct12(gauge, reunitarize_sigma)
gauge = gaugeReunitarizeReconstruct12(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
elif float_nbytes == 8:
gauge = gaugeReconstruct12(gauge)
else:
Expand Down
8 changes: 4 additions & 4 deletions pyquda_utils/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def rotateToDeGrandRossi(propagator: LatticePropagator):
)


def readChromaQIOGauge(filename: str, checksum: bool = True):
def readChromaQIOGauge(filename: str, checksum: bool = True, reunitarize_sigma: float = 2e-7):
from pyquda import getGridSize
from pyquda_io.chroma import readQIOGauge as read

latt_size, gauge_raw = read(filename, getGridSize(), checksum)
latt_size, gauge_raw = read(filename, getGridSize(), checksum, reunitarize_sigma)
return LatticeGauge(LatticeInfo(latt_size), evenodd(gauge_raw, [1, 2, 3, 4]))


Expand All @@ -96,7 +96,7 @@ def readChromaQIOPropagator(filename: str, checksum: bool = True):
return LatticeStaggeredPropagator(LatticeInfo(latt_size), evenodd(propagator_raw, [0, 1, 2, 3]))


def readMILCGauge(filename: str, checksum: bool = True, reunitarize_sigma: bool = True):
def readMILCGauge(filename: str, checksum: bool = True, reunitarize_sigma: float = 2e-7):
from pyquda import getGridSize
from pyquda_io.milc import readGauge as read

Expand Down Expand Up @@ -244,7 +244,7 @@ def readNERSCGauge(
checksum: bool = True,
plaquette: bool = True,
link_trace: bool = True,
reunitarize_sigma: bool = True,
reunitarize_sigma: float = 2e-7,
):
from pyquda import getGridSize
from pyquda_io.nersc import readGauge as read
Expand Down

0 comments on commit 0901006

Please sign in to comment.