Skip to content

Commit

Permalink
Add mypy static type checker (#96)
Browse files Browse the repository at this point in the history
Co-authored-by: NP4567-dev <[email protected]>
  • Loading branch information
samuelrince and NP4567-dev authored Nov 29, 2024
1 parent 97da113 commit 08a3703
Show file tree
Hide file tree
Showing 22 changed files with 459 additions and 454 deletions.
6 changes: 5 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ repos:
- repo: https://github.com/Lucas-C/pre-commit-hooks-safety
rev: v1.3.1
hooks:
- id: python-safety-dependencies-check
- id: python-safety-dependencies-check
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.13.0'
hooks:
- id: mypy
4 changes: 2 additions & 2 deletions ecologits/_ecologits.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def init_mistralai_instrumentor() -> None:
logger.warning("MistralAI client v0.*.* will soon no longer be supported by EcoLogits.")
from ecologits.tracers.mistralai_tracer_v0 import MistralAIInstrumentor
else:
from ecologits.tracers.mistralai_tracer_v1 import MistralAIInstrumentor
from ecologits.tracers.mistralai_tracer_v1 import MistralAIInstrumentor # type: ignore[assignment]

instrumentor = MistralAIInstrumentor()
instrumentor.instrument()
Expand Down Expand Up @@ -123,7 +123,7 @@ class _Config:
@staticmethod
def init(
providers: Optional[Union[str, list[str]]] = None,
electricity_mix_zone: Optional[str] = "WOR",
electricity_mix_zone: str = "WOR",
) -> None:
"""
Initialization static method. Will attempt to initialize all providers by default.
Expand Down
14 changes: 7 additions & 7 deletions ecologits/impacts/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@

class DAG:
def __init__(self) -> None:
self.tasks = {}
self.dependencies = {}
self.__tasks: dict[str, Callable] = {}
self.__dependencies: dict[str, set] = {}

def asset(self, func: Callable) -> Callable:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

# Register the task and its dependencies
self.tasks[func.__name__] = func
self.__tasks[func.__name__] = func
func_params = list(func.__annotations__.keys())[:-1] # Ignore return type
self.dependencies[func.__name__] = set(func_params)
self.__dependencies[func.__name__] = set(func_params)

return wrapper

def build_dag(self) -> TopologicalSorter:
return TopologicalSorter(self.dependencies)
return TopologicalSorter(self.__dependencies)

def execute(self, **kwargs: Any) -> dict[str, Any]:
ts = self.build_dag()
Expand All @@ -30,9 +30,9 @@ def execute(self, **kwargs: Any) -> dict[str, Any]:
for task_name in ts.static_order():
if task_name in results: # Skip execution if result already provided
continue
task = self.tasks[task_name]
task = self.__tasks[task_name]
# Collect results from dependencies or use initial params
dep_results = {dep: results.get(dep) for dep in self.dependencies[task_name]}
dep_results = {dep: results.get(dep) for dep in self.__dependencies[task_name]}
# Filter out None values if not all dependencies are met
dep_results = {k: v for k, v in dep_results.items() if v is not None}
results[task_name] = task(**dep_results)
Expand Down
Loading

0 comments on commit 08a3703

Please sign in to comment.