ryujinx/src/Ryujinx.Cpu/LightningJit/CodeGen/Arm64/RegisterSaveRestore.cs
gdkchan 427b7d06b5
Implement a new JIT for Arm devices (#6057)
* Implement a new JIT for Arm devices

* Auto-format

* Make a lot of Assembler members read-only

* More read-only

* Fix more warnings

* ObjectDisposedException.ThrowIf

* New JIT cache for platforms that enforce W^X, currently unused

* Remove unused using

* Fix assert

* Pass memory manager type around

* Safe memory manager mode support + other improvements

* Actual safe memory manager mode masking support

* PR feedback
2024-01-20 11:11:28 -03:00

253 lines
9.3 KiB
C#

using System.Numerics;
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
readonly struct RegisterSaveRestore
{
// Index of the register saved/restored as the frame pointer (stored as a pair with LrRegister).
private const int FpRegister = 29;
// Index of the register saved/restored as the link register (stored as a pair with FpRegister).
private const int LrRegister = 30;
/// <summary>
/// Exclusive upper bound for a stack adjustment that is folded into a pre/post-indexed load or store.
/// Adjustments that are not below this value are emitted as a separate add/sub on the stack pointer.
/// </summary>
public const int Encodable9BitsOffsetLimit = 0x100;
// General purpose registers to save, one bit per register index (bit N selects register N).
private readonly uint _gprMask;
// FP/SIMD registers to save, one bit per register index (bit N selects register N).
private readonly uint _fpSimdMask;
// Operand type used for the FP/SIMD saves; its size in bytes is the size of each FP/SIMD slot.
private readonly OperandType _fpSimdType;
// Extra stack space requested by the caller, in bytes; rounded up to a multiple of 16 where it is used.
private readonly int _reservedStackSize;
// When true, the prologue additionally pushes the FP/LR pair and points FP at it, and the epilogue pops it.
private readonly bool _hasCall;
/// <summary>
/// Creates a description of the registers and stack space that the prologue and epilogue must handle.
/// </summary>
/// <param name="gprMask">General purpose registers to save, one bit per register index</param>
/// <param name="fpSimdMask">FP/SIMD registers to save, one bit per register index</param>
/// <param name="fpSimdType">Operand type used to save each FP/SIMD register (selects the slot size)</param>
/// <param name="hasCall">True to also save and restore the frame pointer and link register pair</param>
/// <param name="reservedStackSize">Additional stack space to allocate, in bytes</param>
public RegisterSaveRestore(
    uint gprMask,
    uint fpSimdMask = 0,
    OperandType fpSimdType = OperandType.FP64,
    bool hasCall = false,
    int reservedStackSize = 0)
{
    // The arguments are captured as-is; nothing is validated or derived at construction time.
    (_gprMask, _fpSimdMask, _fpSimdType) = (gprMask, fpSimdMask, fpSimdType);
    (_hasCall, _reservedStackSize) = (hasCall, reservedStackSize);
}
/// <summary>
/// Gets the number of bytes that the prologue places below the reserved stack area:
/// the FP/LR pair (only when calls are made) plus the 16-byte aligned register save area.
/// </summary>
/// <returns>Size in bytes of the FP/LR pair (if any) plus the aligned register save area</returns>
public int GetReservedStackOffset()
{
    // FP and LR take one 16-byte slot, and only exist when the function makes calls.
    int frameRecordSize = _hasCall ? 16 : 0;

    int gprBytes = BitOperations.PopCount(_gprMask) * 8;
    int fpSimdBytes = BitOperations.PopCount(_fpSimdMask) * _fpSimdType.GetSizeInBytes();

    return frameRecordSize + Align16(gprBytes + fpSimdBytes);
}
/// <summary>
/// Emits the function prologue: allocates the stack region, saves the selected general purpose and
/// FP/SIMD registers and, when the function makes calls, pushes the FP/LR pair and sets FP to SP.
/// </summary>
/// <remarks>
/// General purpose registers are stored first, FP/SIMD registers after them. The first store emitted
/// is the one that allocates the whole region (register saves plus reserved stack space).
/// NOTE(review): when both masks are empty and no call is made, nothing is emitted, so the reserved
/// stack space is not allocated here; confirm that callers never rely on that combination.
/// </remarks>
/// <param name="asm">Assembler that receives the prologue instructions</param>
public void WritePrologue(ref Assembler asm)
{
    uint pendingGprs = _gprMask;
    uint pendingFpSimd = _fpSimdMask;

    int gprCount = BitOperations.PopCount(pendingGprs);
    int fpSimdBytes = BitOperations.PopCount(pendingFpSimd) * _fpSimdType.GetSizeInBytes();

    int frameSize = Align16(gprCount * 8 + fpSimdBytes) + Align16(_reservedStackSize);
    int cursor = 0;

    WritePrologueCalleeSavesPreIndexed(ref asm, ref pendingGprs, ref cursor, frameSize, OperandType.I64);

    // An odd number of 8-byte slots would leave the 16-byte vector slots misaligned, so skip 8 bytes.
    bool needsVectorPadding = _fpSimdType == OperandType.V128 && (gprCount & 1) != 0;

    if (needsVectorPadding)
    {
        cursor += 8;
    }

    WritePrologueCalleeSavesPreIndexed(ref asm, ref pendingFpSimd, ref cursor, frameSize, _fpSimdType);

    if (!_hasCall)
    {
        return;
    }

    Operand sp = Register(Assembler.SpRegister);
    Operand fp = Register(FpRegister);
    Operand lr = Register(LrRegister);

    if (cursor != 0)
    {
        // The region was already allocated by a register store, only the pair itself is pushed.
        asm.StpRiPre(fp, lr, sp, -16);
    }
    else if (frameSize + 16 < Encodable9BitsOffsetLimit)
    {
        // Nothing stored so far: allocate the region and push the pair with a single store.
        asm.StpRiPre(fp, lr, sp, -(frameSize + 16));
    }
    else
    {
        // Nothing stored so far and the adjustment is too large to fold into the store.
        asm.Sub(sp, sp, new Operand(OperandKind.Constant, OperandType.I64, (ulong)frameSize));
        asm.StpRiPre(fp, lr, sp, -16);
    }

    asm.MovSp(fp, sp);
}
/// <summary>
/// Emits the stores that save every register selected by <paramref name="mask"/>, from the lowest
/// register index to the highest, at increasing offsets from the stack pointer.
/// </summary>
/// <remarks>
/// While <paramref name="offset"/> is still 0, the store being emitted is the first one of the prologue
/// and is also responsible for allocating <paramref name="calleeSaveRegionSize"/> bytes of stack.
/// </remarks>
/// <param name="asm">Assembler that receives the instructions</param>
/// <param name="mask">Registers to save, one bit per register index; it is 0 when the method returns</param>
/// <param name="offset">Byte offset from SP of the next free slot; advanced by the number of bytes stored</param>
/// <param name="calleeSaveRegionSize">Number of bytes to allocate when the first store is emitted</param>
/// <param name="type">Operand type of the saved registers, which also gives the size of each slot</param>
private static void WritePrologueCalleeSavesPreIndexed(
ref Assembler asm,
ref uint mask,
ref int offset,
int calleeSaveRegionSize,
OperandType type)
{
// Odd register count: store the lowest register on its own so that the rest can be stored in pairs.
if ((BitOperations.PopCount(mask) & 1) != 0)
{
int reg = BitOperations.TrailingZeroCount(mask);
mask &= ~(1u << reg);
if (offset != 0)
{
// The stack region was already allocated by an earlier store, just store at SP + offset.
asm.StrRiUn(Register(reg, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
// First store, small region: the negative pre-index both allocates the region and stores at its base.
asm.StrRiPre(Register(reg, type), Register(Assembler.SpRegister), -calleeSaveRegionSize);
}
else
{
// First store, region too large to fold into the store: allocate explicitly, then store at SP + 0.
asm.Sub(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
asm.StrRiUn(Register(reg, type), Register(Assembler.SpRegister), 0);
}
offset += type.GetSizeInBytes();
}
// The number of remaining registers is even here, two of them are consumed per iteration.
while (mask != 0)
{
int reg = BitOperations.TrailingZeroCount(mask);
mask &= ~(1u << reg);
int reg2 = BitOperations.TrailingZeroCount(mask);
mask &= ~(1u << reg2);
// Same three cases as the single register store above, using the pair form of the store.
if (offset != 0)
{
asm.StpRiUn(Register(reg, type), Register(reg2, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
asm.StpRiPre(Register(reg, type), Register(reg2, type), Register(Assembler.SpRegister), -calleeSaveRegionSize);
}
else
{
asm.Sub(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
asm.StpRiUn(Register(reg, type), Register(reg2, type), Register(Assembler.SpRegister), 0);
}
offset += type.GetSizeInBytes() * 2;
}
}
/// <summary>
/// Emits the function epilogue: pops the FP/LR pair when the function makes calls, then restores the
/// FP/SIMD and general purpose registers, releasing the stack region with the last load.
/// </summary>
/// <remarks>
/// Registers are restored in the reverse of the order used by <see cref="WritePrologue"/>:
/// FP/SIMD registers first, general purpose registers last, both at decreasing offsets.
/// </remarks>
/// <param name="asm">Assembler that receives the epilogue instructions</param>
public void WriteEpilogue(ref Assembler asm)
{
    uint pendingGprs = _gprMask;
    uint pendingFpSimd = _fpSimdMask;

    int gprCount = BitOperations.PopCount(pendingGprs);
    int fpSimdCount = BitOperations.PopCount(pendingFpSimd);

    // Same 8-byte gap that the prologue leaves between an odd number of GPR slots and the vector slots.
    int vectorPadding = (_fpSimdType == OperandType.V128 && (gprCount & 1) != 0) ? 8 : 0;

    // Starts at the end of the register save area and walks back down to 0.
    int cursor = gprCount * 8 + fpSimdCount * _fpSimdType.GetSizeInBytes() + vectorPadding;
    int frameSize = Align16(cursor) + Align16(_reservedStackSize);

    if (_hasCall)
    {
        Operand sp = Register(Assembler.SpRegister);
        Operand fp = Register(FpRegister);
        Operand lr = Register(LrRegister);

        if (cursor != 0)
        {
            // Register loads follow and the last of them releases the region, only pop the pair here.
            asm.LdpRiPost(fp, lr, sp, 16);
        }
        else if (frameSize + 16 < Encodable9BitsOffsetLimit)
        {
            // No register loads follow: pop the pair and release the region with a single load.
            asm.LdpRiPost(fp, lr, sp, frameSize + 16);
        }
        else
        {
            // No register loads follow and the adjustment is too large to fold into the load.
            asm.LdpRiPost(fp, lr, sp, 16);
            asm.Add(sp, sp, new Operand(OperandKind.Constant, OperandType.I64, (ulong)frameSize));
        }
    }

    WriteEpilogueCalleeSavesPostIndexed(ref asm, ref pendingFpSimd, ref cursor, frameSize, _fpSimdType);

    // Step over the alignment gap (if any) to reach the end of the general purpose register slots.
    cursor -= vectorPadding;

    WriteEpilogueCalleeSavesPostIndexed(ref asm, ref pendingGprs, ref cursor, frameSize, OperandType.I64);
}
/// <summary>
/// Emits the loads that restore every register selected by <paramref name="mask"/>, from the highest
/// register index to the lowest, at decreasing offsets from the stack pointer.
/// </summary>
/// <remarks>
/// This walks the slots in the reverse of the order used by WritePrologueCalleeSavesPreIndexed.
/// The load that brings <paramref name="offset"/> down to 0 is the last one of the epilogue and is also
/// responsible for releasing <paramref name="calleeSaveRegionSize"/> bytes of stack.
/// </remarks>
/// <param name="asm">Assembler that receives the instructions</param>
/// <param name="mask">Registers to restore, one bit per register index; it is 0 when the method returns</param>
/// <param name="offset">Byte offset from SP just past the next slot to load; decreased by the number of bytes loaded</param>
/// <param name="calleeSaveRegionSize">Number of bytes to release when the last load is emitted</param>
/// <param name="type">Operand type of the restored registers, which also gives the size of each slot</param>
private static void WriteEpilogueCalleeSavesPostIndexed(
ref Assembler asm,
ref uint mask,
ref int offset,
int calleeSaveRegionSize,
OperandType type)
{
while (mask != 0)
{
int reg = HighestBitSet(mask);
mask &= ~(1u << reg);
if (mask != 0)
{
// At least two registers left: restore the two highest ones as a pair.
// reg2 is the lower index of the two, so it is loaded from the lower address.
int reg2 = HighestBitSet(mask);
mask &= ~(1u << reg2);
offset -= type.GetSizeInBytes() * 2;
if (offset != 0)
{
// More slots remain below this one, load from SP + offset and keep the region allocated.
asm.LdpRiUn(Register(reg2, type), Register(reg, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
// Last load, small region: the post-index both loads from the base and releases the region.
asm.LdpRiPost(Register(reg2, type), Register(reg, type), Register(Assembler.SpRegister), calleeSaveRegionSize);
}
else
{
// Last load, region too large to fold into the load: load from SP + 0, then release explicitly.
asm.LdpRiUn(Register(reg2, type), Register(reg, type), Register(Assembler.SpRegister), 0);
asm.Add(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
}
}
else
{
// Only one register left (the lowest index of an odd sized set): restore it on its own.
// Same three cases as the pair load above, using the single register form of the load.
offset -= type.GetSizeInBytes();
if (offset != 0)
{
asm.LdrRiUn(Register(reg, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
asm.LdrRiPost(Register(reg, type), Register(Assembler.SpRegister), calleeSaveRegionSize);
}
else
{
asm.LdrRiUn(Register(reg, type), Register(Assembler.SpRegister), 0);
asm.Add(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
}
}
}
}
/// <summary>
/// Gets the index of the most significant bit that is set on <paramref name="value"/>.
/// </summary>
/// <param name="value">Value to inspect</param>
/// <returns>Bit index, 0 being the least significant bit; -1 if <paramref name="value"/> is zero</returns>
private static int HighestBitSet(uint value) => (sizeof(uint) * 8 - 1) - BitOperations.LeadingZeroCount(value);
/// <summary>
/// Creates a register operand for the given register index and operand type.
/// </summary>
/// <remarks>
/// The operand is always created with <see cref="RegisterType.Integer"/>, including when
/// <paramref name="type"/> is an FP/SIMD type.
/// NOTE(review): presumably the assembler picks the register file from the operand type; confirm.
/// </remarks>
/// <param name="register">Register index</param>
/// <param name="type">Operand type of the register</param>
/// <returns>The register operand</returns>
private static Operand Register(int register, OperandType type = OperandType.I64) => new(register, RegisterType.Integer, type);
/// <summary>
/// Rounds <paramref name="value"/> up to the next multiple of 16.
/// </summary>
/// <param name="value">Value to align</param>
/// <returns>The smallest multiple of 16 that is not less than <paramref name="value"/></returns>
private static int Align16(int value)
{
    const int AlignmentMask = 16 - 1;

    return (value + AlignmentMask) & ~AlignmentMask;
}
}
}