diff --git a/netbox/circuits/graphql/schema.py b/netbox/circuits/graphql/schema.py index f65874239ea..32b73e258e6 100644 --- a/netbox/circuits/graphql/schema.py +++ b/netbox/circuits/graphql/schema.py @@ -1,21 +1,38 @@ import graphene +from circuits import models from netbox.graphql.fields import ObjectField, ObjectListField from .types import * +from utilities.graphql_optimizer import gql_query_optimizer class CircuitsQuery(graphene.ObjectType): circuit = ObjectField(CircuitType) circuit_list = ObjectListField(CircuitType) + def resolve_circuit_list(root, info, **kwargs): + return gql_query_optimizer(models.Circuit.objects.all(), info) + circuit_termination = ObjectField(CircuitTerminationType) circuit_termination_list = ObjectListField(CircuitTerminationType) + def resolve_circuit_termination_list(root, info, **kwargs): + return gql_query_optimizer(models.CircuitTermination.objects.all(), info) + circuit_type = ObjectField(CircuitTypeType) circuit_type_list = ObjectListField(CircuitTypeType) + def resolve_circuit_type_list(root, info, **kwargs): + return gql_query_optimizer(models.CircuitType.objects.all(), info) + provider = ObjectField(ProviderType) provider_list = ObjectListField(ProviderType) + def resolve_provider_list(root, info, **kwargs): + return gql_query_optimizer(models.Provider.objects.all(), info) + provider_network = ObjectField(ProviderNetworkType) provider_network_list = ObjectListField(ProviderNetworkType) + + def resolve_provider_network_list(root, info, **kwargs): + return gql_query_optimizer(models.ProviderNetwork.objects.all(), info) diff --git a/netbox/core/graphql/schema.py b/netbox/core/graphql/schema.py index 201965430a8..876faa44265 100644 --- a/netbox/core/graphql/schema.py +++ b/netbox/core/graphql/schema.py @@ -1,12 +1,20 @@ import graphene +from core import models from netbox.graphql.fields import ObjectField, ObjectListField from .types import * +from utilities.graphql_optimizer import gql_query_optimizer class CoreQuery(graphene.ObjectType): data_file = ObjectField(DataFileType) data_file_list = ObjectListField(DataFileType) + def resolve_data_file_list(root, info, **kwargs): + return gql_query_optimizer(models.DataFile.objects.all(), info) + data_source = ObjectField(DataSourceType) data_source_list = ObjectListField(DataSourceType) + + def resolve_data_source_list(root, info, **kwargs): + return gql_query_optimizer(models.DataSource.objects.all(), info) diff --git a/netbox/dcim/graphql/schema.py b/netbox/dcim/graphql/schema.py index eba3114208e..6d689ac2d8a 100644 --- a/netbox/dcim/graphql/schema.py +++ b/netbox/dcim/graphql/schema.py @@ -2,126 +2,248 @@ from netbox.graphql.fields import ObjectField, ObjectListField from .types import * +from dcim import models from .types import VirtualDeviceContextType +from utilities.graphql_optimizer import gql_query_optimizer class DCIMQuery(graphene.ObjectType): cable = ObjectField(CableType) cable_list = ObjectListField(CableType) + def resolve_cable_list(root, info, **kwargs): + return gql_query_optimizer(models.Cable.objects.all(), info) + console_port = ObjectField(ConsolePortType) console_port_list = ObjectListField(ConsolePortType) + def resolve_console_port_list(root, info, **kwargs): + return gql_query_optimizer(models.ConsolePort.objects.all(), info) + console_port_template = ObjectField(ConsolePortTemplateType) console_port_template_list = ObjectListField(ConsolePortTemplateType) + def resolve_console_port_template_list(root, info, **kwargs): + return gql_query_optimizer(models.ConsolePortTemplate.objects.all(), info) + console_server_port = ObjectField(ConsoleServerPortType) console_server_port_list = ObjectListField(ConsoleServerPortType) + def resolve_console_server_port_list(root, info, **kwargs): + return gql_query_optimizer(models.ConsoleServerPort.objects.all(), info) + console_server_port_template = ObjectField(ConsoleServerPortTemplateType) console_server_port_template_list = ObjectListField(ConsoleServerPortTemplateType) + def resolve_console_server_port_template_list(root, info, **kwargs): + return gql_query_optimizer(models.ConsoleServerPortTemplate.objects.all(), info) + device = ObjectField(DeviceType) device_list = ObjectListField(DeviceType) + def resolve_device_list(root, info, **kwargs): + return gql_query_optimizer(models.Device.objects.all(), info) + device_bay = ObjectField(DeviceBayType) device_bay_list = ObjectListField(DeviceBayType) + def resolve_device_bay_list(root, info, **kwargs): + return gql_query_optimizer(models.DeviceBay.objects.all(), info) + device_bay_template = ObjectField(DeviceBayTemplateType) device_bay_template_list = ObjectListField(DeviceBayTemplateType) + def resolve_device_bay_template_list(root, info, **kwargs): + return gql_query_optimizer(models.DeviceBayTemplate.objects.all(), info) + device_role = ObjectField(DeviceRoleType) device_role_list = ObjectListField(DeviceRoleType) + def resolve_device_role_list(root, info, **kwargs): + return gql_query_optimizer(models.DeviceRole.objects.all(), info) + device_type = ObjectField(DeviceTypeType) device_type_list = ObjectListField(DeviceTypeType) + def resolve_device_type_list(root, info, **kwargs): + return gql_query_optimizer(models.DeviceType.objects.all(), info) + front_port = ObjectField(FrontPortType) front_port_list = ObjectListField(FrontPortType) + def resolve_front_port_list(root, info, **kwargs): + return gql_query_optimizer(models.FrontPort.objects.all(), info) + front_port_template = ObjectField(FrontPortTemplateType) front_port_template_list = ObjectListField(FrontPortTemplateType) + def resolve_front_port_template_list(root, info, **kwargs): + return gql_query_optimizer(models.FrontPortTemplate.objects.all(), info) + interface = ObjectField(InterfaceType) interface_list = ObjectListField(InterfaceType) + def resolve_interface_list(root, info, **kwargs): + return gql_query_optimizer(models.Interface.objects.all(), info) + interface_template = ObjectField(InterfaceTemplateType) interface_template_list = ObjectListField(InterfaceTemplateType) + def resolve_interface_template_list(root, info, **kwargs): + return gql_query_optimizer(models.InterfaceTemplate.objects.all(), info) + inventory_item = ObjectField(InventoryItemType) inventory_item_list = ObjectListField(InventoryItemType) + def resolve_inventory_item_list(root, info, **kwargs): + return gql_query_optimizer(models.InventoryItem.objects.all(), info) + inventory_item_role = ObjectField(InventoryItemRoleType) inventory_item_role_list = ObjectListField(InventoryItemRoleType) + def resolve_inventory_item_role_list(root, info, **kwargs): + return gql_query_optimizer(models.InventoryItemRole.objects.all(), info) + inventory_item_template = ObjectField(InventoryItemTemplateType) inventory_item_template_list = ObjectListField(InventoryItemTemplateType) + def resolve_inventory_item_template_list(root, info, **kwargs): + return gql_query_optimizer(models.InventoryItemTemplate.objects.all(), info) + location = ObjectField(LocationType) location_list = ObjectListField(LocationType) + def resolve_location_list(root, info, **kwargs): + return gql_query_optimizer(models.Location.objects.all(), info) + manufacturer = ObjectField(ManufacturerType) manufacturer_list = ObjectListField(ManufacturerType) + def resolve_manufacturer_list(root, info, **kwargs): + return gql_query_optimizer(models.Manufacturer.objects.all(), info) + module = ObjectField(ModuleType) module_list = ObjectListField(ModuleType) + def resolve_module_list(root, info, **kwargs): + return gql_query_optimizer(models.Module.objects.all(), info) + module_bay = ObjectField(ModuleBayType) module_bay_list = ObjectListField(ModuleBayType) + def resolve_module_bay_list(root, info, **kwargs): + return gql_query_optimizer(models.ModuleBay.objects.all(), info) + module_bay_template = ObjectField(ModuleBayTemplateType) module_bay_template_list = ObjectListField(ModuleBayTemplateType) + def resolve_module_bay_template_list(root, info, **kwargs): + return gql_query_optimizer(models.ModuleBayTemplate.objects.all(), info) + module_type = ObjectField(ModuleTypeType) module_type_list = ObjectListField(ModuleTypeType) + def resolve_module_type_list(root, info, **kwargs): + return gql_query_optimizer(models.ModuleType.objects.all(), info) + platform = ObjectField(PlatformType) platform_list = ObjectListField(PlatformType) + def resolve_platform_list(root, info, **kwargs): + return gql_query_optimizer(models.Platform.objects.all(), info) + power_feed = ObjectField(PowerFeedType) power_feed_list = ObjectListField(PowerFeedType) + def resolve_power_feed_list(root, info, **kwargs): + return gql_query_optimizer(models.PowerFeed.objects.all(), info) + power_outlet = ObjectField(PowerOutletType) power_outlet_list = ObjectListField(PowerOutletType) + def resolve_power_outlet_list(root, info, **kwargs): + return gql_query_optimizer(models.PowerOutlet.objects.all(), info) + power_outlet_template = ObjectField(PowerOutletTemplateType) power_outlet_template_list = ObjectListField(PowerOutletTemplateType) + def resolve_power_outlet_template_list(root, info, **kwargs): + return gql_query_optimizer(models.PowerOutletTemplate.objects.all(), info) + power_panel = ObjectField(PowerPanelType) power_panel_list = ObjectListField(PowerPanelType) + def resolve_power_panel_list(root, info, **kwargs): + return gql_query_optimizer(models.PowerPanel.objects.all(), info) + power_port = ObjectField(PowerPortType) power_port_list = ObjectListField(PowerPortType) + def resolve_power_port_list(root, info, **kwargs): + return gql_query_optimizer(models.PowerPort.objects.all(), info) + power_port_template = ObjectField(PowerPortTemplateType) power_port_template_list = ObjectListField(PowerPortTemplateType) + def resolve_power_port_template_list(root, info, **kwargs): + return gql_query_optimizer(models.PowerPortTemplate.objects.all(), info) + rack = ObjectField(RackType) rack_list = ObjectListField(RackType) + def resolve_rack_list(root, info, **kwargs): + return gql_query_optimizer(models.Rack.objects.all(), info) + rack_reservation = ObjectField(RackReservationType) rack_reservation_list = ObjectListField(RackReservationType) + def resolve_rack_reservation_list(root, info, **kwargs): + return gql_query_optimizer(models.RackReservation.objects.all(), info) + rack_role = ObjectField(RackRoleType) rack_role_list = ObjectListField(RackRoleType) + def resolve_rack_role_list(root, info, **kwargs): + return gql_query_optimizer(models.RackRole.objects.all(), info) + rear_port = ObjectField(RearPortType) rear_port_list = ObjectListField(RearPortType) + def resolve_rear_port_list(root, info, **kwargs): + return gql_query_optimizer(models.RearPort.objects.all(), info) + rear_port_template = ObjectField(RearPortTemplateType) rear_port_template_list = ObjectListField(RearPortTemplateType) + def resolve_rear_port_template_list(root, info, **kwargs): + return gql_query_optimizer(models.RearPortTemplate.objects.all(), info) + region = ObjectField(RegionType) region_list = ObjectListField(RegionType) + def resolve_region_list(root, info, **kwargs): + return gql_query_optimizer(models.Region.objects.all(), info) + site = ObjectField(SiteType) site_list = ObjectListField(SiteType) + def resolve_site_list(root, info, **kwargs): + return gql_query_optimizer(models.Site.objects.all(), info) + site_group = ObjectField(SiteGroupType) site_group_list = ObjectListField(SiteGroupType) + def resolve_site_group_list(root, info, **kwargs): + return gql_query_optimizer(models.SiteGroup.objects.all(), info) + virtual_chassis = ObjectField(VirtualChassisType) virtual_chassis_list = ObjectListField(VirtualChassisType) + def resolve_virtual_chassis_list(root, info, **kwargs): + return gql_query_optimizer(models.VirtualChassis.objects.all(), info) + virtual_device_context = ObjectField(VirtualDeviceContextType) virtual_device_context_list = ObjectListField(VirtualDeviceContextType) + + def resolve_virtual_device_context_list(root, info, **kwargs): + return gql_query_optimizer(models.VirtualDeviceContext.objects.all(), info) diff --git a/netbox/extras/graphql/schema.py b/netbox/extras/graphql/schema.py index 3e116023f7e..c61b0b88cc5 100644 --- a/netbox/extras/graphql/schema.py +++ b/netbox/extras/graphql/schema.py @@ -1,36 +1,68 @@ import graphene +from extras import models from netbox.graphql.fields import ObjectField, ObjectListField from .types import * +from utilities.graphql_optimizer import gql_query_optimizer class ExtrasQuery(graphene.ObjectType): config_context = ObjectField(ConfigContextType) config_context_list = ObjectListField(ConfigContextType) + def resolve_config_context_list(root, info, **kwargs): + return gql_query_optimizer(models.ConfigContext.objects.all(), info) + config_template = ObjectField(ConfigTemplateType) config_template_list = ObjectListField(ConfigTemplateType) + def resolve_config_template_list(root, info, **kwargs): + return gql_query_optimizer(models.ConfigTemplate.objects.all(), info) + custom_field = ObjectField(CustomFieldType) custom_field_list = ObjectListField(CustomFieldType) + def resolve_custom_field_list(root, info, **kwargs): + return gql_query_optimizer(models.CustomField.objects.all(), info) + custom_link = ObjectField(CustomLinkType) custom_link_list = ObjectListField(CustomLinkType) + def resolve_custom_link_list(root, info, **kwargs): + return gql_query_optimizer(models.CustomLink.objects.all(), info) + export_template = ObjectField(ExportTemplateType) export_template_list = ObjectListField(ExportTemplateType) + def resolve_export_template_list(root, info, **kwargs): + return gql_query_optimizer(models.ExportTemplate.objects.all(), info) + image_attachment = ObjectField(ImageAttachmentType) image_attachment_list = ObjectListField(ImageAttachmentType) + def resolve_image_attachment_list(root, info, **kwargs): + return gql_query_optimizer(models.ImageAttachment.objects.all(), info) + saved_filter = ObjectField(SavedFilterType) saved_filter_list = ObjectListField(SavedFilterType) + def resolve_saved_filter_list(root, info, **kwargs): + return gql_query_optimizer(models.SavedFilter.objects.all(), info) + journal_entry = ObjectField(JournalEntryType) journal_entry_list = ObjectListField(JournalEntryType) + def resolve_journal_entry_list(root, info, **kwargs): + return gql_query_optimizer(models.JournalEntry.objects.all(), info) + tag = ObjectField(TagType) tag_list = ObjectListField(TagType) + def resolve_tag_list(root, info, **kwargs): + return gql_query_optimizer(models.Tag.objects.all(), info) + webhook = ObjectField(WebhookType) webhook_list = ObjectListField(WebhookType) + + def resolve_webhook_list(root, info, **kwargs): + return gql_query_optimizer(models.Webhook.objects.all(), info) diff --git a/netbox/ipam/graphql/schema.py b/netbox/ipam/graphql/schema.py index 3f77de7494e..596b5eb7851 100644 --- a/netbox/ipam/graphql/schema.py +++ b/netbox/ipam/graphql/schema.py @@ -1,6 +1,9 @@ import graphene +from ipam import models +from utilities.graphql_optimizer import gql_query_optimizer from netbox.graphql.fields import ObjectField, ObjectListField + from .types import * @@ -8,53 +11,107 @@ class IPAMQuery(graphene.ObjectType): asn = ObjectField(ASNType) asn_list = ObjectListField(ASNType) + def resolve_asn_list(root, info, **kwargs): + return gql_query_optimizer(models.ASN.objects.all(), info) + asn_range = ObjectField(ASNRangeType) asn_range_list = ObjectListField(ASNRangeType) + def resolve_asn_range_list(root, info, **kwargs): + return gql_query_optimizer(models.ASNRange.objects.all(), info) + aggregate = ObjectField(AggregateType) aggregate_list = ObjectListField(AggregateType) + def resolve_aggregate_list(root, info, **kwargs): + return gql_query_optimizer(models.Aggregate.objects.all(), info) + ip_address = ObjectField(IPAddressType) ip_address_list = ObjectListField(IPAddressType) + def resolve_ip_address_list(root, info, **kwargs): + return gql_query_optimizer(models.IPAddress.objects.all(), info) + ip_range = ObjectField(IPRangeType) ip_range_list = ObjectListField(IPRangeType) + def resolve_ip_range_list(root, info, **kwargs): + return gql_query_optimizer(models.IPRange.objects.all(), info) + l2vpn = ObjectField(L2VPNType) l2vpn_list = ObjectListField(L2VPNType) + def resolve_l2vpn_list(root, info, **kwargs): + return gql_query_optimizer(models.L2VPN.objects.all(), info) + l2vpn_termination = ObjectField(L2VPNTerminationType) l2vpn_termination_list = ObjectListField(L2VPNTerminationType) + def resolve_l2vpn_termination_list(root, info, **kwargs): + return gql_query_optimizer(models.L2VPNTermination.objects.all(), info) + prefix = ObjectField(PrefixType) prefix_list = ObjectListField(PrefixType) + def resolve_prefix_list(root, info, **kwargs): + return gql_query_optimizer(models.Prefix.objects.all(), info) + rir = ObjectField(RIRType) rir_list = ObjectListField(RIRType) + def resolve_rir_list(root, info, **kwargs): + return gql_query_optimizer(models.RIR.objects.all(), info) + role = ObjectField(RoleType) role_list = ObjectListField(RoleType) + def resolve_role_list(root, info, **kwargs): + return gql_query_optimizer(models.Role.objects.all(), info) + route_target = ObjectField(RouteTargetType) route_target_list = ObjectListField(RouteTargetType) + def resolve_route_target_list(root, info, **kwargs): + return gql_query_optimizer(models.RouteTarget.objects.all(), info) + service = ObjectField(ServiceType) service_list = ObjectListField(ServiceType) + def resolve_service_list(root, info, **kwargs): + return gql_query_optimizer(models.Service.objects.all(), info) + service_template = ObjectField(ServiceTemplateType) service_template_list = ObjectListField(ServiceTemplateType) + def resolve_service_template_list(root, info, **kwargs): + return gql_query_optimizer(models.ServiceTemplate.objects.all(), info) + fhrp_group = ObjectField(FHRPGroupType) fhrp_group_list = ObjectListField(FHRPGroupType) + def resolve_fhrp_group_list(root, info, **kwargs): + return gql_query_optimizer(models.FHRPGroup.objects.all(), info) + fhrp_group_assignment = ObjectField(FHRPGroupAssignmentType) fhrp_group_assignment_list = ObjectListField(FHRPGroupAssignmentType) + def resolve_fhrp_group_assignment_list(root, info, **kwargs): + return gql_query_optimizer(models.FHRPGroupAssignment.objects.all(), info) + vlan = ObjectField(VLANType) vlan_list = ObjectListField(VLANType) + def resolve_vlan_list(root, info, **kwargs): + return gql_query_optimizer(models.VLAN.objects.all(), info) + vlan_group = ObjectField(VLANGroupType) vlan_group_list = ObjectListField(VLANGroupType) + def resolve_vlan_group_list(root, info, **kwargs): + return gql_query_optimizer(models.VLANGroup.objects.all(), info) + vrf = ObjectField(VRFType) vrf_list = ObjectListField(VRFType) + + def resolve_vrf_list(root, info, **kwargs): + return gql_query_optimizer(models.VRF.objects.all(), info) diff --git a/netbox/netbox/graphql/fields.py b/netbox/netbox/graphql/fields.py index 7c359e82ea3..0f5221b47c8 100644 --- a/netbox/netbox/graphql/fields.py +++ b/netbox/netbox/graphql/fields.py @@ -2,7 +2,6 @@ import graphene from graphene_django import DjangoListField - from .utils import get_graphene_type __all__ = ( @@ -56,10 +55,14 @@ def __init__(self, _type, *args, **kwargs): def list_resolver(django_object_type, resolver, default_manager, root, info, **args): queryset = super(ObjectListField, ObjectListField).list_resolver(django_object_type, resolver, default_manager, root, info, **args) - # Instantiate and apply the FilterSet, if defined + # if there are no filter params then don't need to filter + if not args: + return queryset + filterset_class = django_object_type._meta.filterset_class if filterset_class: - filterset = filterset_class(data=args, queryset=queryset, request=info.context) + filterset = filterset_class(data=args if args else None, queryset=queryset, request=info.context) + if not filterset.is_valid(): return queryset.none() return filterset.qs diff --git a/netbox/tenancy/graphql/schema.py b/netbox/tenancy/graphql/schema.py index de0a1781a74..8c464882096 100644 --- a/netbox/tenancy/graphql/schema.py +++ b/netbox/tenancy/graphql/schema.py @@ -1,24 +1,44 @@ import graphene from netbox.graphql.fields import ObjectField, ObjectListField +from tenancy import models from .types import * +from utilities.graphql_optimizer import gql_query_optimizer class TenancyQuery(graphene.ObjectType): tenant = ObjectField(TenantType) tenant_list = ObjectListField(TenantType) + def resolve_tenant_list(root, info, **kwargs): + return gql_query_optimizer(models.Tenant.objects.all(), info) + tenant_group = ObjectField(TenantGroupType) tenant_group_list = ObjectListField(TenantGroupType) + def resolve_tenant_group_list(root, info, **kwargs): + return gql_query_optimizer(models.TenantGroup.objects.all(), info) + contact = ObjectField(ContactType) contact_list = ObjectListField(ContactType) + def resolve_contact_list(root, info, **kwargs): + return gql_query_optimizer(models.Contact.objects.all(), info) + contact_role = ObjectField(ContactRoleType) contact_role_list = ObjectListField(ContactRoleType) + def resolve_contact_role_list(root, info, **kwargs): + return gql_query_optimizer(models.ContactRole.objects.all(), info) + contact_group = ObjectField(ContactGroupType) contact_group_list = ObjectListField(ContactGroupType) + def resolve_contact_group_list(root, info, **kwargs): + return gql_query_optimizer(models.ContactGroup.objects.all(), info) + contact_assignment = ObjectField(ContactAssignmentType) contact_assignment_list = ObjectListField(ContactAssignmentType) + + def resolve_contact_assignment_list(root, info, **kwargs): + return gql_query_optimizer(models.ContactAssignment.objects.all(), info) diff --git a/netbox/users/graphql/schema.py b/netbox/users/graphql/schema.py index 4a58be128c9..3b04d841834 100644 --- a/netbox/users/graphql/schema.py +++ b/netbox/users/graphql/schema.py @@ -1,12 +1,20 @@ import graphene +from django.contrib.auth.models import Group, User from netbox.graphql.fields import ObjectField, ObjectListField from .types import * +from utilities.graphql_optimizer import gql_query_optimizer class UsersQuery(graphene.ObjectType): group = ObjectField(GroupType) group_list = ObjectListField(GroupType) + def resolve_group_list(root, info, **kwargs): + return gql_query_optimizer(Group.objects.all(), info) + user = ObjectField(UserType) user_list = ObjectListField(UserType) + + def resolve_user_list(root, info, **kwargs): + return gql_query_optimizer(User.objects.all(), info) diff --git a/netbox/utilities/graphql_optimizer.py b/netbox/utilities/graphql_optimizer.py new file mode 100644 index 00000000000..e50a5a44b8b --- /dev/null +++ b/netbox/utilities/graphql_optimizer.py @@ -0,0 +1,249 @@ +import functools + +import graphql +from django.core.exceptions import FieldDoesNotExist +from django.db.models import ForeignKey, Prefetch +from django.db.models.constants import LOOKUP_SEP +from django.db.models.fields.reverse_related import ManyToOneRel +from graphene import InputObjectType +from graphene.types.generic import GenericScalar +from graphene.types.resolver import default_resolver +from graphene_django import DjangoObjectType +from graphql import FieldNode, GraphQLObjectType, GraphQLResolveInfo, GraphQLSchema +from graphql.execution.execute import get_field_def +from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode, VariableNode +from graphql.pyutils import Path +from graphql.type.definition import GraphQLInterfaceType, GraphQLUnionType + + +def gql_query_optimizer(queryset, info, **options): + return QueryOptimizer(info).optimize(queryset) + + +class QueryOptimizer(object): + def __init__(self, info, **options): + self.root_info = info + + def optimize(self, queryset): + info = self.root_info + field_def = get_field_def(info.schema, info.parent_type, info.field_nodes[0]) + + field_names = self._optimize_gql_selections( + self._get_type(field_def), + info.field_nodes[0], + ) + + qs = queryset.prefetch_related(*field_names) + return qs + + def _get_type(self, field_def): + a_type = field_def.type + while hasattr(a_type, "of_type"): + a_type = a_type.of_type + return a_type + + def _get_graphql_schema(self, schema): + if isinstance(schema, GraphQLSchema): + return schema + else: + return schema.graphql_schema + + def _get_possible_types(self, graphql_type): + if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)): + graphql_schema = self._get_graphql_schema(self.root_info.schema) + return graphql_schema.get_possible_types(graphql_type) + else: + return (graphql_type,) + + def _get_base_model(self, graphql_types): + models = tuple(t.graphene_type._meta.model for t in graphql_types) + for model in models: + if all(issubclass(m, model) for m in models): + return model + return None + + def handle_inline_fragment(self, selection, schema, possible_types, field_names): + fragment_type_name = selection.type_condition.name.value + graphql_schema = self._get_graphql_schema(schema) + fragment_type = graphql_schema.get_type(fragment_type_name) + fragment_possible_types = self._get_possible_types(fragment_type) + for fragment_possible_type in fragment_possible_types: + fragment_model = fragment_possible_type.graphene_type._meta.model + parent_model = self._get_base_model(possible_types) + if not parent_model: + continue + path_from_parent = fragment_model._meta.get_path_from_parent(parent_model) + select_related_name = LOOKUP_SEP.join(p.join_field.name for p in path_from_parent) + if not select_related_name: + continue + sub_field_names = self._optimize_gql_selections( + fragment_possible_type, + selection, + ) + field_names.append(select_related_name) + return + + def handle_fragment_spread(self, field_names, name, field_type): + fragment = self.root_info.fragments[name] + sub_field_names = self._optimize_gql_selections( + field_type, + fragment, + ) + + def _optimize_gql_selections(self, field_type, field_ast): + field_names = [] + selection_set = field_ast.selection_set + if not selection_set: + return field_names + optimized_fields_by_model = {} + schema = self.root_info.schema + graphql_schema = self._get_graphql_schema(schema) + graphql_type = graphql_schema.get_type(field_type.name) + + possible_types = self._get_possible_types(graphql_type) + for selection in selection_set.selections: + if isinstance(selection, InlineFragmentNode): + self.handle_inline_fragment(selection, schema, possible_types, field_names) + else: + name = selection.name.value + if isinstance(selection, FragmentSpreadNode): + self.handle_fragment_spread(field_names, name, field_type) + else: + for possible_type in possible_types: + selection_field_def = possible_type.fields.get(name) + if not selection_field_def: + continue + + graphene_type = possible_type.graphene_type + model = getattr(graphene_type._meta, "model", None) + if model and name not in optimized_fields_by_model: + field_model = optimized_fields_by_model[name] = model + if field_model == model: + self._optimize_field( + field_names, + model, + selection, + selection_field_def, + possible_type, + ) + return field_names + + def _get_field_info(self, field_names, model, selection, field_def): + name = None + model_field = None + name = self._get_name_from_resolver(field_def.resolve) + if not name and callable(field_def.resolve) and not isinstance(field_def.resolve, functools.partial): + name = selection.name.value + if name: + model_field = self._get_model_field_from_name(model, name) + + return (name, model_field) + + def _optimize_field(self, field_names, model, selection, field_def, parent_type): + name, model_field = self._get_field_info(field_names, model, selection, field_def) + if model_field: + self._optimize_field_by_name(field_names, model, selection, field_def, name, model_field) + + return + + def _optimize_field_by_name(self, field_names, model, selection, field_def, name, model_field): + if model_field.many_to_one or model_field.one_to_one: + sub_field_names = self._optimize_gql_selections( + self._get_type(field_def), + selection, + ) + if name not in field_names: + field_names.append(name) + + for field in sub_field_names: + prefetch_key = f"{name}__{field}" + if prefetch_key not in field_names: + field_names.append(prefetch_key) + + if model_field.one_to_many or model_field.many_to_many: + sub_field_names = self._optimize_gql_selections( + self._get_type(field_def), + selection, + ) + + if isinstance(model_field, ManyToOneRel): + sub_field_names.append(model_field.field.name) + + field_names.append(name) + for field in sub_field_names: + prefetch_key = f"{name}__{field}" + if prefetch_key not in field_names: + field_names.append(prefetch_key) + + return + + def _get_optimization_hints(self, resolver): + return getattr(resolver, "optimization_hints", None) + + def _get_value(self, info, value): + if isinstance(value, VariableNode): + var_name = value.name.value + value = info.variable_values.get(var_name) + return value + elif isinstance(value, InputObjectType): + return value.__dict__ + else: + return GenericScalar.parse_literal(value) + + def _get_name_from_resolver(self, resolver): + optimization_hints = self._get_optimization_hints(resolver) + if optimization_hints: + name_fn = optimization_hints.model_field + if name_fn: + return name_fn() + if self._is_resolver_for_id_field(resolver): + return "id" + elif isinstance(resolver, functools.partial): + resolver_fn = resolver + if resolver_fn.func != default_resolver: + # Some resolvers have the partial function as the second + # argument. + for arg in resolver_fn.args: + if isinstance(arg, (str, functools.partial)): + break + else: + # No suitable instances found, default to first arg + arg = resolver_fn.args[0] + resolver_fn = arg + if isinstance(resolver_fn, functools.partial) and resolver_fn.func == default_resolver: + return resolver_fn.args[0] + if self._is_resolver_for_id_field(resolver_fn): + return "id" + return resolver_fn + + def _is_resolver_for_id_field(self, resolver): + resolve_id = DjangoObjectType.resolve_id + return resolver == resolve_id + + def _get_model_field_from_name(self, model, name): + try: + return model._meta.get_field(name) + except FieldDoesNotExist: + descriptor = model.__dict__.get(name) + if not descriptor: + return None + return getattr(descriptor, "rel", None) or getattr(descriptor, "related", None) # Django < 1.9 + + def _is_foreign_key_id(self, model_field, name): + return isinstance(model_field, ForeignKey) and model_field.name != name and model_field.get_attname() == name + + def _create_resolve_info(self, field_name, field_asts, return_type, parent_type): + return GraphQLResolveInfo( + field_name, + field_asts, + return_type, + parent_type, + Path(None, 0, None), + schema=self.root_info.schema, + fragments=self.root_info.fragments, + root_value=self.root_info.root_value, + operation=self.root_info.operation, + variable_values=self.root_info.variable_values, + context=self.root_info.context, + is_awaitable=self.root_info.is_awaitable, + ) diff --git a/netbox/virtualization/graphql/schema.py b/netbox/virtualization/graphql/schema.py index e22532214c2..88e6aac6460 100644 --- a/netbox/virtualization/graphql/schema.py +++ b/netbox/virtualization/graphql/schema.py @@ -2,20 +2,37 @@ from netbox.graphql.fields import ObjectField, ObjectListField from .types import * +from utilities.graphql_optimizer import gql_query_optimizer +from virtualization import models class VirtualizationQuery(graphene.ObjectType): cluster = ObjectField(ClusterType) cluster_list = ObjectListField(ClusterType) + def resolve_cluster_list(root, info, **kwargs): + return gql_query_optimizer(models.Cluster.objects.all(), info) + cluster_group = ObjectField(ClusterGroupType) cluster_group_list = ObjectListField(ClusterGroupType) + def resolve_cluster_group_list(root, info, **kwargs): + return gql_query_optimizer(models.ClusterGroup.objects.all(), info) + cluster_type = ObjectField(ClusterTypeType) cluster_type_list = ObjectListField(ClusterTypeType) + def resolve_cluster_type_list(root, info, **kwargs): + return gql_query_optimizer(models.ClusterType.objects.all(), info) + virtual_machine = ObjectField(VirtualMachineType) virtual_machine_list = ObjectListField(VirtualMachineType) + def resolve_virtual_machine_list(root, info, **kwargs): + return gql_query_optimizer(models.VirtualMachine.objects.all(), info) + vm_interface = ObjectField(VMInterfaceType) vm_interface_list = ObjectListField(VMInterfaceType) + + def resolve_vm_interface_list(root, info, **kwargs): + return gql_query_optimizer(models.VMInterface.objects.all(), info) diff --git a/netbox/wireless/graphql/schema.py b/netbox/wireless/graphql/schema.py index cd8fd9f5267..e6e46be3f99 100644 --- a/netbox/wireless/graphql/schema.py +++ b/netbox/wireless/graphql/schema.py @@ -2,14 +2,25 @@ from netbox.graphql.fields import ObjectField, ObjectListField from .types import * +from utilities.graphql_optimizer import gql_query_optimizer +from wireless import models class WirelessQuery(graphene.ObjectType): wireless_lan = ObjectField(WirelessLANType) wireless_lan_list = ObjectListField(WirelessLANType) + def resolve_wireless_lan_list(root, info, **kwargs): + return gql_query_optimizer(models.WirelessLAN.objects.all(), info) + wireless_lan_group = ObjectField(WirelessLANGroupType) wireless_lan_group_list = ObjectListField(WirelessLANGroupType) + def resolve_wireless_lan_group_list(root, info, **kwargs): + return gql_query_optimizer(models.WirelessLANGroup.objects.all(), info) + wireless_link = ObjectField(WirelessLinkType) wireless_link_list = ObjectListField(WirelessLinkType) + + def resolve_wireless_link_list(root, info, **kwargs): + return gql_query_optimizer(models.WirelessLink.objects.all(), info)