1
0
mirror of https://github.com/VCMP-SqMod/SqMod.git synced 2025-02-23 13:17:14 +01:00
Sandu Liviu Catalin 4a6bfc086c Major plugin refactor and cleanup.
Switched to POCO library for unified platform/library interface.
Deprecated the external module API. It was creating more problems than solving.
Removed most built-in libraries in favor of system libraries for easier maintenance.
Cleaned and secured code with help from static analyzers.
2021-01-30 08:51:39 +02:00

522 lines
13 KiB
C++

//
// Context.cpp
//
// Library: NetSSL_Win
// Package: SSLCore
// Module: Context
//
// Copyright (c) 2006-2014, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#include "Poco/Net/Context.h"
#include "Poco/Net/SSLManager.h"
#include "Poco/Net/SSLException.h"
#include "Poco/Net/Utility.h"
#include "Poco/UnicodeConverter.h"
#include "Poco/Format.h"
#include "Poco/File.h"
#include "Poco/FileStream.h"
#include "Poco/MemoryStream.h"
#include "Poco/Base64Decoder.h"
#include "Poco/Buffer.h"
#include <algorithm>
#include <cctype>
namespace Poco {
namespace Net {
// Certificate store name constants. CERT_STORE_ROOT is the store that init()
// opens as a system store when OPT_TRUST_ROOTS_WIN_CERT_STORE is set.
const std::string Context::CERT_STORE_MY     = "MY";
const std::string Context::CERT_STORE_ROOT   = "ROOT";
const std::string Context::CERT_STORE_TRUST  = "TRUST";
const std::string Context::CERT_STORE_CA     = "CA";
const std::string Context::CERT_STORE_USERDS = "USERDS";
//
// Creates a Context.
//
// The arguments are stored as given; every native handle starts out null and
// the certificate stores are then set up by init(), which may throw
// SSLException.
//
Context::Context(Usage usage,
	const std::string& certNameOrPath,
	VerificationMode verMode,
	int options,
	const std::string& certStore):
	// configuration
	_usage(usage),
	_mode(verMode),
	_options(options),
	_disabledProtocols(0),
	_extendedCertificateVerification(true),
	_certNameOrPath(certNameOrPath),
	_certStoreName(certStore),
	// native handles, opened by init() or lazily later on
	_hMemCertStore(0),
	_hCollectionCertStore(0),
	_hRootCertStore(0),
	_hCertStore(0),
	_pCert(0),
	_securityFunctions(SSLManager::instance().securityFunctions())
{
	init();
}
//
// Releases the certificate context, all certificate stores and the
// Schannel credentials handle owned by this Context.
//
Context::~Context()
{
	if (_pCert)
	{
		CertFreeCertificateContext(_pCert);
	}
	if (_hCertStore)
	{
		CertCloseStore(_hCertStore, 0);
	}
	// init() throws unless both of these stores were opened, so they are
	// always valid in a fully constructed object.
	CertCloseStore(_hCollectionCertStore, 0);
	CertCloseStore(_hMemCertStore, 0);
	if (_hRootCertStore)
	{
		CertCloseStore(_hRootCertStore, 0);
	}
	// credentials() considers the handle "not acquired" only when BOTH halves
	// are zero, so it must be freed as soon as EITHER half is non-zero.
	// Testing with && would skip the release of a handle with one zero half.
	if (_hCreds.dwLower != 0 || _hCreds.dwUpper != 0)
	{
		_securityFunctions.FreeCredentialsHandle(&_hCreds);
	}
}
void Context::init()
{
_hCreds.dwLower = 0;
_hCreds.dwUpper = 0;
_hMemCertStore = CertOpenStore(
CERT_STORE_PROV_MEMORY, // The memory provider type
0, // The encoding type is not needed
NULL, // Use the default provider
0, // Accept the default dwFlags
NULL); // pvPara is not used
if (!_hMemCertStore)
throw SSLException("Failed to create memory certificate store", GetLastError());
_hCollectionCertStore = CertOpenStore(
CERT_STORE_PROV_COLLECTION, // A collection store
0, // Encoding type; not used with a collection store
NULL, // Use the default provider
0, // No flags
NULL); // Not needed
if (!_hCollectionCertStore)
throw SSLException("Failed to create collection store", GetLastError());
if (!CertAddStoreToCollection(_hCollectionCertStore, _hMemCertStore, CERT_PHYSICAL_STORE_ADD_ENABLE_FLAG, 1))
throw SSLException("Failed to add memory certificate store to collection store", GetLastError());
if (_options & OPT_TRUST_ROOTS_WIN_CERT_STORE)
{
// add root certificates
std::wstring rootStore;
Poco::UnicodeConverter::convert(CERT_STORE_ROOT, rootStore);
_hRootCertStore = CertOpenSystemStoreW(0, rootStore.c_str());
if (!_hRootCertStore)
throw SSLException("Failed to open root certificate store", GetLastError());
if (!CertAddStoreToCollection(_hCollectionCertStore, _hRootCertStore, CERT_PHYSICAL_STORE_ADD_ENABLE_FLAG, 1))
throw SSLException("Failed to add root certificate store to collection store", GetLastError());
}
}
//
// Stores the extended certificate verification flag. When it is false,
// acquireSchannelCredentials() adds SCH_CRED_NO_SERVERNAME_CHECK for
// client-side contexts.
//
void Context::enableExtendedCertificateVerification(bool flag)
{
	_extendedCertificateVerification = flag;
}
//
// Adds cert to the in-memory certificate store, replacing an existing
// matching entry (CERT_STORE_ADD_REPLACE_EXISTING). The store is modified
// while holding _mutex.
//
// Throws CertificateException if the certificate cannot be added.
//
void Context::addTrustedCert(const Poco::Net::X509Certificate& cert)
{
	Poco::FastMutex::ScopedLock lock(_mutex);

	const BOOL added = CertAddCertificateContextToStore(
		_hMemCertStore,
		cert.system(),
		CERT_STORE_ADD_REPLACE_EXISTING,
		0);
	if (!added)
	{
		throw CertificateException("Failed to add certificate to store", GetLastError());
	}
}
//
// Returns the certificate of this context, loading it on first use.
//
// If no certificate has been loaded yet, it is imported from the file named
// by _certNameOrPath when OPT_LOAD_CERT_FROM_FILE is set, and looked up in
// the certificate store otherwise.
//
// Throws NoCertificateException if no certificate name or path was given.
//
Poco::Net::X509Certificate Context::certificate()
{
	if (!_pCert)
	{
		if (_certNameOrPath.empty())
			throw NoCertificateException("Certificate requested, but no certificate name or path provided");

		if (_options & OPT_LOAD_CERT_FROM_FILE)
			importCertificate();
		else
			loadCertificate();
	}
	return Poco::Net::X509Certificate(_pCert, true);
}
void Context::loadCertificate()
{
std::wstring wcertStoreName;
Poco::UnicodeConverter::convert(_certStoreName, wcertStoreName);
if (!_hCertStore)
{
if (_options & OPT_USE_MACHINE_STORE)
_hCertStore = CertOpenStore(CERT_STORE_PROV_SYSTEM, 0, 0, CERT_SYSTEM_STORE_LOCAL_MACHINE, wcertStoreName.c_str());
else
_hCertStore = CertOpenSystemStoreW(0, wcertStoreName.c_str());
}
if (!_hCertStore) throw CertificateException("Failed to open certificate store", _certStoreName, GetLastError());
CERT_RDN_ATTR cert_rdn_attr;
cert_rdn_attr.pszObjId = szOID_COMMON_NAME;
cert_rdn_attr.dwValueType = CERT_RDN_ANY_TYPE;
cert_rdn_attr.Value.cbData = (DWORD) _certNameOrPath.size();
cert_rdn_attr.Value.pbData = (BYTE *) _certNameOrPath.c_str();
CERT_RDN cert_rdn;
cert_rdn.cRDNAttr = 1;
cert_rdn.rgRDNAttr = &cert_rdn_attr;
_pCert = CertFindCertificateInStore(_hCertStore, X509_ASN_ENCODING, 0, CERT_FIND_SUBJECT_ATTR, &cert_rdn, NULL);
if (!_pCert) throw NoCertificateException(Poco::format("Failed to find certificate %s in store %s", _certNameOrPath, _certStoreName));
}
void Context::importCertificate()
{
Poco::File certFile(_certNameOrPath);
if (!certFile.exists()) throw Poco::FileNotFoundException(_certNameOrPath);
Poco::File::FileSize size = certFile.getSize();
Poco::Buffer<char> buffer(static_cast<std::size_t>(size));
Poco::FileInputStream istr(_certNameOrPath);
istr.read(buffer.begin(), buffer.size());
if (istr.gcount() != size) throw Poco::IOException("error reading PKCS #12 certificate file");
importCertificate(buffer.begin(), buffer.size());
}
void Context::importCertificate(const char* pBuffer, std::size_t size)
{
std::string password;
SSLManager::instance().PrivateKeyPassphraseRequired.notify(&SSLManager::instance(), password);
std::wstring wpassword;
Poco::UnicodeConverter::toUTF16(password, wpassword);
// clear UTF-8 password
std::fill(const_cast<char*>(password.data()), const_cast<char*>(password.data() + password.size()), 'X');
CRYPT_DATA_BLOB blob;
blob.cbData = static_cast<DWORD>(size);
blob.pbData = reinterpret_cast<BYTE*>(const_cast<char*>(pBuffer));
HCERTSTORE hTempStore = PFXImportCertStore(&blob, wpassword.data(), PKCS12_ALLOW_OVERWRITE_KEY | PKCS12_INCLUDE_EXTENDED_PROPERTIES);
// clear UTF-16 password
std::fill(const_cast<wchar_t*>(wpassword.data()), const_cast<wchar_t*>(wpassword.data() + password.size()), L'X');
if (hTempStore)
{
PCCERT_CONTEXT pCert = 0;
pCert = CertEnumCertificatesInStore(hTempStore, pCert);
while (pCert)
{
PCCERT_CONTEXT pStoreCert = 0;
BOOL res = CertAddCertificateContextToStore(_hMemCertStore, pCert, CERT_STORE_ADD_REPLACE_EXISTING, &pStoreCert);
if (res)
{
if (!_pCert)
{
_pCert = pStoreCert;
}
else
{
CertFreeCertificateContext(pStoreCert);
pStoreCert = 0;
}
}
pCert = CertEnumCertificatesInStore(hTempStore, pCert);
}
CertCloseStore(hTempStore, 0);
}
else throw CertificateException("failed to import certificate", GetLastError());
}
//
// Returns the Schannel credentials handle of this context, acquiring it on
// the first call (while holding _mutex). A handle whose two halves are both
// zero is treated as "not yet acquired".
//
CredHandle& Context::credentials()
{
	Poco::FastMutex::ScopedLock lock(_mutex);

	const bool alreadyAcquired = (_hCreds.dwLower != 0 || _hCreds.dwUpper != 0);
	if (!alreadyAcquired)
	{
		acquireSchannelCredentials(_hCreds);
	}
	return _hCreds;
}
//
// Fills in a SCHANNEL_CRED structure from this context's settings and
// acquires a credentials handle for it.
//
//   credHandle  receives the acquired handle
//
// The handle is inbound for server contexts and outbound for client
// contexts. Throws SSLException if AcquireCredentialsHandleW does not
// return SEC_E_OK.
//
void Context::acquireSchannelCredentials(CredHandle& credHandle) const
{
	SCHANNEL_CRED schannelCred;
	ZeroMemory(&schannelCred, sizeof(schannelCred));
	schannelCred.dwVersion = SCHANNEL_CRED_VERSION;

	if (_pCert)
	{
		schannelCred.cCreds = 1; // how many cred are stored in &pCertContext
		// paCred must point at the certificate pointer itself. Take the address
		// of the member and strip the constness added by this const member
		// function; taking the address of a const_cast-to-value result is not
		// valid standard C++ (it is a temporary).
		schannelCred.paCred = const_cast<PCCERT_CONTEXT*>(&_pCert);
	}

	schannelCred.grbitEnabledProtocols = proto();

	// Windows NT and Windows Me/98/95: revocation checking not supported via flags
	if (_options & Context::OPT_PERFORM_REVOCATION_CHECK)
		schannelCred.dwFlags |= SCH_CRED_REVOCATION_CHECK_CHAIN;
	else
		schannelCred.dwFlags |= SCH_CRED_IGNORE_NO_REVOCATION_CHECK | SCH_CRED_IGNORE_REVOCATION_OFFLINE;

	if (isForServerUse())
	{
		if (_mode >= Context::VERIFY_STRICT)
			schannelCred.dwFlags |= SCH_CRED_NO_SYSTEM_MAPPER;

		if (_mode == Context::VERIFY_NONE)
			schannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION;
	}
	else
	{
		if (_mode >= Context::VERIFY_STRICT)
			schannelCred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
		else
			schannelCred.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS;

		if (_mode == Context::VERIFY_NONE)
			schannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION | SCH_CRED_NO_SERVERNAME_CHECK;

		if (!_extendedCertificateVerification)
			schannelCred.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK;
	}

#if defined(SCH_USE_STRONG_CRYPTO)
	if (_options & Context::OPT_USE_STRONG_CRYPTO)
		schannelCred.dwFlags |= SCH_USE_STRONG_CRYPTO;
#endif

	// Only server contexts hand the collection store to Schannel as root store.
	schannelCred.hRootStore = isForServerUse() ? _hCollectionCertStore : NULL;

	TimeStamp tsExpiry;
	tsExpiry.LowPart = tsExpiry.HighPart = 0;
	SECURITY_STATUS status = _securityFunctions.AcquireCredentialsHandleW(
		NULL,
		UNISP_NAME_W,
		isForServerUse() ? SECPKG_CRED_INBOUND : SECPKG_CRED_OUTBOUND,
		NULL,
		&schannelCred,
		NULL,
		NULL,
		&credHandle,
		&tsExpiry);

	if (status != SEC_E_OK)
	{
		throw SSLException("Failed to acquire Schannel credentials", Utility::formatError(status));
	}
}
//
// Returns the Schannel protocol bit mask (SP_PROT_*) for this context.
//
// The mask is built from the usage: each usage enables its own protocol
// version plus every newer version the SDK headers define, for the client or
// the server side respectively. The result is then restricted to the
// protocols reported by enabledProtocols().
//
// Throws Poco::InvalidArgumentException for a usage that has no case here
// (including usages whose protocol macros are not defined at compile time).
//
DWORD Context::proto() const
{
	DWORD mask = 0;
	switch (_usage)
	{
	case Context::TLS_CLIENT_USE:
	case Context::CLIENT_USE:
		mask = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT;
#if defined(SP_PROT_TLS1_1)
		mask |= SP_PROT_TLS1_1_CLIENT;
#endif
#if defined(SP_PROT_TLS1_2)
		mask |= SP_PROT_TLS1_2_CLIENT;
#endif
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_CLIENT;
#endif
		break;

	case Context::TLS_SERVER_USE:
	case Context::SERVER_USE:
		mask = SP_PROT_SSL3_SERVER | SP_PROT_TLS1_SERVER;
#if defined(SP_PROT_TLS1_1)
		mask |= SP_PROT_TLS1_1_SERVER;
#endif
#if defined(SP_PROT_TLS1_2)
		mask |= SP_PROT_TLS1_2_SERVER;
#endif
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_SERVER;
#endif
		break;

	case Context::TLSV1_CLIENT_USE:
		mask = SP_PROT_TLS1_CLIENT;
#if defined(SP_PROT_TLS1_1)
		mask |= SP_PROT_TLS1_1_CLIENT;
#endif
#if defined(SP_PROT_TLS1_2)
		mask |= SP_PROT_TLS1_2_CLIENT;
#endif
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_CLIENT;
#endif
		break;

	case Context::TLSV1_SERVER_USE:
		mask = SP_PROT_TLS1_SERVER;
#if defined(SP_PROT_TLS1_1)
		mask |= SP_PROT_TLS1_1_SERVER;
#endif
#if defined(SP_PROT_TLS1_2)
		mask |= SP_PROT_TLS1_2_SERVER;
#endif
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_SERVER;
#endif
		break;

#if defined(SP_PROT_TLS1_1)
	case Context::TLSV1_1_CLIENT_USE:
		mask = SP_PROT_TLS1_1_CLIENT;
#if defined(SP_PROT_TLS1_2)
		mask |= SP_PROT_TLS1_2_CLIENT;
#endif
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_CLIENT;
#endif
		break;

	case Context::TLSV1_1_SERVER_USE:
		mask = SP_PROT_TLS1_1_SERVER;
#if defined(SP_PROT_TLS1_2)
		mask |= SP_PROT_TLS1_2_SERVER;
#endif
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_SERVER;
#endif
		break;
#endif

#if defined(SP_PROT_TLS1_2)
	case Context::TLSV1_2_CLIENT_USE:
		mask = SP_PROT_TLS1_2_CLIENT;
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_CLIENT;
#endif
		break;

	case Context::TLSV1_2_SERVER_USE:
		mask = SP_PROT_TLS1_2_SERVER;
#if defined(SP_PROT_TLS1_3)
		mask |= SP_PROT_TLS1_3_SERVER;
#endif
		break;
#endif

#if defined(SP_PROT_TLS1_3)
	case Context::TLSV1_3_CLIENT_USE:
		mask = SP_PROT_TLS1_3_CLIENT;
		break;

	case Context::TLSV1_3_SERVER_USE:
		mask = SP_PROT_TLS1_3_SERVER;
		break;
#endif

	default:
		throw Poco::InvalidArgumentException("Unsupported SSL/TLS protocol version");
	}
	return mask & enabledProtocols();
}
//
// Returns the SP_PROT_* bits (client and server) of every protocol that is
// not marked in _disabledProtocols. Protocols whose SP_PROT_* macros are not
// defined at compile time contribute nothing.
//
DWORD Context::enabledProtocols() const
{
	DWORD allowed = 0;

	if ((_disabledProtocols & PROTO_SSLV3) == 0)
		allowed |= SP_PROT_SSL3_CLIENT | SP_PROT_SSL3_SERVER;

	if ((_disabledProtocols & PROTO_TLSV1) == 0)
		allowed |= SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_SERVER;

#ifdef SP_PROT_TLS1_1_CLIENT
	if ((_disabledProtocols & PROTO_TLSV1_1) == 0)
		allowed |= SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_1_SERVER;
#endif

#ifdef SP_PROT_TLS1_2_CLIENT
	if ((_disabledProtocols & PROTO_TLSV1_2) == 0)
		allowed |= SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_2_SERVER;
#endif

#ifdef SP_PROT_TLS1_3_CLIENT
	if ((_disabledProtocols & PROTO_TLSV1_3) == 0)
		allowed |= SP_PROT_TLS1_3_CLIENT | SP_PROT_TLS1_3_SERVER;
#endif

	return allowed;
}
//
// Replaces the set of disabled protocols with the given bit mask of PROTO_*
// values. The mask is consulted by enabledProtocols().
//
void Context::disableProtocols(int protocols)
{
	_disabledProtocols = protocols;
}
//
// Disables every protocol older than the given one, replacing any previously
// configured set of disabled protocols. For a value with no case below,
// nothing is disabled.
//
void Context::requireMinimumProtocol(Protocols protocol)
{
	// Each case adds the next-older protocol and deliberately falls through,
	// so the mask accumulates everything below the requested minimum.
	int disabled = 0;
	switch (protocol)
	{
	case PROTO_TLSV1_3:
		disabled |= PROTO_TLSV1_2;
		// fall through
	case PROTO_TLSV1_2:
		disabled |= PROTO_TLSV1_1;
		// fall through
	case PROTO_TLSV1_1:
		disabled |= PROTO_TLSV1;
		// fall through
	case PROTO_TLSV1:
		disabled |= PROTO_SSLV3;
		// fall through
	case PROTO_SSLV3:
		disabled |= PROTO_SSLV2;
		break;
	default:
		break;
	}
	_disabledProtocols = disabled;
}
} } // namespace Poco::Net