// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #pragma once #include "winadapter.h" // defined by winadapter.h and needed by some windows headers, but conflicts // with some libc++ implementation headers #ifdef __in #undef __in #endif #ifdef __out #undef __out #endif #include #include #include #include #include namespace Microsoft { namespace WRL { namespace Details { struct BoolStruct { int Member; }; typedef int BoolStruct::* BoolType; template // T should be the ComPtr or a derived type of it, not just the interface class ComPtrRefBase { public: typedef typename T::InterfaceType InterfaceType; operator IUnknown**() const throw() { static_assert(__is_base_of(IUnknown, InterfaceType), "Invalid cast: InterfaceType does not derive from IUnknown"); return reinterpret_cast(ptr_->ReleaseAndGetAddressOf()); } protected: T* ptr_; }; template class ComPtrRef : public Details::ComPtrRefBase // T should be the ComPtr or a derived type of it, not just the interface { using Super = Details::ComPtrRefBase; using InterfaceType = typename Super::InterfaceType; public: ComPtrRef(_In_opt_ T* ptr) throw() { this->ptr_ = ptr; } // Conversion operators operator void**() const throw() { return reinterpret_cast(this->ptr_->ReleaseAndGetAddressOf()); } // This is our operator ComPtr (or the latest derived class from ComPtr (e.g. WeakRef)) operator T*() throw() { *this->ptr_ = nullptr; return this->ptr_; } // We define operator InterfaceType**() here instead of on ComPtrRefBase, since // if InterfaceType is IUnknown or IInspectable, having it on the base will collide. operator InterfaceType**() throw() { return this->ptr_->ReleaseAndGetAddressOf(); } // This is used for IID_PPV_ARGS in order to do __uuidof(**(ppType)). // It does not need to clear ptr_ at this point, it is done at IID_PPV_ARGS_Helper(ComPtrRef&) later in this file. InterfaceType* operator *() throw() { return this->ptr_->Get(); } // Explicit functions InterfaceType* const * GetAddressOf() const throw() { return this->ptr_->GetAddressOf(); } InterfaceType** ReleaseAndGetAddressOf() throw() { return this->ptr_->ReleaseAndGetAddressOf(); } }; } template class ComPtr { public: typedef T InterfaceType; protected: InterfaceType *ptr_; template friend class ComPtr; void InternalAddRef() const throw() { if (ptr_ != nullptr) { ptr_->AddRef(); } } unsigned long InternalRelease() throw() { unsigned long ref = 0; T* temp = ptr_; if (temp != nullptr) { ptr_ = nullptr; ref = temp->Release(); } return ref; } public: ComPtr() throw() : ptr_(nullptr) { } ComPtr(decltype(nullptr)) throw() : ptr_(nullptr) { } template ComPtr(_In_opt_ U *other) throw() : ptr_(other) { InternalAddRef(); } ComPtr(const ComPtr& other) throw() : ptr_(other.ptr_) { InternalAddRef(); } // copy constructor that allows to instantiate class when U* is convertible to T* template ComPtr(const ComPtr &other, typename std::enable_if::value, void *>::type * = 0) throw() : ptr_(other.ptr_) { InternalAddRef(); } ComPtr(_Inout_ ComPtr &&other) throw() : ptr_(nullptr) { if (this != reinterpret_cast(&reinterpret_cast(other))) { Swap(other); } } // Move constructor that allows instantiation of a class when U* is convertible to T* template ComPtr(_Inout_ ComPtr&& other, typename std::enable_if::value, void *>::type * = 0) throw() : ptr_(other.ptr_) { other.ptr_ = nullptr; } ~ComPtr() throw() { InternalRelease(); } ComPtr& operator=(decltype(nullptr)) throw() { InternalRelease(); return *this; } ComPtr& operator=(_In_opt_ T *other) throw() { if (ptr_ != other) { ComPtr(other).Swap(*this); } return *this; } template ComPtr& operator=(_In_opt_ U *other) throw() { ComPtr(other).Swap(*this); return *this; } ComPtr& operator=(const ComPtr &other) throw() { if (ptr_ != other.ptr_) { ComPtr(other).Swap(*this); } return *this; } template ComPtr& operator=(const ComPtr& other) throw() { ComPtr(other).Swap(*this); return *this; } ComPtr& operator=(_Inout_ ComPtr &&other) throw() { ComPtr(static_cast(other)).Swap(*this); return *this; } template ComPtr& operator=(_Inout_ ComPtr&& other) throw() { ComPtr(static_cast&&>(other)).Swap(*this); return *this; } void Swap(_Inout_ ComPtr&& r) throw() { T* tmp = ptr_; ptr_ = r.ptr_; r.ptr_ = tmp; } void Swap(_Inout_ ComPtr& r) throw() { T* tmp = ptr_; ptr_ = r.ptr_; r.ptr_ = tmp; } operator Details::BoolType() const throw() { return Get() != nullptr ? &Details::BoolStruct::Member : nullptr; } T* Get() const throw() { return ptr_; } InterfaceType* operator->() const throw() { return ptr_; } Details::ComPtrRef> operator&() throw() { return Details::ComPtrRef>(this); } const Details::ComPtrRef> operator&() const throw() { return Details::ComPtrRef>(this); } T* const* GetAddressOf() const throw() { return &ptr_; } T** GetAddressOf() throw() { return &ptr_; } T** ReleaseAndGetAddressOf() throw() { InternalRelease(); return &ptr_; } T* Detach() throw() { T* ptr = ptr_; ptr_ = nullptr; return ptr; } void Attach(_In_opt_ InterfaceType* other) throw() { if (ptr_ != nullptr) { auto ref = ptr_->Release(); // DBG_UNREFERENCED_LOCAL_VARIABLE(ref); // Attaching to the same object only works if duplicate references are being coalesced. Otherwise // re-attaching will cause the pointer to be released and may cause a crash on a subsequent dereference. assert(ref != 0 || ptr_ != other); } ptr_ = other; } unsigned long Reset() { return InternalRelease(); } // Previously, unsafe behavior could be triggered when 'this' is ComPtr or ComPtr and CopyTo is used to copy to another type U. // The user will use operator& to convert the destination into a ComPtrRef, which can then implicit cast to IInspectable** and IUnknown**. // If this overload of CopyTo is not present, it will implicitly cast to IInspectable or IUnknown and match CopyTo(InterfaceType**) instead. // A valid polymoprhic downcast requires run-time type checking via QueryInterface, so CopyTo(InterfaceType**) will break type safety. // This overload matches ComPtrRef before the implicit cast takes place, preventing the unsafe downcast. template HRESULT CopyTo(Details::ComPtrRef> ptr, typename std::enable_if< (std::is_same::value) && !std::is_same::value, void *>::type * = 0) const throw() { return ptr_->QueryInterface(uuidof(), ptr); } HRESULT CopyTo(_Outptr_result_maybenull_ InterfaceType** ptr) const throw() { InternalAddRef(); *ptr = ptr_; return S_OK; } HRESULT CopyTo(REFIID riid, _Outptr_result_nullonfailure_ void** ptr) const throw() { return ptr_->QueryInterface(riid, ptr); } template HRESULT CopyTo(_Outptr_result_nullonfailure_ U** ptr) const throw() { return ptr_->QueryInterface(uuidof(), reinterpret_cast(ptr)); } // query for U interface template HRESULT As(_Inout_ Details::ComPtrRef> p) const throw() { return ptr_->QueryInterface(uuidof(), p); } // query for U interface template HRESULT As(_Out_ ComPtr* p) const throw() { return ptr_->QueryInterface(uuidof(), reinterpret_cast(p->ReleaseAndGetAddressOf())); } // query for riid interface and return as IUnknown HRESULT AsIID(REFIID riid, _Out_ ComPtr* p) const throw() { return ptr_->QueryInterface(riid, reinterpret_cast(p->ReleaseAndGetAddressOf())); } }; // ComPtr namespace Details { // Empty struct used as default template parameter class Nil { }; // Empty struct used for validating template parameter types in Implements struct ImplementsBase { }; class RuntimeClassBase { protected: template static HRESULT AsIID(_In_ T* implements, REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) noexcept { *ppvObject = nullptr; bool isRefDelegated = false; // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. if (InlineIsEqualGUID(riid, uuidof())) { *ppvObject = implements->CastToUnknown(); static_cast(*ppvObject)->AddRef(); return S_OK; } HRESULT hr = implements->CanCastTo(riid, ppvObject, &isRefDelegated); if (SUCCEEDED(hr) && !isRefDelegated) { static_cast(*ppvObject)->AddRef(); } #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable: 6102) // '*ppvObject' is used but may not be initialized #endif _Analysis_assume_(SUCCEEDED(hr) || (*ppvObject == nullptr)); #ifdef _MSC_VER #pragma warning(pop) #endif return hr; } public: HRESULT RuntimeClassInitialize() noexcept { return S_OK; } }; // Interface traits provides casting and filling iids methods helpers template struct InterfaceTraits { typedef I0 Base; template static Base* CastToBase(_In_ T* ptr) noexcept { return static_cast(ptr); } template static IUnknown* CastToUnknown(_In_ T* ptr) noexcept { return static_cast(static_cast(ptr)); } template _Success_(return == true) static bool CanCastTo(_In_ T* ptr, REFIID riid, _Outptr_ void **ppv) noexcept { // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. if (InlineIsEqualGUID(riid, uuidof())) { *ppv = static_cast(ptr); return true; } return false; } }; // Specialization for Nil parameter template<> struct InterfaceTraits { typedef Nil Base; template _Success_(return == true) static bool CanCastTo(_In_ T*, REFIID, _Outptr_ void **) noexcept { return false; } }; // ChainInterfaces - template allows specifying a derived COM interface along with its class hierarchy to allow QI for the base interfaces template struct ChainInterfaces : I0 { protected: HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv) throw() { typename InterfaceTraits::Base* ptr = InterfaceTraits::CastToBase(this); return (InterfaceTraits::CanCastTo(this, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv) || InterfaceTraits::CanCastTo(ptr, riid, ppv)) ? S_OK : E_NOINTERFACE; } IUnknown* CastToUnknown() throw() { return InterfaceTraits::CastToUnknown(this); } }; // Helper template used by Implements. This template traverses a list of interfaces and adds them as base class and information // to enable QI. template struct ImplementsHelper; template struct ImplementsMarker {}; template struct MarkImplements; template struct MarkImplements { typedef I0 Type; }; template struct MarkImplements { typedef ImplementsMarker Type; }; // AdjustImplements pre-processes the type list for more efficient builds. template struct AdjustImplements; template struct AdjustImplements { typedef ImplementsHelper::value>::Type, Bases...> Type; }; // Use AdjustImplements to remove instances of "Nil" from the type list. template struct AdjustImplements { typedef typename AdjustImplements::Type Type; }; template <> struct AdjustImplements<> { typedef ImplementsHelper<> Type; }; // Specialization handles unadorned interfaces template struct ImplementsHelper : I0, AdjustImplements::Type { template friend struct ImplementsHelper; friend class RuntimeClassBase; protected: HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) noexcept { // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. if (InlineIsEqualGUID(riid, uuidof())) { *ppv = reinterpret_cast(reinterpret_cast(this)); return S_OK; } return AdjustImplements::Type::CanCastTo(riid, ppv, pRefDelegated); } IUnknown* CastToUnknown() noexcept { return reinterpret_cast(reinterpret_cast(this)); } }; // Selector is used to "tag" base interfaces to be used in casting, since a runtime class may indirectly derive from // the same interface or Implements<> template multiple times template struct Selector : public base { }; // Specialization handles types that derive from ImplementsHelper (e.g. nested Implements). template struct ImplementsHelper, TInterfaces...> : Selector, TInterfaces...>>, Selector::Type, ImplementsHelper, TInterfaces...>> { template friend struct ImplementsHelper; friend class RuntimeClassBase; protected: typedef Selector, TInterfaces...>> CurrentType; typedef Selector::Type, ImplementsHelper, TInterfaces...>> BaseType; HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) noexcept { HRESULT hr = CurrentType::CanCastTo(riid, ppv); if (hr == E_NOINTERFACE) { hr = BaseType::CanCastTo(riid, ppv, pRefDelegated); } return hr; } IUnknown* CastToUnknown() noexcept { // First in list wins. return CurrentType::CastToUnknown(); } }; // terminal case specialization. template <> struct ImplementsHelper<> { template friend struct ImplementsHelper; friend class RuntimeClassBase; protected: HRESULT CanCastTo(_In_ REFIID /*riid*/, _Outptr_ void ** /*ppv*/, bool * /*pRefDelegated*/ = nullptr) noexcept { return E_NOINTERFACE; } // IUnknown* CastToUnknown() noexcept; // not defined for terminal case. }; // Specialization handles chaining interfaces template struct ImplementsHelper, TInterfaces...> : ChainInterfaces, AdjustImplements::Type { template friend struct ImplementsHelper; friend class RuntimeClassBase; protected: typedef typename AdjustImplements::Type BaseType; HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) noexcept { HRESULT hr = ChainInterfaces::CanCastTo(riid, ppv); if (FAILED(hr)) { hr = BaseType::CanCastTo(riid, ppv, pRefDelegated); } return hr; } IUnknown* CastToUnknown() noexcept { return ChainInterfaces::CastToUnknown(); } }; // Implements - template implementing QI using the information provided through its template parameters // Each template parameter has to be one of the following: // * COM Interface // * A class that implements one or more COM interfaces // * ChainInterfaces template template struct Implements : AdjustImplements::Type, ImplementsBase { public: typedef I0 FirstInterface; protected: typedef typename AdjustImplements::Type BaseType; template friend struct ImplementsHelper; friend class RuntimeClassBase; HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv) noexcept { return BaseType::CanCastTo(riid, ppv); } IUnknown* CastToUnknown() noexcept { return BaseType::CastToUnknown(); } }; // Used on RuntimeClass to protect it from being constructed with new class DontUseNewUseMake { private: void* operator new(size_t) noexcept { assert(false); return 0; } public: void* operator new(size_t, _In_ void* placement) noexcept { return placement; } }; template class RuntimeClassImpl : public AdjustImplements::Type, public RuntimeClassBase, public DontUseNewUseMake { public: STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) { return Super::AsIID(this, riid, ppvObject); } STDMETHOD_(ULONG, AddRef)() { return InternalAddRef(); } STDMETHOD_(ULONG, Release)() { ULONG ref = InternalRelease(); if (ref == 0) { delete this; } return ref; } protected: using Super = RuntimeClassBase; static const LONG c_lProtectDestruction = -(LONG_MAX / 2); RuntimeClassImpl() noexcept = default; virtual ~RuntimeClassImpl() noexcept { // Set refcount_ to -(LONG_MAX/2) to protect destruction and // also catch mismatched Release in debug builds refcount_ = static_cast(c_lProtectDestruction); } ULONG InternalAddRef() noexcept { return ++refcount_; } ULONG InternalRelease() noexcept { return --refcount_; } unsigned long GetRefCount() const noexcept { return refcount_; } std::atomic refcount_{1}; }; } template class Base : public Details::RuntimeClassImpl { Base(const Base&) = delete; Base& operator=(const Base&) = delete; protected: HRESULT CustomQueryInterface(REFIID /*riid*/, _Outptr_result_nullonfailure_ void** /*ppvObject*/, _Out_ bool *handled) { *handled = false; return S_OK; } public: Base() throw() = default; typedef Base RuntimeClassT; }; // Creates a Nano-COM object wrapped in a smart pointer. template ComPtr Make(TArgs&&... args) { ComPtr object; std::unique_ptr buffer(new unsigned char[sizeof(T)]); if (buffer) { T* ptr = new (buffer.get())T(std::forward(args)...); object.Attach(ptr); buffer.release(); } return object; } using Details::ChainInterfaces; } } // Overloaded global function to provide to IID_PPV_ARGS that support Details::ComPtrRef template void** IID_PPV_ARGS_Helper(Microsoft::WRL::Details::ComPtrRef pp) throw() { return pp; }