diff --git a/pkg/adidns/formats.go b/pkg/adidns/formats.go index ad60362..0150b04 100644 --- a/pkg/adidns/formats.go +++ b/pkg/adidns/formats.go @@ -105,3 +105,17 @@ func MSTimeToUnixTimestamp(msTime uint64) int64 { return unixTimestamp } + +func GetCurrentMSTime() uint32 { + baseTime := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) + + currentTime := time.Now().UTC() + + duration := currentTime.Sub(baseTime) + + targetTime := duration.Hours() + + msTime := uint32(targetTime) + uint32(3234576) + + return msTime +} diff --git a/pkg/adidns/types.go b/pkg/adidns/types.go index ea837e8..678ebc9 100644 --- a/pkg/adidns/types.go +++ b/pkg/adidns/types.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "reflect" "strings" "time" @@ -57,6 +58,16 @@ var DnsRecordTypes map[uint16]string = map[uint16]string{ 0xFF02: "WINSR", } +func FindRecordType(rTypeStr string) uint16 { + for key, val := range DnsRecordTypes { + if rTypeStr == val { + return uint16(key) + } + } + + return 0 +} + type DcPromoFlag struct { Value uint32 Description string @@ -124,6 +135,21 @@ type DNSRecord struct { Data []byte } +func MakeDNSRecord(rec FriendlyRecord, recType uint16, ttl uint32) DNSRecord { + serial := uint32(1) + msTime := GetCurrentMSTime() + data := rec.Encode() + + return DNSRecord{ + uint16(len(data)), recType, + 0x05, 0xF0, + 0x0000, + serial, ttl, + 0x00000000, + msTime, data, + } +} + type DNSProperty struct { DataLength uint32 NameLength uint32 @@ -134,7 +160,57 @@ type DNSProperty struct { Name uint8 } -func (prop *DNSProperty) Format(timeFormat string) string { +func MakeProp(id uint32, data []byte) DNSProperty { + return DNSProperty{ + uint32(len(data)), + 1, 0, 1, + id, data, + 0, + } +} + +func (prop *DNSProperty) ExportFormat() any { + var propDataArr [8]byte + copy(propDataArr[:], prop.Data) + propVal := binary.LittleEndian.Uint64(propDataArr[:]) + + switch prop.Id { + case 0x01, 0x02, 0x10, 0x20, 0x40, 0x83: + // DSPROPERTY_ZONE_TYPE + // DSPROPERTY_ZONE_ALLOW_UPDATE + // DSPROPERTY_ZONE_NOREFRESH_INTERVAL + // DSPROPERTY_ZONE_REFRESH_INTERVAL + // DSPROPERTY_ZONE_AGING_STATE + // DSPROPERTY_ZONE_DCPROMO_CONVERT + return propVal + case 0x00000008: + unixTimestamp := MSTimeToUnixTimestamp(propVal) + return unixTimestamp + case 0x00000012: + // DSPROPERTY_ZONE_AGING_ENABLED_TIME + msTime := propVal * 3600 + unixTimestamp := MSTimeToUnixTimestamp(msTime) + return unixTimestamp + case 0x00000080: + // DSPROPERTY_ZONE_DELETED_FROM_HOSTNAME + return string(prop.Data[:]) + case 0x00000090, 0x00000091, 0x00000092: + // DSPROPERTY_ZONE_SCAVENGING_SERVERS_DA + // DSPROPERTY_ZONE_MASTER_SERVERS_DA + // DSPROPERTY_ZONE_AUTO_NS_SERVERS_DA + return ParseAddrArray(prop.Data) + case 0x00000082, 0x00000011: + // DSPROPERTY_ZONE_SCAVENGING_SERVERS + // DSPROPERTY_ZONE_AUTO_NS_SERVERS + return ParseIP4Array(prop.Data) + default: + // DSPROPERTY_ZONE_NODE_DBFLAGS + // Or other unknown codes + return prop.Data + } +} + +func (prop *DNSProperty) PrintFormat(timeFormat string) string { var propDataArr [8]byte copy(propDataArr[:], prop.Data) propVal := binary.LittleEndian.Uint64(propDataArr[:]) @@ -194,11 +270,9 @@ func (prop *DNSProperty) Format(timeFormat string) string { } else { return "Not specified" } - - //return hex.EncodeToString(prop.Data) case 0x00000080: // DSPROPERTY_ZONE_DELETED_FROM_HOSTNAME - return string(propVal) + return string(prop.Data[:]) case 0x00000040: // DSPROPERTY_ZONE_AGING_STATE if propVal == 1 { @@ -334,12 +408,14 @@ func (d *DNSRecord) UnixTimestamp() int64 { // DNS_RPC_NAME parser func ParseRpcName(buf *bytes.Reader) (string, error) { var nameLen uint8 - if err := binary.Read(buf, binary.LittleEndian, &nameLen); err != nil { + + nameLen, err := buf.ReadByte() + if err != nil { return "", err } nameBuf := make([]byte, nameLen) - if _, err := io.ReadFull(buf, nameBuf); err != nil { + if err := binary.Read(buf, binary.LittleEndian, &nameBuf); err != nil { return "", err } @@ -348,27 +424,35 @@ func ParseRpcName(buf *bytes.Reader) (string, error) { func ParseRpcNameSingle(data []byte) (string, error) { buf := bytes.NewReader(data) - return ParseCountName(buf) + return ParseRpcName(buf) +} + +func EncodeRpcName(buf *bytes.Buffer, name string) { + buf.WriteByte(byte(len(name))) + buf.Write([]byte(name)) } // DNS_COUNT_NAME parser func ParseCountName(buf *bytes.Reader) (string, error) { - var rawNameLen uint8 var labelCnt uint8 var labLen uint8 + var err error - if err := binary.Read(buf, binary.LittleEndian, &rawNameLen); err != nil { + _, err = buf.ReadByte() + if err != nil { return "", err } - if err := binary.Read(buf, binary.LittleEndian, &labelCnt); err != nil { + labelCnt, err = buf.ReadByte() + if err != nil { return "", err } labels := make([]string, labelCnt) for cnt := uint8(0); cnt < labelCnt; cnt += 1 { - if err := binary.Read(buf, binary.LittleEndian, &labLen); err != nil { + labLen, err = buf.ReadByte() + if err != nil { return "", err } @@ -386,6 +470,36 @@ func ParseCountName(buf *bytes.Reader) (string, error) { return strings.Join(labels, "."), nil } +func EncodeCountName(buf *bytes.Buffer, name string) error { + labels := strings.Split(name, ".") + + rawNameLen := uint8(len(name) + 2) + if err := binary.Write(buf, binary.LittleEndian, rawNameLen); err != nil { + return err + } + + labelCnt := uint8(len(labels)) + if err := binary.Write(buf, binary.LittleEndian, labelCnt); err != nil { + return err + } + + for _, label := range labels { + labLen := uint8(len(label)) + if err := binary.Write(buf, binary.LittleEndian, labLen); err != nil { + return err + } + if _, err := buf.WriteString(label); err != nil { + return err + } + } + + if err := buf.WriteByte(0); err != nil { + return err + } + + return nil +} + func ParseCountNameSingle(data []byte) (string, error) { buf := bytes.NewReader(data) return ParseCountName(buf) @@ -448,30 +562,29 @@ func (p *DNSProperty) Decode(data []byte) error { // {Reference} MS-DNSP 2.2.2.2.4 DNS_RPC_RECORD_DATA // IP addresses (v4 or v6) are stored using their string representations -// Interface/structure to hold the parsed record fields +// Interface to hold the parsed record fields type FriendlyRecord interface { // Parses a record from its byte array in the Data field of the - // DNSRecord attribute + // DNSRecord AD attribute Parse([]byte) -} -type RecordContainer struct { - Name string - Contents FriendlyRecord + // Encode the record into a byte array to be used in the Data field + // of the DNSRecord AD attribute + Encode() []byte } +// Using a bit of reflection so that +// I don't have to manually implement a DumpField +// method on every type type Field struct { Name any Value any } -// Using a bit of reflection so that -// I don't have to manually implement a DumpField -// method on every type -func (rc RecordContainer) DumpFields() []Field { +func DumpRecordFields(fr FriendlyRecord) []Field { result := make([]Field, 0) - v := reflect.ValueOf(rc.Contents).Elem() + v := reflect.ValueOf(fr).Elem() for i := 0; i < v.NumField(); i++ { result = append(result, Field{v.Type().Field(i).Name, v.Field(i).Interface()}) } @@ -491,6 +604,14 @@ func (rnn *RecordNodeName) Parse(data []byte) { } } +func (rnn *RecordNodeName) Encode() []byte { + buf := new(bytes.Buffer) + + EncodeCountName(buf, rnn.NameNode) + + return buf.Bytes() +} + // 2.2.2.2.4.6 DNS_RPC_RECORD_STRING type RecordString struct { StrData []string @@ -510,6 +631,17 @@ func (rs *RecordString) Parse(data []byte) { rs.StrData = result } +func (rs *RecordString) Encode() []byte { + data := make([]byte, 0) + + for _, val := range rs.StrData { + data = append(data, byte(len(val))) + data = append(data, []byte(val)...) + } + + return data +} + // 2.2.2.2.4.7 DNS_RPC_RECORD_MAIL_ERROR type RecordMailError struct { MailBX string @@ -530,6 +662,16 @@ func (rs *RecordMailError) Parse(data []byte) { } } +func (rs *RecordMailError) Encode() []byte { + buf := new(bytes.Buffer) + + EncodeCountName(buf, rs.MailBX) + + EncodeCountName(buf, rs.ErrorMailBX) + + return buf.Bytes() +} + // 2.2.2.2.4.8 DNS_RPC_RECORD_NAME_PREFERENCE type RecordNamePreference struct { Preference uint16 @@ -544,6 +686,16 @@ func (rnp *RecordNamePreference) Parse(data []byte) { } } +func (rnp *RecordNamePreference) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, rnp.Preference) + + EncodeCountName(buf, rnp.Exchange) + + return buf.Bytes() +} + type NSRecord = RecordNodeName type MDRecord = RecordNodeName type MFRecord = RecordNodeName @@ -571,6 +723,9 @@ type RTRecord = RecordNamePreference type ZERORecord struct{} func (zr *ZERORecord) Parse(data []byte) {} +func (zr *ZERORecord) Encode() []byte { + return []byte{} +} // 2.2.2.2.4.1 DNS_RPC_RECORD_A type ARecord struct { @@ -581,6 +736,11 @@ func (v4r *ARecord) Parse(data []byte) { v4r.Address = ParseIP(data) } +func (v4r *ARecord) Encode() []byte { + ip := net.ParseIP(v4r.Address) + return []byte(net.IP.To4(ip)) +} + // 2.2.2.2.4.16 DNS_RPC_RECORD_AAAA type AAAARecord struct { Address string // Parsed from a [16]byte @@ -590,6 +750,11 @@ func (v6r *AAAARecord) Parse(data []byte) { v6r.Address = ParseIP(data) } +func (v6r *AAAARecord) Encode() []byte { + ip := net.ParseIP(v6r.Address) + return []byte(net.IP.To16(ip)) +} + // 2.2.2.2.4.3 DNS_RPC_RECORD_SOA type SOARecord struct { Serial uint32 @@ -620,6 +785,22 @@ func (r *SOARecord) Parse(data []byte) { } } +func (r *SOARecord) Encode() []byte { + buf := new(bytes.Buffer) + + fields := []uint32{r.Serial, r.Refresh, r.Retry, r.Expire, r.MinimumTTL} + for _, field := range fields { + binary.Write(buf, binary.BigEndian, field) + } + + names := []string{r.NamePrimaryServer, r.ZoneAdminEmail} + for _, name := range names { + EncodeCountName(buf, name) + } + + return buf.Bytes() +} + // 2.2.2.2.4.4 DNS_RPC_RECORD_NULL type NULLRecord struct { Data []byte @@ -629,6 +810,10 @@ func (r *NULLRecord) Parse(data []byte) { r.Data = data } +func (r *NULLRecord) Encode() []byte { + return r.Data +} + // 2.2.2.2.4.5 DNS_RPC_RECORD_WKS type WKSRecord struct { Address string @@ -642,6 +827,16 @@ func (r *WKSRecord) Parse(data []byte) { r.BitMask = data[5:] } +func (r *WKSRecord) Encode() []byte { + ip := net.ParseIP(r.Address) + + data := []byte(net.IP.To4(ip)) + data = append(data, byte(r.Protocol)) + data = append(data, r.BitMask...) + + return data +} + // 2.2.2.2.4.9 DNS_RPC_RECORD_SIG type SIGRecord struct { TypeCovered uint16 @@ -670,10 +865,30 @@ func (r *SIGRecord) Parse(data []byte) { r.NameSigner = parsedName } - sigInfo, err := ioutil.ReadAll(buf) + sigInfo, _ := ioutil.ReadAll(buf) r.SignatureInfo = sigInfo } +func (r *SIGRecord) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, r.TypeCovered) + + buf.WriteByte(r.Algorithm) + buf.WriteByte(r.Labels) + + binary.Write(buf, binary.BigEndian, r.OriginalTTL) + binary.Write(buf, binary.BigEndian, r.SigExpiration) + binary.Write(buf, binary.BigEndian, r.SigInception) + binary.Write(buf, binary.BigEndian, r.KeyTag) + + EncodeCountName(buf, r.NameSigner) + + buf.Write(r.SignatureInfo) + + return buf.Bytes() +} + // 2.2.2.2.4.13 DNS_RPC_RECORD_KEY type KEYRecord struct { Flags uint16 @@ -689,6 +904,28 @@ func (r *KEYRecord) Parse(data []byte) { r.Key = data[4:] } +func (r *KEYRecord) Encode() []byte { + buf := new(bytes.Buffer) + + if err := binary.Write(buf, binary.BigEndian, r.Flags); err != nil { + return []byte{} + } + + if err := buf.WriteByte(r.Protocol); err != nil { + return []byte{} + } + + if err := buf.WriteByte(r.Algorithm); err != nil { + return []byte{} + } + + if _, err := buf.Write(r.Key); err != nil { + return []byte{} + } + + return buf.Bytes() +} + // 2.2.2.2.4.17 DNS_RPC_RECORD_NXT type NXTRecord struct { NumRecordTypes uint16 @@ -697,13 +934,12 @@ type NXTRecord struct { } func (r *NXTRecord) Parse(data []byte) { + // TODO: Fix NXT parsing // This type does not seem to be following MS spec properly. // I'll just ignore it for the moment and hope to figure it out later. - r.NumRecordTypes = binary.LittleEndian.Uint16(data[:2]) - r.NextName = "" - /* + r.NumRecordTypes = binary.LittleEndian.Uint16(data[:2]) r.TypeWords = make([]uint16, r.NumRecordTypes) offset := 2 @@ -719,6 +955,11 @@ func (r *NXTRecord) Parse(data []byte) { */ } +func (r *NXTRecord) Encode() []byte { + // TODO: Fix NXT parsing + return []byte{} +} + // 2.2.2.2.4.18 DNS_RPC_RECORD_SRV type SRVRecord struct { Priority uint16 @@ -738,6 +979,20 @@ func (r *SRVRecord) Parse(data []byte) { } } +func (r *SRVRecord) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, r.Priority) + + binary.Write(buf, binary.BigEndian, r.Weight) + + binary.Write(buf, binary.BigEndian, r.Port) + + EncodeCountName(buf, r.NameTarget) + + return buf.Bytes() +} + // 2.2.2.2.4.19 DNS_RPC_RECORD_ATMA type ATMARecord struct { Format uint8 @@ -750,6 +1005,16 @@ func (r *ATMARecord) Parse(data []byte) { r.Address = string(data[1:]) } +func (r *ATMARecord) Encode() []byte { + buf := new(bytes.Buffer) + + buf.WriteByte(r.Format) + + buf.WriteString(r.Address) + + return buf.Bytes() +} + // 2.2.2.2.4.20 DNS_RPC_RECORD_NAPTR type NAPTRRecord struct { Order uint16 @@ -789,6 +1054,24 @@ func (r *NAPTRRecord) Parse(data []byte) { } } +func (r *NAPTRRecord) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, r.Order) + + binary.Write(buf, binary.BigEndian, r.Preference) + + EncodeRpcName(buf, r.Flags) + + EncodeRpcName(buf, r.Service) + + EncodeRpcName(buf, r.Substitution) + + EncodeCountName(buf, r.Replacement) + + return buf.Bytes() +} + // 2.2.2.2.4.12 DNS_RPC_RECORD_DS type DSRecord struct { KeyTag uint16 @@ -804,6 +1087,20 @@ func (r *DSRecord) Parse(data []byte) { r.Digest = data[4:] } +func (r *DSRecord) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, r.KeyTag) + + buf.WriteByte(r.Algorithm) + + buf.WriteByte(r.DigestType) + + buf.Write(r.Digest) + + return buf.Bytes() +} + // 2.2.2.2.4.10 DNS_RPC_RECORD_RRSIG type RRSIGRecord = SIGRecord @@ -820,7 +1117,17 @@ func (r *NSECRecord) Parse(data []byte) { r.NameSigner = parsedName } - binary.Read(buf, binary.LittleEndian, &r.NSECBitmap) + r.NSECBitmap, _ = ioutil.ReadAll(buf) +} + +func (r *NSECRecord) Encode() []byte { + buf := new(bytes.Buffer) + + EncodeCountName(buf, r.NameSigner) + + binary.Write(buf, binary.LittleEndian, r.NSECBitmap) + + return buf.Bytes() } // 2.2.2.2.4.15 DNS_RPC_RECORD_DNSKEY @@ -838,6 +1145,20 @@ func (r *DNSKEYRecord) Parse(data []byte) { r.Key = data[4:] } +func (r *DNSKEYRecord) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, r.Flags) + + buf.WriteByte(r.Protocol) + + buf.WriteByte(r.Algorithm) + + buf.Write(r.Key) + + return buf.Bytes() +} + // 2.2.2.2.4.14 DNS_RPC_RECORD_DHCID type DHCIDRecord struct { Digest []byte @@ -847,6 +1168,10 @@ func (r *DHCIDRecord) Parse(data []byte) { r.Digest = data } +func (r *DHCIDRecord) Encode() []byte { + return r.Digest +} + // 2.2.2.2.4.24 DNS_RPC_RECORD_NSEC3 type NSEC3Record struct { Algorithm uint8 @@ -870,6 +1195,28 @@ func (r *NSEC3Record) Parse(data []byte) { r.Bitmaps = data[6+int(r.SaltLength)+int(r.HashLength):] } +func (r *NSEC3Record) Encode() []byte { + buf := new(bytes.Buffer) + + buf.WriteByte(r.Algorithm) + + buf.WriteByte(r.Flags) + + binary.Write(buf, binary.BigEndian, r.Iterations) + + buf.WriteByte(r.SaltLength) + + buf.WriteByte(r.HashLength) + + buf.Write(r.Salt) + + buf.Write(r.NextHashedOwnerName) + + buf.Write(r.Bitmaps) + + return buf.Bytes() +} + // 2.2.2.2.4.25 DNS_RPC_RECORD_NSEC3PARAM type NSEC3PARAMRecord struct { Algorithm uint8 @@ -887,6 +1234,22 @@ func (r *NSEC3PARAMRecord) Parse(data []byte) { r.Salt = data[5 : 5+int(r.SaltLength)] } +func (r *NSEC3PARAMRecord) Encode() []byte { + buf := new(bytes.Buffer) + + buf.WriteByte(r.Algorithm) + + buf.WriteByte(r.Flags) + + binary.Write(buf, binary.BigEndian, r.Iterations) + + buf.WriteByte(r.SaltLength) + + buf.Write(r.Salt) + + return buf.Bytes() +} + // 2.2.2.2.4.26 DNS_RPC_RECORD_TLSA type TLSARecord struct { CertificateUsage uint8 @@ -902,33 +1265,69 @@ func (r *TLSARecord) Parse(data []byte) { r.CertificateAssociationData = data[3:] } +func (r *TLSARecord) Encode() []byte { + buf := new(bytes.Buffer) + + buf.WriteByte(r.CertificateUsage) + + buf.WriteByte(r.Selector) + + buf.WriteByte(r.MatchingType) + + buf.Write(r.CertificateAssociationData) + + return buf.Bytes() +} + // 2.2.2.2.4.21 DNS_RPC_RECORD_WINS type WINSRecord struct { MappingFlag uint32 LookupTimeout uint32 CacheTimeout uint32 - WinsServers [4]uint32 + WinsSrvCount uint32 + WinsServers []uint32 } func (r *WINSRecord) Parse(data []byte) { r.MappingFlag = binary.BigEndian.Uint32(data[:4]) r.LookupTimeout = binary.BigEndian.Uint32(data[4:8]) r.CacheTimeout = binary.BigEndian.Uint32(data[8:12]) - for i := 0; i < 4; i++ { - r.WinsServers[i] = binary.BigEndian.Uint32(data[12+i*4 : 16+i*4]) + r.WinsSrvCount = binary.BigEndian.Uint32(data[12:16]) + + for i := uint32(0); i < r.WinsSrvCount; i++ { + addr := binary.BigEndian.Uint32(data[16+i*4 : 20+i*4]) + r.WinsServers = append(r.WinsServers, addr) } } +func (r *WINSRecord) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, r.MappingFlag) + + binary.Write(buf, binary.BigEndian, r.LookupTimeout) + + binary.Write(buf, binary.BigEndian, r.CacheTimeout) + + binary.Write(buf, binary.BigEndian, r.WinsSrvCount) + + for i := uint32(0); i < r.WinsSrvCount; i++ { + binary.Write(buf, binary.BigEndian, r.WinsServers[i]) + } + + return buf.Bytes() +} + // 2.2.2.2.4.22 DNS_RPC_RECORD_WINSR type WINSRRecord struct { - Mapping uint32 + MappingFlag uint32 LookupTimeout uint32 CacheTimeout uint32 NameResultDomain string } func (r *WINSRRecord) Parse(data []byte) { - r.Mapping = binary.BigEndian.Uint32(data[:4]) + r.MappingFlag = binary.BigEndian.Uint32(data[:4]) r.LookupTimeout = binary.BigEndian.Uint32(data[4:8]) r.CacheTimeout = binary.BigEndian.Uint32(data[8:12]) @@ -937,3 +1336,17 @@ func (r *WINSRRecord) Parse(data []byte) { r.NameResultDomain = parsedName } } + +func (r *WINSRRecord) Encode() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.BigEndian, r.MappingFlag) + + binary.Write(buf, binary.BigEndian, r.LookupTimeout) + + binary.Write(buf, binary.BigEndian, r.CacheTimeout) + + EncodeCountName(buf, r.NameResultDomain) + + return buf.Bytes() +} diff --git a/pkg/ldaputils/actions.go b/pkg/ldaputils/actions.go index 97c0e57..a677303 100644 --- a/pkg/ldaputils/actions.go +++ b/pkg/ldaputils/actions.go @@ -457,8 +457,15 @@ func (lc *LDAPConn) AddUser(objectName string, parentDN string) error { return lc.Conn.Add(addRequest) } -func (lc *LDAPConn) AddADIDNSZone(objectName string, props []adidns.DNSProperty) error { - addRequest := ldap.NewAddRequest("DC="+objectName+",CN=MicrosoftDNS,DC=DomainDNSZones,"+lc.RootDN, nil) +func (lc *LDAPConn) AddADIDNSZone(objectName string, props []adidns.DNSProperty, isForest bool) (string, error) { + zoneContainer := "DomainDnsZones" + if isForest { + zoneContainer = "ForestDnsZones" + } + + zoneDN := fmt.Sprintf("DC=%s,CN=MicrosoftDNS,DC=%s,%s", objectName, zoneContainer, lc.RootDN) + + addRequest := ldap.NewAddRequest(zoneDN, nil) addRequest.Attribute("objectClass", []string{"top", "dnsZone"}) addRequest.Attribute("cn", []string{"Zone"}) addRequest.Attribute("name", []string{objectName}) @@ -473,7 +480,7 @@ func (lc *LDAPConn) AddADIDNSZone(objectName string, props []adidns.DNSProperty) addRequest.Attribute("dNSProperty", dNSPropertyList) - return lc.Conn.Add(addRequest) + return zoneDN, lc.Conn.Add(addRequest) } func (lc *LDAPConn) GetADIDNSZones(name string, isForest bool) ([]adidns.DNSZone, error) { @@ -572,11 +579,12 @@ func (lc *LDAPConn) GetADIDNSNodes(zoneDN string) ([]adidns.DNSNode, error) { return nodes, nil } -func (lc *LDAPConn) AddADIDNSNode(objectName string, records []adidns.DNSRecord) error { - addRequest := ldap.NewAddRequest("DC="+objectName+",CN=MicrosoftDNS,DC=DomainDNSZones,"+lc.RootDN, nil) +func (lc *LDAPConn) AddADIDNSNode(nodeName string, zoneDN string, records []adidns.DNSRecord) (string, error) { + nodeDN := fmt.Sprintf("DC=%s,%s", nodeName, zoneDN) + + addRequest := ldap.NewAddRequest(nodeDN, nil) addRequest.Attribute("objectClass", []string{"top", "dnsNode"}) - addRequest.Attribute("cn", []string{"Zone"}) - addRequest.Attribute("name", []string{objectName}) + addRequest.Attribute("name", []string{nodeName}) var dNSRecordList []string for _, record := range records { @@ -586,9 +594,47 @@ func (lc *LDAPConn) AddADIDNSNode(objectName string, records []adidns.DNSRecord) } } - addRequest.Attribute("dnsRecord", dNSRecordList) + if len(dNSRecordList) > 0 { + addRequest.Attribute("dnsRecord", dNSRecordList) + } + + return nodeDN, lc.Conn.Add(addRequest) +} - return lc.Conn.Add(addRequest) +func (lc *LDAPConn) AddADIDNSRecords(nodeDN string, records []adidns.DNSRecord) error { + modifyRequest := ldap.NewModifyRequest(nodeDN, nil) + + var dNSRecordList []string + for _, record := range records { + encodedProp, err := record.Encode() + if err == nil { + dNSRecordList = append(dNSRecordList, string(encodedProp)) + } + } + + if len(dNSRecordList) > 0 { + modifyRequest.Add("dnsRecord", dNSRecordList) + } + + return lc.Conn.Modify(modifyRequest) +} + +func (lc *LDAPConn) ReplaceADIDNSRecords(nodeDN string, records []adidns.DNSRecord) error { + modifyRequest := ldap.NewModifyRequest(nodeDN, nil) + + var dNSRecordList []string + for _, record := range records { + encodedProp, err := record.Encode() + if err == nil { + dNSRecordList = append(dNSRecordList, string(encodedProp)) + } + } + + if len(dNSRecordList) > 0 { + modifyRequest.Replace("dnsRecord", dNSRecordList) + } + + return lc.Conn.Modify(modifyRequest) } // Attributes diff --git a/tui/ace.go b/tui/ace.go index 02980f6..b1c39ad 100644 --- a/tui/ace.go +++ b/tui/ace.go @@ -128,7 +128,8 @@ func removeAce(aceIdx int) { err = lc.ModifyDACL(object, string(newSd)) if err == nil { - go updateDaclEntries() + go app.QueueUpdateDraw(updateDaclEntries) + updateLog("ACE deleted for object '"+object+"'", "green") if aceIdx > 0 { @@ -236,7 +237,7 @@ func createOrUpdateAce(aceIdx int, newAllowOrDeny bool, newACEFlags int, newMask err = lc.ModifyDACL(object, string(newSd)) if err == nil { - go updateDaclEntries() + go app.QueueUpdateDraw(updateDaclEntries) updateLog("DACL updated successfully!", "green") // Update selection @@ -839,9 +840,12 @@ func loadAceEditorForm(aceIdx int) { app.SetRoot(appPanel, true).SetFocus(daclEntriesPanel) } }) + assignButtonTheme(updateBtn) + cancelBtn := tview.NewButton("Go Back").SetSelectedFunc(func() { app.SetRoot(appPanel, true).SetFocus(daclEntriesPanel) }) + assignButtonTheme(cancelBtn) currentAceTable := tview.NewTable(). SetBorders(false). diff --git a/tui/dacl.go b/tui/dacl.go index df25c3e..28f13cf 100644 --- a/tui/dacl.go +++ b/tui/dacl.go @@ -217,7 +217,7 @@ func initDaclPage(includeCurSchema bool) { daclPage.SetInputCapture(daclPageKeyHandler) objectNameInputDacl.SetDoneFunc(func(tcell.Key) { updateLog("Fetching DACL for '"+objectNameInputDacl.GetText()+"'", "yellow") - go updateDaclEntries() + go app.QueueUpdateDraw(updateDaclEntries) }) } @@ -367,8 +367,6 @@ func updateDaclEntries() { } else { updateLog(fmt.Sprint(err), "red") } - - app.Draw() } func daclRotateFocus() { @@ -465,7 +463,7 @@ func loadChangeOwnerForm() { updateLog("Owner for '"+object+"' changed to '"+newOwner+"'", "green") - go updateDaclEntries() + go app.QueueUpdateDraw(updateDaclEntries) } else { updateLog(fmt.Sprint(err), "red") } @@ -542,8 +540,7 @@ func loadChangeControlFlagsForm() { if err == nil { updateLog("Control flags updated for '"+object+"'", "green") - - go updateDaclEntries() + go app.QueueUpdateDraw(updateDaclEntries) } else { updateLog(fmt.Sprint(err), "red") } diff --git a/tui/dns.go b/tui/dns.go index eb261fc..565bb46 100644 --- a/tui/dns.go +++ b/tui/dns.go @@ -43,7 +43,7 @@ var forestZones []adidns.DNSZone var zoneCache = make(map[string]adidns.DNSZone, 0) var nodeCache = make(map[string]adidns.DNSNode, 0) -var recordCache = make(map[string][]adidns.RecordContainer, 0) +var recordCache = make(map[string][]adidns.FriendlyRecord, 0) func getParentZone(objectDN string) (adidns.DNSZone, error) { objectDNParts := strings.Split(objectDN, ",") @@ -77,7 +77,7 @@ func exportADIDNSToFile(currentNode *tview.TreeNode, outputFilename string) { zoneProps := make(map[string]any, 0) for _, prop := range zone.Props { propName := adidns.FindPropName(prop.Id) - zoneProps[propName] = prop.Data + zoneProps[propName] = prop.ExportFormat() } exportMap[objectDN] = map[string]any{ @@ -95,9 +95,8 @@ func exportADIDNSToFile(currentNode *tview.TreeNode, outputFilename string) { for idx, rec := range records { recordType := node.Records[idx].PrintType() recordsObj = append(recordsObj, map[string]any{ - "Type": recordType, - "Name": rec.Name, - "Contents": rec.Contents, + "Type": recordType, + "Value": rec, }) } @@ -112,15 +111,28 @@ func exportADIDNSToFile(currentNode *tview.TreeNode, outputFilename string) { // to include in the export _, alreadyExported := exportMap[parentZone.DN] if !alreadyExported { + parentZoneProps := make(map[string]any, 0) + for _, prop := range parentZone.Props { + propName := adidns.FindPropName(prop.Id) + parentZoneProps[propName] = prop.ExportFormat() + } + exportMap[parentZone.DN] = map[string]any{ - "Zone": parentZone, + "Zone": map[string]any{ + "Name": parentZone.Name, + "DN": parentZone.DN, + "Props": parentZoneProps, + }, "Nodes": nodesMap, } } parentZone := (exportMap[parentZone.DN]).(map[string]any) parentZoneNodes := parentZone["Nodes"].(map[string]any) - parentZoneNodes[node.DN] = recordsObj + parentZoneNodes[node.DN] = map[string]any{ + "Name": node.Name, + "Records": recordsObj, + } } } } @@ -138,125 +150,171 @@ func exportADIDNSToFile(currentNode *tview.TreeNode, outputFilename string) { } } -func showZoneOrNodeDetails(objectDN string) { - zone, ok := zoneCache[objectDN] - if ok { - dnsSidePanel.SetTitle("dnsZone Properties") - dnsSidePanel.SwitchToPage("zone-props") - - propsMap := make(map[uint32]adidns.DNSProperty, 0) - for _, prop := range zone.Props { - propsMap[prop.Id] = prop - } +func showZoneDetails(zone *adidns.DNSZone) { + dnsSidePanel.SetTitle("Zone Properties") + dnsSidePanel.SwitchToPage("zone-props") - dnsZoneProps.SetCell(0, 0, tview.NewTableCell("Id").SetSelectable(false)) - dnsZoneProps.SetCell(0, 1, tview.NewTableCell("Description").SetSelectable(false)) - dnsZoneProps.SetCell(0, 2, tview.NewTableCell("Value").SetSelectable(false)) + propsMap := make(map[uint32]adidns.DNSProperty, 0) + for _, prop := range zone.Props { + propsMap[prop.Id] = prop + } - idx := 1 - for _, prop := range adidns.DnsPropertyIds { - dnsZoneProps.SetCell(idx, 0, tview.NewTableCell(fmt.Sprint(prop.Id))) - dnsZoneProps.SetCell(idx, 1, tview.NewTableCell(prop.Name)) + dnsZoneProps.SetCell(0, 0, tview.NewTableCell("Id").SetSelectable(false)) + dnsZoneProps.SetCell(0, 1, tview.NewTableCell("Description").SetSelectable(false)) + dnsZoneProps.SetCell(0, 2, tview.NewTableCell("Value").SetSelectable(false)) - mappedProp, ok := propsMap[prop.Id] - if ok { - mappedPropStr := fmt.Sprintf("%v", mappedProp.Data) - if FormatAttrs { - mappedPropStr = mappedProp.Format(TimeFormat) - } + idx := 1 + for _, prop := range adidns.DnsPropertyIds { + dnsZoneProps.SetCell(idx, 0, tview.NewTableCell(fmt.Sprint(prop.Id))) + dnsZoneProps.SetCell(idx, 1, tview.NewTableCell(prop.Name)) - if Colors { - color, change := adidns.GetPropCellColor(mappedProp.Id, mappedPropStr) - if change { - mappedPropStr = fmt.Sprintf("[%s]%s[c]", color, mappedPropStr) - } - } + mappedProp, ok := propsMap[prop.Id] + if ok { + mappedPropStr := fmt.Sprintf("%v", mappedProp.Data) + if FormatAttrs { + mappedPropStr = mappedProp.PrintFormat(TimeFormat) + } - dnsZoneProps.SetCell(idx, 2, tview.NewTableCell(mappedPropStr)) - } else { - notSpecifiedVal := "Not specified" - if Colors { - notSpecifiedVal = fmt.Sprintf("[gray]%s[c]", notSpecifiedVal) + if Colors { + color, change := adidns.GetPropCellColor(mappedProp.Id, mappedPropStr) + if change { + mappedPropStr = fmt.Sprintf("[%s]%s[c]", color, mappedPropStr) } + } - dnsZoneProps.SetCell(idx, 2, tview.NewTableCell(notSpecifiedVal)) + dnsZoneProps.SetCell(idx, 2, tview.NewTableCell(mappedPropStr)) + } else { + notSpecifiedVal := "Not specified" + if Colors { + notSpecifiedVal = fmt.Sprintf("[gray]%s[c]", notSpecifiedVal) } - idx += 1 + + dnsZoneProps.SetCell(idx, 2, tview.NewTableCell(notSpecifiedVal)) } + idx += 1 + } +} - return +type recordRef struct { + nodeDN string + idx int +} + +func reloadADIDNSZone(currentNode *tview.TreeNode) { + objectDN := currentNode.GetReference().(string) + + updateLog("Fetching nodes for zone '"+objectDN+"'...", "yellow") + + numLoadedNodes := loadZoneNodes(currentNode) + + if numLoadedNodes >= 0 { + updateLog(fmt.Sprintf("Loaded %d nodes (%s)", numLoadedNodes, objectDN), "green") } - node, ok := nodeCache[objectDN] - if ok { - parsedRecords, _ := recordCache[objectDN] - parentZone, err := getParentZone(objectDN) - if err == nil { - dnsSidePanel.SetTitle(fmt.Sprintf("dnsNode Records (%s)", parentZone.Name)) - } else { - dnsSidePanel.SetTitle("dnsNode Records") - } + if len(currentNode.GetChildren()) != 0 && !currentNode.IsExpanded() { + currentNode.SetExpanded(true) + } +} - dnsSidePanel.SwitchToPage("node-records") +func reloadADIDNSNode(currentNode *tview.TreeNode) { + objectDN := currentNode.GetReference().(string) - rootNode := tview.NewTreeNode(node.Name) - dnsNodeRecords.SetRoot(rootNode) + node, err := lc.GetADIDNSNode(objectDN) + nodeCache[node.DN] = node - for idx, record := range node.Records { - unixTimestamp := record.UnixTimestamp() - timeObj := time.Unix(unixTimestamp, 0) + if err == nil { + updateLog(fmt.Sprintf("Loaded node '%s'", node.DN), "green") + } else { + updateLog(fmt.Sprint(err), "red") + } - formattedTime := fmt.Sprintf("%d", unixTimestamp) - timeDistance := time.Since(timeObj) - if FormatAttrs { - if unixTimestamp != -1 { - formattedTime = timeObj.Format(TimeFormat) + storeNodeRecords(node) + showDetails(node.DN) +} + +func showNodeDetails(node *adidns.DNSNode, records []adidns.FriendlyRecord, targetTree *tview.TreeView) { + rootNode := tview.NewTreeNode(node.Name) + + for idx, record := range node.Records { + unixTimestamp := record.UnixTimestamp() + timeObj := time.Unix(unixTimestamp, 0) + + formattedTime := fmt.Sprintf("%d", unixTimestamp) + timeDistance := time.Since(timeObj) + if FormatAttrs { + if unixTimestamp != -1 { + formattedTime = timeObj.Format(TimeFormat) + } else { + formattedTime = "static" + } + } + + if Colors { + daysDiff := timeDistance.Hours() / 24 + color := "gray" + if unixTimestamp != -1 { + if daysDiff <= 7 { + color = "green" + } else if daysDiff <= 90 { + color = "yellow" } else { - formattedTime = "static" + color = "red" } } - if Colors { - daysDiff := timeDistance.Hours() / 24 - color := "gray" - if unixTimestamp != -1 { - if daysDiff <= 7 { - color = "green" - } else if daysDiff <= 90 { - color = "yellow" - } else { - color = "red" - } - } + formattedTime = fmt.Sprintf("[%s]%s[c]", color, formattedTime) + } - formattedTime = fmt.Sprintf("[%s]%s[c]", color, formattedTime) - } + recordName := fmt.Sprintf( + "%s [TTL=%d] (%s)", + record.PrintType(), + record.TTLSeconds, + formattedTime, + ) - nodeName := fmt.Sprintf( - "%s [TTL=%d] (%s)", - record.PrintType(), - record.TTLSeconds, - formattedTime, - ) - - recordTreeNode := tview.NewTreeNode(nodeName). - SetSelectable(true) - - parsedRecord := parsedRecords[idx] - recordFields := parsedRecord.DumpFields() - for _, field := range recordFields { - fieldName := tview.Escape(fmt.Sprintf("%s=%v", field.Name, field.Value)) - fieldTreeNode := tview.NewTreeNode(fieldName) - recordTreeNode.AddChild(fieldTreeNode) - } + recordTreeNode := tview.NewTreeNode(recordName). + SetReference(recordRef{node.DN, idx}) + + parsedRecord := records[idx] + recordFields := adidns.DumpRecordFields(parsedRecord) + for idx, field := range recordFields { + fieldName := tview.Escape(fmt.Sprintf("%s=%v", field.Name, field.Value)) + fieldTreeNode := tview.NewTreeNode(fieldName).SetReference(idx) + recordTreeNode.AddChild(fieldTreeNode) + } + + rootNode.AddChild(recordTreeNode) + } + + targetTree.SetRoot(rootNode) + go func() { + app.Draw() + }() +} + +func showDetails(objectDN string) { + zone, ok := zoneCache[objectDN] + if ok { + showZoneDetails(&zone) + } - rootNode.AddChild(recordTreeNode) + node, ok := nodeCache[objectDN] + if ok { + parsedRecords, _ := recordCache[objectDN] + parentZone, err := getParentZone(objectDN) + if err == nil { + dnsSidePanel.SetTitle(fmt.Sprintf("Records (%s)", parentZone.Name)) + } else { + dnsSidePanel.SetTitle("Records") } + dnsSidePanel.SwitchToPage("node-records") + + showNodeDetails(&node, parsedRecords, dnsNodeRecords) } } func storeNodeRecords(node adidns.DNSNode) { - records := make([]adidns.RecordContainer, 0) + records := make([]adidns.FriendlyRecord, 0) var fRec adidns.FriendlyRecord for _, record := range node.Records { @@ -349,12 +407,7 @@ func storeNodeRecords(node adidns.DNSNode) { fRec.Parse(record.Data) - container := adidns.RecordContainer{ - node.Name, - fRec, - } - - records = append(records, container) + records = append(records, fRec) } recordCache[node.DN] = records @@ -408,29 +461,23 @@ func initADIDNSPage() { dnsQueryPanel = tview.NewInputField() dnsQueryPanel. SetPlaceholder("Type a DNS zone or leave it blank and hit enter to query all zones"). - SetPlaceholderStyle(placeholderStyle). - SetPlaceholderTextColor(placeholderTextColor). - SetFieldBackgroundColor(fieldBackgroundColor). SetTitle("Zone Search"). SetBorder(true) + assignInputFieldTheme(dnsQueryPanel) dnsNodeFilter = tview.NewInputField() dnsNodeFilter. SetPlaceholder("Regex for dnsNode name"). - SetPlaceholderStyle(placeholderStyle). - SetPlaceholderTextColor(placeholderTextColor). - SetFieldBackgroundColor(fieldBackgroundColor). SetTitle("dnsNode Filter"). SetBorder(true) + assignInputFieldTheme(dnsNodeFilter) dnsZoneFilter = tview.NewInputField() dnsZoneFilter. SetPlaceholder("Regex for dnsZone name"). - SetPlaceholderStyle(placeholderStyle). - SetPlaceholderTextColor(placeholderTextColor). - SetFieldBackgroundColor(fieldBackgroundColor). SetTitle("dnsZone Filter"). SetBorder(true) + assignInputFieldTheme(dnsZoneFilter) dnsZoneProps = tview.NewTable(). SetSelectable(true, true). @@ -440,7 +487,7 @@ func initADIDNSPage() { dnsTreePanel = tview.NewTreeView() dnsTreePanel. - SetTitle("Search Results"). + SetTitle("Zones & Nodes"). SetBorder(true) dnsTreePanel.SetChangedFunc(func(objNode *tview.TreeNode) { @@ -452,7 +499,7 @@ func initADIDNSPage() { } nodeDN := objNodeRef.(string) - showZoneOrNodeDetails(nodeDN) + showDetails(nodeDN) }) dnsZoneFilter.SetChangedFunc(func(text string) { @@ -471,43 +518,19 @@ func initADIDNSPage() { return event } - objectDN := currentNode.GetReference().(string) + level := currentNode.GetLevel() switch event.Rune() { case 'r', 'R': - if currentNode == dnsTreePanel.GetRoot() { - return nil - } - - go func() { - level := currentNode.GetLevel() - if level == 1 { - updateLog("Fetching nodes for zone '"+objectDN+"'...", "yellow") - - numLoadedNodes := loadZoneNodes(currentNode) - - if numLoadedNodes >= 0 { - updateLog(fmt.Sprintf("Loaded %d nodes (%s)", numLoadedNodes, objectDN), "green") - } - - if len(currentNode.GetChildren()) != 0 && !currentNode.IsExpanded() { - currentNode.SetExpanded(true) - } + go app.QueueUpdateDraw(func() { + if level == 0 { + go queryDnsZones(dnsQueryPanel.GetText()) + } else if level == 1 { + reloadADIDNSZone(currentNode) } else if level == 2 { - node, err := lc.GetADIDNSNode(objectDN) - - if err == nil { - updateLog(fmt.Sprintf("Loaded node '%s'", node.DN), "green") - } else { - updateLog(fmt.Sprint(err), "red") - } - - storeNodeRecords(node) - showZoneOrNodeDetails(node.DN) + reloadADIDNSNode(currentNode) } - - app.Draw() - }() + }) return nil } @@ -532,9 +555,24 @@ func initADIDNSPage() { } return nil case tcell.KeyDelete: - if currentNode.GetReference() != nil { - openDeleteObjectForm(currentNode, nil) + if currentNode.GetReference() == nil { + return nil } + + openDeleteObjectForm(currentNode, func() { + level := currentNode.GetLevel() + if level == 1 { + go queryDnsZones(dnsQueryPanel.GetText()) + } else if level == 2 { + pathToCurrent := dnsTreePanel.GetPath(currentNode) + if len(pathToCurrent) > 1 { + parentNode := pathToCurrent[len(pathToCurrent)-2] + reloadADIDNSZone(parentNode) + } + } + }) + + return nil case tcell.KeyCtrlS: unixTimestamp := time.Now().UnixMilli() outputFilename := fmt.Sprintf("%d_dns.json", unixTimestamp) @@ -586,6 +624,10 @@ func dnsPageKeyHandler(event *tcell.EventKey) *tcell.EventKey { } func rebuildDnsTree(rootNode *tview.TreeNode) int { + if rootNode == nil { + return 0 + } + expandedZones := make(map[string]bool) childrenZones := rootNode.GetChildren() for _, child := range childrenZones { @@ -654,27 +696,25 @@ func rebuildDnsTree(rootNode *tview.TreeNode) int { return totalNodes } -func dnsQueryDoneHandler(key tcell.Key) { +func queryDnsZones(targetZone string) { + dnsRunControl.Lock() + if dnsRunning { + dnsRunControl.Unlock() + updateLog("Another query is still running...", "yellow") + return + } + dnsRunning = true + dnsRunControl.Unlock() + clear(nodeCache) clear(zoneCache) clear(domainZones) clear(forestZones) clear(recordCache) - go func() { - dnsRunControl.Lock() - if dnsRunning { - dnsRunControl.Unlock() - updateLog("Another query is still running...", "yellow") - return - } - dnsRunning = true - dnsRunControl.Unlock() - + app.QueueUpdateDraw(func() { updateLog("Querying ADIDNS zones...", "yellow") - targetZone := dnsQueryPanel.GetText() - domainZones, _ = lc.GetADIDNSZones(targetZone, false) forestZones, _ = lc.GetADIDNSZones(targetZone, true) @@ -682,7 +722,6 @@ func dnsQueryDoneHandler(key tcell.Key) { if totalZones == 0 { updateLog("No ADIDNS zones found", "red") rootNode.ClearChildren() - app.Draw() dnsRunControl.Lock() dnsRunning = false @@ -702,13 +741,15 @@ func dnsQueryDoneHandler(key tcell.Key) { updateLog(fmt.Sprintf("Found %d ADIDNS zones and %d nodes", totalZones, totalNodes), "green") app.SetFocus(dnsTreePanel) + }) - app.Draw() + dnsRunControl.Lock() + dnsRunning = false + dnsRunControl.Unlock() +} - dnsRunControl.Lock() - dnsRunning = false - dnsRunControl.Unlock() - }() +func dnsQueryDoneHandler(key tcell.Key) { + go queryDnsZones(dnsQueryPanel.GetText()) } func dnsRotateFocus() { @@ -718,6 +759,10 @@ func dnsRotateFocus() { case dnsTreePanel: app.SetFocus(dnsQueryPanel) case dnsQueryPanel: + app.SetFocus(dnsNodeFilter) + case dnsNodeFilter: + app.SetFocus(dnsZoneFilter) + case dnsZoneFilter: app.SetFocus(dnsZoneProps) case dnsZoneProps: app.SetFocus(dnsTreePanel) diff --git a/tui/explorer.go b/tui/explorer.go index 7b7e614..e6287da 100644 --- a/tui/explorer.go +++ b/tui/explorer.go @@ -103,7 +103,7 @@ func initExplorerPage() { func expandTreeNode(node *tview.TreeNode) { if !node.IsExpanded() { if len(node.GetChildren()) == 0 { - go func() { + go app.QueueUpdateDraw(func() { updateLog("Loading children ("+node.GetReference().(string)+")", "yellow") loadChildren(node) @@ -115,8 +115,7 @@ func expandTreeNode(node *tview.TreeNode) { } else { updateLog("Node "+node.GetReference().(string)+" has no children", "green") } - app.Draw() - }() + }) } else { node.SetExpanded(true) } @@ -171,14 +170,6 @@ func findEntryInChildren(dn string, parent *tview.TreeNode) int { return -1 } -func handleEscapeToTree(event *tcell.EventKey) *tcell.EventKey { - if event.Key() == tcell.KeyEscape { - app.SetRoot(appPanel, true).SetFocus(treePanel) - return nil - } - return event -} - func exportCacheToFile(currentNode *tview.TreeNode, cache *EntryCache, outputFilename string) { exportMap := make(map[string]*ldap.Entry) currentNode.Walk(func(node, parent *tview.TreeNode) bool { @@ -235,7 +226,7 @@ func openUpdateUacForm(node *tview.TreeNode, cache *EntryCache, done func()) { SetButtonBackgroundColor(formButtonBackgroundColor). SetButtonTextColor(formButtonTextColor). SetButtonActivatedStyle(formButtonActivatedStyle) - updateUacForm.SetInputCapture(handleEscapeToTree) + updateUacForm.SetInputCapture(handleEscape(treePanel)) updateUacForm.SetItemPadding(0) var checkboxState int = 0 @@ -314,7 +305,7 @@ func openCreateObjectForm(node *tview.TreeNode, done func()) { SetButtonBackgroundColor(formButtonBackgroundColor). SetButtonTextColor(formButtonTextColor). SetButtonActivatedStyle(formButtonActivatedStyle). - SetInputCapture(handleEscapeToTree) + SetInputCapture(handleEscape(treePanel)) createObjectForm. AddButton("Go Back", func() { @@ -402,7 +393,7 @@ func openAddMemberToGroupForm(groupDN string) { SetButtonBackgroundColor(formButtonBackgroundColor). SetButtonTextColor(formButtonTextColor). SetButtonActivatedStyle(formButtonActivatedStyle). - SetInputCapture(handleEscapeToTree) + SetInputCapture(handleEscape(treePanel)) addMemberForm. AddButton("Go Back", func() { @@ -441,7 +432,7 @@ func treePanelKeyHandler(event *tcell.EventKey) *tcell.EventKey { switch event.Rune() { case 'r', 'R': - go func() { + go app.QueueUpdateDraw(func() { updateLog("Reloading node "+baseDN, "yellow") explorerCache.Delete(baseDN) @@ -451,9 +442,7 @@ func treePanelKeyHandler(event *tcell.EventKey) *tcell.EventKey { loadChildren(currentNode) updateLog("Node "+baseDN+" reloaded", "green") - - app.Draw() - }() + }) return event } @@ -531,10 +520,10 @@ func treePanelKeyHandler(event *tcell.EventKey) *tcell.EventKey { } func treePanelChangeHandler(node *tview.TreeNode) { - go func() { + go app.QueueUpdateDraw(func() { // TODO: Implement cancellation reloadExplorerAttrsPanel(node, CacheEntries) - }() + }) } func explorerRotateFocus() { @@ -577,7 +566,7 @@ func openPasswordChangeForm(node *tview.TreeNode) { SetButtonBackgroundColor(formButtonBackgroundColor). SetButtonTextColor(formButtonTextColor). SetButtonActivatedStyle(formButtonActivatedStyle) - changePasswordForm.SetInputCapture(handleEscapeToTree) + changePasswordForm.SetInputCapture(handleEscape(treePanel)) app.SetRoot(changePasswordForm, true).SetFocus(changePasswordForm) } @@ -611,7 +600,7 @@ func openMoveObjectForm(node *tview.TreeNode, done func(string)) { }) moveObjectForm.SetTitle("Move Object").SetBorder(true) - moveObjectForm.SetInputCapture(handleEscapeToTree) + moveObjectForm.SetInputCapture(handleEscape(treePanel)) moveObjectForm. SetButtonBackgroundColor(formButtonBackgroundColor). SetButtonTextColor(formButtonTextColor). diff --git a/tui/gpo.go b/tui/gpo.go index 6e9dcc6..0574235 100644 --- a/tui/gpo.go +++ b/tui/gpo.go @@ -210,117 +210,117 @@ func updateGPOEntries() { gpoListPanel.Clear() gpoPath.Clear() - gpoListPanel.SetCell(0, 0, tview.NewTableCell("Name").SetSelectable(false)) - gpoListPanel.SetCell(0, 1, tview.NewTableCell("Created").SetSelectable(false)) - gpoListPanel.SetCell(0, 2, tview.NewTableCell("Changed").SetSelectable(false)) - gpoListPanel.SetCell(0, 3, tview.NewTableCell("GUID").SetSelectable(false)) + app.QueueUpdateDraw(func() { + gpoListPanel.SetCell(0, 0, tview.NewTableCell("Name").SetSelectable(false)) + gpoListPanel.SetCell(0, 1, tview.NewTableCell("Created").SetSelectable(false)) + gpoListPanel.SetCell(0, 2, tview.NewTableCell("Changed").SetSelectable(false)) + gpoListPanel.SetCell(0, 3, tview.NewTableCell("GUID").SetSelectable(false)) - // Load all gpLinks - updateLog("Querying all gpLinks", "yellow") - gpLinkObjs, err := lc.Query(lc.RootDN, "(gpLink=*)", ldap.ScopeWholeSubtree, false) + // Load all gpLinks + updateLog("Querying all gpLinks", "yellow") - for _, gpLinkObj := range gpLinkObjs { - gpLinkVals := gpLinkObj.GetAttributeValue("gPLink") + gpLinkObjs, err := lc.Query(lc.RootDN, "(gpLink=*)", ldap.ScopeWholeSubtree, false) - links, _ := ParseGPLinks(gpLinkVals, gpLinkObj.DN) + for _, gpLinkObj := range gpLinkObjs { + gpLinkVals := gpLinkObj.GetAttributeValue("gPLink") - for _, link := range links { - gpLinks[link.GUID] = append(gpLinks[link.GUID], link) - containerLinks[link.Target] = append(containerLinks[link.Target], link.GUID) + links, _ := ParseGPLinks(gpLinkVals, gpLinkObj.DN) + + for _, link := range links { + gpLinks[link.GUID] = append(gpLinks[link.GUID], link) + containerLinks[link.Target] = append(containerLinks[link.Target], link.GUID) + } } - } - updateLog("gpLinks loaded successfully", "green") + updateLog("gpLinks loaded successfully", "green") - // Load all GPOs from corresponding links - gpoQuery := "(objectClass=groupPolicyContainer)" - gpoTarget = gpoTargetInput.GetText() + // Load all GPOs from corresponding links + gpoQuery := "(objectClass=groupPolicyContainer)" + gpoTarget = gpoTargetInput.GetText() - gpoTargetDN := gpoTarget - if gpoTarget != "" { - gpoTargetQuery := fmt.Sprintf("(distinguishedName=%s)", ldap.EscapeFilter(gpoTarget)) - if !strings.Contains(gpoTarget, "=") { - gpoTargetQuery = fmt.Sprintf("(cn=%s)", ldap.EscapeFilter(gpoTarget)) - } + gpoTargetDN := gpoTarget + if gpoTarget != "" { + gpoTargetQuery := fmt.Sprintf("(distinguishedName=%s)", ldap.EscapeFilter(gpoTarget)) + if !strings.Contains(gpoTarget, "=") { + gpoTargetQuery = fmt.Sprintf("(cn=%s)", ldap.EscapeFilter(gpoTarget)) + } - entries, err := lc.Query(lc.RootDN, gpoTargetQuery, ldap.ScopeWholeSubtree, false) + entries, err := lc.Query(lc.RootDN, gpoTargetQuery, ldap.ScopeWholeSubtree, false) - updateLog("Querying for '"+gpoTargetQuery+"'", "yellow") - if err != nil { - updateLog(fmt.Sprint(err), "red") - return - } + updateLog("Querying for '"+gpoTargetQuery+"'", "yellow") + if err != nil { + updateLog(fmt.Sprint(err), "red") + return + } - if len(entries) > 0 { - updateLog("GPO target found ("+entries[0].DN+")", "green") - gpoTargetDN = entries[0].DN - } else { - updateLog("GPO target not found", "red") - app.Draw() - return + if len(entries) > 0 { + updateLog("GPO target found ("+entries[0].DN+")", "green") + gpoTargetDN = entries[0].DN + } else { + updateLog("GPO target not found", "red") + return + } } - } - var applicableGPOs []string + var applicableGPOs []string - dnParts := strings.Split(gpoTargetDN, ",") - for idx := len(dnParts) - 1; idx >= 0; idx -= 1 { - candidateDN := strings.Join(dnParts[idx:], ",") + dnParts := strings.Split(gpoTargetDN, ",") + for idx := len(dnParts) - 1; idx >= 0; idx -= 1 { + candidateDN := strings.Join(dnParts[idx:], ",") - candidateGuids, ok := containerLinks[candidateDN] - if ok { - applicableGPOs = append(applicableGPOs, candidateGuids...) + candidateGuids, ok := containerLinks[candidateDN] + if ok { + applicableGPOs = append(applicableGPOs, candidateGuids...) + } } - } - gpoQuerySuffix := "" - if len(applicableGPOs) > 0 { - gpoQuerySuffix = "name=" + ldap.EscapeFilter(applicableGPOs[0]) - for _, gpoGuid := range applicableGPOs[1:] { - gpoQuerySuffix = "(|(" + gpoQuerySuffix + ")(name=" + ldap.EscapeFilter(gpoGuid) + "))" + gpoQuerySuffix := "" + if len(applicableGPOs) > 0 { + gpoQuerySuffix = "name=" + ldap.EscapeFilter(applicableGPOs[0]) + for _, gpoGuid := range applicableGPOs[1:] { + gpoQuerySuffix = "(|(" + gpoQuerySuffix + ")(name=" + ldap.EscapeFilter(gpoGuid) + "))" + } } - } - - if gpoQuerySuffix != "" { - gpoQuery = "(&(" + gpoQuery + ")(" + gpoQuerySuffix + "))" - } - updateLog("Searching applicable GPOs...", "yellow") + if gpoQuerySuffix != "" { + gpoQuery = "(&(" + gpoQuery + ")(" + gpoQuerySuffix + "))" + } - entries, err := lc.Query(lc.RootDN, gpoQuery, ldap.ScopeWholeSubtree, false) - if err != nil { - updateLog(fmt.Sprint(err), "red") - return - } + updateLog("Searching applicable GPOs...", "yellow") - if len(entries) > 0 { - updateLog("GPOs query completed ("+strconv.Itoa(len(entries))+" GPOs found)", "green") - } else { - updateLog("No applicable GPOs found", "red") - } + entries, err := lc.Query(lc.RootDN, gpoQuery, ldap.ScopeWholeSubtree, false) + if err != nil { + updateLog(fmt.Sprint(err), "red") + return + } - for idx, entry := range entries { - gpoGuid := entry.GetAttributeValue("cn") - gpEntry[gpoGuid] = entry + if len(entries) > 0 { + updateLog("GPOs query completed ("+strconv.Itoa(len(entries))+" GPOs found)", "green") + } else { + updateLog("No applicable GPOs found", "red") + } - gpoName := entry.GetAttributeValue("displayName") + for idx, entry := range entries { + gpoGuid := entry.GetAttributeValue("cn") + gpEntry[gpoGuid] = entry - gpoCreated := entry.GetAttributeValue("whenCreated") - gpoChanged := entry.GetAttributeValue("whenChanged") + gpoName := entry.GetAttributeValue("displayName") - gpoListPanel.SetCellSimple(idx+1, 0, gpoName) - gpoListPanel.SetCellSimple(idx+1, 1, ldaputils.FormatLDAPTime(gpoCreated, TimeFormat)) - gpoListPanel.SetCellSimple(idx+1, 2, ldaputils.FormatLDAPTime(gpoChanged, TimeFormat)) - gpoListPanel.SetCellSimple(idx+1, 3, gpoGuid) - } + gpoCreated := entry.GetAttributeValue("whenCreated") + gpoChanged := entry.GetAttributeValue("whenChanged") - if len(entries) > 0 { - gpoListPanel.SetTitle("Applied GPOs (" + strconv.Itoa(len(entries)) + ")") - gpoListPanel.Select(1, 0) + gpoListPanel.SetCellSimple(idx+1, 0, gpoName) + gpoListPanel.SetCellSimple(idx+1, 1, ldaputils.FormatLDAPTime(gpoCreated, TimeFormat)) + gpoListPanel.SetCellSimple(idx+1, 2, ldaputils.FormatLDAPTime(gpoChanged, TimeFormat)) + gpoListPanel.SetCellSimple(idx+1, 3, gpoGuid) + } - app.SetFocus(gpoListPanel) - } + if len(entries) > 0 { + gpoListPanel.SetTitle("Applied GPOs (" + strconv.Itoa(len(entries)) + ")") + gpoListPanel.Select(1, 0) - app.Draw() + app.SetFocus(gpoListPanel) + } + }) } func exportCurrentGpos() { diff --git a/tui/main.go b/tui/main.go index 362ac0a..49ede94 100644 --- a/tui/main.go +++ b/tui/main.go @@ -169,8 +169,9 @@ func upgradeStartTLS() { } func reconnectLdap() { - // TODO: Check possible race conditions - go setupLDAPConn() + go app.QueueUpdateDraw(func() { + setupLDAPConn() + }) } func openConfigForm() { @@ -513,17 +514,15 @@ func setupTimeFormat(f string) string { } func updateStateBox(target *tview.TextView, control bool) { - go func() { - app.QueueUpdateDraw(func() { - if control { - target.SetText("ON") - target.SetTextColor(tcell.GetColor("green")) - } else { - target.SetText("OFF") - target.SetTextColor(tcell.GetColor("red")) - } - }) - }() + go app.QueueUpdateDraw(func() { + if control { + target.SetText("ON") + target.SetTextColor(tcell.GetColor("green")) + } else { + target.SetText("OFF") + target.SetTextColor(tcell.GetColor("red")) + } + }) } func updateLog(msg string, color string) { diff --git a/tui/search.go b/tui/search.go index 2dbcc72..14485c3 100644 --- a/tui/search.go +++ b/tui/search.go @@ -278,38 +278,38 @@ func searchQueryDoneHandler(key tcell.Key) { childNode, ok := searchLoadedDNs[partialDN] if !ok { - if i == 0 { - // Leaf node - nodeName = entryName - childNode = tview.NewTreeNode(nodeName). - SetReference(entry.DN). - SetExpanded(false). - SetSelectable(true) - - if Colors { - color, changed := ldaputils.GetEntryColor(entry) - if changed { - childNode.SetColor(color) + app.QueueUpdateDraw(func() { + if i == 0 { + // Leaf node + nodeName = entryName + childNode = tview.NewTreeNode(nodeName). + SetReference(entry.DN). + SetExpanded(false). + SetSelectable(true) + + if Colors { + color, changed := ldaputils.GetEntryColor(entry) + if changed { + childNode.SetColor(color) + } } - } - currentNode.AddChild(childNode) - - if firstLeaf { - searchTreePanel.SetCurrentNode(childNode) - firstLeaf = false - } + currentNode.AddChild(childNode) - searchCache.Add(entry.DN, entry) - } else { - // Non-leaf node - nodeName = components[i] - childNode = tview.NewTreeNode(nodeName). - SetExpanded(true). - SetSelectable(true) - currentNode.AddChild(childNode) - } + if firstLeaf { + searchTreePanel.SetCurrentNode(childNode) + firstLeaf = false + } - app.Draw() + searchCache.Add(entry.DN, entry) + } else { + // Non-leaf node + nodeName = components[i] + childNode = tview.NewTreeNode(nodeName). + SetExpanded(true). + SetSelectable(true) + currentNode.AddChild(childNode) + } + }) searchLoadedDNs[partialDN] = childNode } @@ -318,9 +318,9 @@ func searchQueryDoneHandler(key tcell.Key) { } } - updateLog("Query completed ("+strconv.Itoa(len(entries))+" objects found)", "green") - - app.Draw() + app.QueueUpdateDraw(func() { + updateLog("Query completed ("+strconv.Itoa(len(entries))+" objects found)", "green") + }) runControl.Lock() running = false diff --git a/tui/tree.go b/tui/tree.go index d02fa6c..5ddc48e 100644 --- a/tui/tree.go +++ b/tui/tree.go @@ -190,7 +190,7 @@ func handleAttrsKeyCtrlE(currentNode *tview.TreeNode, attrsPanel *tview.Table, c SetButtonBackgroundColor(formButtonBackgroundColor). SetButtonTextColor(formButtonTextColor). SetButtonActivatedStyle(formButtonActivatedStyle) - writeAttrValsForm.SetInputCapture(handleEscapeToTree) + writeAttrValsForm.SetInputCapture(handleEscape(treePanel)) writeAttrValsForm.SetTitle("Attribute Editor").SetBorder(true) app.SetRoot(writeAttrValsForm, true).SetFocus(writeAttrValsForm) } @@ -230,7 +230,7 @@ func handleAttrsKeyCtrlN(currentNode *tview.TreeNode, attrsPanel *tview.Table, c SetButtonBackgroundColor(formButtonBackgroundColor). SetButtonTextColor(formButtonTextColor). SetButtonActivatedStyle(formButtonActivatedStyle) - createAttrForm.SetInputCapture(handleEscapeToTree) + createAttrForm.SetInputCapture(handleEscape(treePanel)) baseDN := currentNode.GetReference().(string)