-
Notifications
You must be signed in to change notification settings - Fork 3
/
part12.py
93 lines (77 loc) · 2.5 KB
/
part12.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""https://adventofcode.com/2021/day/16"""
from dataclasses import dataclass
from numpy import product
@dataclass
class Packet:
    """One decoded BITS packet, either a literal value or an operator node."""

    # 3-bit version field from the packet header.
    version: int
    # 3-bit type field; 4 marks a literal packet, anything else an operator.
    type_id: int
    # Decoded literal value (parse_packet stores 0 for operator packets).
    literal: int
    # Number of bits this packet occupies, header included.
    length: int
    # Child packets of an operator packet (empty for literals).
    sub_packets: "list[Packet]"
def parse_packet(bits, depth=0):
    """Decode the BITS packet that starts at the beginning of ``bits``.

    ``bits`` is a string of "0"/"1" characters. The returned Packet's
    ``length`` is the number of bits consumed, header included, so a caller
    can advance past it. ``depth`` is only forwarded to recursive calls.

    Literal packets (type 4) carry 5-bit groups: a leading flag bit ("0" on
    the final group) followed by 4 value bits. Operator packets carry one
    length-type bit. If it is 0, a 15-bit field gives the total bit length
    of the sub-packets. Otherwise an 11-bit field gives the number of
    sub-packets.
    """
    cursor = 0

    def take(width):
        # Return the next ``width`` bits and move the cursor past them.
        nonlocal cursor
        chunk = bits[cursor : cursor + width]
        cursor += width
        return chunk

    version = int(take(3), 2)
    type_id = int(take(3), 2)

    if type_id == 4:
        value = 0
        more = True
        while more:
            group = take(5)
            # Any flag bit other than "0" means another group follows.
            more = group[0] != "0"
            value = (value << 4) | int(group[1:], 2)
        return Packet(version, type_id, value, cursor, [])

    mode = int(bits[cursor], 2)
    cursor += 1
    children = []
    if mode == 0:
        span = int(take(15), 2)
        consumed = 0
        while consumed != span:
            # Hand the child only the bits still inside the declared span.
            child = parse_packet(bits[cursor : cursor + span - consumed], depth + 1)
            children.append(child)
            consumed += child.length
            cursor += child.length
    else:
        count = int(take(11), 2)
        for _ in range(count):
            child = parse_packet(bits[cursor:], depth + 1)
            children.append(child)
            cursor += child.length
    return Packet(version, type_id, 0, cursor, children)
def add_versions(packets):
    """Return the sum of the version numbers of ``packets`` and all their nested sub-packets."""
    return sum(node.version + add_versions(node.sub_packets) for node in packets)
def compute(p):
    """Recursively evaluate the expression encoded by packet ``p``.

    Type 4 returns the packet's literal. Otherwise the sub-packets are
    evaluated first, and then combined according to ``type_id``:

    - 0: sum of the values
    - 1: product of the values
    - 2: minimum of the values
    - 3: maximum of the values
    - 5: 1 if the first value is greater than the second, else 0
    - 6: 1 if the first value is less than the second, else 0
    - 7: 1 if the first two values are equal, else 0

    Raises:
        ValueError: if ``type_id`` is none of the above.
    """
    # math.prod multiplies exact Python ints. The numpy.product used before
    # packed the values into an int64 array, which overflows silently on
    # large products, and numpy.product itself was removed in NumPy 2.0.
    import math

    if p.type_id == 4:
        return p.literal
    values = [compute(x) for x in p.sub_packets]
    if p.type_id == 0:
        return sum(values)
    if p.type_id == 1:
        return math.prod(values)
    if p.type_id == 2:
        return min(values)
    if p.type_id == 3:
        return max(values)
    if p.type_id == 5:
        return int(values[0] > values[1])
    if p.type_id == 6:
        return int(values[0] < values[1])
    if p.type_id == 7:
        return int(values[0] == values[1])
    raise ValueError(f"Unknown type_id {p.type_id}")
# Read the puzzle input. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open("day-16/input", "r", encoding="utf-8") as handle:
    data = handle.read().strip()
# Expand the hex transmission to a bit string. bin() drops leading zeros,
# so pad back to 4 bits per hex digit.
bits = bin(int(data, 16))[2:].zfill(len(data) * 4)
packets = parse_packet(bits)
# part 1
print(add_versions([packets]))
# part 2
print(compute(packets))