From e6a7723f2f4b62279cd4f6d4b48eb02a9b60ffb6 Mon Sep 17 00:00:00 2001
From: Subv <subv2112@gmail.com>
Date: Sun, 1 Jan 2017 16:53:22 -0500
Subject: [PATCH] Kernel: Object ShouldWait and Acquire calls now take a thread
 as a parameter.

This will be useful when implementing mutex priority inheritance.
---
 src/core/hle/kernel/event.cpp          |  6 +++---
 src/core/hle/kernel/event.h            |  4 ++--
 src/core/hle/kernel/kernel.cpp         | 14 ++++++--------
 src/core/hle/kernel/kernel.h           |  9 +++++----
 src/core/hle/kernel/mutex.cpp          | 24 ++++++------------------
 src/core/hle/kernel/mutex.h            |  5 +++--
 src/core/hle/kernel/semaphore.cpp      |  6 +++---
 src/core/hle/kernel/semaphore.h        |  4 ++--
 src/core/hle/kernel/server_port.cpp    |  6 +++---
 src/core/hle/kernel/server_port.h      |  4 ++--
 src/core/hle/kernel/server_session.cpp |  6 +++---
 src/core/hle/kernel/server_session.h   |  4 ++--
 src/core/hle/kernel/thread.cpp         |  6 +++---
 src/core/hle/kernel/thread.h           |  4 ++--
 src/core/hle/kernel/timer.cpp          |  6 +++---
 src/core/hle/kernel/timer.h            |  4 ++--
 src/core/hle/svc.cpp                   | 12 ++++++------
 17 files changed, 56 insertions(+), 68 deletions(-)

diff --git a/src/core/hle/kernel/event.cpp b/src/core/hle/kernel/event.cpp
index 3e116e3dfc..e1f42af051 100644
--- a/src/core/hle/kernel/event.cpp
+++ b/src/core/hle/kernel/event.cpp
@@ -30,12 +30,12 @@ SharedPtr<Event> Event::Create(ResetType reset_type, std::string name) {
     return evt;
 }
 
-bool Event::ShouldWait() {
+bool Event::ShouldWait(Thread* thread) const {
     return !signaled;
 }
 
-void Event::Acquire() {
-    ASSERT_MSG(!ShouldWait(), "object unavailable!");
+void Event::Acquire(Thread* thread) {
+    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 
     // Release the event if it's not sticky...
     if (reset_type != ResetType::Sticky)
diff --git a/src/core/hle/kernel/event.h b/src/core/hle/kernel/event.h
index 8dcd23edb8..39452bf338 100644
--- a/src/core/hle/kernel/event.h
+++ b/src/core/hle/kernel/event.h
@@ -35,8 +35,8 @@ public:
     bool signaled;    ///< Whether the event has already been signaled
     std::string name; ///< Name of event (optional)
 
-    bool ShouldWait() override;
-    void Acquire() override;
+    bool ShouldWait(Thread* thread) const override;
+    void Acquire(Thread* thread) override;
 
     void Signal();
     void Clear();
diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp
index 1db8e102f0..ef9dbafa56 100644
--- a/src/core/hle/kernel/kernel.cpp
+++ b/src/core/hle/kernel/kernel.cpp
@@ -39,11 +39,6 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
                thread->status == THREADSTATUS_DEAD;
     });
 
-    // TODO(Subv): This call should be performed inside the loop below to check if an object can be
-    // acquired by a particular thread. This is useful for things like recursive locking of Mutexes.
-    if (ShouldWait())
-        return nullptr;
-
     Thread* candidate = nullptr;
     s32 candidate_priority = THREADPRIO_LOWEST + 1;
 
@@ -51,9 +46,12 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
         if (thread->current_priority >= candidate_priority)
             continue;
 
+        if (ShouldWait(thread.get()))
+            continue;
+
         bool ready_to_run =
             std::none_of(thread->wait_objects.begin(), thread->wait_objects.end(),
-                         [](const SharedPtr<WaitObject>& object) { return object->ShouldWait(); });
+                         [&thread](const SharedPtr<WaitObject>& object) { return object->ShouldWait(thread.get()); });
         if (ready_to_run) {
             candidate = thread.get();
             candidate_priority = thread->current_priority;
@@ -66,7 +64,7 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
 void WaitObject::WakeupAllWaitingThreads() {
     while (auto thread = GetHighestPriorityReadyThread()) {
         if (!thread->IsSleepingOnWaitAll()) {
-            Acquire();
+            Acquire(thread.get());
             // Set the output index of the WaitSynchronizationN call to the index of this object.
             if (thread->wait_set_output) {
                 thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(this));
@@ -74,7 +72,7 @@ void WaitObject::WakeupAllWaitingThreads() {
             }
         } else {
             for (auto& object : thread->wait_objects) {
-                object->Acquire();
+                object->Acquire(thread.get());
                 object->RemoveWaitingThread(thread.get());
             }
             // Note: This case doesn't update the output index of WaitSynchronizationN.
diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h
index 9503e7d044..67eae93f27 100644
--- a/src/core/hle/kernel/kernel.h
+++ b/src/core/hle/kernel/kernel.h
@@ -132,13 +132,14 @@ using SharedPtr = boost::intrusive_ptr<T>;
 class WaitObject : public Object {
 public:
     /**
-     * Check if the current thread should wait until the object is available
+     * Check if the specified thread should wait until the object is available
+     * @param thread The thread about which we're deciding.
      * @return True if the current thread should wait due to this object being unavailable
      */
-    virtual bool ShouldWait() = 0;
+    virtual bool ShouldWait(Thread* thread) const = 0;
 
-    /// Acquire/lock the object if it is available
-    virtual void Acquire() = 0;
+    /// Acquire/lock the object for the specified thread if it is available
+    virtual void Acquire(Thread* thread) = 0;
 
     /**
      * Add a thread to wait on this object
diff --git a/src/core/hle/kernel/mutex.cpp b/src/core/hle/kernel/mutex.cpp
index 8d92a9b8ed..072e4e7c18 100644
--- a/src/core/hle/kernel/mutex.cpp
+++ b/src/core/hle/kernel/mutex.cpp
@@ -40,31 +40,19 @@ SharedPtr<Mutex> Mutex::Create(bool initial_locked, std::string name) {
     mutex->name = std::move(name);
     mutex->holding_thread = nullptr;
 
-    // Acquire mutex with current thread if initialized as locked...
+    // Acquire mutex with current thread if initialized as locked
     if (initial_locked)
-        mutex->Acquire();
+        mutex->Acquire(GetCurrentThread());
 
     return mutex;
 }
 
-bool Mutex::ShouldWait() {
-    auto thread = GetCurrentThread();
-    bool wait = lock_count > 0 && holding_thread != thread;
-
-    // If the holding thread of the mutex is lower priority than this thread, that thread should
-    // temporarily inherit this thread's priority
-    if (wait && thread->current_priority < holding_thread->current_priority)
-        holding_thread->BoostPriority(thread->current_priority);
-
-    return wait;
-}
-
-void Mutex::Acquire() {
-    Acquire(GetCurrentThread());
+bool Mutex::ShouldWait(Thread* thread) const {
+    return lock_count > 0 && thread != holding_thread;
 }
 
-void Mutex::Acquire(SharedPtr<Thread> thread) {
-    ASSERT_MSG(!ShouldWait(), "object unavailable!");
+void Mutex::Acquire(Thread* thread) {
+    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 
     // Actually "acquire" the mutex only if we don't already have it...
     if (lock_count == 0) {
diff --git a/src/core/hle/kernel/mutex.h b/src/core/hle/kernel/mutex.h
index 53c3dc1f1f..98b3d40b53 100644
--- a/src/core/hle/kernel/mutex.h
+++ b/src/core/hle/kernel/mutex.h
@@ -38,8 +38,9 @@ public:
     std::string name;                 ///< Name of mutex (optional)
     SharedPtr<Thread> holding_thread; ///< Thread that has acquired the mutex
 
-    bool ShouldWait() override;
-    void Acquire() override;
+    bool ShouldWait(Thread* thread) const override;
+    void Acquire(Thread* thread) override;
+
 
     /**
      * Acquires the specified mutex for the specified thread
diff --git a/src/core/hle/kernel/semaphore.cpp b/src/core/hle/kernel/semaphore.cpp
index bf76007805..5e6139265d 100644
--- a/src/core/hle/kernel/semaphore.cpp
+++ b/src/core/hle/kernel/semaphore.cpp
@@ -30,12 +30,12 @@ ResultVal<SharedPtr<Semaphore>> Semaphore::Create(s32 initial_count, s32 max_cou
     return MakeResult<SharedPtr<Semaphore>>(std::move(semaphore));
 }
 
-bool Semaphore::ShouldWait() {
+bool Semaphore::ShouldWait(Thread* thread) const {
     return available_count <= 0;
 }
 
-void Semaphore::Acquire() {
-    ASSERT_MSG(!ShouldWait(), "object unavailable!");
+void Semaphore::Acquire(Thread* thread) {
+    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
     --available_count;
 }
 
diff --git a/src/core/hle/kernel/semaphore.h b/src/core/hle/kernel/semaphore.h
index e01908a257..cde94f7ccd 100644
--- a/src/core/hle/kernel/semaphore.h
+++ b/src/core/hle/kernel/semaphore.h
@@ -39,8 +39,8 @@ public:
     s32 available_count; ///< Number of free slots left in the semaphore
     std::string name;    ///< Name of semaphore (optional)
 
-    bool ShouldWait() override;
-    void Acquire() override;
+    bool ShouldWait(Thread* thread) const override;
+    void Acquire(Thread* thread) override;
 
     /**
      * Releases a certain number of slots from a semaphore.
diff --git a/src/core/hle/kernel/server_port.cpp b/src/core/hle/kernel/server_port.cpp
index 6c19aa7c09..fd3bbbcad3 100644
--- a/src/core/hle/kernel/server_port.cpp
+++ b/src/core/hle/kernel/server_port.cpp
@@ -14,13 +14,13 @@ namespace Kernel {
 ServerPort::ServerPort() {}
 ServerPort::~ServerPort() {}
 
-bool ServerPort::ShouldWait() {
+bool ServerPort::ShouldWait(Thread* thread) const {
     // If there are no pending sessions, we wait until a new one is added.
     return pending_sessions.size() == 0;
 }
 
-void ServerPort::Acquire() {
-    ASSERT_MSG(!ShouldWait(), "object unavailable!");
+void ServerPort::Acquire(Thread* thread) {
+    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 }
 
 std::tuple<SharedPtr<ServerPort>, SharedPtr<ClientPort>> ServerPort::CreatePortPair(
diff --git a/src/core/hle/kernel/server_port.h b/src/core/hle/kernel/server_port.h
index b0f8df62c5..6f8bdb6a90 100644
--- a/src/core/hle/kernel/server_port.h
+++ b/src/core/hle/kernel/server_port.h
@@ -53,8 +53,8 @@ public:
     /// ServerSessions created from this port inherit a reference to this handler.
     std::shared_ptr<Service::SessionRequestHandler> hle_handler;
 
-    bool ShouldWait() override;
-    void Acquire() override;
+    bool ShouldWait(Thread* thread) const override;
+    void Acquire(Thread* thread) override;
 
 private:
     ServerPort();
diff --git a/src/core/hle/kernel/server_session.cpp b/src/core/hle/kernel/server_session.cpp
index 146458c1ca..9447ff236a 100644
--- a/src/core/hle/kernel/server_session.cpp
+++ b/src/core/hle/kernel/server_session.cpp
@@ -29,12 +29,12 @@ ResultVal<SharedPtr<ServerSession>> ServerSession::Create(
     return MakeResult<SharedPtr<ServerSession>>(std::move(server_session));
 }
 
-bool ServerSession::ShouldWait() {
+bool ServerSession::ShouldWait(Thread* thread) const {
     return !signaled;
 }
 
-void ServerSession::Acquire() {
-    ASSERT_MSG(!ShouldWait(), "object unavailable!");
+void ServerSession::Acquire(Thread* thread) {
+    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
     signaled = false;
 }
 
diff --git a/src/core/hle/kernel/server_session.h b/src/core/hle/kernel/server_session.h
index 458284a5db..c088b9a199 100644
--- a/src/core/hle/kernel/server_session.h
+++ b/src/core/hle/kernel/server_session.h
@@ -57,9 +57,9 @@ public:
      */
     ResultCode HandleSyncRequest();
 
-    bool ShouldWait() override;
+    bool ShouldWait(Thread* thread) const override;
 
-    void Acquire() override;
+    void Acquire(Thread* thread) override;
 
     std::string name; ///< The name of this session (optional)
     bool signaled;    ///< Whether there's new data available to this ServerSession
diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp
index 5fb95dadaa..7d03a2cf72 100644
--- a/src/core/hle/kernel/thread.cpp
+++ b/src/core/hle/kernel/thread.cpp
@@ -27,12 +27,12 @@ namespace Kernel {
 /// Event type for the thread wake up event
 static int ThreadWakeupEventType;
 
-bool Thread::ShouldWait() {
+bool Thread::ShouldWait(Thread* thread) const {
     return status != THREADSTATUS_DEAD;
 }
 
-void Thread::Acquire() {
-    ASSERT_MSG(!ShouldWait(), "object unavailable!");
+void Thread::Acquire(Thread* thread) {
+    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 }
 
 // TODO(yuriks): This can be removed if Thread objects are explicitly pooled in the future, allowing
diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h
index c77ac644d5..f2bc1ec9c5 100644
--- a/src/core/hle/kernel/thread.h
+++ b/src/core/hle/kernel/thread.h
@@ -72,8 +72,8 @@ public:
         return HANDLE_TYPE;
     }
 
-    bool ShouldWait() override;
-    void Acquire() override;
+    bool ShouldWait(Thread* thread) const override;
+    void Acquire(Thread* thread) override;
 
     /**
      * Gets the thread's current priority
diff --git a/src/core/hle/kernel/timer.cpp b/src/core/hle/kernel/timer.cpp
index b50cf520df..8f2bc4c7f1 100644
--- a/src/core/hle/kernel/timer.cpp
+++ b/src/core/hle/kernel/timer.cpp
@@ -39,12 +39,12 @@ SharedPtr<Timer> Timer::Create(ResetType reset_type, std::string name) {
     return timer;
 }
 
-bool Timer::ShouldWait() {
+bool Timer::ShouldWait(Thread* thread) const {
     return !signaled;
 }
 
-void Timer::Acquire() {
-    ASSERT_MSG(!ShouldWait(), "object unavailable!");
+void Timer::Acquire(Thread* thread) {
+    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 
     if (reset_type == ResetType::OneShot)
         signaled = false;
diff --git a/src/core/hle/kernel/timer.h b/src/core/hle/kernel/timer.h
index 18ea0236b7..2e3b31b23d 100644
--- a/src/core/hle/kernel/timer.h
+++ b/src/core/hle/kernel/timer.h
@@ -39,8 +39,8 @@ public:
     u64 initial_delay;  ///< The delay until the timer fires for the first time
     u64 interval_delay; ///< The delay until the timer fires after the first time
 
-    bool ShouldWait() override;
-    void Acquire() override;
+    bool ShouldWait(Thread* thread) const override;
+    void Acquire(Thread* thread) override;
 
     /**
      * Starts the timer, with the specified initial delay and interval.
diff --git a/src/core/hle/svc.cpp b/src/core/hle/svc.cpp
index 5b538be22f..159ac0bf63 100644
--- a/src/core/hle/svc.cpp
+++ b/src/core/hle/svc.cpp
@@ -272,7 +272,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds)
     LOG_TRACE(Kernel_SVC, "called handle=0x%08X(%s:%s), nanoseconds=%lld", handle,
               object->GetTypeName().c_str(), object->GetName().c_str(), nano_seconds);
 
-    if (object->ShouldWait()) {
+    if (object->ShouldWait(thread)) {
 
         if (nano_seconds == 0)
             return ERR_SYNC_TIMEOUT;
@@ -294,7 +294,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds)
         return ERR_SYNC_TIMEOUT;
     }
 
-    object->Acquire();
+    object->Acquire(thread);
 
     return RESULT_SUCCESS;
 }
@@ -336,11 +336,11 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha
     if (wait_all) {
         bool all_available =
             std::all_of(objects.begin(), objects.end(),
-                        [](const ObjectPtr& object) { return !object->ShouldWait(); });
+                        [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); });
         if (all_available) {
             // We can acquire all objects right now, do so.
             for (auto& object : objects)
-                object->Acquire();
+                object->Acquire(thread);
             // Note: In this case, the `out` parameter is not set,
             // and retains whatever value it had before.
             return RESULT_SUCCESS;
@@ -380,12 +380,12 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha
     } else {
         // Find the first object that is acquirable in the provided list of objects
         auto itr = std::find_if(objects.begin(), objects.end(),
-                                [](const ObjectPtr& object) { return !object->ShouldWait(); });
+                                [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); });
 
         if (itr != objects.end()) {
             // We found a ready object, acquire it and set the result value
             Kernel::WaitObject* object = itr->get();
-            object->Acquire();
+            object->Acquire(thread);
             *out = std::distance(objects.begin(), itr);
             return RESULT_SUCCESS;
         }
-- 
GitLab