Add ExecuTorch support #9879

Open
sicong-li-arm opened this issue Dec 19, 2024 · 1 comment

sicong-li-arm commented Dec 19, 2024

🚀 The feature, motivation and pitch

Hi! I'm trying to deploy a GNN built with PyG on edge devices (mobile), and I'd like it to be supported with ExecuTorch.

Alternatives

No response

Additional context

Hi! I'm trying to deploy a GNN built with PyG (following this blog) on edge devices (mobile), and I've been exploring ExecuTorch.

However, it seems that the PyG model can't even be lowered to the ATen dialect. I suspect this is because TorchDynamo does not support the pyg.data.Data structure: I saw similar errors when I tried to export to ONNX using dynamo.

I attached the full code and full error messages at the end:

aten_dialect = export(simulator, example_inputs) # Error occurred on this line
graph = pyg.data.Data(...) # Error occurred on this line in the user code

I'm using PyTorch==2.5.0, PyG==2.6.1, torch_scatter==2.1.2
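One workaround that might sidestep the failing pyg.data.Data(...) call is to carry the graph fields in a plain container of tensors instead. The sketch below is only an illustration: GraphTensors is a hypothetical name (not a PyG API), the tensor values are made up, and I have not verified that the rest of the model (for example torch.nonzero or torch_scatter.scatter) exports once the Data object is gone.

```python
from typing import NamedTuple

import torch


class GraphTensors(NamedTuple):
    """Hypothetical plain-tensor replacement for pyg.data.Data (not a PyG API)."""
    x: torch.Tensor
    edge_index: torch.Tensor
    edge_attr: torch.Tensor
    pos: torch.Tensor
    recent_position: torch.Tensor
    recent_velocity: torch.Tensor


# Tiny made-up example: 3 particles in 2D connected by 2 edges
graph = GraphTensors(
    x=torch.zeros(3, dtype=torch.int64),
    edge_index=torch.tensor([[0, 1], [1, 2]]),
    edge_attr=torch.rand(2, 3),   # displacement (2) + distance (1)
    pos=torch.rand(3, 14),        # 5 velocities * 2 dims + 4 boundary distances
    recent_position=torch.rand(3, 2),
    recent_velocity=torch.rand(3, 2),
)

# Fields are read by attribute, the same way the model reads data.x or data.edge_index
print(graph.x.shape, graph.edge_index.shape)
```

Because a NamedTuple exposes its fields as attributes, code that reads data.x, data.pos, or data.edge_attr would not need to change; only the construction site in preprocess would.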

export_gnn_executorch.py.txt:

# -*- coding: utf-8 -*-
import torch
print(f"PyTorch has version {torch.__version__} with cuda {torch.version.cuda}")
import numpy as np
import torch_geometric as pyg

"""## GNN helper layers
"""

import math
import torch_scatter

class MLP(torch.nn.Module):
    """Multi-Layer perceptron"""
    def __init__(self, input_size, hidden_size, output_size, layers, layernorm=True):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for i in range(layers):
            self.layers.append(torch.nn.Linear(
                input_size if i == 0 else hidden_size,
                output_size if i == layers - 1 else hidden_size,
            ))
            if i != layers - 1:
                self.layers.append(torch.nn.ReLU())
        if layernorm:
            self.layers.append(torch.nn.LayerNorm(output_size))
        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.layers:
            if isinstance(layer, torch.nn.Linear):
                layer.weight.data.normal_(0, 1 / math.sqrt(layer.in_features))
                layer.bias.data.fill_(0)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def find_connectivity(positions, radius):
    """ Find all edges connecting to all nodes within the radius

    Args:
        positions (Tensor): [N_particles, N_Dim] Most recent positions
        radius (float): Radius
    
    Return:
        edge_index: [2, X], containing all X edges in the graph in the form of (source node index, target node index).
                    Node index is the same as the indices in the node positions
    """
    squared_norm = torch.sum(positions*positions, 1) # [N_particles]
    squared_norm = torch.reshape(squared_norm, [-1, 1]) # [N_particles, 1]
    distance_tensor = squared_norm - 2*torch.matmul(positions, torch.transpose(positions, 0, 1)) + torch.transpose(squared_norm, 0, 1) # [N_particles, N_particles] Pair-wise square distance matrix
    # Find index pairs where the distance is less-than or equal to the radius
    # equivalent to torch.where (but as_tuple=false)
    edge_index = torch.nonzero(torch.less_equal(distance_tensor, radius * radius), as_tuple=False)
    # Expected shape: [2, X]
    return edge_index.T


def preprocess(particle_type, position_seq, metadata):
    """Preprocess a trajectory and construct the graph
    particle_type: [N, dtype=int64], particle type
    position_seq: [N, Sequence length, dim, dtype=float32], position sequence
    metadata: dict, meta data
    """
    # calculate the velocities of particles
    recent_position = position_seq[:, -1]
    velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]

    # construct the graph based on the distances between particles
    # edge_index = pyg.nn.radius_graph(recent_position, metadata["default_connectivity_radius"], loop=True, max_num_neighbors=n_particle)
    edge_index = find_connectivity(recent_position, metadata["default_connectivity_radius"])

    # node-level features: velocity, distance to the boundary
    boundary = torch.tensor(metadata["bounds"])
    distance_to_lower_boundary = recent_position - boundary[:, 0]
    distance_to_upper_boundary = boundary[:, 1] - recent_position
    distance_to_boundary = torch.cat((distance_to_lower_boundary, distance_to_upper_boundary), dim=-1)
    distance_to_boundary = torch.clip(distance_to_boundary / metadata["default_connectivity_radius"], -1.0, 1.0)

    # edge-level features: displacement, distance
    dim = recent_position.size(-1)
    edge_displacement = (torch.gather(recent_position, dim=0, index=edge_index[0].unsqueeze(-1).expand(-1, dim)) -
                   torch.gather(recent_position, dim=0, index=edge_index[1].unsqueeze(-1).expand(-1, dim)))
    edge_displacement /= metadata["default_connectivity_radius"]
    edge_distance = torch.norm(edge_displacement, dim=-1, keepdim=True)

    # return the graph with features
    graph = pyg.data.Data(
        x=particle_type,
        edge_index=edge_index,
        edge_attr=torch.cat((edge_displacement, edge_distance), dim=-1),
        y=None, # Ground truth for training
        pos=torch.cat((velocity_seq.reshape(velocity_seq.size(0), -1), distance_to_boundary), dim=-1),
        recent_position=recent_position,
        recent_velocity=velocity_seq[:, -1]
    )
    return graph

class BuildGraph(torch.nn.Module):
    """Preprocessing unit. Build graphs from positions"""
    def __init__(self, metadata):
        super().__init__()
        self.metadata = metadata

    def forward(self, particle_type, position_sequence):
        graph = preprocess(particle_type, position_sequence, self.metadata)
        return graph

class Postprocess(torch.nn.Module):
    """Postprocessing unit. Convert predicted normalized accelerations into new positions"""
    def __init__(self, metadata):
        super().__init__()
        self.metadata = metadata

    def forward(self, graph, acceleration):
        acceleration = acceleration * torch.sqrt(torch.tensor(self.metadata["acc_std"]) ** 2) + torch.tensor(self.metadata["acc_mean"])

        recent_position = graph.recent_position
        recent_velocity = graph.recent_velocity
        new_velocity = recent_velocity + acceleration
        new_position = recent_position + new_velocity

        return new_position

class InteractionNetwork(pyg.nn.MessagePassing):
    """Interaction Network as proposed in this paper:
    https://proceedings.neurips.cc/paper/2016/hash/3147da8ab4a0437c15ef51a5cc7f2dc4-Abstract.html"""
    def __init__(self, hidden_size, layers):
        super().__init__()
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size, layers)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        x = self.lin_edge(x)
        return x

    def aggregate(self, inputs, index, dim_size=None):
        out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum")
        return (inputs, out)

"""### The GNN
"""
class LearnedSimulatorFull(torch.nn.Module):
    """Graph Network-based Simulators(GNS) full pipeline (with preprocessor and postprocessor)"""
    def __init__(
        self,
        metadata,
        hidden_size=128,
        n_mp_layers=10, # number of GNN layers
        num_particle_types=9,
        particle_type_dim=16, # embedding dimension of particle types
        dim=2, # dimension of the world, typically 2D or 3D
        window_size=5, # the model looks into W frames before the frame to be predicted
    ):
        super().__init__()
        self.window_size = window_size
        self.embed_type = torch.nn.Embedding(num_particle_types, particle_type_dim)
        self.node_in = MLP(particle_type_dim + dim * (window_size + 2), hidden_size, hidden_size, 3)
        self.edge_in = MLP(dim + 1, hidden_size, hidden_size, 3)
        self.node_out = MLP(hidden_size, hidden_size, dim, 3, layernorm=False)
        self.n_mp_layers = n_mp_layers
        self.layers = torch.nn.ModuleList([InteractionNetwork(
            hidden_size, 3
        ) for _ in range(n_mp_layers)])
        self.build_graph = BuildGraph(metadata=metadata)
        self.postproc = Postprocess(metadata=metadata)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.embed_type.weight)

    def forward(self, particle_type, position_sequence):
        # pre-processing: graph building
        data = self.build_graph(particle_type, position_sequence)
        # node feature: combine categorical feature data.x and continuous feature data.pos.
        embedded = self.embed_type(data.x)
        node_feature = torch.cat((embedded, data.pos), dim=-1)
        node_feature = self.node_in(node_feature)
        edge_feature = self.edge_in(data.edge_attr)
        # stack of GNN layers
        for i in range(self.n_mp_layers):
            node_feature, edge_feature = self.layers[i](node_feature, data.edge_index, edge_feature=edge_feature)
        # decode node features into normalized accelerations
        norm_acceleration = self.node_out(node_feature)
        # post-processing: denormalize and integrate to get the new positions
        new_pos = self.postproc(data, norm_acceleration)
        return new_pos

N_PARTICLES = 858
N_DIM = 2
metadata = {"bounds": [[0.1, 0.9], [0.1, 0.9]], "sequence_length": 1000, "default_connectivity_radius": 0.015, "dim": 2, "dt": 0.0025, "vel_mean": [-4.906372733478189e-06, -0.0003581614249505887], "vel_std": [0.0018492343327724738, 0.0018154400863548657], "acc_mean": [-1.3758095862050814e-08, 1.114232425851392e-07], "acc_std": [0.0001279824304831018, 0.0001388316140032424]}

simulator = LearnedSimulatorFull(metadata=metadata)
device = next(simulator.parameters()).device

window_size = simulator.window_size + 1

# Prepare example inputs
particle_type = np.zeros((N_PARTICLES), dtype=np.int64) 
particle_type = torch.from_numpy(particle_type)
position_sequence = np.random.random((N_PARTICLES, window_size, N_DIM)).astype(np.float32)
position_sequence = torch.from_numpy(position_sequence)
example_inputs = (particle_type, position_sequence)

with torch.no_grad():
    simulator.eval()
    acc = simulator(*example_inputs).cpu()

# ============== Export to Executorch (float)
import torch
import torch_scatter
from torch.export import export
from executorch.exir import to_edge

# 1. Lower to aten
aten_dialect = export(simulator, example_inputs)

# 2. to_edge: Make optimizations for Edge devices
edge_program = to_edge(aten_dialect)

# 3. to_executorch: Convert the graph to an ExecuTorch program
executorch_program = edge_program.to_executorch()

# 4. Save the compiled .pte program
with open("physics_gnn_2d.pte", "wb") as file:
    file.write(executorch_program.buffer)

export_gnn_executorch_error.txt:

PyTorch has version 2.5.0+cu124 with cuda 12.4
/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
Traceback (most recent call last):
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 227, in <module>
    aten_dialect = export(simulator, example_inputs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/__init__.py", line 270, in export
    return _export(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
    raise e
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 990, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1880, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1224, in _strict_export
    return _strict_export_lower_to_aten_ir(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1252, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 560, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1432, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 496, in call_function
    var.call_method(tx, "__init__", args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 788, in call_method
    return UserMethodVariable(method, self, source=source).call_function(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 1024, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 195, in call_method
    ).call_function(tx, [self.objvar] + args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 1024, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 195, in call_method
    ).call_function(tx, [self.objvar] + args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1803, in STORE_SUBSCR
    result = obj.call_method(self, "__setitem__", [key, val], {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 1082, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 343, in call_method
    unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_method GetAttrVariable(UserDefinedObjectVariable(Data), __dict__) __setitem__ [ConstantVariable(), UserDefinedClassVariable()] {}

from user code:
   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 185, in forward
    data = self.build_graph(particle_type, position_sequence)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 108, in forward
    graph = preprocess(particle_type, position_sequence, self.metadata)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 90, in preprocess
    graph = pyg.data.Data(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/data.py", line 530, in __init__
    super().__init__(tensor_attr_cls=DataTensorAttr)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/feature_store.py", line 278, in __init__
    super().__init__()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/graph_store.py", line 111, in __init__
    self.__dict__['_edge_attr_cls'] = edge_attr_cls or EdgeAttr

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
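
The detailed log below is dominated by `[__trace_bytecode]` entries, so the user-code lines Dynamo actually traced are hard to spot. Here is a minimal sketch of a filter that keeps only the `[__trace_source]` entries, assuming the log keeps the `TRACE starts_line <path>:<line> in <func>` layout seen in this attachment. The names `TracedLine`, `traced_lines` and `path_contains` are hypothetical helpers written for this issue, not part of PyTorch or PyG, and the sample lines are shortened copies of entries from the log.

```python
import re
from typing import Iterable, Iterator, NamedTuple, Optional


class TracedLine(NamedTuple):
    """One source line that Dynamo reported while tracing (hypothetical helper type)."""
    path: str
    lineno: int
    func: str
    depth: int   # inline depth; 0 for the top-level frame
    source: str  # the source text echoed on the following log line


# Header entry: "[__trace_source] TRACE starts_line <path>:<line> in <func> (...)"
_HEADER = re.compile(
    r"\[__trace_source\] TRACE starts_line (?P<path>\S+):(?P<lineno>\d+) "
    r"in (?P<func>\S+)"
    r"(?: \([^)]*\))?"
    r"(?: \(inline depth: (?P<depth>\d+)\))?"
)
# The entry right after a header repeats the tag, followed by the source text.
_SOURCE = re.compile(r"\[__trace_source\]\s+(?P<source>.*)$")


def traced_lines(
    log_lines: Iterable[str], path_contains: Optional[str] = None
) -> Iterator[TracedLine]:
    """Yield traced source lines, optionally only those whose path contains a substring."""
    pending = None
    for raw in log_lines:
        line = raw.rstrip("\n")
        header = _HEADER.search(line)
        if header:
            pending = header
            continue
        if pending is None:
            continue
        body = _SOURCE.search(line)
        header, pending = pending, None  # a header pairs only with the very next line
        if body is None:
            continue
        if path_contains and path_contains not in header["path"]:
            continue
        yield TracedLine(
            path=header["path"],
            lineno=int(header["lineno"]),
            func=header["func"],
            depth=int(header["depth"] or 0),
            source=body["source"].strip(),
        )


# Shortened copies of entries from the attached log (prefixes and paths abbreviated).
sample = [
    "V1219 16:33:18.588000 1 symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /proj/export_gnn_executorch.py:185 in forward (LearnedSimulatorFull.forward)",
    "V1219 16:33:18.588000 1 symbolic_convert.py:865] [0/0] [__trace_source]             data = self.build_graph(particle_type, position_sequence)",
    "V1219 16:33:18.591000 1 symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST self []",
    "V1219 16:33:18.593000 1 symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /env/torch/nn/modules/module.py:1741 in _call_impl (Module._call_impl) (inline depth: 1)",
    "V1219 16:33:18.593000 1 symbolic_convert.py:865] [0/0] [__trace_source]             forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)",
    "V1219 16:33:18.617000 1 symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /proj/export_gnn_executorch.py:108 in forward (BuildGraph.forward) (inline depth: 2)",
    "V1219 16:33:18.617000 1 symbolic_convert.py:865] [0/0] [__trace_source]             graph = preprocess(particle_type, position_sequence, self.metadata)",
]

for row in traced_lines(sample, path_contains="export_gnn_executorch.py"):
    print(f"{'  ' * row.depth}{row.func}:{row.lineno}  {row.source}")
```

To run it on the full attachment, pass an open file handle for `export_gnn_executorch_error_detailed.txt` instead of `sample`. Filtering on the script name hides the `torch/nn/modules/module.py` frames and leaves the path from `forward` through `preprocess` toward the `pyg.data.Data(...)` call named in the traceback above.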


export_gnn_executorch_error_detailed.txt:

V1219 16:33:18.212000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:1234] skipping: _wrapped_call_impl (reason: in skipfiles, file: /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py)
V1219 16:33:18.212000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:1234] skipping: _call_impl (reason: in skipfiles, file: /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0] torchdynamo start compiling forward /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:183, stack (elided 4 frames):
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 227, in <module>
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     aten_dialect = export(simulator, example_inputs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/__init__.py", line 270, in export
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     return _export(
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 990, in wrapper
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     ep = fn(*args, **kwargs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/exported_program.py", line 114, in wrapper
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     return fn(*args, **kwargs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1880, in _export
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     export_artifact = export_func(  # type: ignore[operator]
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1224, in _strict_export
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     return _strict_export_lower_to_aten_ir(
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1252, in _strict_export_lower_to_aten_ir
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     gm_torch_level = _export_to_torch_ir(
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 560, in _export_to_torch_ir
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     gm_torch_level, _ = torch._dynamo.export(
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1432, in inner
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     result_traced = opt_f(*args, **kwargs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     return self._call_impl(*args, **kwargs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     return forward_call(*args, **kwargs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     return fn(*args, **kwargs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0]     return self._call_impl(*args, **kwargs)
V1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:864] [0/0] 
I1219 16:33:18.215000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:859] [0/0] ChromiumEventLogger initialized with id d4706065-cfc9-4744-8922-8c3f00f352a2
I1219 16:33:18.231000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/logging.py:57] [0/0] Step 1: torchdynamo start tracing forward /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:183
V1219 16:33:18.231000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:2498] [0/0] create_env
/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
V1219 16:33:18.585000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:2076] [0/0] create_graph_input L_particle_type_ L['particle_type']
V1219 16:33:18.586000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:2702] [0/0] wrap_to_fake L['particle_type'] (858,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[StrictMinMaxConstraint(warn_only=False, vr=VR[858, 858])], constraint_strides=[None], view_base_context=None, tensor_source=LocalSource(local_name='particle_type', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V1219 16:33:18.587000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:2076] [0/0] create_graph_input L_position_sequence_ L['position_sequence']
V1219 16:33:18.587000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:2702] [0/0] wrap_to_fake L['position_sequence'] (858, 6, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>, <DimDynamic.DYNAMIC: 0>, <DimDynamic.DYNAMIC: 0>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[StrictMinMaxConstraint(warn_only=False, vr=VR[858, 858]), StrictMinMaxConstraint(warn_only=False, vr=VR[6, 6]), StrictMinMaxConstraint(warn_only=False, vr=VR[2, 2])], constraint_strides=[None, None, None], view_base_context=None, tensor_source=LocalSource(local_name='position_sequence', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V1219 16:33:18.588000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:185 in forward (LearnedSimulatorFull.forward)
V1219 16:33:18.588000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             data = self.build_graph(particle_type, position_sequence)
V1219 16:33:18.591000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST self []
V1219 16:33:18.591000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR build_graph [NNModuleVariable()]
V1219 16:33:18.592000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST particle_type [NNModuleVariable()]
V1219 16:33:18.592000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST position_sequence [NNModuleVariable(), TensorVariable()]
V1219 16:33:18.592000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [NNModuleVariable(), TensorVariable(), TensorVariable()]
V1219 16:33:18.593000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3099] [0/0] INLINING <code object _call_impl at 0x78f885d08f50, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1740>, inlined according trace_rules.lookup MOD_INLINELIST
V1219 16:33:18.593000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1741 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.593000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
V1219 16:33:18.612000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.613000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _C [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.613000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _get_tracing_state [PythonModuleVariable(<module 'torch._C' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_C.cpython-310-x86_64-linux-gnu.so'>)]
V1219 16:33:18.613000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 0 [TorchInGraphFunctionVariable(<built-in method _get_tracing_state of PyCapsule object at 0x78f8d131b000>)]
V1219 16:33:18.613000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_FALSE 16 [ConstantVariable()]
V1219 16:33:18.613000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF self []
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR forward [NNModuleVariable()]
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_DEREF forward_call [UserMethodVariable()]
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1744 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF self []
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _backward_hooks [NNModuleVariable()]
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF self []
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _backward_pre_hooks [NNModuleVariable()]
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF self []
V1219 16:33:18.614000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _forward_hooks [NNModuleVariable()]
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF self []
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _forward_pre_hooks [NNModuleVariable()]
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1745 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]                     or _global_backward_pre_hooks or _global_backward_hooks
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL _global_backward_pre_hooks []
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1744 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1745 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]                     or _global_backward_pre_hooks or _global_backward_hooks
V1219 16:33:18.615000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL _global_backward_hooks []
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1744 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1746 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]                     or _global_forward_hooks or _global_forward_pre_hooks):
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL _global_forward_hooks []
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1744 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1746 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]                     or _global_forward_hooks or _global_forward_pre_hooks):
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL _global_forward_pre_hooks []
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1744 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_JUMP_IF_TRUE 76 [ConstDictVariable()]
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1747 in _call_impl (Module._call_impl) (inline depth: 1)
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]                 return forward_call(*args, **kwargs)
V1219 16:33:18.616000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF forward_call []
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF args [UserMethodVariable()]
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_MAP 0 [UserMethodVariable(), TupleVariable(length=2)]
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF kwargs [UserMethodVariable(), TupleVariable(length=2), ConstDictVariable()]
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE DICT_MERGE 1 [UserMethodVariable(), TupleVariable(length=2), ConstDictVariable(), ConstDictVariable()]
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_EX 1 [UserMethodVariable(), TupleVariable(length=2), ConstDictVariable()]
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3099] [0/0] INLINING <code object forward at 0x78f8d25a8190, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 107>, inlined according trace_rules.lookup inlined by default
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:108 in forward (BuildGraph.forward) (inline depth: 2)
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             graph = preprocess(particle_type, position_sequence, self.metadata)
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL preprocess []
V1219 16:33:18.617000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST particle_type [UserFunctionVariable()]
V1219 16:33:18.618000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST position_sequence [UserFunctionVariable(), TensorVariable()]
V1219 16:33:18.618000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST self [UserFunctionVariable(), TensorVariable(), TensorVariable()]
V1219 16:33:18.618000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR metadata [UserFunctionVariable(), TensorVariable(), TensorVariable(), NNModuleVariable()]
V1219 16:33:18.618000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 3 [UserFunctionVariable(), TensorVariable(), TensorVariable(), ConstDictVariable()]
V1219 16:33:18.618000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3099] [0/0] INLINING <code object preprocess at 0x78f8d25a8030, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 61>, inlined according trace_rules.lookup inlined by default
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:68 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         recent_position = position_seq[:, -1]
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST position_seq []
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable()]
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), ConstantVariable()]
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TensorVariable(), SliceVariable()]
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [TensorVariable(), SliceVariable(), ConstantVariable()]
V1219 16:33:18.619000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.621000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST recent_position [TensorVariable()]
V1219 16:33:18.621000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:69 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.621000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]
V1219 16:33:18.621000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST position_seq []
V1219 16:33:18.621000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable()]
V1219 16:33:18.622000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), ConstantVariable()]
V1219 16:33:18.622000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.622000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [TensorVariable(), SliceVariable()]
V1219 16:33:18.622000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), SliceVariable(), ConstantVariable()]
V1219 16:33:18.622000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [TensorVariable(), SliceVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.622000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [TensorVariable(), SliceVariable(), SliceVariable()]
V1219 16:33:18.622000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.623000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST position_seq [TensorVariable()]
V1219 16:33:18.623000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), TensorVariable()]
V1219 16:33:18.623000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), TensorVariable(), ConstantVariable()]
V1219 16:33:18.623000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [TensorVariable(), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.623000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), TensorVariable(), SliceVariable()]
V1219 16:33:18.624000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TensorVariable(), TensorVariable(), SliceVariable(), ConstantVariable()]
V1219 16:33:18.624000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [TensorVariable(), TensorVariable(), SliceVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.624000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [TensorVariable(), TensorVariable(), SliceVariable(), SliceVariable()]
V1219 16:33:18.624000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TensorVariable(), TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.625000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBTRACT None [TensorVariable(), TensorVariable()]
V1219 16:33:18.626000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST velocity_seq [TensorVariable()]
V1219 16:33:18.626000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:73 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.626000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         edge_index = find_connectivity(recent_position, metadata["default_connectivity_radius"])
V1219 16:33:18.626000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL find_connectivity []
V1219 16:33:18.626000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST recent_position [UserFunctionVariable()]
V1219 16:33:18.626000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST metadata [UserFunctionVariable(), TensorVariable()]
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST default_connectivity_radius [UserFunctionVariable(), TensorVariable(), ConstDictVariable()]
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [UserFunctionVariable(), TensorVariable(), ConstDictVariable(), ConstantVariable()]
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [UserFunctionVariable(), TensorVariable(), LazyVariableTracker()]
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3099] [0/0] INLINING <code object find_connectivity at 0x78f8d257fec0, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 40>, inlined according trace_rules.lookup inlined by default
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:51 in find_connectivity (find_connectivity) (inline depth: 4)
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         squared_norm = torch.sum(positions*positions, 1) # [N_particles]
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.627000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR sum [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.628000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST positions [TorchInGraphFunctionVariable(<built-in method sum of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.628000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST positions [TorchInGraphFunctionVariable(<built-in method sum of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.628000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_MULTIPLY None [TorchInGraphFunctionVariable(<built-in method sum of type object at 0x78f8d10bf1c0>), TensorVariable(), TensorVariable()]
V1219 16:33:18.629000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [TorchInGraphFunctionVariable(<built-in method sum of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.629000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [TorchInGraphFunctionVariable(<built-in method sum of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST squared_norm [TensorVariable()]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:52 in find_connectivity (find_connectivity) (inline depth: 4)
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         squared_norm = torch.reshape(squared_norm, [-1, 1]) # [N_particles, 1]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR reshape [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST squared_norm [TorchInGraphFunctionVariable(<built-in method reshape of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TorchInGraphFunctionVariable(<built-in method reshape of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [TorchInGraphFunctionVariable(<built-in method reshape of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_LIST 2 [TorchInGraphFunctionVariable(<built-in method reshape of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.630000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [TorchInGraphFunctionVariable(<built-in method reshape of type object at 0x78f8d10bf1c0>), TensorVariable(), ListVariable(length=2)]
V1219 16:33:18.631000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST squared_norm [TensorVariable()]
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:53 in find_connectivity (find_connectivity) (inline depth: 4)
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         distance_tensor = squared_norm - 2*torch.matmul(positions, torch.transpose(positions, 0, 1)) + torch.transpose(squared_norm, 0, 1) # [N_particles, N_particles] Pair-wise square distance matrix
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST squared_norm []
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 2 [TensorVariable()]
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch [TensorVariable(), ConstantVariable()]
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR matmul [TensorVariable(), ConstantVariable(), PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST positions [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR transpose [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>), TensorVariable(), PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.632000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST positions [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>), TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.633000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>), TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.633000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>), TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.633000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 3 [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>), TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.634000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method matmul of type object at 0x78f8d10bf1c0>), TensorVariable(), TensorVariable()]
V1219 16:33:18.635000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_MULTIPLY None [TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.635000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBTRACT None [TensorVariable(), TensorVariable()]
V1219 16:33:18.636000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch [TensorVariable()]
V1219 16:33:18.636000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR transpose [TensorVariable(), PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.636000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST squared_norm [TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.637000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.637000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.637000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 3 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method transpose of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.637000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [TensorVariable(), TensorVariable()]
V1219 16:33:18.638000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST distance_tensor [TensorVariable()]
V1219 16:33:18.638000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:56 in find_connectivity (find_connectivity) (inline depth: 4)
V1219 16:33:18.638000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         edge_index = torch.nonzero(torch.less_equal(distance_tensor, radius * radius), as_tuple=False)
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR nonzero [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR less_equal [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST distance_tensor [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TorchInGraphFunctionVariable(<built-in method less_equal of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST radius [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TorchInGraphFunctionVariable(<built-in method less_equal of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST radius [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TorchInGraphFunctionVariable(<built-in method less_equal of type object at 0x78f8d10bf1c0>), TensorVariable(), LazyVariableTracker()]
V1219 16:33:18.639000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_MULTIPLY None [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TorchInGraphFunctionVariable(<built-in method less_equal of type object at 0x78f8d10bf1c0>), TensorVariable(), LazyVariableTracker(), LazyVariableTracker()]
V1219 16:33:18.640000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TorchInGraphFunctionVariable(<built-in method less_equal of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.641000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST False [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.641000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('as_tuple',) [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.641000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 2 [TorchInGraphFunctionVariable(<built-in method nonzero of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TupleVariable(length=1)]
I1219 16:33:18.642000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:3317] [0/0] create_unbacked_symint u0 [-int_oo, int_oo] at export_gnn_executorch.py:56 in find_connectivity (_subclasses/fake_impls.py:426 in nonzero)
V1219 16:33:18.643000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:4734] [0/0] _update_var_to_range u0 = VR[0, 736164] (update)
I1219 16:33:18.643000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5481] [0/0] constrain_symbol_range u0 [0, 736164]
V1219 16:33:18.659000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.661000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.662000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval Eq(u0, 0) == False [statically known]
V1219 16:33:18.663000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
I1219 16:33:18.663000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:604] [0/0] compute_unbacked_bindings [u0]
V1219 16:33:18.664000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST edge_index [TensorVariable()]
V1219 16:33:18.664000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:58 in find_connectivity (find_connectivity) (inline depth: 4)
V1219 16:33:18.664000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         return edge_index.T
V1219 16:33:18.664000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_index []
V1219 16:33:18.664000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR T [TensorVariable()]
V1219 16:33:18.665000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
V1219 16:33:18.665000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3164] [0/0] DONE INLINING <code object find_connectivity at 0x78f8d257fec0, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 40>
V1219 16:33:18.665000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST edge_index [TensorVariable()]
V1219 16:33:18.665000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:76 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.665000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         boundary = torch.tensor(metadata["bounds"])
V1219 16:33:18.665000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.665000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR tensor [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.666000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST metadata [TorchInGraphFunctionVariable(<built-in method tensor of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.666000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST bounds [TorchInGraphFunctionVariable(<built-in method tensor of type object at 0x78f8d10bf1c0>), ConstDictVariable()]
V1219 16:33:18.666000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TorchInGraphFunctionVariable(<built-in method tensor of type object at 0x78f8d10bf1c0>), ConstDictVariable(), ConstantVariable()]
V1219 16:33:18.666000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [TorchInGraphFunctionVariable(<built-in method tensor of type object at 0x78f8d10bf1c0>), LazyVariableTracker()]
V1219 16:33:18.667000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST boundary [TensorVariable()]
V1219 16:33:18.667000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:77 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.667000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         distance_to_lower_boundary = recent_position - boundary[:, 0]
V1219 16:33:18.667000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST recent_position []
V1219 16:33:18.667000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST boundary [TensorVariable()]
V1219 16:33:18.667000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), TensorVariable()]
V1219 16:33:18.668000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), TensorVariable(), ConstantVariable()]
V1219 16:33:18.668000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [TensorVariable(), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.668000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [TensorVariable(), TensorVariable(), SliceVariable()]
V1219 16:33:18.668000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [TensorVariable(), TensorVariable(), SliceVariable(), ConstantVariable()]
V1219 16:33:18.668000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TensorVariable(), TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.669000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBTRACT None [TensorVariable(), TensorVariable()]
V1219 16:33:18.670000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST distance_to_lower_boundary [TensorVariable()]
V1219 16:33:18.670000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:78 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.670000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         distance_to_upper_boundary = boundary[:, 1] - recent_position
V1219 16:33:18.670000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST boundary []
V1219 16:33:18.670000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable()]
V1219 16:33:18.670000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [TensorVariable(), ConstantVariable()]
V1219 16:33:18.670000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.671000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [TensorVariable(), SliceVariable()]
V1219 16:33:18.671000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [TensorVariable(), SliceVariable(), ConstantVariable()]
V1219 16:33:18.671000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.672000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST recent_position [TensorVariable()]
V1219 16:33:18.672000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBTRACT None [TensorVariable(), TensorVariable()]
V1219 16:33:18.672000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST distance_to_upper_boundary [TensorVariable()]
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:79 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         distance_to_boundary = torch.cat((distance_to_lower_boundary, distance_to_upper_boundary), dim=-1)
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR cat [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST distance_to_lower_boundary [TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST distance_to_upper_boundary [TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TensorVariable(), TensorVariable()]
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2)]
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('dim',) [TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2), ConstantVariable()]
V1219 16:33:18.673000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 2 [TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2), ConstantVariable(), TupleVariable(length=1)]
V1219 16:33:18.674000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST distance_to_boundary [TensorVariable()]
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:80 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         distance_to_boundary = torch.clip(distance_to_boundary / metadata["default_connectivity_radius"], -1.0, 1.0)
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR clip [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST distance_to_boundary [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST metadata [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST default_connectivity_radius [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstDictVariable()]
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstDictVariable(), ConstantVariable()]
V1219 16:33:18.675000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_TRUE_DIVIDE None [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>), TensorVariable(), LazyVariableTracker()]
V1219 16:33:18.676000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1.0 [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.676000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1.0 [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.676000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 3 [TorchInGraphFunctionVariable(<built-in method clip of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.678000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST distance_to_boundary [TensorVariable()]
V1219 16:33:18.678000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:83 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.678000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         dim = recent_position.size(-1)
V1219 16:33:18.678000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST recent_position []
V1219 16:33:18.678000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR size [TensorVariable()]
V1219 16:33:18.678000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [GetAttrVariable()]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST dim [ConstantVariable()]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:84 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         edge_displacement = (torch.gather(recent_position, dim=0, index=edge_index[0].unsqueeze(-1).expand(-1, dim)) -
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR gather [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST recent_position [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_index [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.679000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable(), ConstantVariable()]
V1219 16:33:18.680000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR unsqueeze [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.680000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable()]
V1219 16:33:18.680000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.682000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval Eq(u0, 1) == False [statically known]
V1219 16:33:18.682000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.683000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval False == False [statically known]
V1219 16:33:18.683000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR expand [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.683000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable()]
V1219 16:33:18.683000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST dim [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.683000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.685000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval Eq(u0, -1) == False [statically known]
V1219 16:33:18.686000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.686000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.687000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('dim', 'index') [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.687000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 3 [TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.688000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval Eq(2*u0, 0) == False [statically known]
V1219 16:33:18.689000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.689000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:85 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.689000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]                        torch.gather(recent_position, dim=0, index=edge_index[1].unsqueeze(-1).expand(-1, dim)))
V1219 16:33:18.689000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch [TensorVariable()]
V1219 16:33:18.690000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR gather [TensorVariable(), PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.690000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST recent_position [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.690000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.690000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_index [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.690000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.690000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable(), ConstantVariable()]
V1219 16:33:18.691000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR unsqueeze [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.691000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable()]
V1219 16:33:18.691000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.691000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.692000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR expand [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.692000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable()]
V1219 16:33:18.692000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST dim [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.692000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.693000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.693000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.694000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('dim', 'index') [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.694000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 3 [TensorVariable(), TorchInGraphFunctionVariable(<built-in method gather of type object at 0x78f8d10bf1c0>), TensorVariable(), ConstantVariable(), TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.694000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.695000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:84 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.695000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         edge_displacement = (torch.gather(recent_position, dim=0, index=edge_index[0].unsqueeze(-1).expand(-1, dim)) -
V1219 16:33:18.695000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBTRACT None [TensorVariable(), TensorVariable()]
V1219 16:33:18.695000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.696000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.696000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval True == True [statically known]
V1219 16:33:18.696000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST edge_displacement [TensorVariable()]
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:86 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         edge_displacement /= metadata["default_connectivity_radius"]
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_displacement []
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST metadata [TensorVariable()]
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST default_connectivity_radius [TensorVariable(), ConstDictVariable()]
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TensorVariable(), ConstDictVariable(), ConstantVariable()]
V1219 16:33:18.697000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE INPLACE_TRUE_DIVIDE None [TensorVariable(), LazyVariableTracker()]
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST edge_displacement [TensorVariable()]
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:87 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         edge_distance = torch.norm(edge_displacement, dim=-1, keepdim=True)
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR norm [PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_displacement [TorchInGraphFunctionVariable(<function norm at 0x78f87368b9a0>)]
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [TorchInGraphFunctionVariable(<function norm at 0x78f87368b9a0>), TensorVariable()]
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST True [TorchInGraphFunctionVariable(<function norm at 0x78f87368b9a0>), TensorVariable(), ConstantVariable()]
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('dim', 'keepdim') [TorchInGraphFunctionVariable(<function norm at 0x78f87368b9a0>), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.698000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 3 [TorchInGraphFunctionVariable(<function norm at 0x78f87368b9a0>), TensorVariable(), ConstantVariable(), ConstantVariable(), TupleVariable(length=2)]
V1219 16:33:18.700000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval u0 < 0 == False [statically known]
V1219 16:33:18.700000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.700000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.701000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.701000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.703000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.704000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.704000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.705000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval True == True [statically known]
V1219 16:33:18.706000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5201] [0/0] eval Ne(u0, 1) == True [statically known]
V1219 16:33:18.707000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.707000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.708000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_FAST edge_distance [TensorVariable()]
V1219 16:33:18.708000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:90 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.708000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         graph = pyg.data.Data(
V1219 16:33:18.708000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL pyg []
V1219 16:33:18.708000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR data [PythonModuleVariable(<module 'torch_geometric' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/__init__.py'>)]
V1219 16:33:18.708000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR Data [PythonModuleVariable(<module 'torch_geometric.data' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/__init__.py'>)]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:91 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             x=particle_type,
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST particle_type [UserDefinedClassVariable()]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:92 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             edge_index=edge_index,
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_index [UserDefinedClassVariable(), TensorVariable()]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:93 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             edge_attr=torch.cat((edge_displacement, edge_distance), dim=-1),
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch [UserDefinedClassVariable(), TensorVariable(), TensorVariable()]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR cat [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_displacement [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_distance [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TensorVariable(), TensorVariable()]
V1219 16:33:18.709000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2)]
V1219 16:33:18.710000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('dim',) [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2), ConstantVariable()]
V1219 16:33:18.710000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 2 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2), ConstantVariable(), TupleVariable(length=1)]
V1219 16:33:18.710000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.710000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert True == True [statically known]
V1219 16:33:18.711000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.712000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.713000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.714000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.714000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5358] [0/0] runtime_assert u0 >= 0 == True [statically known]
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:94 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             y=None, # Ground truth for training
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable()]
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:95 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             pos=torch.cat((velocity_seq.reshape(velocity_seq.size(0), -1), distance_to_boundary), dim=-1),
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable()]
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR cat [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), PythonModuleVariable(<module 'torch' from '/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/__init__.py'>)]
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST velocity_seq [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>)]
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR reshape [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST velocity_seq [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), GetAttrVariable()]
V1219 16:33:18.715000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR size [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), GetAttrVariable(), TensorVariable()]
V1219 16:33:18.716000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), GetAttrVariable(), GetAttrVariable()]
V1219 16:33:18.716000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), GetAttrVariable(), GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.716000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.716000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), GetAttrVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.717000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST distance_to_boundary [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TensorVariable()]
V1219 16:33:18.717000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TensorVariable(), TensorVariable()]
V1219 16:33:18.717000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2)]
V1219 16:33:18.717000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('dim',) [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2), ConstantVariable()]
V1219 16:33:18.717000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 2 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TorchInGraphFunctionVariable(<built-in method cat of type object at 0x78f8d10bf1c0>), TupleVariable(length=2), ConstantVariable(), TupleVariable(length=1)]
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:96 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             recent_position=recent_position,
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST recent_position [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable()]
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:97 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             recent_velocity=velocity_seq[:, -1]
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST velocity_seq [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable()]
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable()]
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST None [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable()]
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_SLICE 2 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), ConstantVariable()]
V1219 16:33:18.718000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable(), SliceVariable()]
V1219 16:33:18.719000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable(), SliceVariable(), ConstantVariable()]
V1219 16:33:18.719000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TupleVariable(length=2)]
V1219 16:33:18.720000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py:90 in preprocess (preprocess) (inline depth: 3)
V1219 16:33:18.720000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]         graph = pyg.data.Data(
V1219 16:33:18.720000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('x', 'edge_index', 'edge_attr', 'y', 'pos', 'recent_position', 'recent_velocity') [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable()]
V1219 16:33:18.720000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 7 [UserDefinedClassVariable(), TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TupleVariable(length=7)]
V1219 16:33:18.721000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3099] [0/0] INLINING <code object __init__ at 0x78f782f0c190, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/data.py", line 518>, inlined according trace_rules.lookup inlined by default
V1219 16:33:18.721000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/data.py:530 in __init__ (Data.__init__) (inline depth: 4)
V1219 16:33:18.721000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             super().__init__(tensor_attr_cls=DataTensorAttr)
V1219 16:33:18.732000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL super []
V1219 16:33:18.733000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF __class__ [BuiltinVariable()]
V1219 16:33:18.733000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST self [BuiltinVariable(), UserDefinedClassVariable()]
V1219 16:33:18.733000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [BuiltinVariable(), UserDefinedClassVariable(), UserDefinedObjectVariable(Data)]
V1219 16:33:18.733000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR __init__ [SuperVariable()]
V1219 16:33:18.733000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL DataTensorAttr [GetAttrVariable()]
V1219 16:33:18.733000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('tensor_attr_cls',) [GetAttrVariable(), UserDefinedClassVariable()]
V1219 16:33:18.733000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 1 [GetAttrVariable(), UserDefinedClassVariable(), TupleVariable(length=1)]
V1219 16:33:18.734000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3099] [0/0] INLINING <code object __init__ at 0x78f7839df940, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/feature_store.py", line 277>, inlined according trace_rules.lookup inlined by default
V1219 16:33:18.734000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/feature_store.py:278 in __init__ (FeatureStore.__init__) (inline depth: 5)
V1219 16:33:18.734000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             super().__init__()
V1219 16:33:18.737000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL super []
V1219 16:33:18.737000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF __class__ [BuiltinVariable()]
V1219 16:33:18.737000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST self [BuiltinVariable(), UserDefinedClassVariable()]
V1219 16:33:18.737000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [BuiltinVariable(), UserDefinedClassVariable(), UserDefinedObjectVariable(Data)]
V1219 16:33:18.737000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR __init__ [SuperVariable()]
V1219 16:33:18.738000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 0 [GetAttrVariable()]
V1219 16:33:18.738000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3099] [0/0] INLINING <code object __init__ at 0x78f7839fa290, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/graph_store.py", line 109>, inlined according trace_rules.lookup inlined by default
V1219 16:33:18.738000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/graph_store.py:110 in __init__ (GraphStore.__init__) (inline depth: 6)
V1219 16:33:18.738000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             super().__init__()
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL super []
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_DEREF __class__ [BuiltinVariable()]
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST self [BuiltinVariable(), UserDefinedClassVariable()]
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 2 [BuiltinVariable(), UserDefinedClassVariable(), UserDefinedObjectVariable(Data)]
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR __init__ [SuperVariable()]
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 0 [GetAttrVariable()]
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE POP_TOP None [LambdaVariable()]
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source] TRACE starts_line /home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/graph_store.py:111 in __init__ (GraphStore.__init__) (inline depth: 6)
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:865] [0/0] [__trace_source]             self.__dict__['_edge_attr_cls'] = edge_attr_cls or EdgeAttr
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST edge_attr_cls []
V1219 16:33:18.741000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE JUMP_IF_TRUE_OR_POP 16 [ConstantVariable()]
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL EdgeAttr []
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_FAST self [UserDefinedClassVariable()]
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_ATTR __dict__ [UserDefinedClassVariable(), UserDefinedObjectVariable(Data)]
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE LOAD_CONST _edge_attr_cls [UserDefinedClassVariable(), GetAttrVariable()]
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:888] [0/0] [__trace_bytecode] TRACE STORE_SUBSCR None [UserDefinedClassVariable(), GetAttrVariable(), ConstantVariable()]
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:904] [0/0] empty checkpoint
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3153] [0/0] FAILED INLINING <code object __init__ at 0x78f7839fa290, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/graph_store.py", line 109>
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:904] [0/0] empty checkpoint
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3153] [0/0] FAILED INLINING <code object __init__ at 0x78f7839df940, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/feature_store.py", line 277>
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:904] [0/0] empty checkpoint
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3153] [0/0] FAILED INLINING <code object __init__ at 0x78f782f0c190, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/data.py", line 518>
V1219 16:33:18.742000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:904] [0/0] empty checkpoint
V1219 16:33:18.743000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3153] [0/0] FAILED INLINING <code object preprocess at 0x78f8d25a8030, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 61>
V1219 16:33:18.743000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:904] [0/0] empty checkpoint
V1219 16:33:18.743000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3153] [0/0] FAILED INLINING <code object forward at 0x78f8d25a8190, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 107>
V1219 16:33:18.743000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:904] [0/0] empty checkpoint
V1219 16:33:18.743000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:3153] [0/0] FAILED INLINING <code object _call_impl at 0x78f885d08f50, file "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1740>
V1219 16:33:18.743000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:904] [0/0] empty checkpoint
PyTorch has version 2.5.0+cu124 with cuda 12.4
-1 0
-1 0
Traceback (most recent call last):
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 227, in <module>
    aten_dialect = export(simulator, example_inputs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/__init__.py", line 270, in export
    return _export(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
    raise e
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 990, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1880, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1224, in _strict_export
    return _strict_export_lower_to_aten_ir(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 1252, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/export/_trace.py", line 560, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1432, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 496, in call_function
    var.call_method(tx, "__init__", args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 788, in call_method
    return UserMethodVariable(method, self, source=source).call_function(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 1024, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 195, in call_method
    ).call_function(tx, [self.objvar] + args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 1024, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 195, in call_method
    ).call_function(tx, [self.objvar] + args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1803, in STORE_SUBSCR
    result = obj.call_method(self, "__setitem__", [key, val], {})
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 1082, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 343, in call_method
    unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_method GetAttrVariable(UserDefinedObjectVariable(Data), __dict__) __setitem__ [ConstantVariable(), UserDefinedClassVariable()] {}

from user code:
   File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 185, in forward
    data = self.build_graph(particle_type, position_sequence)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 108, in forward
    graph = preprocess(particle_type, position_sequence, self.metadata)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/export_gnn_executorch.py", line 90, in preprocess
    graph = pyg.data.Data(
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/data.py", line 530, in __init__
    super().__init__(tensor_attr_cls=DataTensorAttr)
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/feature_store.py", line 278, in __init__
    super().__init__()
  File "/home/sicli01/Projects/FluidML/gnn-physics-pytorch/gnn_env/lib/python3.10/site-packages/torch_geometric/data/graph_store.py", line 111, in __init__
    self.__dict__['_edge_attr_cls'] = edge_attr_cls or EdgeAttr

I1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:399] TorchDynamo compilation metrics:
I1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:399] Function                  Runtimes (s)
I1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:399] ----------------------  --------------
I1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:399] _compile.compile_inner               0
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats constrain_symbol_range: CacheInfo(hits=18, misses=2, maxsize=None, currsize=2)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats evaluate_expr: CacheInfo(hits=39, misses=9, maxsize=256, currsize=9)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _find: CacheInfo(hits=20, misses=1, maxsize=None, currsize=1)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats simplify: CacheInfo(hits=4, misses=9, maxsize=None, currsize=9)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:18.755000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats replace: CacheInfo(hits=1208, misses=30, maxsize=None, currsize=30)
V1219 16:33:18.756000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=49, misses=13, maxsize=None, currsize=13)
V1219 16:33:18.756000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:18.756000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_axioms: CacheInfo(hits=10, misses=3, maxsize=None, currsize=3)
V1219 16:33:18.756000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats safe_expand: CacheInfo(hits=182, misses=30, maxsize=256, currsize=30)
V1219 16:33:18.756000 2552752 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats uninteresting_files: CacheInfo(hits=4, misses=1, maxsize=None, currsize=1)
I1219 16:33:19.704000 2552920 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:399] TorchDynamo compilation metrics:
I1219 16:33:19.704000 2552920 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:399] Function    Runtimes (s)
I1219 16:33:19.704000 2552920 gnn_env/lib/python3.10/site-packages/torch/_dynamo/utils.py:399] ----------  --------------
V1219 16:33:19.704000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.704000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats evaluate_expr: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:19.704000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.704000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _find: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats simplify: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats replace: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats safe_expand: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1219 16:33:19.705000 2552920 gnn_env/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats uninteresting_files: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
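
For anyone reading the trace: the `Unsupported` error points at a `__setitem__` call on an object's `__dict__`, and the last user-code frame is `self.__dict__['_edge_attr_cls'] = edge_attr_cls or EdgeAttr`. Below is a minimal, stdlib-only sketch of that Python pattern. The class names `Store` and `Record` are hypothetical stand-ins, not PyG classes, and the custom `__setattr__` is an assumption about why a class would write to `__dict__` directly. It only illustrates the kind of statement Dynamo reported as unsupported here; it does not reproduce the export failure itself.

```python
class Store:
    """Hypothetical base class that writes straight into the instance dict."""

    def __init__(self, edge_attr_cls=None):
        # Same shape as the statement in the last user-code frame:
        # a direct __dict__ item assignment that bypasses __setattr__.
        self.__dict__['_edge_attr_cls'] = edge_attr_cls or dict


class Record(Store):
    """Hypothetical container that redirects attribute writes to a mapping."""

    def __init__(self, **fields):
        super().__init__()
        # Internal state must also bypass the custom __setattr__ below.
        self.__dict__['_store'] = dict(fields)

    def __setattr__(self, key, value):
        # Ordinary attribute writes land in the internal mapping instead.
        self.__dict__['_store'][key] = value

    def __getattr__(self, key):
        # Only called when normal lookup fails; fall back to the mapping.
        try:
            return self.__dict__['_store'][key]
        except KeyError:
            raise AttributeError(key) from None


record = Record(x=1)
record.y = 2
print(sorted(vars(record)))  # internal names only; 'x' and 'y' live in _store
```

If this reading is right, one thing that might be worth trying is building the graph container outside the exported `forward` and passing plain tensors in, so the tracer never has to step through such a constructor. That is untested against this script.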

@akihironitta akihironitta self-assigned this Dec 19, 2024
@akihironitta
Member

@sicong-li-arm Thanks for creating this issue! I was able to reproduce the issue with your script. I'll have a deeper look :)
