diff --git a/src/Mscc.GenerativeAI/GenerativeModel.cs b/src/Mscc.GenerativeAI/GenerativeModel.cs index 24f65b5..9835b6d 100644 --- a/src/Mscc.GenerativeAI/GenerativeModel.cs +++ b/src/Mscc.GenerativeAI/GenerativeModel.cs @@ -455,13 +455,18 @@ public async Task TransferOwnership(string model, string emailAddress) /// Required. The resource name of the model. This name should match a model name returned by the models.list method. Format: models/model-id or tunedModels/my-model-id /// /// - public async Task GetModel(string model = GenerativeAI.Model.GeminiPro) + public async Task GetModel(string model = null) { if (_useVertexAi) { throw new NotSupportedException(); } + if (model is null) + { + model = _model; + } + model = model.SanitizeModelName(); if (!string.IsNullOrEmpty(_apiKey) && model.StartsWith("tunedModel", StringComparison.InvariantCultureIgnoreCase)) { diff --git a/tests/Mscc.GenerativeAI/GenerativeAI_Should.cs b/tests/Mscc.GenerativeAI/GenerativeAI_Should.cs index 8f1b1dc..93f8bc7 100644 --- a/tests/Mscc.GenerativeAI/GenerativeAI_Should.cs +++ b/tests/Mscc.GenerativeAI/GenerativeAI_Should.cs @@ -2,6 +2,7 @@ #endif using FluentAssertions; using Mscc.GenerativeAI; +using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; @@ -41,6 +42,22 @@ public void Initialize_Interface_GoogleAI() model.Name.Should().Be($"{expected}"); } + [Fact] + public async Task GetModel() + { + // Arrange + IGenerativeAI genAi; + genAi = new GoogleAI(apiKey: fixture.ApiKey); + var expected = Model.Embedding.SanitizeModelName(); + + // Act + var model = genAi.GenerativeModel(model: Model.Embedding); + var get_model = await model.GetModel(); + + // Assert + get_model.Name.SanitizeModelName().Should().Be(expected); + } + [Fact] public void Initialize_Interface_VertexAI() {