diff --git a/src/asm/asm.s b/src/asm/asm.s index 6e3ec8446..6eff11c31 100644 --- a/src/asm/asm.s +++ b/src/asm/asm.s @@ -94,6 +94,12 @@ _x86_64_asm_invlpg: invlpg (%rdi) retq +.global _x86_64_asm_invpcid +.p2align 4 +_x86_64_asm_invpcid: + invpcid (%rsi), %rdi + retq + .global _x86_64_asm_ltr .p2align 4 _x86_64_asm_ltr: diff --git a/src/asm/mod.rs b/src/asm/mod.rs index 1bc7cf9a7..f67a17e16 100644 --- a/src/asm/mod.rs +++ b/src/asm/mod.rs @@ -144,6 +144,12 @@ extern "C" { )] pub(crate) fn x86_64_asm_invlpg(addr: u64); + #[cfg_attr( + any(target_env = "gnu", target_env = "musl"), + link_name = "_x86_64_asm_invpcid" + )] + pub(crate) fn x86_64_asm_invpcid(kind: u64, desc: u64); + #[cfg_attr( any(target_env = "gnu", target_env = "musl"), link_name = "_x86_64_asm_read_cr0" diff --git a/src/instructions/tlb.rs b/src/instructions/tlb.rs index 8972271f0..0a97e7540 100644 --- a/src/instructions/tlb.rs +++ b/src/instructions/tlb.rs @@ -23,3 +23,88 @@ pub fn flush_all() { let (frame, flags) = Cr3::read(); unsafe { Cr3::write(frame, flags) } } + +/// The Invalidate PCID Command to execute. +#[derive(Debug)] +pub enum InvPicdCommand { + /// The logical processor invalidates mappings—except global translations—for the linear address and PCID specified. + Address(VirtAddr, Pcid), + + /// The logical processor invalidates all mappings—except global translations—associated with the PCID. + Single(Pcid), + + /// The logical processor invalidates all mappings—including global translations—associated with any PCID. + All, + + /// The logical processor invalidates all mappings—except global translations—associated with any PCID. + AllExceptGlobal, +} + +/// The INVPCID descriptor comprises 128 bits and consists of a PCID and a linear address. +/// For INVPCID type 0, the processor uses the full 64 bits of the linear address even outside 64-bit mode; the linear address is not used for other INVPCID types. +#[repr(C)] +#[derive(Debug)] +struct InvpcidDescriptor { + address: u64, + pcid: u64, +} + +/// Structure of a PCID. A PCID has to be <= 4096 for x86_64. +#[repr(transparent)] +#[derive(Debug)] +pub struct Pcid(u16); + +impl Pcid { + /// Create a new PCID. Will result in a failure if the value of + /// PCID is out of expected bounds. + pub const fn new(pcid: u16) -> Result { + if pcid >= 4096 { + Err("PCID should be < 4096.") + } else { + Ok(Pcid(pcid)) + } + } + + /// Get the value of the current PCID. + pub const fn value(&self) -> u16 { + self.0 + } +} + +/// Invalidate the given address in the TLB using the `invpcid` instruction. +/// +/// ## Safety +/// This function is unsafe as it requires CPUID.(EAX=07H, ECX=0H):EBX.INVPCID to be 1. +#[inline] +pub unsafe fn flush_pcid(command: InvPicdCommand) { + let mut desc = InvpcidDescriptor { + address: 0, + pcid: 0, + }; + + let kind: u64; + match command { + InvPicdCommand::Address(addr, pcid) => { + kind = 0; + desc.pcid = pcid.value().into(); + desc.address = addr.as_u64() + } + InvPicdCommand::Single(pcid) => { + kind = 1; + desc.pcid = pcid.0.into() + } + InvPicdCommand::All => kind = 2, + InvPicdCommand::AllExceptGlobal => kind = 3, + } + + #[cfg(feature = "inline_asm")] + { + let desc_value = &desc as *const InvpcidDescriptor as u64; + asm!("invpcid {1}, [{0}]", in(reg) desc_value, in(reg) kind); + }; + + #[cfg(not(feature = "inline_asm"))] + { + crate::asm::x86_64_asm_invpcid(kind, &desc as *const InvpcidDescriptor as u64) + }; +} diff --git a/src/registers/control.rs b/src/registers/control.rs index bf62fa218..97abb027e 100644 --- a/src/registers/control.rs +++ b/src/registers/control.rs @@ -128,8 +128,7 @@ bitflags! { #[cfg(feature = "instructions")] mod x86_64 { use super::*; - use crate::structures::paging::PhysFrame; - use crate::{PhysAddr, VirtAddr}; + use crate::{instructions::tlb::Pcid, structures::paging::PhysFrame, PhysAddr, VirtAddr}; impl Cr0 { /// Read the current set of CR0 flags. @@ -233,6 +232,14 @@ mod x86_64 { /// Read the current P4 table address from the CR3 register. #[inline] pub fn read() -> (PhysFrame, Cr3Flags) { + let (frame, value) = Cr3::read_raw(); + let flags = Cr3Flags::from_bits_truncate(value.into()); + (frame, flags) + } + + /// Read the current P4 table address from the CR3 register + #[inline] + pub fn read_raw() -> (PhysFrame, u16) { let value: u64; #[cfg(feature = "inline_asm")] @@ -245,10 +252,18 @@ mod x86_64 { value = crate::asm::x86_64_asm_read_cr3(); } - let flags = Cr3Flags::from_bits_truncate(value); let addr = PhysAddr::new(value & 0x_000f_ffff_ffff_f000); let frame = PhysFrame::containing_address(addr); - (frame, flags) + (frame, (value & 0xFFF) as u16) + } + + /// Read the current P4 table address from the CR3 register along with PCID. + /// The correct functioning of this requires CR4.PCIDE = 1. + /// See [`Cr4Flags::PCID`] + #[inline] + pub fn read_pcid() -> (PhysFrame, Pcid) { + let (frame, value) = Cr3::read_raw(); + (frame, Pcid::new(value as u16).unwrap()) } /// Write a new P4 table address into the CR3 register. @@ -258,8 +273,29 @@ mod x86_64 { /// changing the page mapping. #[inline] pub unsafe fn write(frame: PhysFrame, flags: Cr3Flags) { + Cr3::write_raw(frame, flags.bits() as u16); + } + + /// Write a new P4 table address into the CR3 register. + /// + /// ## Safety + /// Changing the level 4 page table is unsafe, because it's possible to violate memory safety by + /// changing the page mapping. + /// [`Cr4Flags::PCID`] must be set before calling this method. + #[inline] + pub unsafe fn write_pcid(frame: PhysFrame, pcid: Pcid) { + Cr3::write_raw(frame, pcid.value()); + } + + /// Write a new P4 table address into the CR3 register. + /// + /// ## Safety + /// Changing the level 4 page table is unsafe, because it's possible to violate memory safety by + /// changing the page mapping. + #[inline] + unsafe fn write_raw(frame: PhysFrame, val: u16) { let addr = frame.start_address(); - let value = addr.as_u64() | flags.bits(); + let value = addr.as_u64() | val as u64; #[cfg(feature = "inline_asm")] asm!("mov cr3, {}", in(reg) value, options(nostack));