diff --git a/infinigen/tools/export.py b/infinigen/tools/export.py index 8debf69ff..4db14c172 100644 --- a/infinigen/tools/export.py +++ b/infinigen/tools/export.py @@ -485,37 +485,95 @@ def bake_pass(obj, dest: Path, img_size, bake_type, export_usd): def bake_metal( obj, dest, img_size, export_usd -): # metal baking is not really set up for node graphs w/ 2 mixed BSDFs. - metal_map_mats = [] +): + # If at least one material has both a BSDF and non-zero Metallic value, then bake + should_bake = False + + # (Root node, From Socket, To Socket) + links_removed = [] + links_added = [] + for slot in obj.material_slots: mat = slot.material - if mat is None or not mat.use_nodes: + if mat is None: + logging.warn("No material on mesh, skipping...") + continue + if not mat.use_nodes: + logging.warn("Material has no nodes, skipping...") continue + nodes = mat.node_tree.nodes - if nodes.get("Principled BSDF") and nodes.get("Material Output"): + principled_bsdf_node = None + root_node = None + logging.info(f"{mat.name} has {len(nodes)} nodes: {nodes}") + for node in nodes: + if node.type != "GROUP": + continue + + for subnode in node.node_tree.nodes: + logging.info(f" [{subnode.type}] {subnode.name} {subnode.bl_idname}") + if subnode.type == "BSDF_PRINCIPLED": + logging.debug(f" BSDF_PRINCIPLED: {subnode.inputs}") + principled_bsdf_node = subnode + root_node = node + + if nodes.get("Principled BSDF"): principled_bsdf_node = nodes["Principled BSDF"] - outputNode = nodes["Material Output"] - else: + root_node = mat + elif not principled_bsdf_node: + logging.warn("No Principled BSDF, skipping...") + continue + elif "Metallic" not in principled_bsdf_node.inputs: + logging.warn("No Metallic input, skipping...") continue - links = mat.node_tree.links + # Here, we've found the proper BSDF and Metallic input. Set up the scene graph + # for baking. + + outputSoc = principled_bsdf_node.outputs[0].links[0].to_socket + + # Remove the BSDF link to Output first + l = principled_bsdf_node.outputs[0].links[0] + from_socket, to_socket = l.from_socket, l.to_socket + logging.debug(f"Removing link: {from_socket.name} => {to_socket.name}") + root_node.node_tree.links.remove(l) + links_removed.append((root_node, from_socket, to_socket)) + + # Get metallic value + metallic_input = principled_bsdf_node.inputs["Metallic"] + metallic_val = metallic_input.default_value + logging.info(f"Metallic value: {metallic_val}") - if len(principled_bsdf_node.inputs["Metallic"].links) != 0: - link = principled_bsdf_node.inputs["Metallic"].links[0] - from_socket = link.from_socket - links.remove(link) - links.new(outputNode.inputs[0], from_socket) - metal_map_mats.append(mat) + if metallic_val > 0: + should_bake = True - if len(metal_map_mats) != 0: + # Make a color input matching the metallic value + col = root_node.node_tree.nodes.new("ShaderNodeRGB") + col.outputs[0].default_value = (metallic_val, metallic_val, metallic_val, 1.0) + new_link = root_node.node_tree.links.new(col.outputs[0], metallic_input) + links_added.append((root_node, col.outputs[0], metallic_input)) + logging.debug(f"Linking {col.outputs[0].name} to {metallic_input.name}({metallic_input.bl_idname}): {new_link}") + + # Link the color to output + new_link = root_node.node_tree.links.new(col.outputs[0], outputSoc) + links_added.append((root_node, col.outputs[0], outputSoc)) + logging.debug(f"Linking {col.outputs[0].name} to {outputSoc.name}({outputSoc.bl_idname}): {new_link}") + + # After setting up all materials, bake if applicable + if should_bake: bake_pass(obj, dest, img_size, "METAL", export_usd) - for mat in metal_map_mats: - nodes = mat.node_tree.nodes - outputNode = nodes["Material Output"] - principled_bsdf_node = nodes["Principled BSDF"] - links.remove(outputNode.inputs[0].links[0]) - links.new(outputNode.inputs[0], principled_bsdf_node.outputs[0]) + # After baking, undo the temporary changes to the scene graph + for n, from_soc, to_soc in links_added: + logging.debug(f"Removing added link:\t{n.name}: {from_soc.name} => {to_soc.name}") + for l in n.node_tree.links: + if l.from_socket == from_soc and l.to_socket == to_soc: + n.node_tree.links.remove(l) + logging.debug(f"Removed link:\t{n.name}: {from_soc.name} => {to_soc.name}") + + for n, from_soc, to_soc in links_removed: + logging.debug(f"Adding back link:\t{n.name}: {from_soc.name} => {to_soc.name}") + n.node_tree.links.new(from_soc, to_soc) def bake_normals(obj, dest, img_size, export_usd):