diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 77fc58f25a997f2bd1d7028b5751fa203cbbd33c..505e4957026ba240dc0c54df58b2daf67d265352 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) {
 
 } // namespace
 
+class ASTDecompiler;
+class ExprDecompiler;
+
 class SPIRVDecompiler : public Sirit::Module {
 public:
     explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage)
@@ -97,27 +100,7 @@ public:
         AddExtension("SPV_KHR_variable_pointers");
     }
 
-    void Decompile() {
-        AllocateBindings();
-        AllocateLabels();
-
-        DeclareVertex();
-        DeclareGeometry();
-        DeclareFragment();
-        DeclareRegisters();
-        DeclarePredicates();
-        DeclareLocalMemory();
-        DeclareInternalFlags();
-        DeclareInputAttributes();
-        DeclareOutputAttributes();
-        DeclareConstantBuffers();
-        DeclareGlobalBuffers();
-        DeclareSamplers();
-
-        execute_function =
-            Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
-        Emit(OpLabel());
-
+    void DecompileBranchMode() {
         const u32 first_address = ir.GetBasicBlocks().begin()->first;
         const Id loop_label = OpLabel("loop");
         const Id merge_label = OpLabel("merge");
@@ -174,6 +157,43 @@ public:
         Emit(continue_label);
         Emit(OpBranch(loop_label));
         Emit(merge_label);
+    }
+
+    void DecompileAST();
+
+    void Decompile() {
+        const bool is_fully_decompiled = ir.IsDecompiled();
+        AllocateBindings();
+        if (!is_fully_decompiled) {
+            AllocateLabels();
+        }
+
+        DeclareVertex();
+        DeclareGeometry();
+        DeclareFragment();
+        DeclareRegisters();
+        DeclarePredicates();
+        if (is_fully_decompiled) {
+            DeclareFlowVariables();
+        }
+        DeclareLocalMemory();
+        DeclareInternalFlags();
+        DeclareInputAttributes();
+        DeclareOutputAttributes();
+        DeclareConstantBuffers();
+        DeclareGlobalBuffers();
+        DeclareSamplers();
+
+        execute_function =
+            Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
+        Emit(OpLabel());
+
+        if (is_fully_decompiled) {
+            DecompileAST();
+        } else {
+            DecompileBranchMode();
+        }
+
         Emit(OpReturn());
         Emit(OpFunctionEnd());
     }
@@ -206,6 +226,9 @@ public:
     }
 
 private:
+    friend class ASTDecompiler;
+    friend class ExprDecompiler;
+
     static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
 
     void AllocateBindings() {
@@ -294,6 +317,14 @@ private:
         }
     }
 
+    void DeclareFlowVariables() {
+        for (u32 i = 0; i < ir.GetASTNumVariables(); i++) {
+            const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
+            Name(id, fmt::format("flow_var_{}", static_cast<u32>(i)));
+            flow_variables.emplace(i, AddGlobalVariable(id));
+        }
+    }
+
     void DeclareLocalMemory() {
         if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
             const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
@@ -1019,7 +1050,7 @@ private:
         return {};
     }
 
-    Id Exit(Operation operation) {
+    Id PreExit() {
         switch (stage) {
         case ShaderStage::Vertex: {
             // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
@@ -1067,6 +1098,11 @@ private:
         }
         }
 
+        return {};
+    }
+
+    Id Exit(Operation operation) {
+        PreExit();
         BranchingOp([&]() { Emit(OpReturn()); });
         return {};
     }
@@ -1545,6 +1581,7 @@ private:
     Id per_vertex{};
     std::map<u32, Id> registers;
     std::map<Tegra::Shader::Pred, Id> predicates;
+    std::map<u32, Id> flow_variables;
     Id local_memory{};
     std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
     std::map<Attribute::Index, Id> input_attributes;
@@ -1580,6 +1617,223 @@ private:
     std::map<u32, Id> labels;
 };
 
+class ExprDecompiler {
+public:
+    ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
+
+    void operator()(VideoCommon::Shader::ExprAnd& expr) {
+        const Id type_def = decomp.GetTypeDefinition(Type::Bool);
+        const Id op1 = Visit(expr.operand1);
+        const Id op2 = Visit(expr.operand2);
+        current_id = decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
+    }
+
+    void operator()(VideoCommon::Shader::ExprOr& expr) {
+        const Id type_def = decomp.GetTypeDefinition(Type::Bool);
+        const Id op1 = Visit(expr.operand1);
+        const Id op2 = Visit(expr.operand2);
+        current_id = decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
+    }
+
+    void operator()(VideoCommon::Shader::ExprNot& expr) {
+        const Id type_def = decomp.GetTypeDefinition(Type::Bool);
+        const Id op1 = Visit(expr.operand1);
+        current_id = decomp.Emit(decomp.OpLogicalNot(type_def, op1));
+    }
+
+    void operator()(VideoCommon::Shader::ExprPredicate& expr) {
+        auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
+        current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
+    }
+
+    void operator()(VideoCommon::Shader::ExprCondCode& expr) {
+        Node cc = decomp.ir.GetConditionCode(expr.cc);
+        Id target;
+
+        if (const auto pred = std::get_if<PredicateNode>(&*cc)) {
+            const auto index = pred->GetIndex();
+            switch (index) {
+            case Tegra::Shader::Pred::NeverExecute:
+                target = decomp.v_false;
+            case Tegra::Shader::Pred::UnusedIndex:
+                target = decomp.v_true;
+            default:
+                target = decomp.predicates.at(index);
+            }
+        } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
+            target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
+        }
+        current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
+    }
+
+    void operator()(VideoCommon::Shader::ExprVar& expr) {
+        current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
+    }
+
+    void operator()(VideoCommon::Shader::ExprBoolean& expr) {
+        current_id = expr.value ? decomp.v_true : decomp.v_false;
+    }
+
+    Id GetResult() {
+        return current_id;
+    }
+
+    Id Visit(VideoCommon::Shader::Expr& node) {
+        std::visit(*this, *node);
+        return current_id;
+    }
+
+private:
+    Id current_id;
+    SPIRVDecompiler& decomp;
+};
+
+class ASTDecompiler {
+public:
+    ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
+
+    void operator()(VideoCommon::Shader::ASTProgram& ast) {
+        ASTNode current = ast.nodes.GetFirst();
+        while (current) {
+            Visit(current);
+            current = current->GetNext();
+        }
+    }
+
+    void operator()(VideoCommon::Shader::ASTIfThen& ast) {
+        ExprDecompiler expr_parser{decomp};
+        const Id condition = expr_parser.Visit(ast.condition);
+        const Id then_label = decomp.OpLabel();
+        const Id endif_label = decomp.OpLabel();
+        decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
+        decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
+        decomp.Emit(then_label);
+        ASTNode current = ast.nodes.GetFirst();
+        while (current) {
+            Visit(current);
+            current = current->GetNext();
+        }
+        decomp.Emit(endif_label);
+    }
+
+    void operator()(VideoCommon::Shader::ASTIfElse& ast) {
+        UNREACHABLE();
+    }
+
+    void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) {
+        UNREACHABLE();
+    }
+
+    void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) {
+        decomp.VisitBasicBlock(ast.nodes);
+    }
+
+    void operator()(VideoCommon::Shader::ASTVarSet& ast) {
+        ExprDecompiler expr_parser{decomp};
+        const Id condition = expr_parser.Visit(ast.condition);
+        decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
+    }
+
+    void operator()(VideoCommon::Shader::ASTLabel& ast) {
+        // Do nothing
+    }
+
+    void operator()(VideoCommon::Shader::ASTGoto& ast) {
+        UNREACHABLE();
+    }
+
+    void operator()(VideoCommon::Shader::ASTDoWhile& ast) {
+        const Id loop_label = decomp.OpLabel();
+        const Id endloop_label = decomp.OpLabel();
+        const Id loop_start_block = decomp.OpLabel();
+        const Id loop_end_block = decomp.OpLabel();
+        current_loop_exit = endloop_label;
+        decomp.Emit(loop_label);
+        decomp.Emit(decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
+        decomp.Emit(decomp.OpBranch(loop_start_block));
+        decomp.Emit(loop_start_block);
+        ASTNode current = ast.nodes.GetFirst();
+        while (current) {
+            Visit(current);
+            current = current->GetNext();
+        }
+        decomp.Emit(decomp.OpBranch(loop_end_block));
+        decomp.Emit(loop_end_block);
+        ExprDecompiler expr_parser{decomp};
+        const Id condition = expr_parser.Visit(ast.condition);
+        decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
+        decomp.Emit(endloop_label);
+    }
+
+    void operator()(VideoCommon::Shader::ASTReturn& ast) {
+        bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
+        if (!is_true) {
+            ExprDecompiler expr_parser{decomp};
+            const Id condition = expr_parser.Visit(ast.condition);
+            const Id then_label = decomp.OpLabel();
+            const Id endif_label = decomp.OpLabel();
+            decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
+            decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
+            decomp.Emit(then_label);
+            if (ast.kills) {
+                decomp.Emit(decomp.OpKill());
+            } else {
+                decomp.PreExit();
+                decomp.Emit(decomp.OpReturn());
+            }
+            decomp.Emit(endif_label);
+        } else {
+            decomp.Emit(decomp.OpLabel());
+            if (ast.kills) {
+                decomp.Emit(decomp.OpKill());
+            } else {
+                decomp.PreExit();
+                decomp.Emit(decomp.OpReturn());
+            }
+            decomp.Emit(decomp.OpLabel());
+        }
+    }
+
+    void operator()(VideoCommon::Shader::ASTBreak& ast) {
+        bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
+        if (!is_true) {
+            ExprDecompiler expr_parser{decomp};
+            const Id condition = expr_parser.Visit(ast.condition);
+            const Id then_label = decomp.OpLabel();
+            const Id endif_label = decomp.OpLabel();
+            decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
+            decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
+            decomp.Emit(then_label);
+            decomp.Emit(decomp.OpBranch(current_loop_exit));
+            decomp.Emit(endif_label);
+        } else {
+            decomp.Emit(decomp.OpLabel());
+            decomp.Emit(decomp.OpBranch(current_loop_exit));
+            decomp.Emit(decomp.OpLabel());
+        }
+    }
+
+    void Visit(VideoCommon::Shader::ASTNode& node) {
+        std::visit(*this, *node->GetInnerData());
+    }
+
+private:
+    SPIRVDecompiler& decomp;
+    Id current_loop_exit;
+};
+
+void SPIRVDecompiler::DecompileAST() {
+    u32 num_flow_variables = ir.GetASTNumVariables();
+    for (u32 i = 0; i < num_flow_variables; i++) {
+        const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
+        Name(id, fmt::format("flow_var_{}", i));
+        flow_variables.emplace(i, AddGlobalVariable(id));
+    }
+    ASTDecompiler decompiler{*this};
+    VideoCommon::Shader::ASTNode program = ir.GetASTProgram();
+    decompiler.Visit(program);
+}
+
 DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
                            Maxwell::ShaderStage stage) {
     auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);
diff --git a/src/video_core/shader/ast.h b/src/video_core/shader/ast.h
index 07deb58e4ddc1cf253c3d5a630cec140644b382a..12db336df257871545e6a68581e7139a593be22e 100644
--- a/src/video_core/shader/ast.h
+++ b/src/video_core/shader/ast.h
@@ -205,13 +205,29 @@ public:
         return nullptr;
     }
 
-    void MarkLabelUnused() const {
+    void MarkLabelUnused() {
         auto inner = std::get_if<ASTLabel>(&data);
         if (inner) {
             inner->unused = true;
         }
     }
 
+    bool IsLabelUnused() const {
+        auto inner = std::get_if<ASTLabel>(&data);
+        if (inner) {
+            return inner->unused;
+        }
+        return true;
+    }
+
+    u32 GetLabelIndex() const {
+        auto inner = std::get_if<ASTLabel>(&data);
+        if (inner) {
+            return inner->index;
+        }
+        return -1;
+    }
+
     Expr GetIfCondition() const {
         auto inner = std::get_if<ASTIfThen>(&data);
         if (inner) {
@@ -336,6 +352,10 @@ public:
         return variables;
     }
 
+    const std::vector<ASTNode>& GetLabels() const {
+        return labels;
+    }
+
 private:
     bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const;
 
diff --git a/src/video_core/shader/shader_ir.h b/src/video_core/shader/shader_ir.h
index 7a91c9bb67db07e9ce144e437cc08baacb523ba6..105981d67a92f308ae6919cf0d5338d212df7cb7 100644
--- a/src/video_core/shader/shader_ir.h
+++ b/src/video_core/shader/shader_ir.h
@@ -151,6 +151,10 @@ public:
         return decompiled;
     }
 
+    const ASTManager& GetASTManager() const {
+        return program_manager;
+    }
+
     ASTNode GetASTProgram() const {
         return program_manager.GetProgram();
     }