Skip to content
This repository has been archived by the owner on Feb 28, 2024. It is now read-only.

Allow loading of mods with unloadable types #83

Merged
merged 3 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 66 additions & 8 deletions NeosModLoader/AssemblyHider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using HarmonyLib;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace NeosModLoader
Expand Down Expand Up @@ -39,6 +40,11 @@ internal static void PatchNeos(Harmony harmony, HashSet<Assembly> initialAssembl
MethodInfo getTypeTarget = AccessTools.DeclaredMethod(typeof(WorkerManager), nameof(WorkerManager.GetType));
MethodInfo getTypePatch = AccessTools.DeclaredMethod(typeof(AssemblyHider), nameof(FindTypePostfix));
harmony.Patch(getTypeTarget, postfix: new HarmonyMethod(getTypePatch));

// FrooxEngine likes to enumerate all types in all assemblies, which is prone to issues (such as crashing FrooxCode if a type isn't loadable)
MethodInfo getAssembliesTarget = AccessTools.DeclaredMethod(typeof(AppDomain), nameof(AppDomain.GetAssemblies));
MethodInfo getAssembliesPatch = AccessTools.DeclaredMethod(typeof(AssemblyHider), nameof(GetAssembliesPostfix));
harmony.Patch(getAssembliesTarget, postfix: new HarmonyMethod(getAssembliesPatch));
}
}

Expand All @@ -64,20 +70,32 @@ private static HashSet<Assembly> GetModAssemblies()
return assemblies;
}

private static bool IsModType(Type type)
/// <summary>
/// Check if an assembly belongs to a mod or not
/// </summary>
/// <param name="assembly">The assembly to check</param>
/// <param name="typeOrAssembly">Type of root check being performed. Should be "type" or "assembly". Used in logging.</param>
/// <param name="name">Name of the root check being performed. Used in logging.</param>
/// <param name="log">If `true`, this will emit logs. If `false`, this function will not log.</param>
/// <param name="forceShowLate">If `true`, then this function will always return `false` for late-loaded types</param>
/// <returns>`true` if this assembly belongs to a mod.</returns>
private static bool IsModAssembly(Assembly assembly, string typeOrAssembly, string name, bool log, bool forceShowLate)
{
if (neosAssemblies!.Contains(type.Assembly))
if (neosAssemblies!.Contains(assembly))
{
// the type belongs to a Neos assembly
return false; // don't hide the type
return false; // don't hide the thing
}
else
{
if (modAssemblies!.Contains(type.Assembly))
if (modAssemblies!.Contains(assembly))
{
// known type from a mod assembly
Logger.DebugInternal($"Hid type \"{type}\" from Neos");
return true; // hide the type
if (log)
ljoonal marked this conversation as resolved.
Show resolved Hide resolved
{
Logger.DebugInternal($"Hid {typeOrAssembly} \"{name}\" from Neos");
}
return true; // hide the thing
}
else
{
Expand All @@ -86,12 +104,39 @@ private static bool IsModType(Type type)
// this is super weird, and probably shouldn't ever happen... but if it does, I want to know about it.
// since this is an edge case users may want to handle in different ways, the HideLateTypes nml config option allows them to choose.
bool hideLate = ModLoaderConfiguration.Get().HideLateTypes;
Logger.WarnInternal($"The \"{type}\" type does not appear to part of Neos or a mod. It is unclear whether it should be hidden or not. Due to the HideLateTypes config option being {hideLate} it will be {(hideLate ? "Hidden" : "Shown")}");
return hideLate; // hide the type only if hideLate == true
if (log)
{
Logger.WarnInternal($"The \"{name}\" {typeOrAssembly} does not appear to part of Neos or a mod. It is unclear whether it should be hidden or not. Due to the HideLateTypes config option being {hideLate} it will be {(hideLate ? "Hidden" : "Shown")}");
}
// if forceShowLate == true, then this function will always return `false` for late-loaded types
// if forceShowLate == false, then this function will return `true` when hideLate == true
return hideLate && !forceShowLate;
}
}
}

/// <summary>
/// Check if an assembly belongs to a mod or not
/// </summary>
/// <param name="assembly">The assembly to check</param>
/// <param name="forceShowLate">If `true`, then this function will always return `false` for late-loaded types</param>
/// <returns>`true` if this assembly belongs to a mod.</returns>
private static bool IsModAssembly(Assembly assembly, bool forceShowLate = false)
{
// this generates a lot of logspam, as a single call to AppDomain.GetAssemblies() calls this many times
return IsModAssembly(assembly, "assembly", assembly.ToString(), log: false, forceShowLate);
}

/// <summary>
/// Check if a type belongs to a mod or not
/// </summary>
/// <param name="type">The type to check</param>
/// <returns>true` if this type belongs to a mod.</returns>
private static bool IsModType(Type type)
{
return IsModAssembly(type.Assembly, "type", type.ToString(), log: true, forceShowLate: false);
}

// postfix for a method that searches for a type, and returns a reference to it if found (TypeHelper.FindType and WorkerManager.GetType)
private static void FindTypePostfix(ref Type? __result)
{
Expand All @@ -117,5 +162,18 @@ private static void IsValidTypePostfix(ref bool __result, Type type)
}
}
}

private static void GetAssembliesPostfix(ref Assembly[] __result)
{
Assembly? callingAssembly = Util.GetCallingAssembly();
if (callingAssembly != null && neosAssemblies!.Contains(callingAssembly))
{
// if we're being called by Neos, then hide mod assemblies
Logger.DebugFuncInternal(() => $"Intercepting call to AppDomain.GetAssemblies() from {callingAssembly}");
__result = __result
.Where(assembly => !IsModAssembly(assembly, forceShowLate: true)) // it turns out Neos itself late-loads a bunch of stuff, so we force-show late-loaded assemblies here
.ToArray();
}
}
}
}
4 changes: 2 additions & 2 deletions NeosModLoader/ModLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace NeosModLoader
{
public class ModLoader
{
internal const string VERSION_CONSTANT = "1.12.5";
internal const string VERSION_CONSTANT = "1.12.6";
ljoonal marked this conversation as resolved.
Show resolved Hide resolved
/// <summary>
/// NeosModLoader's version
/// </summary>
Expand Down Expand Up @@ -174,7 +174,7 @@ private static string TypesForOwner(Patches patches, string owner)
return null;
}

Type[] modClasses = mod.Assembly.GetTypes().Where(t => t.IsClass && !t.IsAbstract && NEOS_MOD_TYPE.IsAssignableFrom(t)).ToArray();
Type[] modClasses = mod.Assembly.GetLoadableTypes(t => t.IsClass && !t.IsAbstract && NEOS_MOD_TYPE.IsAssignableFrom(t)).ToArray();
if (modClasses.Length == 0)
{
Logger.ErrorInternal($"no mods found in {mod.File}");
Expand Down
72 changes: 72 additions & 0 deletions NeosModLoader/Util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Security.Cryptography;
using System.Threading;
Expand Down Expand Up @@ -34,6 +35,25 @@ internal static class Util
return null;
}

/// <summary>
/// Get the calling assembly by stack trace analysis. Always skips the first one frame, being this method and you, the caller.
/// </summary>
/// <param name="skipFrames">The number of extra frame skip in the stack</param>
/// <returns>The executing mod, or null if none found</returns>
internal static Assembly? GetCallingAssembly(int skipFrames = 0)
{
// same logic as ExecutingMod(), but simpler case
StackTrace stackTrace = new(2 + skipFrames);
for (int i = 0; i < stackTrace.FrameCount; i++)
{
Assembly? assembly = stackTrace.GetFrame(i)?.GetMethod()?.DeclaringType?.Assembly;
if (assembly != null)
{
return assembly;
}
}
return null;
}

/// <summary>
/// Used to debounce a method call. The underlying method will be called after there have been no additional calls
Expand Down Expand Up @@ -102,5 +122,57 @@ internal static bool CanBeNull(Type t)
return !CannotBeNull(t);
}

internal static IEnumerable<Type> GetLoadableTypes(this Assembly assembly, Predicate<Type> predicate)
{
try
{
return assembly.GetTypes();
}
catch (ReflectionTypeLoadException e)
{
return e.Types.Where(type => CheckType(type, predicate));
}
}

// check a potentially unloadable type to see if it is (A) loadable and (B) satsifies a predicate without throwing an exception
// this does a series of increasingly aggressive checks to see if the type is unsafe to touch
private static bool CheckType(Type type, Predicate<Type> predicate)
{
if (type == null)
{
return false;
}
try
{
string _name = type.Name;
}
catch (Exception e)
{
Logger.DebugFuncInternal(() => $"Could not read the name for a type: {e}");
return false;
}
try
{
if (type.TypeInitializer == null)
{
return false;
}
}
catch (Exception e)
{
Logger.DebugFuncInternal(() => $"Could not read TypeInitializer for type \"{type.Name}\": {e}");
return false;
}

try
{
return predicate(type);
}
catch (Exception e)
{
Logger.DebugFuncInternal(() => $"Could not load type \"{type.Name}\": {e}");
return false;
}
}
}
}