Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: +rfc197 example #968

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions examples/reverse_engineering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import asyncio
import shutil
from pathlib import Path

import typer

from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.context import Context
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo

app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)


@app.command("", help="Python project reverse engineering.")
def startup(
project_root: str = typer.Argument(
default="",
help="Specify the root directory of the existing project for reverse engineering.",
),
output_dir: str = typer.Option(default="", help="Specify the output directory path for reverse engineering."),
):
package_root = Path(project_root)
if not package_root.exists():
raise FileNotFoundError(f"{project_root} not exists")
if not _is_python_package_root(package_root):
raise FileNotFoundError(f'There are no "*.py" files under "{project_root}".')
init_file = package_root / "__init__.py" # used by pyreverse
init_file_exists = init_file.exists()
if not init_file_exists:
init_file.touch()

if not output_dir:
output_dir = package_root / "../reverse_engineering_output"
logger.info(f"output dir:{output_dir}")
try:
asyncio.run(reverse_engineering(package_root, Path(output_dir)))
finally:
if not init_file_exists:
init_file.unlink(missing_ok=True)
tmp_dir = package_root / "__dot__"
if tmp_dir.exists():
shutil.rmtree(tmp_dir, ignore_errors=True)


def _is_python_package_root(package_root: Path) -> bool:
for file_path in package_root.iterdir():
if file_path.is_file():
if file_path.suffix == ".py":
return True
return False


async def reverse_engineering(package_root: Path, output_dir: Path):
ctx = Context()
ctx.git_repo = GitRepository(output_dir)
ctx.repo = ProjectRepo(ctx.git_repo)
action = RebuildClassView(name="ReverseEngineering", i_context=str(package_root), llm=LLM(), context=ctx)
await action.run()

action = RebuildSequenceView(name="ReverseEngineering", llm=LLM(), context=ctx)
await action.run()


if __name__ == "__main__":
app()
2 changes: 1 addition & 1 deletion metagpt/actions/rebuild_class_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def _create_mermaid_class_views(self) -> str:
path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
filename = str(pathname.with_suffix(".mmd"))
filename = str(pathname.with_suffix(".class_diagram.mmd"))
async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)
Expand Down
75 changes: 33 additions & 42 deletions metagpt/actions/rebuild_sequence_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Set

from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
Expand Down Expand Up @@ -125,7 +125,7 @@ async def _rebuild_main_sequence_view(self, entry: SPO):
if prefix in r.subject:
classes.append(r)
await self._rebuild_use_case(r.subject)
participants = set()
participants = await self._search_participants(split_namespace(entry.subject)[0])
class_details = []
class_views = []
for c in classes:
Expand Down Expand Up @@ -171,7 +171,8 @@ async def _rebuild_main_sequence_view(self, entry: SPO):
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
Expand All @@ -184,7 +185,7 @@ async def _rebuild_main_sequence_view(self, entry: SPO):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
)
await self.graph_db.save()
await self._save_sequence_view(subject=entry.subject, content=sequence_view)

async def _merge_sequence_view(self, entry: SPO) -> bool:
"""
Expand Down Expand Up @@ -267,38 +268,6 @@ async def _rebuild_use_case(self, ns_class_name: str):
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)

# class _UseCase(BaseModel):
# description: str = Field(default="...", description="Describes about what the use case to do")
# inputs: List[str] = Field(default=["input name 1", "input name 2"],
# description="Lists the input names of the use case from external sources")
# outputs: List[str] = Field(default=["output name 1", "output name 2"],
# description="Lists the output names of the use case to external sources")
# actors: List[str] = Field(default=["actor name 1", "actor name 2"],
# description="Lists the participant actors of the use case")
# steps: List[str] = Field(default=["Step 1", "Step 2"],
# description="Lists the steps about how the use case works step by step")
# reason: str = Field(default="Because ...",
# description="Explaining under what circumstances would the external system execute this use case.")
#
#
# class _UseCaseList(BaseModel):
# description: str = Field(default="...",
# description="A summary explains what the whole source code want to do")
# use_cases: List[_UseCase] = Field(default=[
# {
# "description": "Describes about what the use case to do",
# "inputs": ["input name 1", "input name 2"],
# "outputs": ["output name 1", "output name 2"],
# "actors": ["actor name 1", "actor name 2"],
# "steps": ["Step 1", "Step 2"],
# "reason": "Because ..."
# }
# ], description="List all use cases.")
# relationship: List[str] = Field(default=["use case 1 ..."],
# description="Lists all the descriptions of relationship among these use cases")

# rsp = await ActionNode.from_pydantic(_UseCaseList).fill(context=prompt, llm=self.llm)

rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
Expand Down Expand Up @@ -327,7 +296,6 @@ async def _rebuild_use_case(self, ns_class_name: str):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
await self.graph_db.save()

@retry(
wait=wait_random_exponential(min=1, max=20),
Expand All @@ -347,7 +315,6 @@ async def _rebuild_sequence_view(self, ns_class_name: str):
use_case_markdown = await self._get_class_use_cases(ns_class_name)
if not use_case_markdown: # external class
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
await self.graph_db.save()
return
block = f"## Use Cases\n{use_case_markdown}"
prompts_blocks.append(block)
Expand Down Expand Up @@ -382,7 +349,6 @@ async def _rebuild_sequence_view(self, ns_class_name: str):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.save()

async def _get_participants(self, ns_class_name: str) -> List[str]:
"""
Expand Down Expand Up @@ -574,14 +540,12 @@ async def _merge_participant(self, entry: SPO, class_name: str):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
await self.graph_db.save()
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
)
await self.graph_db.save()
return

participant = participants[0]
Expand Down Expand Up @@ -619,4 +583,31 @@ async def _merge_participant(self, entry: SPO, class_name: str):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
)
await self.graph_db.save()
await self._save_sequence_view(subject=entry.subject, content=sequence_view)

async def _save_sequence_view(self, subject: str, content: str):
pattern = re.compile(r"[^a-zA-Z0-9]")
name = re.sub(pattern, "_", subject)
filename = Path(name).with_suffix(".sequence_diagram.mmd")
await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content)

async def _search_participants(self, filename: str) -> Set:
content = await self._get_source_code(filename)

rsp = await self.llm.aask(
msg=content,
system_msgs=[
"You are a tool for listing all class names used in a source file.",
"Return a markdown JSON object with: "
'- a "class_names" key containing the list of class names used in the file; '
'- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
],
)

class _Data(BaseModel):
class_names: List[str]
reasons: List

json_blocks = parse_json_code_block(rsp)
data = _Data.model_validate_json(json_blocks[0])
return set(data.class_names)
13 changes: 10 additions & 3 deletions metagpt/repo_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,14 +722,19 @@ async def rebuild_class_views(self, path: str | Path = None):
path = Path(path)
if not path.exists():
return
init_file = path / "__init__.py"
if not init_file.exists():
raise ValueError("Failed to import module __init__ with error:No module named __init__.")
command = f"pyreverse {str(path)} -o dot"
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
output_dir = path / "__dot__"
output_dir.mkdir(parents=True, exist_ok=True)
result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir))
if result.returncode != 0:
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_view_pathname = output_dir / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
packages_pathname = output_dir / "packages.dot"
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
Expand Down Expand Up @@ -975,6 +980,8 @@ def _repair_ns(package: str, mappings: Dict[str, str]) -> str:
file_ns = file_ns[0:ix]
continue
break
if file_ns == "":
return ""
internal_ns = package[ix + 1 :]
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
return ns
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_rebuild_class_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from metagpt.llm import LLM


@pytest.mark.skip
@pytest.mark.asyncio
async def test_rebuild(context):
action = RebuildClassView(
Expand Down
2 changes: 2 additions & 0 deletions tests/metagpt/actions/test_rebuild_sequence_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ async def test_rebuild(context, mocker):
context=context,
)
await action.run()
rows = await action.graph_db.select()
assert rows
assert context.repo.docs.graph_repo.changed_files


Expand Down
Loading