Skip to content

Commit

Permalink
add model tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
jochenkirstaetter committed Mar 18, 2024
1 parent 6e2d4f2 commit f230e02
Show file tree
Hide file tree
Showing 11 changed files with 1,064 additions and 747 deletions.
1,540 changes: 805 additions & 735 deletions src/Mscc.GenerativeAI/GenerativeModel.cs

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions src/Mscc.GenerativeAI/GenerativeModelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ public static class GenerativeModelExtensions
{
if (value == null) return value;

if (value.StartsWith("model", StringComparison.InvariantCultureIgnoreCase))
if (value.StartsWith("tuned", StringComparison.InvariantCultureIgnoreCase))
return value;

if (!value.StartsWith("model", StringComparison.InvariantCultureIgnoreCase))
{
var parts = value.Split(new char[] { '/' }, StringSplitOptions.RemoveEmptyEntries);
value = parts.Last();
return $"models/{value}";
}
return value.ToLower();
return value;
}

public static string? GetValue(this JsonElement element, string key)
Expand Down
9 changes: 9 additions & 0 deletions src/Mscc.GenerativeAI/Types/CreateTunedModelRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Mscc.GenerativeAI
{
/// <summary>
/// Request object passed to <c>GenerativeModel.CreateTunedModel</c>.
/// It bundles a display name, a base model and a <see cref="TuningTask"/>.
/// </summary>
public class CreateTunedModelRequest
{
/// <summary>
/// Human-readable name of the model to create
/// (the test suite uses "Autogenerated Test model").
/// </summary>
public string DisplayName { get; set; }
/// <summary>
/// Name of the model the tuning starts from. The test suite passes it
/// with a "models/" prefix; NOTE(review): confirm whether the prefix is mandatory.
/// </summary>
public string BaseModel { get; set; }
/// <summary>
/// Tuning configuration; the test suite fills in its
/// <c>Hyperparameters</c> and <c>TrainingData</c> members.
/// </summary>
public TuningTask TuningTask { get; set; }
}
}
18 changes: 18 additions & 0 deletions src/Mscc.GenerativeAI/Types/CreateTunedModelResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using System.Diagnostics;

namespace Mscc.GenerativeAI
{
/// <summary>
/// Result object obtained from awaiting <c>GenerativeModel.CreateTunedModel</c>.
/// </summary>
public class CreateTunedModelResponse
{
/// <summary>
/// Name returned by the service for the create call.
/// NOTE(review): presumably the name of the long-running operation - confirm against the API reference.
/// </summary>
public string Name { get; set; }
/// <summary>
/// Additional details about the create call, see <see cref="CreateTunedModelMetadata"/>.
/// </summary>
public CreateTunedModelMetadata Metadata { get; set; }
}

/// <summary>
/// Metadata attached to a <c>CreateTunedModelResponse</c>.
/// </summary>
/// <remarks>
/// The debugger display shows the tuned model name only; the previous format
/// string carried a stray closing parenthesis.
/// </remarks>
[DebuggerDisplay("{TunedModel}")]
public class CreateTunedModelMetadata
{
    /// <summary>
    /// Type discriminator of the metadata payload.
    /// NOTE(review): presumably mapped from a type marker in the JSON response - confirm the serializer mapping.
    /// </summary>
    public string Type { get; set; }

    /// <summary>
    /// Total number of steps reported for the tuning job.
    /// </summary>
    public int TotalSteps { get; set; }

    /// <summary>
    /// Name of the tuned model being created.
    /// </summary>
    public string TunedModel { get; set; }
}
}
9 changes: 9 additions & 0 deletions src/Mscc.GenerativeAI/Types/HyperParameters.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Mscc.GenerativeAI
{
/// <summary>
/// Hyperparameter settings of a tuning run; referenced by <c>TuningTask.Hyperparameters</c>.
/// </summary>
public class HyperParameters
{
/// <summary>
/// Batch size used for tuning (the test suite uses 2).
/// </summary>
public int BatchSize { get; set; }
/// <summary>
/// Learning rate used for tuning (the test suite uses 0.001f).
/// </summary>
public float LearningRate { get; set; }
/// <summary>
/// Number of epochs used for tuning (the test suite uses 3).
/// </summary>
public int EpochCount { get; set; }
}
}
9 changes: 9 additions & 0 deletions src/Mscc.GenerativeAI/Types/ModelResponse.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#if NET472_OR_GREATER || NETSTANDARD2_0
using System;
using System.Collections.Generic;
#endif
using System.Diagnostics;
Expand All @@ -23,5 +24,13 @@ public class ModelResponse
public float? Temperature { get; set; } = default;
public float? TopP { get; set; } = default;
public int? TopK { get; set; } = default;

// Properties related to tunedModels.
public string? BaseModel { get; set; }
public string? State { get; set; }
public DateTime? CreateTime { get; set; }
public DateTime? UpdateTime { get; set; }
public TuningTask? TuningTask { get; set; }

}
}
16 changes: 16 additions & 0 deletions src/Mscc.GenerativeAI/Types/Snapshot.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#if NET472_OR_GREATER || NETSTANDARD2_0
using System;
#endif
using System.Diagnostics;

namespace Mscc.GenerativeAI
{
/// <summary>
/// One snapshot entry of a tuning run, as listed in <c>TuningTask.Snapshots</c>.
/// </summary>
/// <remarks>
/// The debugger display renders as <c>Step: (Epoch - ComputeTime)</c>; the
/// previous format string was missing its closing parenthesis.
/// </remarks>
[DebuggerDisplay("{Step}: ({Epoch,nq} - {ComputeTime,nq})")]
public class Snapshot
{
    /// <summary>
    /// Step number this snapshot belongs to.
    /// </summary>
    public int Step { get; set; }

    /// <summary>
    /// Mean loss recorded for this step; <c>null</c> when not provided.
    /// </summary>
    public float? MeanLoss { get; set; }

    /// <summary>
    /// Timestamp associated with this snapshot.
    /// NOTE(review): presumably the time the step's metrics were computed - confirm against the API reference.
    /// </summary>
    public DateTime ComputeTime { get; set; }

    /// <summary>
    /// Epoch this step belongs to; <c>null</c> when not provided.
    /// </summary>
    public int? Epoch { get; set; }
}
}
24 changes: 24 additions & 0 deletions src/Mscc.GenerativeAI/Types/TrainingData.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#if NET472_OR_GREATER || NETSTANDARD2_0
using System.Collections.Generic;
#endif
using System.Diagnostics;

namespace Mscc.GenerativeAI
{
/// <summary>
/// Training data of a tuning task; referenced by <c>TuningTask.TrainingData</c>.
/// </summary>
public class TrainingData
{
/// <summary>
/// Wrapper object that in turn holds the list of examples, so the
/// examples are reached via <c>TrainingData.Examples.Examples</c>.
/// </summary>
public TrainingDataExamples Examples { get; set; }
}

/// <summary>
/// Container for the individual training examples; referenced by <c>TrainingData.Examples</c>.
/// </summary>
public class TrainingDataExamples
{
/// <summary>
/// The input/output example pairs used for tuning.
/// </summary>
public List<TrainingDataExample> Examples { get; set; }
}

/// <summary>
/// A single training example consisting of a text input and its output
/// (the test suite uses pairs such as "1" / "2" and "seven" / "eight").
/// </summary>
[DebuggerDisplay("Input: {TextInput,nq} - Output: {Output,nq}")]
public class TrainingDataExample
{
/// <summary>
/// Input text of the example.
/// </summary>
public string TextInput { get; set; }
/// <summary>
/// Output text paired with <see cref="TextInput"/>.
/// </summary>
public string Output { get; set; }
}
}
28 changes: 28 additions & 0 deletions src/Mscc.GenerativeAI/Types/TunedModelResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#if NET472_OR_GREATER || NETSTANDARD2_0
using System;
using System.Collections.Generic;
#endif
using System.Diagnostics;

namespace Mscc.GenerativeAI
{
/// <summary>
/// Response wrapper carrying a list of tuned models.
/// NOTE(review): presumably the deserialization target of the call behind
/// <c>GenerativeModel.ListTunedModels</c> - confirm in GenerativeModel.cs.
/// </summary>
public class ListTunedModelResponse
{
/// <summary>
/// The tuned models. Note that the element type is <c>ModelResponse</c>,
/// not <see cref="TunedModelResponse"/>.
/// </summary>
public List<ModelResponse> TunedModels { get; set; }
}

/// <summary>
/// Description of a tuned model. Shown in the debugger as "DisplayName (Name)".
/// NOTE(review): <c>ListTunedModelResponse</c> uses <c>ModelResponse</c> as its element type,
/// and <c>ModelResponse</c> declares nullable counterparts of several members below;
/// confirm where this type is actually consumed.
/// </summary>
[DebuggerDisplay("{DisplayName} ({Name})")]
public class TunedModelResponse
{
/// <summary>
/// Name of the tuned model (the test suite uses names of the form "tunedModels/...").
/// </summary>
public string Name { get; set; }
/// <summary>
/// Name of the model this tuned model is based on.
/// </summary>
public string BaseModel { get; set; }
/// <summary>
/// Human-readable name of the tuned model.
/// </summary>
public string DisplayName { get; set; }
/// <summary>
/// State of the tuned model, kept as the raw string.
/// </summary>
public string State { get; set; }
/// <summary>
/// Creation timestamp.
/// </summary>
public DateTime CreateTime { get; set; }
/// <summary>
/// Last update timestamp.
/// </summary>
public DateTime UpdateTime { get; set; }
/// <summary>
/// Tuning task associated with this model.
/// </summary>
public TuningTask TuningTask { get; set; }
/// <summary>
/// Temperature setting of the model.
/// </summary>
public float Temperature { get; set; }
/// <summary>
/// Top-P setting of the model.
/// </summary>
public float TopP { get; set; }
/// <summary>
/// Top-K setting of the model.
/// </summary>
public int TopK { get; set; }
}
}
16 changes: 16 additions & 0 deletions src/Mscc.GenerativeAI/Types/TuningTask.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#if NET472_OR_GREATER || NETSTANDARD2_0
using System;
using System.Collections.Generic;
#endif

namespace Mscc.GenerativeAI
{
/// <summary>
/// Describes a tuning task: its timing, snapshots, hyperparameters and training data.
/// All members are nullable; when creating a tuned model the test suite only sets
/// <see cref="Hyperparameters"/> and <see cref="TrainingData"/>.
/// </summary>
public class TuningTask
{
/// <summary>
/// Start timestamp of the task; <c>null</c> when not provided.
/// </summary>
public DateTime? StartTime { get; set; }
/// <summary>
/// Completion timestamp of the task; <c>null</c> when not provided.
/// </summary>
public DateTime? CompleteTime { get; set; }
/// <summary>
/// Snapshots recorded for the task; <c>null</c> when not provided.
/// </summary>
public List<Snapshot>? Snapshots { get; set; }
/// <summary>
/// Hyperparameters of the task. Note the casing difference between this
/// property name and its type <see cref="HyperParameters"/>.
/// </summary>
public HyperParameters? Hyperparameters { get; set; }
/// <summary>
/// Training data of the task; <c>null</c> when not provided.
/// </summary>
public TrainingData? TrainingData { get; set; }
}
}
132 changes: 124 additions & 8 deletions tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ public async void List_Models()
public async void List_Models_Using_OAuth()
{
// Arrange
var model = new GenerativeModel();
model.AccessToken = fixture.AccessToken;
var model = new GenerativeModel { AccessToken = fixture.AccessToken };

// Act
var sut = await model.ListModels();
Expand All @@ -104,11 +103,31 @@ public async void List_Models_Using_OAuth()
});
}

/// <summary>
/// Lists the tuned models available to the OAuth-authenticated client and
/// writes each model together with its tuning snapshots to the test output.
/// </summary>
[Fact]
public async void List_Tuned_Models()
{
    // Arrange
    var model = new GenerativeModel { AccessToken = fixture.AccessToken };

    // Act
    var sut = await model.ListTunedModels();

    // Assert
    sut.Should().NotBeNull().And.HaveCountGreaterThanOrEqualTo(1);
    sut.ForEach(x =>
    {
        output.WriteLine($"Model: {x.DisplayName} ({x.Name})");
        // TuningTask and Snapshots are both declared nullable, so guard the
        // traversal instead of failing the listing with a NullReferenceException.
        // Snapshot has no ToString override; print its members explicitly
        // rather than the type name.
        x.TuningTask?.Snapshots?.ForEach(m =>
            output.WriteLine($" Snapshot: {m.Step} (Epoch: {m.Epoch}, MeanLoss: {m.MeanLoss}, ComputeTime: {m.ComputeTime})"));
    });
}

[Theory]
[InlineData(Model.GeminiPro)]
[InlineData(Model.Gemini10Pro001)]
[InlineData(Model.GeminiProVision)]
[InlineData(Model.BisonText)]
[InlineData(Model.BisonChat)]
[InlineData("tunedModels/number-generator-model-psx3d3gljyko")]
public async void Get_Model_Information(string modelName)
{
// Arrange
Expand All @@ -125,24 +144,46 @@ public async void Get_Model_Information(string modelName)
}

// Verifies that GetModel throws NotSupportedException when the client is
// constructed with an API key only (no access token is assigned here).
// NOTE(review): this listing was taken from a rendered diff. Confirm that the
// Model.GeminiPro case really belongs to this theory - it would also have to
// throw NotSupportedException for the test to pass.
[Theory]
[InlineData(Model.GeminiPro)]
[InlineData("tunedModels/number-generator-model-psx3d3gljyko")]
public async void Get_TunedModel_Information_Using_ApiKey(string modelName)
{
// Arrange
var model = new GenerativeModel(apiKey: fixture.ApiKey);


// Act & Assert
// The lambda defers the call so ThrowsAsync can observe the exception.
await Assert.ThrowsAsync<NotSupportedException>(() => model.GetModel(model: modelName));
}

[Theory]
[InlineData(Model.Gemini10Pro001)]
[InlineData(Model.GeminiProVision)]
[InlineData(Model.BisonText)]
[InlineData(Model.BisonChat)]
[InlineData("tunedModels/number-generator-model-psx3d3gljyko")]
public async void Get_Model_Information_Using_OAuth(string modelName)
{
// Arrange
var model = new GenerativeModel();
model.AccessToken = fixture.AccessToken;
var model = new GenerativeModel { AccessToken = fixture.AccessToken };
var expected = modelName;
if (!expected.Contains("/"))
expected = $"models/{expected}";

// Act
var sut = await model.GetModel(model: modelName);

// Assert
sut.Should().NotBeNull();
sut.Name.Should().Be($"models/{modelName}");
sut.Name.Should().Be(expected);
output.WriteLine($"Model: {sut.DisplayName} ({sut.Name})");
sut.SupportedGenerationMethods.ForEach(m => output.WriteLine($" Method: {m}"));
if (sut.State is null)
{
sut?.SupportedGenerationMethods?.ForEach(m => output.WriteLine($" Method: {m}"));
}
else
{
output.WriteLine($"State: {sut.State}");
}
}

[Fact]
Expand Down Expand Up @@ -924,5 +965,80 @@ public async void Function_Calling_ContentStream()
// output.WriteLine($"CandidatesTokenCount: {response.LastOrDefault().UsageMetadata.CandidatesTokenCount}");
// output.WriteLine($"TotalTokenCount: {response.LastOrDefault().UsageMetadata.TotalTokenCount}");
}

/// <summary>
/// Submits a tuned-model creation request based on Gemini 1.0 Pro 001 with a
/// small "next number" training set and checks that the response carries a
/// name and metadata.
/// </summary>
[Fact]
public async void Create_Tuned_Model()
{
    // Arrange
    var model = new GenerativeModel(apiKey: null, model: Model.Gemini10Pro001)
    {
        AccessToken = fixture.AccessToken,
        ProjectId = fixture.ProjectId
    };
    var hyperparameters = new HyperParameters
    {
        BatchSize = 2,
        LearningRate = 0.001f,
        EpochCount = 3
    };
    // Each example maps a number (digits or words) to its successor.
    var trainingData = new TrainingData
    {
        Examples = new TrainingDataExamples
        {
            Examples = new()
            {
                new TrainingDataExample { TextInput = "1", Output = "2" },
                new TrainingDataExample { TextInput = "3", Output = "4" },
                new TrainingDataExample { TextInput = "-3", Output = "-2" },
                new TrainingDataExample { TextInput = "twenty two", Output = "twenty three" },
                new TrainingDataExample { TextInput = "two hundred", Output = "two hundred one" },
                new TrainingDataExample { TextInput = "ninety nine", Output = "one hundred" },
                new TrainingDataExample { TextInput = "8", Output = "9" },
                new TrainingDataExample { TextInput = "-98", Output = "-97" },
                new TrainingDataExample { TextInput = "1,000", Output = "1,001" },
                new TrainingDataExample { TextInput = "thirteen", Output = "fourteen" },
                new TrainingDataExample { TextInput = "seven", Output = "eight" },
            }
        }
    };
    var request = new CreateTunedModelRequest
    {
        BaseModel = $"models/{Model.Gemini10Pro001}",
        DisplayName = "Autogenerated Test model",
        TuningTask = new TuningTask
        {
            Hyperparameters = hyperparameters,
            TrainingData = trainingData
        }
    };

    // Act
    var response = await model.CreateTunedModel(request);

    // Assert
    response.Should().NotBeNull();
    response.Name.Should().NotBeNull();
    response.Metadata.Should().NotBeNull();
    output.WriteLine($"Name: {response.Name}");
    output.WriteLine($"Model: {response.Metadata.TunedModel} (Steps: {response.Metadata.TotalSteps})");
}

/// <summary>
/// Sends a number prompt to a previously created tuned model and expects
/// the response text to equal the given successor value.
/// </summary>
[Theory]
[InlineData("255", "256")]
[InlineData("41", "42")]
// [InlineData("five", "six")]
// [InlineData("Six hundred thirty nine", "Six hundred forty")]
public async void Generate_Content_TunedModel(string prompt, string expected)
{
    // Arrange
    const string tunedModelName = "tunedModels/autogenerated-test-model-48gob9c9v54p";
    var tunedModel = new GenerativeModel(apiKey: null, model: tunedModelName)
    {
        AccessToken = fixture.AccessToken,
        ProjectId = fixture.ProjectId
    };

    // Act
    var result = await tunedModel.GenerateContent(prompt);

    // Assert
    result.Should().NotBeNull();
    result.Candidates.Should().NotBeNull().And.HaveCount(1);
    result.Text.Should().NotBeEmpty();
    output.WriteLine(result?.Text);
    result?.Text.Should().Be(expected);
}
}
}

0 comments on commit f230e02

Please sign in to comment.