Skip to content

Commit

Permalink
Use country colors for graphs and maps; avoid duplicates (#117)
Browse files Browse the repository at this point in the history
* Use country colors for graphs and maps; avoid duplicates

* Country color tweaks
  • Loading branch information
MichaelMakesGames authored Nov 12, 2023
1 parent b2998d8 commit 00b4f03
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 12 deletions.
23 changes: 13 additions & 10 deletions stellarisdashboard/dashboard_app/graph_ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def update_content(
if not plot_spec:
continue # just in case it's possible to sneak in an invalid ID
start = time.time()
figure_data = get_raw_plot_data_dicts(plot_data, plot_spec)
figure_data = get_raw_plot_data_dicts(game_id, plot_data, plot_spec)
end = time.time()
logger.debug(
f"Prepared figure {plot_spec.title} in {end - start:5.3f} seconds."
Expand Down Expand Up @@ -422,6 +422,7 @@ def _get_game_ids_matching_url(url):


def get_raw_plot_data_dicts(
game_id: str,
plot_data: visualization_data.PlotDataManager,
plot_spec: visualization_data.PlotSpecification,
) -> List[Dict[str, Any]]:
Expand All @@ -434,18 +435,19 @@ def get_raw_plot_data_dicts(
:return:
"""
if plot_spec.style == visualization_data.PlotStyle.line:
return _get_raw_data_for_line_plot(plot_data, plot_spec)
return _get_raw_data_for_line_plot(game_id, plot_data, plot_spec)
elif plot_spec.style in [
visualization_data.PlotStyle.stacked,
visualization_data.PlotStyle.budget,
]:
return _get_raw_data_for_stacked_and_budget_plots(plot_data, plot_spec)
return _get_raw_data_for_stacked_and_budget_plots(game_id, plot_data, plot_spec)
else:
logger.warning(f"Unknown Plot type {plot_spec}")
return []


def _get_raw_data_for_line_plot(
game_id: str,
plot_data: visualization_data.PlotDataManager,
plot_spec: visualization_data.PlotSpecification,
) -> List[Dict[str, Any]]:
Expand All @@ -459,7 +461,7 @@ def _get_raw_data_for_line_plot(
x=x_values,
y=y_values,
name=dict_key_to_legend_label(key),
line={"color": get_country_color(key, 1.0)},
line={"color": get_country_color(game_id, key, 1.0)},
text=get_plot_value_labels(x_values, y_values, key),
hoverinfo="text",
)
Expand All @@ -468,6 +470,7 @@ def _get_raw_data_for_line_plot(


def _get_raw_data_for_stacked_and_budget_plots(
game_id: str,
plot_data: visualization_data.PlotDataManager,
plot_spec: visualization_data.PlotSpecification,
) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -514,10 +517,10 @@ def _get_raw_data_for_stacked_and_budget_plots(
legendgroup=key, # ensure that budget contributions with mixed signs still behave as a single entry
hoverinfo="text",
mode="lines",
line=dict(width=0.5, color=get_country_color(key, 1.0)),
line=dict(width=0.5, color=get_country_color(game_id, key, 1.0)),
stackgroup=stackgroup,
groupnorm="percent" if normalized else "",
fillcolor=get_country_color(key, 0.5),
fillcolor=get_country_color(game_id, key, 0.5),
text=get_plot_value_labels(x_values, yv, key),
showlegend=i == 0, # only show one legend entry
)
Expand Down Expand Up @@ -574,7 +577,7 @@ def get_galaxy(game_id: str, slider_date: float) -> dcc.Graph:
x=[],
y=[],
text=[],
line=go.scatter.Line(width=0.5, color=get_country_color(country)),
line=go.scatter.Line(width=0.5, color=get_country_color(game_id, country)),
hoverinfo="text",
mode="lines",
showlegend=False,
Expand Down Expand Up @@ -611,7 +614,7 @@ def get_galaxy(game_id: str, slider_date: float) -> dcc.Graph:
marker=dict(color=[], size=4),
name=country,
)
color = get_country_color(country)
color = get_country_color(game_id, country)
country_system_markers[country]["marker"]["color"].append(color)
x, y = nx_graph.nodes[node]["pos"]
country_system_markers[country]["x"].append(x)
Expand Down Expand Up @@ -712,10 +715,10 @@ def get_galaxy(game_id: str, slider_date: float) -> dcc.Graph:
)


def get_country_color(country_name: str, alpha: float = 1.0) -> str:
def get_country_color(game_id: int, country_name: str, alpha: float = 1.0) -> str:
alpha = min(alpha, 1)
alpha = max(alpha, 0)
r, g, b = visualization_data.get_color_vals(country_name)
r, g, b = visualization_data.get_color_vals(game_id, country_name)
color = f"rgba({r},{g},{b},{alpha})"
return color

Expand Down
2 changes: 1 addition & 1 deletion stellarisdashboard/dashboard_app/timelapse_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _timelapse_id(self):
return f"{self.game_id}-{ts}"

def rgb(self, name: str):
r, g, b = get_color_vals(name)
r, g, b = get_color_vals(self.game_id, name)
return r / 255.0, g / 255.0, b / 255.0

def draw_frame(
Expand Down
113 changes: 112 additions & 1 deletion stellarisdashboard/dashboard_app/visualization_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import dataclasses
import enum
import functools
import itertools
import logging
import pathlib
import random
import re
import time
from collections import defaultdict
from typing import List, Dict, Callable, Tuple, Iterable, Union, Set, Optional, Any
Expand All @@ -14,6 +17,7 @@
from scipy.spatial import Voronoi

from stellarisdashboard import datamodel, config, game_info
from stellarisdashboard.parsing.save_parser import rust_parser

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -92,15 +96,29 @@ def get_current_execution_plot_data(
return _CURRENT_EXECUTION_PLOT_DATA[game_name]


_GAME_COUNTRY_COLORS = {}


def clear_cached_country_colors():
_GAME_COUNTRY_COLORS.clear()


def get_color_vals(
key_str: str, range_min: float = 0.1, range_max: float = 1.0
game_id: str, key_str: str, range_min: float = 0.1, range_max: float = 1.0
) -> Tuple[float, float, float]:
"""Generate RGB values for the given identifier. Some special values (tech categories)
have hardcoded colors to roughly match the game's look and feel.
For unknown identifiers, a random color is generated, with the key_str being applied as a seed to
the random number generator. This makes colors consistent across figures and executions.
"""
if game_id not in _GAME_COUNTRY_COLORS:
country_colors = CountryColors()
country_colors.load(game_id)
_GAME_COUNTRY_COLORS[game_id] = country_colors
else:
country_colors = _GAME_COUNTRY_COLORS[game_id]

if key_str.lower() == "physics":
r, g, b = COLOR_PHYSICS
elif key_str.lower() == "society":
Expand All @@ -113,6 +131,8 @@ def get_color_vals(
r, g, b = 255, 0, 0
elif key_str.endswith("internal_market"):
r, g, b = 0, 0, 255
elif country_colors.has_color_for_name(key_str):
r, g, b = country_colors.get_color_by_name(key_str)
else:
random.seed(key_str)
h = random.uniform(0, 1)
Expand Down Expand Up @@ -1679,3 +1699,94 @@ def _country_display_name(self, country: datamodel.Country) -> str:
if not country.has_met_player():
return GalaxyMapData.UNCLAIMED
return country.rendered_name


_MIN_V = 0.4
_MAX_V = 1.0
_V_SHIFTS = [
0.2,
0.4,
0.6,
]
_GRAYSCALE_S = 0.2


class CountryColors:
def __init__(self):
self.map_colors: Dict[str, tuple] = {}
self.country_name_to_primary_color: Dict[str, tuple] = {}
self._used_hsv = set()

def load(self, game_id: str):
for game_data_dir in reversed(config.CONFIG.game_data_dirs):
colors_path = game_data_dir / "flags/colors.txt"
if colors_path.exists():
self.map_colors = self._parse_map_colors(colors_path)
break

with datamodel.get_db_session(game_id) as session:
for c in session.query(datamodel.Country).order_by("country_id_in_game"):
name = c.rendered_name
# avoid grayscale colors if possible; they can be hard to distinguish
color = (
c.secondary_color
if self._is_grayscale_color(c.primary_color)
and not self._is_grayscale_color(c.secondary_color)
else c.primary_color
)
# try to avoid duplicate colors
# don't bother for non-default-or-fallen-empire countries, so that the "real" countries are more likely to get their color
rgb = self._get_rgb(color, avoid_used=c.is_real_country())
self.country_name_to_primary_color[name] = rgb

def _get_rgb(self, key: str, avoid_used: bool):
r, g, b = self.map_colors.get(key, (0, 0, 0))
h, s, v = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0)
v = min(v, _MAX_V)
v = max(v, _MIN_V)
if avoid_used:
for shift in itertools.chain([0.0], *((v, -v) for v in _V_SHIFTS)):
v_shifted = s + shift
if (
v_shifted >= _MIN_V
and v_shifted <= _MAX_V
and self._round_hsv(h, s, v_shifted) not in self._used_hsv
):
v = v_shifted
break
self._used_hsv.add(self._round_hsv(h, s, v))
return tuple(int(v * 255) for v in colorsys.hsv_to_rgb(h, s, v))

@staticmethod
def _round_hsv(h: float, s: float, v: float):
return round(h, 1), round(s, 1), round(v, 1)

def _is_grayscale_color(self, key: str):
r, g, b = self.map_colors.get(key, (0, 0, 0))
_, s, _ = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0)
return s < _GRAYSCALE_S

def has_color_for_name(self, country_name: str):
return country_name in self.country_name_to_primary_color

def get_color_by_name(self, country_name: str):
return self.country_name_to_primary_color[country_name]

@staticmethod
def _parse_map_colors(path: pathlib.Path):
with open(path, "r") as f:
prepared_str = re.sub("#[^\n]*", "", f.read()) # strip out comments
data = rust_parser.parse_save_from_string(prepared_str)

colors = data.get("colors", {})
map_colors = {}
for key in colors:
space, v1, v2, v3 = colors[key]["map"]
if space == "rgb":
map_colors[key] = (int(v1), int(v2), int(v3))
elif space == "hsv":
rgb = tuple(int(v * 255) for v in colorsys.hsv_to_rgb(v1, v2, v3))
map_colors[key] = rgb
else:
raise RuntimeError(f"Unexpected color space: {space}")
return map_colors
2 changes: 2 additions & 0 deletions stellarisdashboard/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ class Country(Base):
country_id_in_game = Column(Integer)
first_player_contact_date = Column(Integer)
country_type = Column(String(50))
primary_color = Column(String(80))
secondary_color = Column(String(80))

game = relationship("Game", back_populates="countries")
capital = relationship("Planet", foreign_keys=[capital_planet_id], post_update=True)
Expand Down
19 changes: 19 additions & 0 deletions stellarisdashboard/parsing/rust_parser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub enum Value<'a> {
Float(f64),
List(Vec<Value<'a>>),
Map(HashMap<&'a str, Value<'a>>),
Color((&'a str, f64, f64, f64)),
}

impl ToPyObject for Value<'_> {
Expand All @@ -49,6 +50,7 @@ impl ToPyObject for Value<'_> {
}
key_vals.into_py_dict(py).to_object(py)
}
Value::Color(color_tuple) => color_tuple.to_object(py),
}
}
}
Expand All @@ -61,6 +63,7 @@ impl Display for Value<'_> {
Value::Float(x) => write!(f, "{}", x),
Value::List(vec) => write!(f, "[{:?}]", vec),
Value::Map(hm) => write!(f, "{:?}", hm),
Value::Color((space, v1, v2, v3)) => write!(f, "{} {{ {} {} {} }}", space, v1, v2, v3),
}
}
}
Expand Down Expand Up @@ -119,6 +122,7 @@ fn parse_value(input: &str) -> IResult<&str, Value> {
context("int", map(parse_int, Value::Int)),
context("float", map(parse_float, Value::Float)),
context("str", map(parse_str, Value::Str)),
context("color", map(parse_color, Value::Color)),
context("unquoted_str", map(parse_unquoted_str, Value::Str)),
context("list", map(parse_list, Value::List)),
context("map", map(parse_map, Value::Map)),
Expand Down Expand Up @@ -255,6 +259,14 @@ fn parse_unquoted_str(input: &str) -> IResult<&str, &str> {
)(input)
}

fn parse_color(input: &str) -> IResult<&str, (&str, f64, f64, f64)> {
let (input, color_space) = preceded(multispace0, alt((tag("rgb"), tag("hsv"))))(input)?;
let (input, _) = preceded(multispace0, tag("{"))(input)?;
let (input, (v1, v2, v3)) = tuple((parse_float, parse_float, parse_float))(input)?;
let (input, _) = preceded(multispace0, tag("}"))(input)?;
Ok((input, (color_space, v1, v2, v3)))
}

#[allow(dead_code)]
pub fn debug_str(input: &str) -> () {
let prefix = &input[..min(input.len(), 150)];
Expand Down Expand Up @@ -741,6 +753,13 @@ mod tests {
("gender", Value::Str("not_set")),
("trait", Value::Str("trait_resilient")),
]))
);

assert_eq!(
parse_file("color = rgb { 1 2 3 }").unwrap(),
Value::Map(HashMap::from([
("color", Value::Color(("rgb", 1.0, 2.0, 3.0)))
]))
)
}
}
12 changes: 12 additions & 0 deletions stellarisdashboard/parsing/timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sqlalchemy

from stellarisdashboard import datamodel, game_info, config
from stellarisdashboard.dashboard_app.visualization_data import clear_cached_country_colors

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -415,11 +416,18 @@ def extract_data_from_gamestate(self, dependencies):
continue
country_type = country_data_dict.get("type")
country_name = dump_name(country_data_dict.get("name", "no name"))
flag_colors = country_data_dict.get("flag", {}).get("colors", [])
primary_color = flag_colors[0] if len(flag_colors) >= 1 else "black"
secondary_color = flag_colors[1] if len(flag_colors) >= 2 else primary_color
country_model = (
self._session.query(datamodel.Country)
.filter_by(game=self._db_game, country_id_in_game=country_id)
.one_or_none()
)

if country_model is None or primary_color != country_model.primary_color or secondary_color != country_model.secondary_color:
clear_cached_country_colors()

if country_model is None:
country_model = datamodel.Country(
is_player=(country_id == self._basic_info.player_country_id),
Expand All @@ -428,13 +436,17 @@ def extract_data_from_gamestate(self, dependencies):
game=self._db_game,
country_type=country_type,
country_name=country_name,
primary_color=primary_color,
secondary_color=secondary_color,
)
if country_id == self._basic_info.player_country_id:
country_model.first_player_contact_date = 0
self._session.add(country_model)
if (
country_name != country_model.country_name
or country_type != country_model.country_type
or primary_color != country_model.primary_color
or secondary_color != country_model.secondary_color
):
country_model.country_name = country_name
country_model.country_type = country_type
Expand Down

0 comments on commit 00b4f03

Please sign in to comment.