diff --git a/src/codergpt/explainer/explainer.py b/src/codergpt/explainer/explainer.py
index 9c35060..88595a7 100644
--- a/src/codergpt/explainer/explainer.py
+++ b/src/codergpt/explainer/explainer.py
@@ -27,11 +27,15 @@ def explain(
         :param classname: The name of the class to explain. Default is None.
         """
         if function:
-            response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"})
+            response = self.chain.invoke(
+                {"input": f"Explain the function {function} in the following {language} code: \n\n```\n{code}\n```"}
+            )
             # Pretty print the response
             print(f"Explanation for '{function}':\n{response.content}")
         elif classname:
-            response = self.chain.invoke({"input": f"Explain the following {language} code: \n\n```\n{code}\n```"})
+            response = self.chain.invoke(
+                {"input": f"Explain the class {classname} in the following {language} code: \n\n```\n{code}\n```"}
+            )
             # Pretty print the response
             print(f"Explanation for '{classname}':\n{response.content}")
         else:
diff --git a/src/codergpt/main.py b/src/codergpt/main.py
index caedfb4..8eb26d1 100644
--- a/src/codergpt/main.py
+++ b/src/codergpt/main.py
@@ -5,13 +5,14 @@ from typing import Optional, Union
 
 import yaml
+from langchain_anthropic import ChatAnthropicMessages
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import ChatOpenAI
 from tabulate import tabulate
 
 from codergpt.commenter.commenter import CodeCommenter
-from codergpt.constants import EXTENSION_MAP_FILE, GEMINI, GPT_4_TURBO, INSPECTION_HEADERS
+from codergpt.constants import CLAUDE, EXTENSION_MAP_FILE, GEMINI, GPT_4_TURBO, INSPECTION_HEADERS
 from codergpt.documenter.documenter import CodeDocumenter
 from codergpt.explainer.explainer import CodeExplainer
 from codergpt.optimizer.optimizer import CodeOptimizer
@@ -23,11 +24,16 @@ class CoderGPT:
 
     def __init__(self, model: str = GPT_4_TURBO):
         """Initialize the CoderGPT class."""
+        temp = 0.3
         if model is None or model.startswith("gpt-"):
-            self.llm = ChatOpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"), temperature=0.3, model=model)
-        # elif model == CLAUDE:
-        #     self.llm = ChatAnthropic()
-        #     print("Coming Soon!")
+            self.llm = ChatOpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"), temperature=temp, model=model)
+        elif model == CLAUDE:
+            self.llm = ChatAnthropicMessages(
+                model_name=model,
+                anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY"),
+                temperature=temp,
+                max_tokens=2048,
+            )
         elif model == GEMINI:
             self.llm = ChatGoogleGenerativeAI(model=model, convert_system_message_to_human=True)
         else:
@@ -115,7 +121,7 @@ def explainer(self, path: Union[str, Path], function: str = None, classname=None
         """
         code_explainer = CodeExplainer(self.chain)
         code, language = self.get_code(filename=path, function_name=function, class_name=classname)
-        code_explainer.explain(code, language)
+        code_explainer.explain(code=code, function=function, classname=classname, language=language)
 
     def commenter(self, path: Union[str, Path], overwrite: bool = False):
         """
diff --git a/tests/test_explainer.py b/tests/test_explainer.py
index 3ebb19c..bb3f72e 100644
--- a/tests/test_explainer.py
+++ b/tests/test_explainer.py
@@ -33,7 +33,10 @@ def test_explain_function(self):
         self.code_explainer.explain(code=sample_code, function=sample_function_name, language="python")
 
         # Verify that invoke was called once with the correct parameters
-        expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"}
+        expected_invoke_input = {
+            "input": f"Explain the function {sample_function_name} "
+            f"in the following python code: \n\n```\n{sample_code}\n```"
+        }
         self.mock_chain.invoke.assert_called_once_with(expected_invoke_input)
 
         # Check if the expected explanation message is in the captured output
@@ -51,7 +54,9 @@ def test_explain_class(self):
         self.code_explainer.explain(code=sample_code, classname=sample_class_name, language="python")
 
         # Verify that invoke was called once with the correct parameters
-        expected_invoke_input = {"input": f"Explain the following python code: \n\n```\n{sample_code}\n```"}
+        expected_invoke_input = {
+            "input": f"Explain the class {sample_class_name} in the following python code: \n\n```\n{sample_code}\n```"
+        }
         self.mock_chain.invoke.assert_called_once_with(expected_invoke_input)
 
         # Check if the expected explanation message is in the captured output