From 261859373e45a77062da17c566c3d75130a90e6b Mon Sep 17 00:00:00 2001
From: Zach Allaun <zach.allaun@gmail.com>
Date: Tue, 23 Jul 2024 16:47:34 -0400
Subject: [PATCH] Strip top-level calls from .exs files before compilation

---
 .../build/document/compilers/elixir.ex        |  5 ++-
 .../build/document/compilers/quoted.ex        | 42 +++++++++++++++++++
 .../lexical/remote_control/build_test.exs     |  2 +-
 .../lexical_shared/lib/lexical/document.ex    |  9 +++-
 4 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/apps/remote_control/lib/lexical/remote_control/build/document/compilers/elixir.ex b/apps/remote_control/lib/lexical/remote_control/build/document/compilers/elixir.ex
index 36a62d074..e53ab498c 100644
--- a/apps/remote_control/lib/lexical/remote_control/build/document/compilers/elixir.ex
+++ b/apps/remote_control/lib/lexical/remote_control/build/document/compilers/elixir.ex
@@ -11,8 +11,9 @@ defmodule Lexical.RemoteControl.Build.Document.Compilers.Elixir do
   @behaviour Build.Document.Compiler
 
   @impl true
-  def recognizes?(%Document{language_id: "elixir"}), do: true
-  def recognizes?(_), do: false
+  def recognizes?(%Document{} = doc) do
+    doc.language_id in ["elixir", "elixir-script"]
+  end
 
   @impl true
   def enabled?, do: true
diff --git a/apps/remote_control/lib/lexical/remote_control/build/document/compilers/quoted.ex b/apps/remote_control/lib/lexical/remote_control/build/document/compilers/quoted.ex
index ae416c681..7e7a63529 100644
--- a/apps/remote_control/lib/lexical/remote_control/build/document/compilers/quoted.ex
+++ b/apps/remote_control/lib/lexical/remote_control/build/document/compilers/quoted.ex
@@ -9,6 +9,13 @@ defmodule Lexical.RemoteControl.Build.Document.Compilers.Quoted do
   def compile(%Document{} = document, quoted_ast, compiler_name) do
     prepare_compile(document.path)
 
+    quoted_ast =
+      if document.language_id == "elixir-script" do
+        wrap_top_level_forms(quoted_ast)
+      else
+        quoted_ast
+      end
+
     {status, diagnostics} =
       if Features.with_diagnostics?() do
         do_compile(quoted_ast, document)
@@ -130,4 +137,39 @@ defmodule Lexical.RemoteControl.Build.Document.Compilers.Quoted do
   defp replace_source(result, source) do
     Map.put(result, :source, source)
   end
+
+  defp wrap_top_level_forms({:__block__, meta, nodes}) do
+    chunks =
+      nodes
+      |> Enum.chunk_by(&should_wrap?/1)
+      |> Enum.with_index(fn [node | _] = nodes, i ->
+        if should_wrap?(node) do
+          wrap_nodes(nodes, i)
+        else
+          nodes
+        end
+      end)
+
+    {:__block__, meta, chunks}
+  end
+
+  defp wrap_top_level_forms(ast) do
+    wrap_top_level_forms({:__block__, [], [ast]})
+  end
+
+  defp wrap_nodes(nodes, i) do
+    module_name = :"lexical_wrapper_#{i}"
+
+    quote do
+      defmodule unquote(module_name) do
+        def __lexical_wrapper__ do
+          (unquote_splicing(nodes))
+        end
+      end
+    end
+  end
+
+  @allowed_top_level [:defmodule, :alias, :import, :require, :use]
+  defp should_wrap?({allowed, _, _}) when allowed in @allowed_top_level, do: false
+  defp should_wrap?(_), do: true
 end
diff --git a/apps/remote_control/test/lexical/remote_control/build_test.exs b/apps/remote_control/test/lexical/remote_control/build_test.exs
index bdca695b5..98c01fc51 100644
--- a/apps/remote_control/test/lexical/remote_control/build_test.exs
+++ b/apps/remote_control/test/lexical/remote_control/build_test.exs
@@ -24,7 +24,7 @@ defmodule Lexical.BuildTest do
         project
         |> Project.root_path()
         |> Path.join(to_string(sequence))
-        |> Path.join("file.exs")
+        |> Path.join("file.ex")
         |> Document.Path.to_uri()
       end
 
diff --git a/projects/lexical_shared/lib/lexical/document.ex b/projects/lexical_shared/lib/lexical/document.ex
index c3afc64ab..3d5abe94c 100644
--- a/projects/lexical_shared/lib/lexical/document.ex
+++ b/projects/lexical_shared/lib/lexical/document.ex
@@ -48,7 +48,12 @@ defmodule Lexical.Document do
     uri = DocumentPath.ensure_uri(maybe_uri)
     path = DocumentPath.from_uri(uri)
 
-    language_id = language_id || language_id_from_path(path)
+    language_id =
+      if String.ends_with?(path, ".exs") do
+        "elixir-script"
+      else
+        language_id || language_id_from_path(path)
+      end
 
     %__MODULE__{
       uri: uri,
@@ -233,7 +238,7 @@ defmodule Lexical.Document do
         "elixir"
 
       ".exs" ->
-        "elixir"
+        "elixir-script"
 
       ".eex" ->
         "eex"