diff --git a/build.zig b/build.zig index f3f2cad..a6a97ef 100644 --- a/build.zig +++ b/build.zig @@ -1,7 +1,7 @@ const Builder = std.build.Builder; const std = @import("std"); pub fn build(b: *Builder) void { - const native_opt = b.option(bool, "native", "if many cpu, turn this on"); + const native_opt = b.option(bool, "native", "if other people exist, turn this off"); const option_libc = (b.option(bool, "libc", "build with libc?")) orelse false; const is_native = native_opt orelse true; @@ -17,21 +17,32 @@ pub fn build(b: *Builder) void { }); } - const mode = b.standardReleaseOptions(); + const optimize = b.standardOptimizeOption(.{}); // this exports both a library and a binary - const exe = b.addExecutable("zigdig", "src/main.zig"); - exe.setTarget(target); - exe.setBuildMode(mode); + const exe = b.addExecutable(.{ + .name = "zigdig", + .root_source_file = .{ .path = "src/main.zig" }, + .target = target, + .optimize = optimize, + }); if (option_libc) exe.linkLibC(); + exe.install(); + + const exe_tinyhost = b.addExecutable(.{ + .name = "zigdig-tiny", + .root_source_file = .{ .path = "src/main_tinyhost.zig" }, + .target = target, + .optimize = optimize, + }); + if (option_libc) exe.linkLibC(); + exe_tinyhost.install(); - const lib = b.addStaticLibrary("zigdig", "src/lib.zig"); - lib.setTarget(target); - lib.setBuildMode(mode); - - var lib_tests = b.addTest("src/main.zig"); - lib_tests.setBuildMode(mode); + var lib_tests = b.addTest(.{ + .root_source_file = .{ .path = "src/main.zig" }, + .optimize = optimize, + }); const test_step = b.step("test", "Run library tests"); test_step.dependOn(&lib_tests.step); @@ -40,12 +51,9 @@ pub fn build(b: *Builder) void { const run_step = b.step("run", "Run example binary"); run_step.dependOn(&run_cmd.step); - b.default_step.dependOn(&lib.step); - b.default_step.dependOn(&exe.step); - - lib.addPackagePath("dns", "src/lib.zig"); - exe.addPackagePath("dns", "src/lib.zig"); - - b.installArtifact(lib); + b.addModule(.{ + .name = "dns", + .source_file = .{ .path = "src/lib.zig" }, + }); b.installArtifact(exe); } diff --git a/src/enums.zig b/src/enums.zig index b2007f8..864f742 100644 --- a/src/enums.zig +++ b/src/enums.zig @@ -97,6 +97,17 @@ pub const ResourceClass = enum(u16) { HS = 4, WILDCARD = 255, + pub fn readFrom(reader: anytype) !@This() { + const resource_class_int = try reader.readIntBig(u16); + return std.meta.intToEnum(@This(), resource_class_int) catch |err| { + logger.err( + "unknown resource class {d}, got {s}", + .{ resource_class_int, @errorName(err) }, + ); + return err; + }; + } + pub fn writeTo(self: @This(), writer: anytype) !usize { try writer.writeIntBig(u16, @enumToInt(self)); return 16 / 8; diff --git a/src/helpers.zig b/src/helpers.zig index fdc5fd7..ce9b453 100644 --- a/src/helpers.zig +++ b/src/helpers.zig @@ -2,16 +2,30 @@ const std = @import("std"); const dns = @import("lib.zig"); /// Print a slice of DNSResource to stderr. -fn printList(packet: *dns.Packet, allocator: std.mem.Allocator, writer: anytype, resource_list: []dns.Resource) !void { +fn printList( + name_pool: *dns.NamePool, + writer: anytype, + resource_list: []dns.Resource, +) !void { // TODO the formatting here is not good... try writer.print(";;name\t\t\trrtype\tclass\tttl\trdata\n", .{}); for (resource_list) |resource| { - var resource_data = try dns.ResourceData.fromOpaque(packet, resource.typ, resource.opaque_rdata, allocator); - defer resource_data.deinit(allocator); + const resource_data = try dns.ResourceData.fromOpaque( + resource.typ, + resource.opaque_rdata.?, + .{ + .name_provider = .{ .full = name_pool }, + .allocator = name_pool.allocator, + }, + ); + defer switch (resource_data) { + .TXT => resource_data.deinit(name_pool.allocator), + else => {}, // managed a layer above + }; - try writer.print("{s}\t\t{s}\t{s}\t{d}\t{any}\n", .{ - resource.name, + try writer.print("{?}\t\t{s}\t{s}\t{d}\t{any}\n", .{ + resource.name.?, @tagName(resource.typ), @tagName(resource.class), resource.ttl, @@ -22,15 +36,25 @@ fn printList(packet: *dns.Packet, allocator: std.mem.Allocator, writer: anytype, try writer.print("\n", .{}); } -/// Print a packet to stderr. -pub fn printAsZoneFile(packet: *dns.Packet, allocator: std.mem.Allocator, writer: anytype) !void { - try writer.print("id: {}, opcode: {}, rcode: {}\n", .{ - packet.header.id, +/// Print a packet in the format of a "zone file". +/// +/// This will deserialize resourcedata in the resource sections, so +/// a NamePool instance is required. +/// +/// This helper method will NOT free the memory created by name allocation, +/// you should do this manually in a defer block calling NamePool.deinitWithNames. +pub fn printAsZoneFile( + packet: *dns.Packet, + name_pool: *dns.NamePool, + writer: anytype, +) !void { + try writer.print(";; opcode: {}, status: {}, id: {}\n", .{ packet.header.opcode, packet.header.response_code, + packet.header.id, }); - try writer.print("qd: {}, an: {}, ns: {}, ar: {}\n\n", .{ + try writer.print(";; QUERY: {}, ANSWER: {}, AUTHORITY: {}, ADDITIONAL: {}\n\n", .{ packet.header.question_length, packet.header.answer_length, packet.header.nameserver_length, @@ -38,11 +62,11 @@ pub fn printAsZoneFile(packet: *dns.Packet, allocator: std.mem.Allocator, writer }); if (packet.header.question_length > 0) { - try writer.print(";;-- question --\n", .{}); + try writer.print(";; QUESTION SECTION:\n", .{}); try writer.print(";;name\ttype\tclass\n", .{}); for (packet.questions) |question| { - try writer.print(";{s}\t{s}\t{s}\n", .{ + try writer.print(";{?}\t{s}\t{s}\n", .{ question.name, @tagName(question.typ), @tagName(question.class), @@ -53,22 +77,22 @@ pub fn printAsZoneFile(packet: *dns.Packet, allocator: std.mem.Allocator, writer } if (packet.header.answer_length > 0) { - try writer.print(";; -- answer --\n", .{}); - try printList(packet, allocator, writer, packet.answers); + try writer.print(";; ANSWER SECTION:\n", .{}); + try printList(name_pool, writer, packet.answers); } else { try writer.print(";; no answer\n", .{}); } if (packet.header.nameserver_length > 0) { - try writer.print(";; -- authority --\n", .{}); - try printList(packet, allocator, writer, packet.nameservers); + try writer.print(";; AUTHORITY SECTION:\n", .{}); + try printList(name_pool, writer, packet.nameservers); } else { try writer.print(";; no authority\n\n", .{}); } if (packet.header.additional_length > 0) { - try writer.print(";; -- additional --\n", .{}); - try printList(packet, allocator, writer, packet.additionals); + try writer.print(";; ADDITIONAL SECTION:\n", .{}); + try printList(name_pool, writer, packet.additionals); } else { try writer.print(";; no additional\n\n", .{}); } @@ -116,10 +140,18 @@ pub const DNSConnection = struct { ); } - pub fn receivePacket( + /// Deserializes and allocates an *entire* DNS packet. + /// + /// This function is not encouraged if you only wish to get A/AAAA + /// records for a domain name through the system DNS resolver, as this + /// allocates all the data of the packet. Use `receiveTrustedAddresses` + /// for such. + pub fn receiveFullPacket( self: Self, packet_allocator: std.mem.Allocator, + /// Maximum size for the incoming UDP datagram comptime max_incoming_message_size: usize, + options: dns.ParserOptions, ) !dns.IncomingPacket { var packet_buffer: [max_incoming_message_size]u8 = undefined; const read_bytes = try self.socket.read(&packet_buffer); @@ -127,10 +159,78 @@ pub const DNSConnection = struct { logger.debug("read {d} bytes", .{read_bytes}); var stream = std.io.FixedBufferStream([]const u8){ .buffer = packet_bytes, .pos = 0 }; - return try dns.Packet.readFrom(stream.reader(), packet_allocator); + return parseFullPacket(stream.reader(), packet_allocator, options); } }; +pub fn parseFullPacket( + reader: anytype, + // TODO separate allocator and options.allocator + allocator: std.mem.Allocator, + options: dns.ParserOptions, +) !dns.IncomingPacket { + if (options.allocator == null) { + @panic("parseFullPacket requires options.allocator to be set"); + } + + var packet = try allocator.create(dns.Packet); + packet.extra_names = null; + errdefer allocator.destroy(packet); + var incoming_packet = dns.IncomingPacket{ + .allocator = allocator, + .packet = packet, + }; + + var ctx = dns.ParserContext{}; + var parser = dns.parser(reader, &ctx, options); + + var builtin_name_pool = dns.NamePool.init(allocator); + defer builtin_name_pool.deinit(); + + var name_pool = if (options.name_pool) |name_pool| name_pool else &builtin_name_pool; + + var questions = std.ArrayList(dns.Question).init(allocator); + defer questions.deinit(); + + var answers = std.ArrayList(dns.Resource).init(allocator); + defer answers.deinit(); + + var nameservers = std.ArrayList(dns.Resource).init(allocator); + defer nameservers.deinit(); + + var additionals = std.ArrayList(dns.Resource).init(allocator); + defer additionals.deinit(); + + while (try parser.next()) |part| { + switch (part) { + .header => |header| packet.header = header, + .question => |question_with_raw_names| { + var question = try name_pool.transmuteResource(question_with_raw_names); + try questions.append(question); + }, + .end_question => packet.questions = try questions.toOwnedSlice(), + .answer, .nameserver, .additional => |raw_resource| { + // since we give it an allocator, we don't receive rdata + // sections + + var resource = try name_pool.transmuteResource(raw_resource); + try (switch (part) { + .answer => answers, + .nameserver => nameservers, + .additional => additionals, + else => unreachable, + }).append(resource); + }, + .end_answer => packet.answers = try answers.toOwnedSlice(), + .end_nameserver => packet.nameservers = try nameservers.toOwnedSlice(), + .end_additional => packet.additionals = try additionals.toOwnedSlice(), + .answer_rdata, .nameserver_rdata, .additional_rdata => unreachable, + } + } + + return incoming_packet; +} + const logger = std.log.scoped(.dns_helpers); /// Open a socket to a random DNS resolver declared in the systems' @@ -215,14 +315,103 @@ const AddressList = struct { } }; -/// A very simple getAddressList that sets up the DNS connection and extracts -/// the A records. +const ReceiveTrustedAddressesOptions = struct { + max_incoming_message_size: usize = 4096, + requested_packet_header: ?dns.Header = null, + //resource_resolution_options: dns.ResourceResolutionOptions = .{}, +}; + +/// This is an optimized deserializer that is only interested in A and AAAA +/// answers, returning a list of std.net.Address. /// -/// This function does not implement the "happy eyeballs" algorithm. -pub fn getAddressList(incoming_name: []const u8, allocator: std.mem.Allocator) !AddressList { - var name_buffer: [128][]const u8 = undefined; - const name = try dns.Name.fromString(incoming_name, &name_buffer); +/// This function trusts the DNS connection to be returning answers related +/// to the given domain sent through DNSConnection.sendPacket. +/// +/// This, however, does not allocate the packet. It is very memory efficient +/// in that regard. +pub fn receiveTrustedAddresses( + allocator: std.mem.Allocator, + connection: *const DNSConnection, + /// Options to receive message and deserialize it + comptime options: ReceiveTrustedAddressesOptions, +) ![]std.net.Address { + var packet_buffer: [options.max_incoming_message_size]u8 = undefined; + const read_bytes = try connection.socket.read(&packet_buffer); + const packet_bytes = packet_buffer[0..read_bytes]; + logger.debug("read {d} bytes", .{read_bytes}); + + var stream = std.io.FixedBufferStream([]const u8){ + .buffer = packet_bytes, + .pos = 0, + }; + + var ctx = dns.ParserContext{}; + + var parser = dns.parser(stream.reader(), &ctx, .{}); + + var addrs = std.ArrayList(std.net.Address).init(allocator); + errdefer addrs.deinit(); + + var current_resource: ?dns.Resource = null; + + while (try parser.next()) |part| { + switch (part) { + .header => |header| { + if (options.requested_packet_header) |given_header| { + if (given_header.id != header.id) + return error.InvalidReply; + } + + if (!header.is_response) return error.InvalidResponse; + + switch (header.response_code) { + .NoError => {}, + .FormatError => return error.ServerFormatError, // bug in implementation caught by server? + .ServerFailure => return error.ServerFailure, + .NameError => return error.ServerNameError, + .NotImplemented => return error.ServerNotImplemented, + .Refused => return error.ServerRefused, + } + }, + .answer => |raw_resource| { + current_resource = raw_resource; + }, + + .answer_rdata => |rdata| { + // TODO parser.reader() + var reader = parser.wrapper_reader.reader(); + defer current_resource = null; + var maybe_addr = switch (current_resource.?.typ) { + .A => blk: { + var ip4addr: [4]u8 = undefined; + _ = try reader.read(&ip4addr); + break :blk std.net.Address.initIp4(ip4addr, 0); + }, + .AAAA => blk: { + var ip6_addr: [16]u8 = undefined; + _ = try reader.read(&ip6_addr); + break :blk std.net.Address.initIp6(ip6_addr, 0, 0, 0); + }, + else => blk: { + try reader.skipBytes(rdata.size, .{}); + break :blk null; + }, + }; + + if (maybe_addr) |addr| try addrs.append(addr); + }, + else => {}, + } + } + + return try addrs.toOwnedSlice(); +} +fn fetchTrustedAddresses( + allocator: std.mem.Allocator, + name: dns.Name, + qtype: dns.ResourceType, +) ![]std.net.Address { var packet = dns.Packet{ .header = .{ .id = dns.helpers.randomHeaderId(), @@ -230,10 +419,12 @@ pub fn getAddressList(incoming_name: []const u8, allocator: std.mem.Allocator) ! .wanted_recursion = true, .question_length = 1, }, + + // TODO test if we can put more than one question and itll reply with all .questions = &[_]dns.Question{ .{ .name = name, - .typ = .A, + .typ = qtype, .class = .IN, }, }, @@ -245,44 +436,32 @@ pub fn getAddressList(incoming_name: []const u8, allocator: std.mem.Allocator) ! const conn = try dns.helpers.connectToSystemResolver(); defer conn.close(); - logger.info("selected nameserver: {}\n", .{conn.address}); - + logger.debug("selected nameserver: {}", .{conn.address}); try conn.sendPacket(packet); + return try receiveTrustedAddresses(allocator, &conn, .{}); +} - const reply = try conn.receivePacket(allocator, 4096); - defer reply.deinit(); - - const reply_packet = reply.packet; - - if (packet.header.id != reply_packet.header.id) return error.InvalidReply; - if (!reply_packet.header.is_response) return error.InvalidResponse; - - switch (reply_packet.header.response_code) { - .NoError => {}, - .FormatError => return error.ServerFormatError, // bug in implementation caught by server? - .ServerFailure => return error.ServerFailure, - .NameError => return error.ServerNameError, - .NotImplemented => return error.ServerNotImplemented, - .Refused => return error.ServerRefused, - } +/// A very simple getAddressList that sets up the DNS connection and extracts +/// the A records. +/// +/// This function does not implement the "happy eyeballs" algorithm. +pub fn getAddressList(incoming_name: []const u8, allocator: std.mem.Allocator) !AddressList { + var name_buffer: [128][]const u8 = undefined; + const name = try dns.Name.fromString(incoming_name, &name_buffer); - var list = std.ArrayList(std.net.Address).init(allocator); - defer list.deinit(); + var final_list = std.ArrayList(std.net.Address).init(allocator); + defer final_list.deinit(); - for (reply_packet.answers) |resource| { - var resource_data = try dns.ResourceData.fromOpaque( - reply_packet, - resource.typ, - resource.opaque_rdata, - allocator, - ); - defer resource_data.deinit(allocator); + var addrs_v4 = try fetchTrustedAddresses(allocator, name, .A); + defer allocator.free(addrs_v4); + for (addrs_v4) |addr| try final_list.append(addr); - try list.append(resource_data.A); - } + var addrs_v6 = try fetchTrustedAddresses(allocator, name, .AAAA); + defer allocator.free(addrs_v6); + for (addrs_v6) |addr| try final_list.append(addr); return AddressList{ .allocator = allocator, - .addrs = list.toOwnedSlice(), + .addrs = try final_list.toOwnedSlice(), }; } diff --git a/src/lib.zig b/src/lib.zig index cf1ee40..0519063 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -1,6 +1,13 @@ pub const ResourceType = @import("enums.zig").ResourceType; pub const ResourceClass = @import("enums.zig").ResourceClass; -pub const Name = @import("name.zig").Name; + +pub const names = @import("name.zig"); +pub const FullName = names.FullName; +pub const RawName = names.RawName; +pub const Name = names.Name; +pub const LabelComponent = names.LabelComponent; +pub const NamePool = names.NamePool; + const pkt = @import("packet.zig"); pub const Packet = pkt.Packet; pub const ResponseCode = pkt.ResponseCode; @@ -8,6 +15,14 @@ pub const OpCode = pkt.OpCode; pub const IncomingPacket = pkt.IncomingPacket; pub const Question = pkt.Question; pub const Resource = pkt.Resource; +pub const Header = pkt.Header; + +pub const parserlib = @import("parser.zig"); +pub const parser = parserlib.parser; +pub const Parser = parserlib.Parser; +pub const ParserOptions = parserlib.ParserOptions; +pub const ParserContext = parserlib.ParserContext; + pub const helpers = @import("helpers.zig"); const resource_data = @import("resource_data.zig"); diff --git a/src/main.zig b/src/main.zig index ea84224..d586e6f 100644 --- a/src/main.zig +++ b/src/main.zig @@ -3,7 +3,26 @@ const dns = @import("lib.zig"); const logger = std.log.scoped(.zigdig_main); +pub const std_options = struct { + pub const log_level = .debug; + pub const logFn = logfn; +}; + +pub var current_log_level: std.log.Level = .info; + +fn logfn( + comptime message_level: std.log.Level, + comptime scope: @Type(.EnumLiteral), + comptime format: []const u8, + args: anytype, +) void { + if (@enumToInt(message_level) <= @enumToInt(@import("root").current_log_level)) { + std.log.defaultLog(message_level, scope, format, args); + } +} + pub fn main() !void { + if (std.mem.eql(u8, std.os.getenv("DEBUG") orelse "", "1")) current_log_level = .debug; var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer { _ = gpa.deinit(); @@ -60,20 +79,31 @@ pub fn main() !void { logger.info("selected nameserver: {}\n", .{conn.address}); const stdout = std.io.getStdOut(); - try dns.helpers.printAsZoneFile(&packet, allocator, stdout.writer()); + try dns.helpers.printAsZoneFile(&packet, undefined, stdout.writer()); try conn.sendPacket(packet); - const reply = try conn.receivePacket(allocator, 4096); - defer reply.deinit(); + // as we need Names inside the NamePool to live beyond the + // ReplyPacket, we must take ownership of them and deinit ourselves + // + // This is required to parse names inside printAsZoneFile + var name_pool = dns.NamePool.init(allocator); + defer name_pool.deinitWithNames(); + + const reply = try conn.receiveFullPacket( + allocator, + 4096, + .{ .allocator = allocator, .name_pool = &name_pool }, + ); + defer reply.deinit(.{ .names = false }); const reply_packet = reply.packet; - logger.info("reply: {}", .{reply_packet}); + logger.debug("reply: {}", .{reply_packet}); try std.testing.expectEqual(packet.header.id, reply_packet.header.id); try std.testing.expect(reply_packet.header.is_response); - try dns.helpers.printAsZoneFile(reply_packet, allocator, stdout.writer()); + try dns.helpers.printAsZoneFile(reply_packet, &name_pool, stdout.writer()); } test "awooga" { diff --git a/src/main_tinyhost.zig b/src/main_tinyhost.zig new file mode 100644 index 0000000..b4b2194 --- /dev/null +++ b/src/main_tinyhost.zig @@ -0,0 +1,47 @@ +const std = @import("std"); +const dns = @import("lib.zig"); + +const logger = std.log.scoped(.zigdig_main); +pub const std_options = struct { + pub const log_level = .debug; + pub const logFn = logfn; +}; + +pub var current_log_level: std.log.Level = .info; + +fn logfn( + comptime message_level: std.log.Level, + comptime scope: @Type(.EnumLiteral), + comptime format: []const u8, + args: anytype, +) void { + if (@enumToInt(message_level) <= @enumToInt(@import("root").current_log_level)) { + std.log.defaultLog(message_level, scope, format, args); + } +} + +pub fn main() !void { + if (std.mem.eql(u8, std.os.getenv("DEBUG") orelse "", "1")) current_log_level = .debug; + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer { + _ = gpa.deinit(); + } + const allocator = gpa.allocator(); + + var args_it = std.process.args(); + _ = args_it.skip(); + + const name_string = (args_it.next() orelse { + logger.warn("no name provided", .{}); + return error.InvalidArgs; + }); + + var addrs = try dns.helpers.getAddressList(name_string, allocator); + defer addrs.deinit(); + + var stdout = std.io.getStdOut().writer(); + + for (addrs.addrs) |addr| { + try stdout.print("{s} has address {any}\n", .{ name_string, addr }); + } +} diff --git a/src/name.zig b/src/name.zig index 0265bda..0647c7b 100644 --- a/src/name.zig +++ b/src/name.zig @@ -1,4 +1,220 @@ const std = @import("std"); +const dns = @import("lib.zig"); + +const logger = std.log.scoped(.dns_name); + +pub const LabelComponent = union(enum) { + Full: []const u8, + /// Holds the first offset component of that pointer. + /// + /// You still have to read a byte for the second component and assemble + /// it into the final packet offset. + Pointer: u16, + Null: void, +}; + +pub const RawName = struct { + // TODO rename this to components + labels: []LabelComponent, + + /// Represents the index of that name in its packet's body. + /// + /// **This is an internal field for DNS name pointer resolution.** + packet_index: ?usize = null, +}; + +const ReadNameOptions = struct { + max_label_count: usize = 128, + is_rdata: bool = false, +}; + +pub const Name = union(enum) { + raw: RawName, + full: FullName, + + const Self = @This(); + + pub fn deinit(self: Self, allocator: std.mem.Allocator) void { + switch (self) { + .raw => |raw| { + for (raw.labels) |label| switch (label) { + .Full => |data| allocator.free(data), + else => {}, + }; + + allocator.free(raw.labels); + }, + .full => |full| { + for (full.labels) |label| allocator.free(label); + allocator.free(full.labels); + }, + } + } + + /// Caller owns returned memory. + pub fn readFrom( + reader: anytype, + options: dns.ParserOptions, + ) !?Self { + const current_byte_index = reader.context.ctx.current_byte_count; + + if (options.allocator) |allocator| { + var components = std.ArrayList(LabelComponent).init(allocator); + defer components.deinit(); + + var is_raw: bool = false; + + while (true) { + if (components.items.len > options.max_label_size) + return error.Overflow; + + const component = (try Self.readLabelComponent(reader, allocator)).?; + logger.debug("read name: component {}", .{component}); + try components.append(component); + switch (component) { + .Null => break, + .Pointer => { + is_raw = true; + break; + }, + else => {}, + } + } + + return if (is_raw) .{ .raw = .{ + .labels = try components.toOwnedSlice(), + .packet_index = current_byte_index, + } } else .{ + .full = try FullName.fromAssumedComponents( + allocator, + components.items, + current_byte_index, + ), + }; + } else { + // skip the name in the reader + var name_index: usize = 0; + + while (true) { + if (name_index > options.max_label_size) + return error.Overflow; + + var maybe_component = try Self.readLabelComponent(reader, null); + if (maybe_component) |component| switch (component) { + .Null, .Pointer => break, + else => {}, + }; + } + + return null; + } + } + + /// Deserialize a LabelComponent, which can be: + /// - a pointer + /// - a full label ([]const u8) + /// - a null octet + fn readLabelComponent( + reader: anytype, + maybe_allocator: ?std.mem.Allocator, + ) !?LabelComponent { + // pointers, in the binary representation of a byte, are as follows + // 1 1 B B B B B B | B B B B B B B B + // they are two bytes length, but to identify one, you check if the + // first two bits are 1 and 1 respectively. + // + // then you read the rest, and turn it into an offset (without the + // starting bits!!!) + // + // to prevent inefficiencies, we just read a single byte, see if it + // has the starting bits, and then we chop it off, merging with the + // next byte. pointer offsets are 14 bits long + // + // when it isn't a pointer, its a length for a given label, and that + // length can only be a single byte. + // + // if the length is 0, its a null octet + logger.debug( + "reading label component at {d} bytes", + .{reader.context.ctx.current_byte_count}, + ); + var possible_length = try reader.readIntBig(u8); + if (possible_length == 0) return LabelComponent{ .Null = {} }; + + // RFC1035: + // since the label must begin with two zero bits because + // labels are restricted to 63 octets or less. + + var bit1 = (possible_length & (1 << 7)) != 0; + var bit2 = (possible_length & (1 << 6)) != 0; + + if (bit1 and bit2) { + const second_offset_component = try reader.readIntBig(u8); + + // merge them together + var offset: u16 = (possible_length << 7) | second_offset_component; + + // set first two bits of ptr_offset to zero as they're the + // pointer prefix bits (which are always 1, which brings problems) + offset &= ~@as(u16, 1 << 15); + offset &= ~@as(u16, 1 << 14); + + return LabelComponent{ .Pointer = offset }; + } else { + // those must be 0 for a correct label length to be made + std.debug.assert((!bit1) and (!bit2)); + + // the next bytes contain a full label. + if (maybe_allocator) |allocator| { + var label = try allocator.alloc(u8, possible_length); + const read_bytes = try reader.read(label); + if (read_bytes != label.len) logger.err( + "possible_length = {d} read_bytes = {d} label.len = {d}", + .{ possible_length, read_bytes, label.len }, + ); + std.debug.assert(read_bytes == label.len); + return LabelComponent{ .Full = label }; + } else { + logger.debug("read_name: skip {d} bytes as no alloc", .{possible_length}); + try reader.skipBytes(possible_length, .{}); + return null; + } + } + } + + pub fn writeTo(self: Self, writer: anytype) !usize { + return switch (self) { + .raw => unreachable, // must resolve against original packet so that we know the full name + .full => |full| try full.writeTo(writer), + }; + } + pub fn networkSize(self: Self) usize { + return switch (self) { + .raw => unreachable, // must resolve against original packet so that we know the full name + .full => |full| full.networkSize(), + }; + } + + pub fn fromString(domain: []const u8, buffer: [][]const u8) !Self { + return .{ .full = try FullName.fromString(domain, buffer) }; + } + + pub fn format( + self: Self, + comptime f: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, + ) !void { + return switch (self) { + .full => |full| full.format(f, options, writer), + .raw => |raw| for (raw.labels) |component| switch (component) { + .Pointer => |ptr| try std.fmt.format(writer, "(pointer={d}).", .{ptr}), + .Full => |label| try std.fmt.format(writer, "{s}.", .{label}), + .Null => break, + }, + }; + } +}; /// Represents a single DNS domain-name, which is a slice of strings. /// @@ -9,7 +225,7 @@ const std = @import("std"); /// Keep in mind Name's are not singularly deserializeable, as the names /// could be pointers to different bytes in the packet. /// (RFC1035, section 4.1.4 Message Compression) -pub const Name = struct { +pub const FullName = struct { /// The name's labels. labels: [][]const u8, @@ -20,6 +236,29 @@ pub const Name = struct { const Self = @This(); + /// Create a FullName from a []LabelComponent. + /// + /// Assumes that the slice does not end in a pointer. + pub fn fromAssumedComponents( + allocator: std.mem.Allocator, + components: []LabelComponent, + packet_index: ?usize, + ) !Self { + var labels = std.ArrayList([]const u8).init(allocator); + defer labels.deinit(); + + for (components) |component| switch (component) { + .Full => |data| try labels.append(data), + .Pointer => unreachable, + .Null => break, + }; + + return Self{ + .labels = try labels.toOwnedSlice(), + .packet_index = packet_index, + }; + } + /// Only use this if you have manually heap allocated a Name /// through the internal Packet.readName function. /// @@ -50,6 +289,8 @@ pub const Name = struct { var it = std.mem.split(u8, domain, "."); var idx: usize = 0; while (it.next()) |label| { + if (label.len == 0) return error.EmptyLabelInName; + // Is there a better error for this? if (idx > (buffer.len - 1)) return error.Underflow; // buffer too small @@ -94,3 +335,144 @@ pub const Name = struct { } } }; + +const NameList = std.ArrayList(dns.Name); + +pub const NamePool = struct { + allocator: std.mem.Allocator, + held_names: NameList, + + const Self = @This(); + pub fn init(allocator: std.mem.Allocator) Self { + return Self{ + .allocator = allocator, + .held_names = NameList.init(allocator), + }; + } + + pub fn deinit(self: Self) void { + self.held_names.deinit(); + } + + pub fn deinitWithNames(self: Self) void { + for (self.held_names.items) |name| name.deinit(self.allocator); + self.deinit(); + } + + /// Convert dns.RawName or FullName to FullName, applying pointer + /// resolution, and storing the name for future pointers to be resolved. + /// + /// takes ownership of the given name. + pub fn transmuteName(self: *Self, name: dns.Name) !dns.Name { + return switch (name) { + .full => blk: { + try self.held_names.append(name); + break :blk name; + }, + .raw => |raw| blk: { + defer name.deinit(self.allocator); + // this ends in a Pointer, create a new FullName + var resolved_labels = std.ArrayList([]const u8).init(self.allocator); + defer resolved_labels.deinit(); + + for (raw.labels) |raw_component| switch (raw_component) { + .Full => |text| try resolved_labels.append(try self.allocator.dupe(u8, text)), + .Pointer => |packet_offset| { + + // step 1: find out the name we already have + // that contains this pointer + var maybe_referenced_name: ?dns.FullName = null; + for (self.held_names.items) |held_name_from_list| { + const held_name = held_name_from_list.full; + + const packet_index = + if (held_name.packet_index) |idx| + idx + else + continue; + + // calculate end packet offset using length of the + // full name. + + const start_index = packet_index; + var name_length: usize = 0; + for (held_name.labels) |label| + name_length += label.len; + const end_index = packet_index + name_length; + + if (start_index <= packet_offset and packet_offset <= end_index) { + maybe_referenced_name = held_name; + } + } + + if (maybe_referenced_name) |referenced_name| { + var label_cursor: usize = referenced_name.packet_index.?; + var label_index: ?usize = null; + + for (referenced_name.labels) |label, idx| { + // if cursor is in offset's range, select that + // label onwards as our new label + const label_start = label_cursor; + if (label_start <= packet_offset) { + label_index = idx; + } + label_cursor += label.len; + } + + const referenced_labels = referenced_name.labels[label_index.?..]; + + for (referenced_labels) |referenced_label| { + try resolved_labels.append(try self.allocator.dupe(u8, referenced_label)); + } + } else { + logger.warn( + "unknown pointer offset: pointer has offset={d}", + .{packet_offset}, + ); + + for (self.held_names.items) |held_name| { + logger.warn( + "known name: {} at offset {?d}", + .{ held_name, held_name.full.packet_index }, + ); + } + + return error.UnknownPointerOffset; + } + }, + .Null => unreachable, + }; + + const full_name = dns.Name{ .full = dns.FullName{ + .labels = try resolved_labels.toOwnedSlice(), + .packet_index = name.raw.packet_index, + } }; + try self.held_names.append(full_name); + break :blk full_name; + }, + }; + } + + /// given a dns.Question or dns.Resource, resolve pointers and return + /// that same Question or Resource with a FullName inside of it. + /// + /// to be able to do this, ALL questions and resources must be registered + /// in the NamePool. + /// + /// this takes ownership of the given resource. + pub fn transmuteResource(self: *Self, resource: anytype) !@TypeOf(resource) { + switch (@TypeOf(resource)) { + dns.Question => { + var new_question = resource; + new_question.name = try self.transmuteName(resource.name.?); + return new_question; + }, + dns.Resource => { + var new_resource = resource; + new_resource.name = try self.transmuteName(resource.name.?); + return new_resource; + }, + else => @compileError("invalid type to resolve in name pool " ++ @typeName(@TypeOf(resource))), + } + } +}; diff --git a/src/packet.zig b/src/packet.zig index 5405c8e..2d5d2f8 100644 --- a/src/packet.zig +++ b/src/packet.zig @@ -148,14 +148,33 @@ pub const Header = packed struct { }; pub const Question = struct { - name: Name, + name: ?dns.Name, typ: ResourceType, class: ResourceClass, + + const Self = @This(); + + pub fn readFrom(reader: anytype, options: dns.ParserOptions) !Self { + logger.debug( + "reading question at {d} bytes", + .{reader.context.ctx.current_byte_count}, + ); + + var name = try Name.readFrom(reader, options); + var qtype = try reader.readEnum(ResourceType, .Big); + var qclass = try ResourceClass.readFrom(reader); + + return Self{ + .name = name, + .typ = qtype, + .class = qclass, + }; + } }; /// DNS resource pub const Resource = struct { - name: Name, + name: ?dns.Name, typ: ResourceType, class: ResourceClass, @@ -163,18 +182,61 @@ pub const Resource = struct { /// Opaque Resource Data. /// Parsing of the data in this is done by a separate package, dns.rdata - opaque_rdata: dns.ResourceData.Opaque, + opaque_rdata: ?dns.ResourceData.Opaque, + + const Self = @This(); + + /// Extract an RDATA. This only spits out a slice of u8. + /// Parsing of RDATA sections are in the dns.rdata module. + /// + /// Caller owns returned memory. + fn readResourceDataFrom( + reader: anytype, + options: dns.ParserOptions, + ) !?dns.ResourceData.Opaque { + if (options.allocator) |allocator| { + const rdata_length = try reader.readIntBig(u16); + const rdata_index = reader.context.ctx.current_byte_count; + + var opaque_rdata = try allocator.alloc(u8, rdata_length); + const read_bytes = try reader.read(opaque_rdata); + std.debug.assert(read_bytes == opaque_rdata.len); + return .{ + .data = opaque_rdata, + .current_byte_count = rdata_index, + }; + } else { + return null; + } + } + + pub fn readFrom(reader: anytype, options: dns.ParserOptions) !Self { + logger.debug("reading resource at {d} bytes", .{reader.context.ctx.current_byte_count}); + var name = try Name.readFrom(reader, options); + var typ = try ResourceType.readFrom(reader); + var class = try ResourceClass.readFrom(reader); + var ttl = try reader.readIntBig(i32); + var opaque_rdata = try Self.readResourceDataFrom(reader, options); + + return Self{ + .name = name, + .typ = typ, + .class = class, + .ttl = ttl, + .opaque_rdata = opaque_rdata, + }; + } pub fn writeTo(self: @This(), writer: anytype) !usize { - const name_size = try self.name.writeTo(writer); + const name_size = try self.name.?.writeTo(writer); const typ_size = try self.typ.writeTo(writer); const class_size = try self.class.writeTo(writer); const ttl_size = 32 / 8; try writer.writeIntBig(i32, self.ttl); const rdata_prefix_size = 16 / 8; - try writer.writeIntBig(u16, @intCast(u16, self.opaque_rdata.data.len)); - const rdata_size = try writer.write(self.opaque_rdata.data); + try writer.writeIntBig(u16, @intCast(u16, self.opaque_rdata.?.data.len)); + const rdata_size = try writer.write(self.opaque_rdata.?.data); return name_size + typ_size + class_size + ttl_size + rdata_prefix_size + rdata_size; } @@ -184,11 +246,7 @@ const ByteList = std.ArrayList(u8); const StringList = std.ArrayList([]u8); const ManyStringList = std.ArrayList([][]const u8); -pub const DeserializationContext = struct { - current_byte_count: usize = 0, -}; - -const LabelComponent = union(enum) { +pub const LabelComponent = union(enum) { Full: []const u8, /// Holds the first offset component of that pointer. /// @@ -198,44 +256,6 @@ const LabelComponent = union(enum) { Null: void, }; -/// Wrap a Reader with a type that contains a DeserializationContext. -/// -/// Automatically increments the DeserializationContext's current_byte_count -/// on every read(). -/// -/// Useful to hold deserialization state without having to pass an entire -/// parameter around on every single helper function. -pub fn WrapperReader(comptime ReaderType: anytype) type { - return struct { - underlying_reader: ReaderType, - ctx: *DeserializationContext, - - const Self = @This(); - - pub fn init( - underlying_reader: ReaderType, - ctx: *DeserializationContext, - ) Self { - return .{ - .underlying_reader = underlying_reader, - .ctx = ctx, - }; - } - - pub fn read(self: *Self, buffer: []u8) !usize { - const bytes_read = try self.underlying_reader.read(buffer); - self.ctx.current_byte_count += bytes_read; - return bytes_read; - } - - pub const Error = ReaderType.Error || error{OutOfMemory}; - pub const Reader = std.io.Reader(*Self, Error, read); - pub fn reader(self: *Self) Reader { - return Reader{ .context = self }; - } - }; -} - const NameList = std.ArrayList(Name); /// A DNS packet, as specified in RFC1035. @@ -246,6 +266,9 @@ pub const Packet = struct { nameservers: []Resource, additionals: []Resource, + /// Names that are held in RDATA sections are added here. + /// + /// This is an internal field that shouldn't be used by API consumers. extra_names: ?NameList = null, const Self = @This(); @@ -270,7 +293,7 @@ pub const Packet = struct { var question_size: usize = 0; for (self.questions) |question| { - const question_name_size = try question.name.writeTo(writer); + const question_name_size = try question.name.?.writeTo(writer); const question_typ_size = try question.typ.writeTo(writer); const question_class_size = try question.class.writeTo(writer); @@ -447,55 +470,6 @@ pub const Packet = struct { return error.UnknownPointerOffset; } } - - /// Deserialize a LabelComponent, which can be: - /// - a pointer - /// - a full label ([]const u8) - /// - a null octet - fn readLabelComponent( - reader: anytype, - allocator: std.mem.Allocator, - ) !LabelComponent { - // pointers, in the binary representation of a byte, are as follows - // 1 1 B B B B B B | B B B B B B B B - // they are two bytes length, but to identify one, you check if the - // first two bits are 1 and 1 respectively. - // - // then you read the rest, and turn it into an offset (without the - // starting bits!!!) - // - // to prevent inefficiencies, we just read a single byte, see if it - // has the starting bits, and then we chop it off, merging with the - // next byte. pointer offsets are 14 bits long - // - // when it isn't a pointer, its a length for a given label, and that - // length can only be a single byte. - // - // if the length is 0, its a null octet - var possible_length = try reader.readIntBig(u8); - if (possible_length == 0) return LabelComponent{ .Null = {} }; - - // RFC1035: - // since the label must begin with two zero bits because - // labels are restricted to 63 octets or less. - - var bit1 = (possible_length & (1 << 7)) != 0; - var bit2 = (possible_length & (1 << 6)) != 0; - - if (bit1 and bit2) { - return LabelComponent{ .Pointer = possible_length }; - } else { - // those must be 0 for a correct label length to be made - std.debug.assert((!bit1) and (!bit2)); - - // the next bytes contain a full label. - var label = try allocator.alloc(u8, possible_length); - const read_bytes = try reader.read(label); - std.debug.assert(read_bytes == label.len); - return LabelComponent{ .Full = label }; - } - } - const ReadNameOptions = struct { max_label_count: usize = 128, is_rdata: bool = false, @@ -572,111 +546,6 @@ pub const Packet = struct { return name; } - - /// Extract an RDATA. This only spits out a slice of u8. - /// Parsing of RDATA sections are in the dns.rdata module. - /// - /// Caller owns returned memory. - fn readResourceDataFrom( - reader: anytype, - allocator: std.mem.Allocator, - ) !dns.ResourceData.Opaque { - const rdata_length = try reader.readIntBig(u16); - var opaque_rdata = try allocator.alloc(u8, rdata_length); - const rdata_index = reader.context.ctx.current_byte_count; - const read_bytes = try reader.read(opaque_rdata); - std.debug.assert(read_bytes == opaque_rdata.len); - return .{ - .data = opaque_rdata, - .current_byte_count = rdata_index, - }; - } - - fn readResourceListFrom( - self: *Self, - reader: anytype, - allocator: std.mem.Allocator, - resource_count: usize, - ) ![]Resource { - var list = std.ArrayList(Resource).init(allocator); - defer list.deinit(); - - var i: usize = 0; - while (i < resource_count) : (i += 1) { - var name = try self.readName(reader, allocator, .{}); - - var typ = try ResourceType.readFrom(reader); - var class = try reader.readEnum(ResourceClass, .Big); - var ttl = try reader.readIntBig(i32); - var opaque_rdata = try Self.readResourceDataFrom(reader, allocator); - - var resource = Resource{ - .name = name, - .typ = typ, - .class = class, - .ttl = ttl, - .opaque_rdata = opaque_rdata, - }; - - try list.append(resource); - } - - return try list.toOwnedSlice(); - } - - pub fn readFrom( - incoming_reader: anytype, - allocator: std.mem.Allocator, - ) !IncomingPacket { - var packet = try allocator.create(Self); - errdefer allocator.destroy(packet); - packet.extra_names = NameList.init(allocator); - - var ctx = DeserializationContext{}; - const WrapperR = WrapperReader(@TypeOf(incoming_reader)); - var wrapper_reader = WrapperR.init(incoming_reader, &ctx); - var reader = wrapper_reader.reader(); - - packet.header = try Header.readFrom(reader); - - var questions = std.ArrayList(Question).init(allocator); - defer questions.deinit(); - - var i: usize = 0; - while (i < packet.header.question_length) { - var name = try packet.readName(reader, allocator, .{}); - var qtype = try reader.readEnum(ResourceType, .Big); - var qclass = try reader.readEnum(ResourceClass, .Big); - - var question = Question{ - .name = name, - .typ = qtype, - .class = qclass, - }; - - try questions.append(question); - i += 1; - } - - packet.questions = try questions.toOwnedSlice(); - packet.answers = try packet.readResourceListFrom( - reader, - allocator, - packet.header.answer_length, - ); - packet.nameservers = try packet.readResourceListFrom( - reader, - allocator, - packet.header.nameserver_length, - ); - packet.additionals = try packet.readResourceListFrom( - reader, - allocator, - packet.header.additional_length, - ); - - return IncomingPacket{ .packet = packet, .allocator = allocator }; - } }; /// Represents a Packet where all of its data was allocated dynamically @@ -684,32 +553,39 @@ pub const IncomingPacket = struct { allocator: std.mem.Allocator, packet: *Packet, - fn freeResource(self: @This(), resource: Resource) void { - for (resource.name.labels) |label| self.allocator.free(label); - self.allocator.free(resource.name.labels); - self.allocator.free(resource.opaque_rdata.data); + fn freeResource( + self: @This(), + resource: Resource, + options: DeinitOptions, + ) void { + if (options.names) + if (resource.name) |name| name.deinit(self.allocator); + if (resource.opaque_rdata) |opaque_rdata| + self.allocator.free(opaque_rdata.data); } - fn freeResourceList(self: @This(), resource_list: []Resource) void { - for (resource_list) |resource| self.freeResource(resource); + fn freeResourceList( + self: @This(), + resource_list: []Resource, + options: DeinitOptions, + ) void { + for (resource_list) |resource| self.freeResource(resource, options); self.allocator.free(resource_list); } - pub fn deinit(self: @This()) void { - for (self.packet.questions) |question| { - for (question.name.labels) |label| self.allocator.free(label); - self.allocator.free(question.name.labels); - } + pub const DeinitOptions = struct { + names: bool = true, + }; - self.allocator.free(self.packet.questions); - self.freeResourceList(self.packet.answers); - self.freeResourceList(self.packet.nameservers); - self.freeResourceList(self.packet.additionals); + pub fn deinit(self: @This(), options: DeinitOptions) void { + if (options.names) for (self.packet.questions) |question| { + if (question.name) |name| name.deinit(self.allocator); + }; - if (self.packet.extra_names) |list| { - for (list.items) |name| name.deinit(self.allocator); - list.deinit(); - } + self.allocator.free(self.packet.questions); + self.freeResourceList(self.packet.answers, options); + self.freeResourceList(self.packet.nameservers, options); + self.freeResourceList(self.packet.additionals, options); self.allocator.destroy(self.packet); } diff --git a/src/parser.zig b/src/parser.zig new file mode 100644 index 0000000..2178dfc --- /dev/null +++ b/src/parser.zig @@ -0,0 +1,316 @@ +const std = @import("std"); +const dns = @import("lib.zig"); + +const logger = std.log.scoped(.dns_parser); + +pub fn parser( + reader: anytype, + ctx: *ParserContext, + options: dns.ParserOptions, +) Parser(@TypeOf(reader)) { + return Parser(@TypeOf(reader)).init(reader, ctx, options); +} + +fn Output(typ: type) type { + return switch (typ) { + dns.Question => dns.Question, + dns.Resource => dns.Resource, + else => @compileError("invalid input to resolve"), + }; +} + +pub const ResourceResolutionOptions = struct { + max_follow: usize = 32, +}; + +pub const NamePool = struct { + allocator: std.mem.Allocator, + + const Self = @This(); + + pub fn init(allocator: std.mem.Allocator) Self { + return .{ .allocator = allocator }; + } + + fn resolve(raw_data: anytype, options: ResourceResolutionOptions) Output(@TypeOf(raw_data)) { + _ = options; + @compileError("TODO"); + } +}; + +const ParserState = enum { + header, + question, + answer, + nameserver, + additional, + answer_rdata, + nameserver_rdata, + additional_rdata, + done, +}; + +pub const ParserFrame = union(enum) { + header: dns.Header, + + question: dns.Question, + end_question: void, + + answer: dns.Resource, + answer_rdata: dns.parserlib.ResourceDataHolder, + end_answer: void, + + nameserver: dns.Resource, + nameserver_rdata: dns.parserlib.ResourceDataHolder, + end_nameserver: void, + + additional: dns.Resource, + additional_rdata: dns.parserlib.ResourceDataHolder, + end_additional: void, +}; + +pub const ResourceDataHolder = struct { + size: usize, + current_byte_index: usize, + + pub fn skip(self: @This(), reader: anytype) !void { + try reader.skipBytes(self.size, .{}); + } + + pub fn readAllAlloc( + self: @This(), + allocator: std.mem.Allocator, + reader: anytype, + ) !dns.ResourceData.Opaque { + var opaque_rdata = try allocator.alloc(u8, self.size); + const read_bytes = try reader.read(opaque_rdata); + std.debug.assert(read_bytes == opaque_rdata.len); + return .{ + .data = opaque_rdata, + .current_byte_count = self.current_byte_index, + }; + } +}; + +pub const ParserOptions = struct { + /// Give an allocator if you want names to appear properly. + allocator: ?std.mem.Allocator = null, + name_pool: ?*dns.NamePool = null, + + max_label_size: usize = 32, +}; + +pub const ParserContext = struct { + header: ?dns.Header = null, + current_byte_count: usize = 0, + current_counts: struct { + question: usize = 0, + answer: usize = 0, + nameserver: usize = 0, + additional: usize = 0, + } = .{}, +}; + +pub const DeserializationContext = struct { + current_byte_count: usize = 0, +}; + +/// Wrap a Reader with a type that contains a DeserializationContext. +/// +/// Automatically increments the DeserializationContext's current_byte_count +/// on every read(). +/// +/// Useful to hold deserialization state without having to pass an entire +/// parameter around on every single helper function. +pub fn WrapperReader(comptime ReaderType: anytype) type { + return struct { + underlying_reader: ReaderType, + ctx: *ParserContext, + + const Self = @This(); + + pub fn read(self: *Self, buffer: []u8) !usize { + const bytes_read = try self.underlying_reader.read(buffer); + self.ctx.current_byte_count += bytes_read; + logger.debug( + "wrapper reader: read {d} bytes, now at {d}", + .{ bytes_read, self.ctx.current_byte_count }, + ); + return bytes_read; + } + + pub const Error = ReaderType.Error; + pub const Reader = std.io.Reader(*Self, Error, read); + pub fn reader(self: *Self) Reader { + return Reader{ .context = self }; + } + }; +} + +/// Low level parser for DNS packets. +pub fn Parser(comptime ReaderType: type) type { + const WrapperR = WrapperReader(ReaderType); + + return struct { + state: ParserState = .header, + wrapper_reader: WrapperR, + options: ParserOptions, + ctx: *ParserContext, + + const Self = @This(); + + pub fn init(incoming_reader: ReaderType, ctx: *ParserContext, options: ParserOptions) Self { + var self = Self{ + .wrapper_reader = WrapperR{ + .underlying_reader = incoming_reader, + .ctx = ctx, + }, + .options = options, + .ctx = ctx, + }; + + return self; + } + + pub fn next(self: *Self) !?ParserFrame { + // self.state dictates what we *want* from the reader + // at the moment, first state always being header. + logger.debug("next(): enter {}", .{self.state}); + + logger.debug( + "parser reader is at {d} bytes of message", + .{self.wrapper_reader.ctx.current_byte_count}, + ); + + var reader = self.wrapper_reader.reader(); + + switch (self.state) { + .header => { + // since header is constant size, store it + // in our parser state so we know how to continue + const header = try dns.Header.readFrom(reader); + self.ctx.header = header; + self.state = .question; + logger.debug( + "next(): header read ({?}). state is now {}", + .{ self.ctx.header, self.state }, + ); + return ParserFrame{ .header = header }; + }, + .question => { + logger.debug("next(): read {d} out of {d} questions", .{ + self.ctx.current_counts.question, + self.ctx.header.?.question_length, + }); + + self.ctx.current_counts.question += 1; + + if (self.ctx.current_counts.question > self.ctx.header.?.question_length) { + self.state = .answer; + logger.debug("parser: end question, go to resources", .{}); + return ParserFrame{ .end_question = {} }; + } else { + const raw_question = try dns.Question.readFrom(reader, self.options); + return ParserFrame{ .question = raw_question }; + } + }, + .answer, .nameserver, .additional => { + var count_holder = (switch (self.state) { + .answer => &self.ctx.current_counts.answer, + .nameserver => &self.ctx.current_counts.nameserver, + .additional => &self.ctx.current_counts.additional, + else => unreachable, + }); + + const header_count = switch (self.state) { + .answer => self.ctx.header.?.answer_length, + .nameserver => self.ctx.header.?.nameserver_length, + .additional => self.ctx.header.?.additional_length, + else => unreachable, + }; + + logger.debug("next(): read {d} out of {d} resources", .{ + count_holder.*, header_count, + }); + + count_holder.* += 1; + + if (count_holder.* > header_count) { + const old_state = self.state; + self.state = switch (self.state) { + .answer => .nameserver, + .nameserver => .additional, + .additional => .done, + else => unreachable, + }; + + logger.debug( + "end resource list. state transition {} -> {}", + .{ old_state, self.state }, + ); + + return switch (old_state) { + .answer => ParserFrame{ .end_answer = {} }, + .nameserver => ParserFrame{ .end_nameserver = {} }, + .additional => ParserFrame{ .end_additional = {} }, + else => unreachable, + }; + } else { + const raw_resource = try dns.Resource.readFrom(reader, self.options); + + // not at end yet, which means resource_rdata event + // must happen if we don't have allocator + + const old_state = self.state; + + // if we don't have allocator, we emit rdata records + if (self.options.allocator == null) { + self.state = switch (self.state) { + .answer => .answer_rdata, + .nameserver => .nameserver_rdata, + .additional => .additional_rdata, + else => unreachable, + }; + } + + logger.debug("resource from {}: {}", .{ old_state, raw_resource }); + + return switch (old_state) { + .answer => ParserFrame{ .answer = raw_resource }, + .nameserver => ParserFrame{ .nameserver = raw_resource }, + .additional => ParserFrame{ .additional = raw_resource }, + else => unreachable, + }; + } + }, + + .answer_rdata, .nameserver_rdata, .additional_rdata => { + const old_state = self.state; + + self.state = switch (self.state) { + .answer_rdata => .answer, + .nameserver_rdata => .nameserver, + .additional_rdata => .additional, + else => unreachable, + }; + + const rdata_length = try reader.readIntBig(u16); + const rdata_index = reader.context.ctx.current_byte_count; + var rdata = ResourceDataHolder{ + .size = rdata_length, + .current_byte_index = rdata_index, + }; + + return switch (old_state) { + .answer_rdata => ParserFrame{ .answer_rdata = rdata }, + .nameserver_rdata => ParserFrame{ .nameserver_rdata = rdata }, + .additional_rdata => ParserFrame{ .additional_rdata = rdata }, + else => unreachable, + }; + }, + + .done => return null, + } + } + }; +} diff --git a/src/resource_data.zig b/src/resource_data.zig index 289075e..9fc5f38 100644 --- a/src/resource_data.zig +++ b/src/resource_data.zig @@ -8,8 +8,8 @@ const Type = dns.ResourceType; const logger = std.log.scoped(.dns_rdata); pub const SOAData = struct { - mname: dns.Name, - rname: dns.Name, + mname: ?dns.Name, + rname: ?dns.Name, serial: u32, refresh: u32, retry: u32, @@ -19,30 +19,47 @@ pub const SOAData = struct { pub const MXData = struct { preference: u16, - exchange: dns.Name, + exchange: ?dns.Name, }; pub const SRVData = struct { priority: u16, weight: u16, port: u16, - target: dns.Name, + target: ?dns.Name, }; +fn maybeReadResourceName( + reader: anytype, + options: ResourceData.ParseOptions, +) !?dns.Name { + return switch (options.name_provider) { + .none => null, + .raw => |allocator| try dns.Name.readFrom(reader, .{ .allocator = allocator }), + .full => |name_pool| blk: { + var name = try dns.Name.readFrom( + reader, + .{ .allocator = name_pool.allocator }, + ); + break :blk try name_pool.transmuteName(name.?); + }, + }; +} + /// Common representations of DNS' Resource Data. pub const ResourceData = union(Type) { A: std.net.Address, AAAA: std.net.Address, - NS: dns.Name, - MD: dns.Name, - MF: dns.Name, - CNAME: dns.Name, + NS: ?dns.Name, + MD: ?dns.Name, + MF: ?dns.Name, + CNAME: ?dns.Name, SOA: SOAData, - MB: dns.Name, - MG: dns.Name, - MR: dns.Name, + MB: ?dns.Name, + MG: ?dns.Name, + MR: ?dns.Name, // ???? NULL: void, @@ -53,19 +70,19 @@ pub const ResourceData = union(Type) { proto: u8, // how to define bit map? align(8)? }, - PTR: dns.Name, + PTR: ?dns.Name, - // TODO replace by Name? + // TODO replace []const u8 by Name? HINFO: struct { cpu: []const u8, os: []const u8, }, MINFO: struct { - rmailbx: dns.Name, - emailbx: dns.Name, + rmailbx: ?dns.Name, + emailbx: ?dns.Name, }, MX: MXData, - TXT: []const u8, + TXT: ?[]const u8, SRV: SRVData, OPT: void, // EDNS0 is not implemented @@ -106,9 +123,9 @@ pub const ResourceData = union(Type) { switch (self) { .A, .AAAA => |addr| return fmt.format(writer, "{}", .{addr}), - .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| return fmt.format(writer, "{}", .{name}), + .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| return fmt.format(writer, "{?}", .{name}), - .SOA => |soa| return fmt.format(writer, "{} {} {} {} {} {} {}", .{ + .SOA => |soa| return fmt.format(writer, "{?} {?} {} {} {} {} {}", .{ soa.mname, soa.rname, soa.serial, @@ -118,15 +135,15 @@ pub const ResourceData = union(Type) { soa.minimum, }), - .MX => |mx| return fmt.format(writer, "{} {}", .{ mx.preference, mx.exchange }), - .SRV => |srv| return fmt.format(writer, "{} {} {} {}", .{ + .MX => |mx| return fmt.format(writer, "{} {?}", .{ mx.preference, mx.exchange }), + .SRV => |srv| return fmt.format(writer, "{} {} {} {?}", .{ srv.priority, srv.weight, srv.port, srv.target, }), - .TXT => |text| return fmt.format(writer, "{s}", .{text}), + .TXT => |text| return fmt.format(writer, "{?s}", .{text}), else => return fmt.format(writer, "TODO support {s}", .{@tagName(self)}), } } @@ -139,11 +156,11 @@ pub const ResourceData = union(Type) { }, .AAAA => |addr| try writer.write(&addr.in6.sa.addr), - .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| try name.writeTo(writer), + .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| try name.?.writeTo(writer), .SOA => |soa_data| blk: { - const mname_size = try soa_data.mname.writeTo(writer); - const rname_size = try soa_data.rname.writeTo(writer); + const mname_size = try soa_data.mname.?.writeTo(writer); + const rname_size = try soa_data.rname.?.writeTo(writer); try writer.writeIntBig(u32, soa_data.serial); try writer.writeIntBig(u32, soa_data.refresh); @@ -156,7 +173,7 @@ pub const ResourceData = union(Type) { .MX => |mxdata| blk: { try writer.writeIntBig(u16, mxdata.preference); - const exchange_size = try mxdata.exchange.writeTo(writer); + const exchange_size = try mxdata.exchange.?.writeTo(writer); break :blk @sizeOf(@TypeOf(mxdata.preference)) + exchange_size; }, @@ -165,26 +182,26 @@ pub const ResourceData = union(Type) { try writer.writeIntBig(u16, srv.weight); try writer.writeIntBig(u16, srv.port); - const target_size = try srv.target.writeTo(writer); + const target_size = try srv.target.?.writeTo(writer); return target_size + (3 * @sizeOf(u16)); }, + // TODO TXT + else => @panic("not implemented"), }; } - /// Only call this if you dynamically created a ResourceData - /// through the fromOpaque() method. pub fn deinit(self: Self, allocator: std.mem.Allocator) void { switch (self) { - // .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| name.deinit(allocator), - // .SOA => |soa_data| { - // soa_data.mname.deinit(allocator); - // soa_data.rname.deinit(allocator); - // }, - // .MX => |mxdata| mxdata.exchange.deinit(allocator), - // .SRV => |srv| srv.target.deinit(allocator), - .TXT => |data| allocator.free(data), + .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |maybe_name| if (maybe_name) |name| name.deinit(allocator), + .SOA => |soa_data| { + if (soa_data.mname) |name| name.deinit(allocator); + if (soa_data.rname) |name| name.deinit(allocator); + }, + .MX => |mxdata| if (mxdata.exchange) |name| name.deinit(allocator), + .SRV => |srv| if (srv.target) |name| name.deinit(allocator), + .TXT => |maybe_data| if (maybe_data) |data| allocator.free(data), else => {}, } } @@ -194,18 +211,24 @@ pub const ResourceData = union(Type) { current_byte_count: usize, }; + pub const NameProvider = union(enum) { + none: void, + raw: std.mem.Allocator, + full: *dns.NamePool, + }; + + pub const ParseOptions = struct { + name_provider: NameProvider = NameProvider.none, + allocator: ?std.mem.Allocator = null, + }; + /// Deserialize a given opaque resource data. /// /// Call deinit() with the same allocator. pub fn fromOpaque( - /// Packet the resource data comes from. - /// - /// This is required as resource data may have name pointers - /// that refer to the packet index. - packet: *dns.Packet, - typ: dns.ResourceType, + resource_type: dns.ResourceType, opaque_resource_data: Opaque, - allocator: std.mem.Allocator, + options: ParseOptions, ) !ResourceData { const BufferT = std.io.FixedBufferStream([]const u8); var stream = BufferT{ .buffer = opaque_resource_data.data, .pos = 0 }; @@ -213,17 +236,18 @@ pub const ResourceData = union(Type) { // important to keep track of that rdata's position in the packet // as rdata could point to other rdata. - - var ctx = pkt.DeserializationContext{ + var parser_ctx = dns.ParserContext{ .current_byte_count = opaque_resource_data.current_byte_count, }; - const WrapperR = pkt.WrapperReader(BufferT.Reader); - var wrapper_reader = WrapperR.init(underlying_reader, &ctx); - var reader = wrapper_reader.reader(); - const options = .{ .is_rdata = true }; + const WrapperR = dns.parserlib.WrapperReader(BufferT.Reader); + var wrapper_reader = WrapperR{ + .underlying_reader = underlying_reader, + .ctx = &parser_ctx, + }; + var reader = wrapper_reader.reader(); - var rdata = switch (typ) { + return switch (resource_type) { .A => blk: { var ip4addr: [4]u8 = undefined; _ = try reader.read(&ip4addr); @@ -239,24 +263,24 @@ pub const ResourceData = union(Type) { }; }, - .NS => ResourceData{ .NS = try packet.readName(reader, allocator, options) }, - .CNAME => ResourceData{ .CNAME = try packet.readName(reader, allocator, options) }, - .PTR => ResourceData{ .PTR = try packet.readName(reader, allocator, options) }, - .MD => ResourceData{ .MD = try packet.readName(reader, allocator, options) }, - .MF => ResourceData{ .MF = try packet.readName(reader, allocator, options) }, + .NS => ResourceData{ .NS = try maybeReadResourceName(reader, options) }, + .CNAME => ResourceData{ .CNAME = try maybeReadResourceName(reader, options) }, + .PTR => ResourceData{ .PTR = try maybeReadResourceName(reader, options) }, + .MD => ResourceData{ .MD = try maybeReadResourceName(reader, options) }, + .MF => ResourceData{ .MF = try maybeReadResourceName(reader, options) }, .MX => blk: { break :blk ResourceData{ .MX = MXData{ .preference = try reader.readIntBig(u16), - .exchange = try packet.readName(reader, allocator, options), + .exchange = try maybeReadResourceName(reader, options), }, }; }, .SOA => blk: { - var mname = try packet.readName(reader, allocator, options); - var rname = try packet.readName(reader, allocator, options); + var mname = try maybeReadResourceName(reader, options); + var rname = try maybeReadResourceName(reader, options); var serial = try reader.readIntBig(u32); var refresh = try reader.readIntBig(u32); var retry = try reader.readIntBig(u32); @@ -279,7 +303,7 @@ pub const ResourceData = union(Type) { const priority = try reader.readIntBig(u16); const weight = try reader.readIntBig(u16); const port = try reader.readIntBig(u16); - const target = try packet.readName(reader, allocator, options); + const target = try maybeReadResourceName(reader, options); break :blk ResourceData{ .SRV = .{ .priority = priority, @@ -293,18 +317,21 @@ pub const ResourceData = union(Type) { const length = try reader.readIntBig(u8); if (length > 256) return error.Overflow; - var text = try allocator.alloc(u8, length); - _ = try reader.read(text); + if (options.allocator) |allocator| { + var text = try allocator.alloc(u8, length); + _ = try reader.read(text); - break :blk ResourceData{ .TXT = text }; + break :blk ResourceData{ .TXT = text }; + } else { + try reader.skipBytes(length, .{}); + break :blk ResourceData{ .TXT = null }; + } }, else => { - logger.warn("unexpected rdata: {}\n", .{typ}); - return error.InvalidRData; + logger.warn("unexpected rdata: {}\n", .{resource_type}); + return error.UnknownResourceType; }, }; - - return rdata; } }; diff --git a/src/test.zig b/src/test.zig index 86a97b2..6fb05bc 100644 --- a/src/test.zig +++ b/src/test.zig @@ -9,7 +9,7 @@ const Packet = dns.Packet; test "convert domain string to dns name" { const domain = "www.google.com"; var name_buffer: [3][]const u8 = undefined; - var name = try dns.Name.fromString(domain[0..], &name_buffer); + var name = (try dns.Name.fromString(domain[0..], &name_buffer)).full; std.debug.assert(name.labels.len == 3); try std.testing.expect(std.mem.eql(u8, name.labels[0], "www")); try std.testing.expect(std.mem.eql(u8, name.labels[1], "google")); @@ -77,7 +77,7 @@ fn expectGoogleLabels(actual: [][]const u8) !void { } } -test "deserialization of original google.com/A" { +test "deserialization of original question google.com/A" { var write_buffer: [0x10000]u8 = undefined; var decoded = try decodeBase64(TEST_PKT_QUERY, &write_buffer); @@ -95,13 +95,14 @@ test "deserialization of original google.com/A" { const question = pkt.questions[0]; - try expectGoogleLabels(question.name.labels); - try std.testing.expectEqual(@as(usize, 12), question.name.packet_index.?); + try expectGoogleLabels(question.name.?.full.labels); + try std.testing.expectEqual(@as(usize, 12), question.name.?.full.packet_index.?); try std.testing.expectEqual(question.typ, dns.ResourceType.A); try std.testing.expectEqual(question.class, dns.ResourceClass.IN); } test "deserialization of reply google.com/A" { + std.testing.log_level = .debug; var encode_buffer: [0x10000]u8 = undefined; var decoded = try decodeBase64(TEST_PKT_RESPONSE, &encode_buffer); @@ -117,26 +118,27 @@ test "deserialization of reply google.com/A" { var question = pkt.questions[0]; - try expectGoogleLabels(question.name.labels); + try expectGoogleLabels(question.name.?.full.labels); try testing.expectEqual(dns.ResourceType.A, question.typ); try testing.expectEqual(dns.ResourceClass.IN, question.class); var answer = pkt.answers[0]; - try expectGoogleLabels(answer.name.labels); + try expectGoogleLabels(answer.name.?.full.labels); try testing.expectEqual(dns.ResourceType.A, answer.typ); try testing.expectEqual(dns.ResourceClass.IN, answer.class); try testing.expectEqual(@as(i32, 300), answer.ttl); const resource_data = try dns.ResourceData.fromOpaque( - pkt, .A, - answer.opaque_rdata, - std.testing.allocator, + answer.opaque_rdata.?, + .{}, ); - defer resource_data.deinit(std.testing.allocator); - try testing.expectEqual(dns.ResourceType.A, @as(dns.ResourceType, resource_data)); + try testing.expectEqual( + dns.ResourceType.A, + @as(dns.ResourceType, resource_data), + ); const addr = @ptrCast(*const [4]u8, &resource_data.A.in.sa.addr).*; try testing.expectEqual(@as(u8, 216), addr[0]); @@ -197,11 +199,11 @@ fn serialTest(packet: Packet, write_buffer: []u8) ![]u8 { const FixedStream = std.io.FixedBufferStream([]const u8); fn deserialTest(packet_data: []const u8) !dns.IncomingPacket { var stream = FixedStream{ .buffer = packet_data, .pos = 0 }; - const incoming_packet = try dns.Packet.readFrom( + return try dns.helpers.parseFullPacket( stream.reader(), std.testing.allocator, + .{ .allocator = std.testing.allocator }, ); - return incoming_packet; } test "convert string to dns type" { @@ -239,7 +241,7 @@ test "resources have good sizes" { // name + rr (2) + class (2) + ttl (4) + rdlength (2) try testing.expectEqual( - @as(usize, name.networkSize() + 10 + resource.opaque_rdata.data.len), + @as(usize, name.networkSize() + 10 + resource.opaque_rdata.?.data.len), network_size, ); }