Skip to content

Commit

Permalink
Fix host uniqueness mysql (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinnybod authored Jan 10, 2023
1 parent a86df09 commit aae6aba
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 21 deletions.
47 changes: 34 additions & 13 deletions empire/server/core/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sqlite3

from sqlalchemy import create_engine, event, text
from sqlalchemy import UniqueConstraint, create_engine, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
Expand Down Expand Up @@ -43,11 +43,16 @@ def try_create_engine(engine_url: str, *args, **kwargs) -> Engine:
return engine


database_config = empire_config.database
def reset_db():
SessionLocal.close_all()
Base.metadata.drop_all(engine)
if use == "sqlite":
os.unlink(database_config.location)


database_config = empire_config.database
use = os.environ.get("DATABASE_USE", database_config.use)
database_config.use = use

database_config = database_config[use.lower()]

if use == "mysql":
Expand All @@ -66,24 +71,40 @@ def try_create_engine(engine_url: str, *args, **kwargs) -> Engine:
f"sqlite:///{location}",
connect_args={
"check_same_thread": False,
# "timeout": 3000
},
echo=False,
)

SessionLocal = sessionmaker(bind=engine)
models.Host.__table_args__ = (
UniqueConstraint(
models.Host.name, models.Host.internal_ip, name="host_unique_idx"
),
)

SessionLocal = sessionmaker(bind=engine)
Base.metadata.create_all(engine)


def reset_db():
SessionLocal.close_all()
Base.metadata.drop_all(engine)
if use == "sqlite":
os.unlink(database_config.location)


with SessionLocal.begin() as db:
if use == "mysql":
database_name = database_config.database_name
query = text(
"""
SELECT * FROM information_schema.statistics
WHERE table_schema = :schema
AND table_name = :table
AND index_name = :index
"""
)
result = engine.execute(
query, schema=database_name, table="hosts", index="host_unique_idx"
)
if not result.fetchone():
db.execute(
"""
CREATE UNIQUE INDEX host_unique_idx ON hosts ((md5(concat(name, internal_ip))))
"""
)

# When Empire starts up for the first time, it will create the database and create
# these default records.
if len(db.query(models.User).all()) == 0:
Expand Down
7 changes: 2 additions & 5 deletions empire/server/core/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
String,
Table,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects import mysql
Expand Down Expand Up @@ -99,10 +98,8 @@ def __repr__(self):
class Host(Base):
__tablename__ = "hosts"
id = Column(Integer, Sequence("host_id_seq"), primary_key=True)
name = Column(String(255), nullable=False)
internal_ip = Column(String(255))

UniqueConstraint(name, internal_ip)
name = Column(Text, nullable=False)
internal_ip = Column(Text)


class Agent(Base):
Expand Down
28 changes: 26 additions & 2 deletions empire/test/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
from datetime import datetime, timedelta, timezone

import pytest
from sqlalchemy.exc import IntegrityError


@pytest.fixture(scope="module", autouse=True)
def agent(db, models):
def host(db, models):
hosts = db.query(models.Host).all()
if len(hosts) == 0:
host = models.Host(name="default_host", internal_ip="127.0.0.1")
else:
host = hosts[0]

yield host


@pytest.fixture(scope="module", autouse=True)
def agent(db, models, host):
agents = db.query(models.Agent).all()
if len(agents) == 0:
agent = models.Agent(
name="TEST123",
session_id="TEST123",
delay=60,
jitter=0.1,
internal_ip=host.internal_ip,
external_ip="1.1.1.1",
session_key="qwerty",
nonce="nonce",
Expand All @@ -41,6 +48,7 @@ def agent(db, models):
session_id="SECOND",
delay=60,
jitter=0.1,
internal_ip=host.internal_ip,
external_ip="1.1.1.1",
session_key="qwerty",
nonce="nonce",
Expand All @@ -64,6 +72,7 @@ def agent(db, models):
session_id="archived",
delay=60,
jitter=0.1,
internal_ip=host.internal_ip,
external_ip="1.1.1.1",
session_key="qwerty",
nonce="nonce",
Expand All @@ -87,6 +96,7 @@ def agent(db, models):
session_id="STALE",
delay=1,
jitter=0.1,
internal_ip=host.internal_ip,
external_ip="1.1.1.1",
session_key="qwerty",
nonce="nonce",
Expand Down Expand Up @@ -150,9 +160,23 @@ def test_stale_expression(empire_config):
assert len(not_stale) == 3


def test_large_internal_ip_works(db, agent):
def test_large_internal_ip_works(db, agent, host):
agent1 = agent[0]

agent1.internal_ip = "192.168.1.75 fe90::51e7:5dc7:be5d:b22e 3600:1900:7bb0:90d0:4d3c:2cd6:3fe:883b 5600:1900:3aa0:80d1:18a4:4431:5023:eef7 6600:1500:1aa0:20d0:fd69:26ff:5c4c:8d27 2900:2700:4aa0:80d0::47 192.168.214.1 fe90::a24c:82de:578b:8626 192.168.245.1 fe00::f321:a1e:18d3:ab9"

db.flush()

host.internal_ip = agent1.internal_ip

db.flush()


def test_duplicate_host(db, models, host):
with pytest.raises(IntegrityError):
host2 = models.Host(name=host.name, internal_ip=host.internal_ip)

db.add(host2)
db.flush()

db.rollback()
1 change: 0 additions & 1 deletion empire/test/test_server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ database:
keyword_obfuscation:
- Invoke-Empire
- Invoke-Mimikatz
# TODO VR: Rename or remove
# an IP white list to ONLY accept clients from
# format is "192.168.1.1,192.168.1.10-192.168.1.100,10.0.0.0/8"
ip-whitelist: ""
Expand Down

0 comments on commit aae6aba

Please sign in to comment.