diff --git a/stellarisdashboard/dashboard_app/graph_ledger.py b/stellarisdashboard/dashboard_app/graph_ledger.py index 6fd498d..5467110 100644 --- a/stellarisdashboard/dashboard_app/graph_ledger.py +++ b/stellarisdashboard/dashboard_app/graph_ledger.py @@ -364,7 +364,7 @@ def update_content( if not plot_spec: continue # just in case it's possible to sneak in an invalid ID start = time.time() - figure_data = get_raw_plot_data_dicts(plot_data, plot_spec) + figure_data = get_raw_plot_data_dicts(game_id, plot_data, plot_spec) end = time.time() logger.debug( f"Prepared figure {plot_spec.title} in {end - start:5.3f} seconds." @@ -422,6 +422,7 @@ def _get_game_ids_matching_url(url): def get_raw_plot_data_dicts( + game_id: str, plot_data: visualization_data.PlotDataManager, plot_spec: visualization_data.PlotSpecification, ) -> List[Dict[str, Any]]: @@ -434,18 +435,19 @@ def get_raw_plot_data_dicts( :return: """ if plot_spec.style == visualization_data.PlotStyle.line: - return _get_raw_data_for_line_plot(plot_data, plot_spec) + return _get_raw_data_for_line_plot(game_id, plot_data, plot_spec) elif plot_spec.style in [ visualization_data.PlotStyle.stacked, visualization_data.PlotStyle.budget, ]: - return _get_raw_data_for_stacked_and_budget_plots(plot_data, plot_spec) + return _get_raw_data_for_stacked_and_budget_plots(game_id, plot_data, plot_spec) else: logger.warning(f"Unknown Plot type {plot_spec}") return [] def _get_raw_data_for_line_plot( + game_id: str, plot_data: visualization_data.PlotDataManager, plot_spec: visualization_data.PlotSpecification, ) -> List[Dict[str, Any]]: @@ -459,7 +461,7 @@ def _get_raw_data_for_line_plot( x=x_values, y=y_values, name=dict_key_to_legend_label(key), - line={"color": get_country_color(key, 1.0)}, + line={"color": get_country_color(game_id, key, 1.0)}, text=get_plot_value_labels(x_values, y_values, key), hoverinfo="text", ) @@ -468,6 +470,7 @@ def _get_raw_data_for_line_plot( def _get_raw_data_for_stacked_and_budget_plots( + game_id: str, plot_data: visualization_data.PlotDataManager, plot_spec: visualization_data.PlotSpecification, ) -> List[Dict[str, Any]]: @@ -514,10 +517,10 @@ def _get_raw_data_for_stacked_and_budget_plots( legendgroup=key, # ensure that budget contributions with mixed signs still behave as a single entry hoverinfo="text", mode="lines", - line=dict(width=0.5, color=get_country_color(key, 1.0)), + line=dict(width=0.5, color=get_country_color(game_id, key, 1.0)), stackgroup=stackgroup, groupnorm="percent" if normalized else "", - fillcolor=get_country_color(key, 0.5), + fillcolor=get_country_color(game_id, key, 0.5), text=get_plot_value_labels(x_values, yv, key), showlegend=i == 0, # only show one legend entry ) @@ -574,7 +577,7 @@ def get_galaxy(game_id: str, slider_date: float) -> dcc.Graph: x=[], y=[], text=[], - line=go.scatter.Line(width=0.5, color=get_country_color(country)), + line=go.scatter.Line(width=0.5, color=get_country_color(game_id, country)), hoverinfo="text", mode="lines", showlegend=False, @@ -611,7 +614,7 @@ def get_galaxy(game_id: str, slider_date: float) -> dcc.Graph: marker=dict(color=[], size=4), name=country, ) - color = get_country_color(country) + color = get_country_color(game_id, country) country_system_markers[country]["marker"]["color"].append(color) x, y = nx_graph.nodes[node]["pos"] country_system_markers[country]["x"].append(x) @@ -712,10 +715,10 @@ def get_galaxy(game_id: str, slider_date: float) -> dcc.Graph: ) -def get_country_color(country_name: str, alpha: float = 1.0) -> str: +def get_country_color(game_id: int, country_name: str, alpha: float = 1.0) -> str: alpha = min(alpha, 1) alpha = max(alpha, 0) - r, g, b = visualization_data.get_color_vals(country_name) + r, g, b = visualization_data.get_color_vals(game_id, country_name) color = f"rgba({r},{g},{b},{alpha})" return color diff --git a/stellarisdashboard/dashboard_app/timelapse_exporter.py b/stellarisdashboard/dashboard_app/timelapse_exporter.py index d60b983..dfea848 100644 --- a/stellarisdashboard/dashboard_app/timelapse_exporter.py +++ b/stellarisdashboard/dashboard_app/timelapse_exporter.py @@ -121,7 +121,7 @@ def _timelapse_id(self): return f"{self.game_id}-{ts}" def rgb(self, name: str): - r, g, b = get_color_vals(name) + r, g, b = get_color_vals(self.game_id, name) return r / 255.0, g / 255.0, b / 255.0 def draw_frame( diff --git a/stellarisdashboard/dashboard_app/visualization_data.py b/stellarisdashboard/dashboard_app/visualization_data.py index 706fbcd..6619944 100644 --- a/stellarisdashboard/dashboard_app/visualization_data.py +++ b/stellarisdashboard/dashboard_app/visualization_data.py @@ -3,8 +3,11 @@ import dataclasses import enum import functools +import itertools import logging +import pathlib import random +import re import time from collections import defaultdict from typing import List, Dict, Callable, Tuple, Iterable, Union, Set, Optional, Any @@ -14,6 +17,7 @@ from scipy.spatial import Voronoi from stellarisdashboard import datamodel, config, game_info +from stellarisdashboard.parsing.save_parser import rust_parser logger = logging.getLogger(__name__) @@ -92,8 +96,15 @@ def get_current_execution_plot_data( return _CURRENT_EXECUTION_PLOT_DATA[game_name] +_GAME_COUNTRY_COLORS = {} + + +def clear_cached_country_colors(): + _GAME_COUNTRY_COLORS.clear() + + def get_color_vals( - key_str: str, range_min: float = 0.1, range_max: float = 1.0 + game_id: str, key_str: str, range_min: float = 0.1, range_max: float = 1.0 ) -> Tuple[float, float, float]: """Generate RGB values for the given identifier. Some special values (tech categories) have hardcoded colors to roughly match the game's look and feel. @@ -101,6 +112,13 @@ def get_color_vals( For unknown identifiers, a random color is generated, with the key_str being applied as a seed to the random number generator. This makes colors consistent across figures and executions. """ + if game_id not in _GAME_COUNTRY_COLORS: + country_colors = CountryColors() + country_colors.load(game_id) + _GAME_COUNTRY_COLORS[game_id] = country_colors + else: + country_colors = _GAME_COUNTRY_COLORS[game_id] + if key_str.lower() == "physics": r, g, b = COLOR_PHYSICS elif key_str.lower() == "society": @@ -113,6 +131,8 @@ def get_color_vals( r, g, b = 255, 0, 0 elif key_str.endswith("internal_market"): r, g, b = 0, 0, 255 + elif country_colors.has_color_for_name(key_str): + r, g, b = country_colors.get_color_by_name(key_str) else: random.seed(key_str) h = random.uniform(0, 1) @@ -1679,3 +1699,94 @@ def _country_display_name(self, country: datamodel.Country) -> str: if not country.has_met_player(): return GalaxyMapData.UNCLAIMED return country.rendered_name + + +_MIN_V = 0.4 +_MAX_V = 1.0 +_V_SHIFTS = [ + 0.2, + 0.4, + 0.6, +] +_GRAYSCALE_S = 0.2 + + +class CountryColors: + def __init__(self): + self.map_colors: Dict[str, tuple] = {} + self.country_name_to_primary_color: Dict[str, tuple] = {} + self._used_hsv = set() + + def load(self, game_id: str): + for game_data_dir in reversed(config.CONFIG.game_data_dirs): + colors_path = game_data_dir / "flags/colors.txt" + if colors_path.exists(): + self.map_colors = self._parse_map_colors(colors_path) + break + + with datamodel.get_db_session(game_id) as session: + for c in session.query(datamodel.Country).order_by("country_id_in_game"): + name = c.rendered_name + # avoid grayscale colors if possible; they can be hard to distinguish + color = ( + c.secondary_color + if self._is_grayscale_color(c.primary_color) + and not self._is_grayscale_color(c.secondary_color) + else c.primary_color + ) + # try to avoid duplicate colors + # don't bother for non-default-or-fallen-empire countries, so that the "real" countries are more likely to get their color + rgb = self._get_rgb(color, avoid_used=c.is_real_country()) + self.country_name_to_primary_color[name] = rgb + + def _get_rgb(self, key: str, avoid_used: bool): + r, g, b = self.map_colors.get(key, (0, 0, 0)) + h, s, v = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0) + v = min(v, _MAX_V) + v = max(v, _MIN_V) + if avoid_used: + for shift in itertools.chain([0.0], *((v, -v) for v in _V_SHIFTS)): + v_shifted = s + shift + if ( + v_shifted >= _MIN_V + and v_shifted <= _MAX_V + and self._round_hsv(h, s, v_shifted) not in self._used_hsv + ): + v = v_shifted + break + self._used_hsv.add(self._round_hsv(h, s, v)) + return tuple(int(v * 255) for v in colorsys.hsv_to_rgb(h, s, v)) + + @staticmethod + def _round_hsv(h: float, s: float, v: float): + return round(h, 1), round(s, 1), round(v, 1) + + def _is_grayscale_color(self, key: str): + r, g, b = self.map_colors.get(key, (0, 0, 0)) + _, s, _ = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0) + return s < _GRAYSCALE_S + + def has_color_for_name(self, country_name: str): + return country_name in self.country_name_to_primary_color + + def get_color_by_name(self, country_name: str): + return self.country_name_to_primary_color[country_name] + + @staticmethod + def _parse_map_colors(path: pathlib.Path): + with open(path, "r") as f: + prepared_str = re.sub("#[^\n]*", "", f.read()) # strip out comments + data = rust_parser.parse_save_from_string(prepared_str) + + colors = data.get("colors", {}) + map_colors = {} + for key in colors: + space, v1, v2, v3 = colors[key]["map"] + if space == "rgb": + map_colors[key] = (int(v1), int(v2), int(v3)) + elif space == "hsv": + rgb = tuple(int(v * 255) for v in colorsys.hsv_to_rgb(v1, v2, v3)) + map_colors[key] = rgb + else: + raise RuntimeError(f"Unexpected color space: {space}") + return map_colors diff --git a/stellarisdashboard/datamodel.py b/stellarisdashboard/datamodel.py index ce5c5d0..f7cbbf0 100644 --- a/stellarisdashboard/datamodel.py +++ b/stellarisdashboard/datamodel.py @@ -615,6 +615,8 @@ class Country(Base): country_id_in_game = Column(Integer) first_player_contact_date = Column(Integer) country_type = Column(String(50)) + primary_color = Column(String(80)) + secondary_color = Column(String(80)) game = relationship("Game", back_populates="countries") capital = relationship("Planet", foreign_keys=[capital_planet_id], post_update=True) diff --git a/stellarisdashboard/parsing/rust_parser/src/parser.rs b/stellarisdashboard/parsing/rust_parser/src/parser.rs index 46ba107..ba466b9 100644 --- a/stellarisdashboard/parsing/rust_parser/src/parser.rs +++ b/stellarisdashboard/parsing/rust_parser/src/parser.rs @@ -27,6 +27,7 @@ pub enum Value<'a> { Float(f64), List(Vec>), Map(HashMap<&'a str, Value<'a>>), + Color((&'a str, f64, f64, f64)), } impl ToPyObject for Value<'_> { @@ -49,6 +50,7 @@ impl ToPyObject for Value<'_> { } key_vals.into_py_dict(py).to_object(py) } + Value::Color(color_tuple) => color_tuple.to_object(py), } } } @@ -61,6 +63,7 @@ impl Display for Value<'_> { Value::Float(x) => write!(f, "{}", x), Value::List(vec) => write!(f, "[{:?}]", vec), Value::Map(hm) => write!(f, "{:?}", hm), + Value::Color((space, v1, v2, v3)) => write!(f, "{} {{ {} {} {} }}", space, v1, v2, v3), } } } @@ -119,6 +122,7 @@ fn parse_value(input: &str) -> IResult<&str, Value> { context("int", map(parse_int, Value::Int)), context("float", map(parse_float, Value::Float)), context("str", map(parse_str, Value::Str)), + context("color", map(parse_color, Value::Color)), context("unquoted_str", map(parse_unquoted_str, Value::Str)), context("list", map(parse_list, Value::List)), context("map", map(parse_map, Value::Map)), @@ -255,6 +259,14 @@ fn parse_unquoted_str(input: &str) -> IResult<&str, &str> { )(input) } +fn parse_color(input: &str) -> IResult<&str, (&str, f64, f64, f64)> { + let (input, color_space) = preceded(multispace0, alt((tag("rgb"), tag("hsv"))))(input)?; + let (input, _) = preceded(multispace0, tag("{"))(input)?; + let (input, (v1, v2, v3)) = tuple((parse_float, parse_float, parse_float))(input)?; + let (input, _) = preceded(multispace0, tag("}"))(input)?; + Ok((input, (color_space, v1, v2, v3))) +} + #[allow(dead_code)] pub fn debug_str(input: &str) -> () { let prefix = &input[..min(input.len(), 150)]; @@ -741,6 +753,13 @@ mod tests { ("gender", Value::Str("not_set")), ("trait", Value::Str("trait_resilient")), ])) + ); + + assert_eq!( + parse_file("color = rgb { 1 2 3 }").unwrap(), + Value::Map(HashMap::from([ + ("color", Value::Color(("rgb", 1.0, 2.0, 3.0))) + ])) ) } } \ No newline at end of file diff --git a/stellarisdashboard/parsing/timeline.py b/stellarisdashboard/parsing/timeline.py index 52213d5..1b04176 100644 --- a/stellarisdashboard/parsing/timeline.py +++ b/stellarisdashboard/parsing/timeline.py @@ -12,6 +12,7 @@ import sqlalchemy from stellarisdashboard import datamodel, game_info, config +from stellarisdashboard.dashboard_app.visualization_data import clear_cached_country_colors logger = logging.getLogger(__name__) @@ -415,11 +416,18 @@ def extract_data_from_gamestate(self, dependencies): continue country_type = country_data_dict.get("type") country_name = dump_name(country_data_dict.get("name", "no name")) + flag_colors = country_data_dict.get("flag", {}).get("colors", []) + primary_color = flag_colors[0] if len(flag_colors) >= 1 else "black" + secondary_color = flag_colors[1] if len(flag_colors) >= 2 else primary_color country_model = ( self._session.query(datamodel.Country) .filter_by(game=self._db_game, country_id_in_game=country_id) .one_or_none() ) + + if country_model is None or primary_color != country_model.primary_color or secondary_color != country_model.secondary_color: + clear_cached_country_colors() + if country_model is None: country_model = datamodel.Country( is_player=(country_id == self._basic_info.player_country_id), @@ -428,6 +436,8 @@ def extract_data_from_gamestate(self, dependencies): game=self._db_game, country_type=country_type, country_name=country_name, + primary_color=primary_color, + secondary_color=secondary_color, ) if country_id == self._basic_info.player_country_id: country_model.first_player_contact_date = 0 @@ -435,6 +445,8 @@ def extract_data_from_gamestate(self, dependencies): if ( country_name != country_model.country_name or country_type != country_model.country_type + or primary_color != country_model.primary_color + or secondary_color != country_model.secondary_color ): country_model.country_name = country_name country_model.country_type = country_type