yuzu/src/shader_recompiler/frontend/maxwell/control_flow.cpp

532 lines
19 KiB
C++
Raw Normal View History

2021-01-08 22:30:07 -08:00
// Copyright 2021 yuzu Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include <algorithm>
#include <array>
#include <optional>
#include <ranges>
#include <string>
#include <utility>
#include <fmt/format.h>
#include "shader_recompiler/exception.h"
#include "shader_recompiler/frontend/maxwell/control_flow.h"
#include "shader_recompiler/frontend/maxwell/decode.h"
#include "shader_recompiler/frontend/maxwell/location.h"
namespace Shader::Maxwell::Flow {
/// Computes the target address of a PC-relative branch encoded in `inst` located at `pc`.
/// The encoded offset is applied together with a fixed +8 byte adjustment
/// (presumably relative to the following instruction slot — TODO confirm).
static u32 BranchOffset(Location pc, Instruction inst) {
    const auto relative_offset{inst.branch.Offset()};
    return pc.Offset() + relative_offset + 8;
}
/// Splits `block` at `pc` into a head [begin, pc) and a tail [pc, end).
/// The head keeps the original id and unconditionally branches to the tail, which receives
/// `new_id` and inherits the original terminator (end class, condition and branch targets).
/// @throws InvalidArgument if `pc` is not strictly inside the block.
static std::array<Block, 2> Split(Block&& block, Location pc, BlockId new_id) {
    if (!(block.begin < pc && pc < block.end)) {
        throw InvalidArgument("Invalid address to split={}", pc);
    }
    // Head: copies the stack because the tail takes ownership of it below
    Block head{
        .begin{block.begin},
        .end{pc},
        .end_class{EndClass::Branch},
        .id{block.id},
        .stack{block.stack},
        .cond{true},
        .branch_true{new_id},
        .branch_false{UNREACHABLE_BLOCK_ID},
    };
    // Tail: everything that used to end the original block now ends this one
    Block tail{
        .begin{pc},
        .end{block.end},
        .end_class{block.end_class},
        .id{new_id},
        .stack{std::move(block.stack)},
        .cond{block.cond},
        .branch_true{block.branch_true},
        .branch_false{block.branch_false},
    };
    return {std::move(head), std::move(tail)};
}
/// Maps a control-flow stack opcode (a push such as PBK/PCNT/SSY, or the matching pop such as
/// BRK/CONT/SYNC) to the stack token it pushes or pops.
/// Each push/pop pair must use its own token so that, e.g., CONT pops the target of the most
/// recent PCNT and never the target of an unrelated PBK.
/// @throws InvalidArgument for opcodes that do not interact with the token stack.
static Token OpcodeToken(Opcode opcode) {
    switch (opcode) {
    case Opcode::PBK:
    case Opcode::BRK:
        return Token::PBK;
    case Opcode::PCNT:
    case Opcode::CONT:
        return Token::PCNT;
    case Opcode::PEXIT:
    case Opcode::EXIT:
        return Token::PEXIT;
    case Opcode::PLONGJMP:
    case Opcode::LONGJMP:
        return Token::PLONGJMP;
    case Opcode::PRET:
    case Opcode::RET:
    case Opcode::CAL:
        return Token::PRET;
    case Opcode::SSY:
    case Opcode::SYNC:
        return Token::SSY;
    default:
        throw InvalidArgument("{}", opcode);
    }
}
/// Returns true for the jump/call opcodes whose target is an absolute address
/// rather than an offset relative to the instruction's own location.
static bool IsAbsoluteJump(Opcode opcode) {
    return opcode == Opcode::JCAL || opcode == Opcode::JMP || opcode == Opcode::JMX;
}
/// Reports whether a branch opcode carries a flow test field in its encoding.
/// @throws InvalidArgument if `opcode` is not a branch-like opcode.
static bool HasFlowTest(Opcode opcode) {
    switch (opcode) {
    case Opcode::CAL:
    case Opcode::JCAL:
        // Calls have no flow test
        return false;
    case Opcode::BRA:
    case Opcode::BRX:
    case Opcode::EXIT:
    case Opcode::JMP:
    case Opcode::JMX:
    case Opcode::BRK:
    case Opcode::CONT:
    case Opcode::LONGJMP:
    case Opcode::RET:
    case Opcode::SYNC:
        return true;
    default:
        throw InvalidArgument("Invalid branch {}", opcode);
    }
}
/// Produces the quoted Graphviz node name for `block`: virtual blocks have no real address,
/// so they are named after their id, while real blocks are named after their start address.
static std::string Name(const Block& block) {
    return block.begin.IsVirtual() ? fmt::format("\"Virtual {}\"", block.id)
                                   : fmt::format("\"{}\"", block.begin);
}
/// Pushes a (token, target) pair on top of the stack (the back of `entries`).
void Stack::Push(Token token, Location target) {
    const StackEntry entry{
        .token{token},
        .target{target},
    };
    entries.push_back(entry);
}
/// Looks up the topmost entry for `token` and returns its target together with a copy of
/// the stack where that entry (and everything above it) has been removed.
/// @throws LogicError if `token` is not on the stack.
std::pair<Location, Stack> Stack::Pop(Token token) const {
    if (const std::optional<Location> target{Peek(token)}) {
        return {*target, Remove(token)};
    }
    throw LogicError("Token could not be found");
}
/// Returns the target of the topmost entry pushed with `token`, or std::nullopt if none.
std::optional<Location> Stack::Peek(Token token) const {
    // Walk from the top of the stack (back of the container) towards the bottom
    for (auto entry = entries.rbegin(); entry != entries.rend(); ++entry) {
        if (entry->token == token) {
            return entry->target;
        }
    }
    return std::nullopt;
}
/// Returns a copy of the stack with the topmost entry pushed with `token` removed, together
/// with every entry pushed after it.
/// @throws LogicError if `token` is not on the stack (same contract as Stack::Pop).
Stack Stack::Remove(Token token) const {
    const auto reverse_entries{entries | std::views::reverse};
    const auto it{std::ranges::find(reverse_entries, token, &StackEntry::token)};
    if (it == reverse_entries.end()) {
        // Without this check `entries.end() - pos - 1` would point before entries.begin()
        throw LogicError("Token could not be found");
    }
    // Number of entries pushed after the matching one
    const auto pos{std::distance(reverse_entries.begin(), it)};
    Stack result;
    result.entries.insert(result.entries.end(), entries.begin(), entries.end() - pos - 1);
    return result;
}
/// Checks whether `pc` lies in the half-open range [begin, end) of this block.
bool Block::Contains(Location pc) const noexcept {
    return begin <= pc && pc < end;
}
/// Creates a function rooted at `start_address`. The entry point is queued as the first
/// label to analyze, using block id 0 and an empty token stack.
Function::Function(Location start_address) : entrypoint{start_address} {
    labels.push_back(Label{
        .address{start_address},
        .block_id{0},
        .stack{},
    });
}
/// Builds the control flow graph starting at `start_address`.
/// Functions discovered through calls are appended to `functions` during analysis, so the
/// outer loop re-reads the size on every iteration and processes them too.
CFG::CFG(Environment& env_, Location start_address) : env{env_} {
    functions.emplace_back(start_address);
    FunctionId function_id{0};
    while (function_id < functions.size()) {
        // Re-index every iteration: analysis may append to (and reallocate) `functions`
        while (!functions[function_id].labels.empty()) {
            auto& pending{functions[function_id].labels};
            Label label{std::move(pending.back())};
            pending.pop_back();
            AnalyzeLabel(function_id, label);
        }
        ++function_id;
    }
}
/// Analyzes the code reachable from `label` and records it as a new block of the function.
/// If the label's address is already inside a visited block, that block is split instead
/// (see InspectVisitedBlocks). Otherwise instructions are analyzed one by one until either a
/// branching instruction is found or the start of the next already-visited block is reached,
/// in which case the new block falls through into that block.
/// The new block is inserted into `blocks` right before the next visited block.
void CFG::AnalyzeLabel(FunctionId function_id, Label& label) {
    if (InspectVisitedBlocks(function_id, label)) {
        // Label address has been visited
        return;
    }
    // Try to find the next block
    // NOTE: upper_bound assumes `function->blocks` is ordered by each block's begin address
    Function* function{&functions[function_id]};
    Location pc{label.address};
    const auto next{std::upper_bound(function->blocks.begin(), function->blocks.end(), pc,
                                     [function](Location pc, u32 block_index) {
                                         return pc < function->blocks_data[block_index].begin;
                                     })};
    // Keep an index rather than the iterator: `functions` may reallocate during analysis
    const auto next_index{std::distance(function->blocks.begin(), next)};
    const bool is_last{next == function->blocks.end()};
    Location next_pc;
    BlockId next_id{UNREACHABLE_BLOCK_ID};
    if (!is_last) {
        next_pc = function->blocks_data[*next].begin;
        next_id = function->blocks_data[*next].id;
    }
    // Insert before the next block
    Block block{
        .begin{pc},
        .end{pc},
        .end_class{EndClass::Branch},
        .id{label.block_id},
        .stack{std::move(label.stack)},
        .cond{true},
        .branch_true{UNREACHABLE_BLOCK_ID},
        .branch_false{UNREACHABLE_BLOCK_ID},
    };
    // Analyze instructions until it reaches an already visited block or there's a branch
    bool is_branch{false};
    while (is_last || pc < next_pc) {
        is_branch = AnalyzeInst(block, function_id, pc) == AnalysisState::Branch;
        if (is_branch) {
            break;
        }
        ++pc;
    }
    if (!is_branch) {
        // If the block finished without a branch,
        // it means that the next instruction is already visited, jump to it
        block.end = pc;
        block.cond = true;
        block.branch_true = next_id;
        block.branch_false = UNREACHABLE_BLOCK_ID;
    }
    // Function's pointer might be invalid, resolve it again
    // (AnalyzeInst can append new functions when it finds CAL/JCAL)
    function = &functions[function_id];
    const u32 new_block_index = static_cast<u32>(function->blocks_data.size());
    function->blocks.insert(function->blocks.begin() + next_index, new_block_index);
    function->blocks_data.push_back(std::move(block));
}
/// Handles a label whose address falls inside an already visited block.
/// The containing block is split at the label's address; the tail becomes a new block
/// with the label's id and is placed right after the head in the ordered block list.
/// @returns false if the address has not been visited yet, true if a split happened.
/// @throws LogicError if the label points exactly at the start of a visited block.
bool CFG::InspectVisitedBlocks(FunctionId function_id, const Label& label) {
    const Location target{label.address};
    Function& func{functions[function_id]};
    const auto containing{std::ranges::find_if(func.blocks, [&func, target](u32 index) {
        return func.blocks_data[index].Contains(target);
    })};
    if (containing == func.blocks.end()) {
        return false;
    }
    const u32 old_index{*containing};
    Block& visited{func.blocks_data[old_index]};
    if (visited.begin == target) {
        throw LogicError("Dangling branch");
    }
    // The tail will live at the end of blocks_data
    const u32 appended_index{static_cast<u32>(func.blocks_data.size())};
    auto [head, tail]{Split(std::move(visited), target, label.block_id)};
    func.blocks_data[old_index] = std::move(head);
    func.blocks_data.push_back(std::move(tail));
    // Replace the original entry in the ordered list with [head, tail]
    const std::array replacement{old_index, appended_index};
    func.blocks.insert(func.blocks.erase(containing), replacement.begin(), replacement.end());
    return true;
}
/// Analyzes the instruction at `pc` as part of `block`.
/// @returns AnalysisState::Branch when the instruction terminates `block` (a taken branch,
///          a stack pop/jump, an EXIT, or a predicated instruction that needs its own
///          conditional block), AnalysisState::Continue when analysis should proceed with the
///          next instruction.
CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Location pc) {
    const Instruction inst{env.ReadInstruction(pc.Offset())};
    const Opcode opcode{Decode(inst.raw)};
    switch (opcode) {
    case Opcode::BRA:
    case Opcode::BRX:
    case Opcode::JMP:
    case Opcode::JMX:
    case Opcode::RET:
        if (!AnalyzeBranch(block, function_id, pc, inst, opcode)) {
            // Predicated with a false predicate: the branch is never taken
            return AnalysisState::Continue;
        }
        switch (opcode) {
        case Opcode::BRA:
        case Opcode::JMP:
            AnalyzeBRA(block, function_id, pc, inst, IsAbsoluteJump(opcode));
            break;
        case Opcode::BRX:
        case Opcode::JMX:
            AnalyzeBRX(block, pc, inst, IsAbsoluteJump(opcode));
            break;
        case Opcode::RET:
            block.end_class = EndClass::Return;
            break;
        default:
            break;
        }
        block.end = pc;
        return AnalysisState::Branch;
    case Opcode::BRK:
    case Opcode::CONT:
    case Opcode::LONGJMP:
    case Opcode::SYNC: {
        if (!AnalyzeBranch(block, function_id, pc, inst, opcode)) {
            return AnalysisState::Continue;
        }
        // Jump to the target recorded by the matching push instruction
        const auto [stack_pc, new_stack]{block.stack.Pop(OpcodeToken(opcode))};
        block.branch_true = AddLabel(block, new_stack, stack_pc, function_id);
        block.end = pc;
        return AnalysisState::Branch;
    }
    case Opcode::PBK:
    case Opcode::PCNT:
    case Opcode::PEXIT:
    case Opcode::PLONGJMP:
    case Opcode::SSY:
        block.stack.Push(OpcodeToken(opcode), BranchOffset(pc, inst));
        return AnalysisState::Continue;
    case Opcode::EXIT:
        return AnalyzeEXIT(block, function_id, pc, inst);
    case Opcode::PRET:
        throw NotImplementedException("PRET flow analysis");
    case Opcode::CAL:
    case Opcode::JCAL:
        // Register the callee as a function (if new) through the shared helper,
        // then handle the call like a regular instruction
        AnalyzeCAL(pc, inst, IsAbsoluteJump(opcode));
        break;
    default:
        break;
    }
    // Regular instruction: only a non-trivial predicate forces a block boundary
    const Predicate pred{inst.Pred()};
    if (pred == Predicate{true} || pred == Predicate{false}) {
        return AnalysisState::Continue;
    }
    const IR::Condition cond{static_cast<IR::Pred>(pred.index), pred.negated};
    AnalyzeCondInst(block, function_id, pc, EndClass::Branch, cond);
    return AnalysisState::Branch;
}
/// Terminates `block` at a predicated (conditional) instruction located at `pc`.
/// - If `block` has instructions before `pc`, it is ended at `pc` and falls through
///   unconditionally into a label for `pc`, so the conditional instruction is analyzed later
///   as the first instruction of its own block.
/// - If `block` starts at `pc`, a virtual block takes over `block`'s id and branches on `cond`
///   to `block` (which receives a fresh id) or, when false, to the instruction after `pc`.
///   `block` gets `insn_end_class` as its end class; for EndClass::Branch it then jumps
///   unconditionally to the instruction after `pc`.
void CFG::AnalyzeCondInst(Block& block, FunctionId function_id, Location pc,
                          EndClass insn_end_class, IR::Condition cond) {
    if (block.begin != pc) {
        // If the block doesn't start in the conditional instruction
        // mark it as a label to visit it later
        block.end = pc;
        block.cond = true;
        block.branch_true = AddLabel(block, block.stack, pc, function_id);
        block.branch_false = UNREACHABLE_BLOCK_ID;
        return;
    }
    // Impersonate the visited block with a virtual block
    // Jump from this virtual to the real conditional instruction and the next instruction
    Function& function{functions[function_id]};
    const BlockId conditional_block_id{++function.current_block_id};
    // NOTE(review): the virtual block is appended at the end of `blocks` instead of at an
    // address-ordered position; confirm this is compatible with the upper_bound lookup in
    // AnalyzeLabel.
    function.blocks.push_back(static_cast<u32>(function.blocks_data.size()));
    // This reference stays valid below: AddLabel appends to `labels`, not to `blocks_data`
    Block& virtual_block{function.blocks_data.emplace_back(Block{
        .begin{}, // Virtual block
        .end{},
        .end_class{EndClass::Branch},
        .id{block.id}, // Impersonating
        .stack{block.stack},
        .cond{cond},
        .branch_true{conditional_block_id},
        .branch_false{UNREACHABLE_BLOCK_ID},
    })};
    // Set the end properties of the conditional instruction and give it a new identity
    Block& conditional_block{block};
    conditional_block.end = pc;
    conditional_block.end_class = insn_end_class;
    conditional_block.id = conditional_block_id;
    // Add a label to the instruction after the conditional instruction
    const BlockId endif_block_id{AddLabel(conditional_block, block.stack, pc + 1, function_id)};
    // Branch to the next instruction from the virtual block
    virtual_block.branch_false = endif_block_id;
    // And branch to it from the conditional instruction if it is a branch
    if (insn_end_class == EndClass::Branch) {
        conditional_block.cond = true;
        conditional_block.branch_true = endif_block_id;
        conditional_block.branch_false = UNREACHABLE_BLOCK_ID;
    }
}
/// Evaluates the predicate and flow test of a branch-like instruction and sets `block.cond`.
/// For a conditional branch, the fall-through path (instruction after `pc`) is queued as
/// `block.branch_false`.
/// @returns false if the branch can never be taken (predicate is always false).
/// @throws NotImplementedException for branches with a constant buffer offset.
bool CFG::AnalyzeBranch(Block& block, FunctionId function_id, Location pc, Instruction inst,
                        Opcode opcode) {
    if (inst.branch.is_cbuf) {
        throw NotImplementedException("Branch with constant buffer offset");
    }
    const Predicate pred{inst.Pred()};
    if (pred == Predicate{false}) {
        return false;
    }
    // Opcodes without a flow test field behave as if the test were always true
    IR::FlowTest flow_test{IR::FlowTest::T};
    if (HasFlowTest(opcode)) {
        flow_test = inst.branch.flow_test.Value();
    }
    const bool always_taken{pred == Predicate{true} && flow_test == IR::FlowTest::T};
    if (always_taken) {
        block.cond = true;
        return true;
    }
    block.cond = IR::Condition(flow_test, static_cast<IR::Pred>(pred.index), pred.negated);
    block.branch_false = AddLabel(block, block.stack, pc + 1, function_id);
    return true;
}
/// Queues the target of a direct branch (BRA, or JMP when `is_absolute`) as the taken
/// successor of `block`.
void CFG::AnalyzeBRA(Block& block, FunctionId function_id, Location pc, Instruction inst,
                     bool is_absolute) {
    const Location target{is_absolute ? inst.branch.Absolute() : BranchOffset(pc, inst)};
    const BlockId target_id{AddLabel(block, block.stack, target, function_id)};
    block.branch_true = target_id;
}
/// Indirect branches (BRX, or JMX when `is_absolute`) are not supported yet.
/// @throws NotImplementedException always.
void CFG::AnalyzeBRX(Block&, Location, Instruction, bool is_absolute) {
    const char* const mnemonic{is_absolute ? "JMX" : "BRX"};
    throw NotImplementedException("{}", mnemonic);
}
/// Registers the callee of a CAL (or JCAL when `is_absolute`) as a function to analyze,
/// unless a function with that entry point already exists.
/// Technically CAL pushes into PRET, but that's implicit in the function call for us.
void CFG::AnalyzeCAL(Location pc, Instruction inst, bool is_absolute) {
    const Location callee{is_absolute ? inst.branch.Absolute() : BranchOffset(pc, inst)};
    const bool already_known{std::ranges::any_of(
        functions, [callee](const Function& candidate) { return candidate.entrypoint == callee; })};
    if (!already_known) {
        functions.emplace_back(callee);
    }
}
/// Analyzes an EXIT instruction at `pc`.
/// - Never-taken EXITs (false predicate or F flow test) are ignored.
/// - Conditional EXITs become a conditional block ending in EndClass::Exit.
/// - An unconditional EXIT with a pending PEXIT token jumps to the PEXIT target.
/// - Otherwise the block ends with EndClass::Exit.
/// @throws NotImplementedException for a conditional EXIT while a PEXIT token is pending.
CFG::AnalysisState CFG::AnalyzeEXIT(Block& block, FunctionId function_id, Location pc,
                                    Instruction inst) {
    const IR::FlowTest flow_test{inst.branch.flow_test};
    const Predicate pred{inst.Pred()};
    if (pred == Predicate{false} || flow_test == IR::FlowTest::F) {
        // EXIT will never be taken
        return AnalysisState::Continue;
    }
    if (pred != Predicate{true} || flow_test != IR::FlowTest::T) {
        if (block.stack.Peek(Token::PEXIT).has_value()) {
            throw NotImplementedException("Conditional EXIT with PEXIT token");
        }
        const IR::Condition cond{flow_test, static_cast<IR::Pred>(pred.index), pred.negated};
        AnalyzeCondInst(block, function_id, pc, EndClass::Exit, cond);
        return AnalysisState::Branch;
    }
    if (const std::optional<Location> exit_pc{block.stack.Peek(Token::PEXIT)}) {
        const Stack popped_stack{block.stack.Remove(Token::PEXIT)};
        // Terminate the block here like every other branching path (e.g. BRK/SYNC);
        // AnalyzeLabel does not set `end` when a branch is reported
        block.end = pc;
        block.cond = true;
        block.branch_true = AddLabel(block, popped_stack, *exit_pc, function_id);
        block.branch_false = UNREACHABLE_BLOCK_ID;
        return AnalysisState::Branch;
    }
    block.end = pc;
    block.end_class = EndClass::Exit;
    return AnalysisState::Branch;
}
/// Returns the id of the block that starts at `pc`, queuing a new label for analysis when
/// no block or pending label exists for that address yet.
/// @param block the block that branches to `pc` (used to detect jumps to itself)
/// @param stack the token stack to use when the target is analyzed for the first time
BlockId CFG::AddLabel(const Block& block, Stack stack, Location pc, FunctionId function_id) {
    Function& function{functions[function_id]};
    if (block.begin == pc) {
        // Jumps to itself
        return block.id;
    }
    const auto target{std::ranges::find(function.blocks_data, pc, &Block::begin)};
    if (target != function.blocks_data.end()) {
        // Already visited
        return target->id;
    }
    // The address may already be queued but not analyzed yet. Reuse that label so the same
    // address is not queued twice under different ids (the second one would otherwise hit
    // the "Dangling branch" check in InspectVisitedBlocks).
    const auto pending{std::ranges::find(function.labels, pc, &Label::address)};
    if (pending != function.labels.end()) {
        return pending->block_id;
    }
    const BlockId block_id{++function.current_block_id};
    function.labels.push_back(Label{
        .address{pc},
        .block_id{block_id},
        .stack{std::move(stack)},
    });
    return block_id;
}
/// Renders the control flow graph in Graphviz DOT format for debugging.
/// Each function becomes a cluster; the first function (the entry point the CFG was built
/// from) is labeled "main". Exit, return and unreachable terminators get synthetic nodes.
std::string CFG::Dot() const {
    int node_uid{0};
    std::string dot{"digraph shader {\n"};
    for (const Function& function : functions) {
        dot += fmt::format("\tsubgraph cluster_{} {{\n", function.entrypoint);
        dot += fmt::format("\t\tnode [style=filled];\n");
        for (const u32 block_index : function.blocks) {
            const Block& block{function.blocks_data[block_index]};
            const std::string name{Name(block)};
            const auto add_branch = [&](BlockId branch_id, bool add_label) {
                const auto it{std::ranges::find(function.blocks_data, branch_id, &Block::id)};
                dot += fmt::format("\t\t{}->", name);
                if (it == function.blocks_data.end()) {
                    dot += fmt::format("\"Unknown label {}\"", branch_id);
                } else {
                    dot += Name(*it);
                }
                if (add_label && block.cond != true && block.cond != false) {
                    dot += fmt::format(" [label=\"{}\"]", block.cond);
                }
                dot += '\n';
            };
            dot += fmt::format("\t\t{};\n", name);
            switch (block.end_class) {
            case EndClass::Branch:
                if (block.cond != false) {
                    add_branch(block.branch_true, true);
                }
                if (block.cond != true) {
                    add_branch(block.branch_false, false);
                }
                break;
            case EndClass::Exit:
                dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
                dot += fmt::format("\t\tN{} [label=\"Exit\"][shape=square][style=stripped];\n",
                                   node_uid);
                ++node_uid;
                break;
            case EndClass::Return:
                dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
                dot += fmt::format("\t\tN{} [label=\"Return\"][shape=square][style=stripped];\n",
                                   node_uid);
                ++node_uid;
                break;
            case EndClass::Unreachable:
                dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
                dot += fmt::format(
                    "\t\tN{} [label=\"Unreachable\"][shape=square][style=stripped];\n", node_uid);
                ++node_uid;
                break;
            }
        }
        // The entry function is always functions.front() (emplaced first by the constructor),
        // regardless of its address; do not rely on a hard-coded entry address
        if (&function == &functions.front()) {
            dot += fmt::format("\t\tlabel = \"main\";\n");
        } else {
            dot += fmt::format("\t\tlabel = \"Function {}\";\n", function.entrypoint);
        }
        dot += "\t}\n";
    }
    if (!functions.empty()) {
        if (functions.front().blocks.empty()) {
            dot += "Start;\n";
        } else {
            dot += fmt::format("\tStart -> {};\n", Name(functions.front().blocks_data.front()));
        }
        dot += fmt::format("\tStart [shape=diamond];\n");
    }
    dot += "}\n";
    return dot;
}
} // namespace Shader::Maxwell::Flow