From 8ce40ee74c6b2b2c606383bf04e2afc912c69cce Mon Sep 17 00:00:00 2001 From: Sandu Liviu Catalin Date: Thu, 28 Jul 2016 00:41:43 +0300 Subject: [PATCH] Prevent server crash by accessing row data from MySQL result-set when there is no valid row available. Throw an error instead. Should close #25 --- modules/mysql/Common.hpp | 5 ++- modules/mysql/Field.cpp | 80 ++++++++++++++++++++++++++++++------- modules/mysql/Field.hpp | 18 +++++++++ modules/mysql/ResultSet.cpp | 50 +++++++++++++++++++++++ modules/mysql/ResultSet.hpp | 54 ++++++++++++++++--------- 5 files changed, 173 insertions(+), 34 deletions(-) diff --git a/modules/mysql/Common.hpp b/modules/mysql/Common.hpp index 40f1f638..e9a35001 100644 --- a/modules/mysql/Common.hpp +++ b/modules/mysql/Common.hpp @@ -29,17 +29,20 @@ namespace SqMod { #define SQMOD_VALIDATE_CREATED(x) (x).ValidateCreated(__FILE__, __LINE__) #define SQMOD_VALIDATE_PARAM(x, i) (x).ValidateParam((i), __FILE__, __LINE__) #define SQMOD_VALIDATE_FIELD(x, i) (x).ValidateField((i), __FILE__, __LINE__) + #define SQMOD_VALIDATE_STEPPED(x) (x).ValidateStepped(__FILE__, __LINE__) #define SQMOD_GET_VALID(x) (x).GetValid(__FILE__, __LINE__) #define SQMOD_GET_CREATED(x) (x).GetCreated(__FILE__, __LINE__) + #define SQMOD_GET_STEPPED(x) (x).GetStepped(__FILE__, __LINE__) #else #define SQMOD_THROW_CURRENT(x, a) (x).ThrowCurrent(a) #define SQMOD_VALIDATE(x) (x).Validate() #define SQMOD_VALIDATE_CREATED(x) (x).ValidateCreated() #define SQMOD_VALIDATE_PARAM(x, i) (x).ValidateParam((i)) #define SQMOD_VALIDATE_FIELD(x, i) (x).ValidateField((i)) - #define SQMOD_VALIDATE_ROW(x) (x).ValidateRow() + #define SQMOD_VALIDATE_STEPPED(x) (x).ValidateStepped() #define SQMOD_GET_VALID(x) (x).GetValid() #define SQMOD_GET_CREATED(x) (x).GetCreated() + #define SQMOD_GET_STEPPED(x) (x).GetStepped() #endif // _DEBUG /* ------------------------------------------------------------------------------------------------ diff --git a/modules/mysql/Field.cpp b/modules/mysql/Field.cpp index b9a77454..13fc447d 100644 --- a/modules/mysql/Field.cpp +++ b/modules/mysql/Field.cpp @@ -88,6 +88,41 @@ void Field::ValidateCreated() const } #endif // _DEBUG +// ------------------------------------------------------------------------------------------------ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) +void Field::ValidateStepped(CCStr file, Int32 line) const +{ + // Do we have a valid result-set handle? + if (!m_Handle) + { + SqThrowF("Invalid MySQL result-set reference =>[%s:%d]", file, line); + } + // Do we have a valid row available? + else if (m_Handle->mRow == nullptr) + { + SqThrowF("No row available in MySQL result-set =>[%s:%d]", file, line); + } + // Are we pointing to a valid index? + m_Handle->ValidateField(m_Index, file, line); +} +#else +void Field::ValidateStepped() const +{ + // Do we have a valid result-set handle? + if (!m_Handle) + { + SqThrowF("Invalid MySQL result-set reference"); + } + // Do we have a valid row available? + else if (m_Handle->mRow == nullptr) + { + SqThrowF("No row available in MySQL result-set"); + } + // Are we pointing to a valid index? + m_Handle->ValidateField(m_Index); +} +#endif // _DEBUG + // ------------------------------------------------------------------------------------------------ #if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) const ResRef & Field::GetValid(CCStr file, Int32 line) const @@ -118,6 +153,21 @@ const ResRef & Field::GetCreated() const } #endif // _DEBUG +// ------------------------------------------------------------------------------------------------ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) +const ResRef & Field::GetStepped(CCStr file, Int32 line) const +{ + ValidateStepped(file, line); + return m_Handle; +} +#else +const ResRef & Field::GetStepped() const +{ + ValidateStepped(); + return m_Handle; +} +#endif // _DEBUG + // ------------------------------------------------------------------------------------------------ #if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) void Field::ValidateField(Uint32 idx, CCStr file, Int32 line) const @@ -232,7 +282,7 @@ Object Field::GetConnection() const // ------------------------------------------------------------------------------------------------ bool Field::GetBoolean() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -247,7 +297,7 @@ bool Field::GetBoolean() const // ------------------------------------------------------------------------------------------------ SQChar Field::GetChar() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -262,7 +312,7 @@ SQChar Field::GetChar() const // ------------------------------------------------------------------------------------------------ SQInteger Field::GetInteger() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -281,7 +331,7 @@ SQInteger Field::GetInteger() const // ------------------------------------------------------------------------------------------------ SQFloat Field::GetFloat() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -300,7 +350,7 @@ SQFloat Field::GetFloat() const // ------------------------------------------------------------------------------------------------ SQInteger Field::GetInt8() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -315,7 +365,7 @@ SQInteger Field::GetInt8() const // ------------------------------------------------------------------------------------------------ SQInteger Field::GetUint8() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -330,7 +380,7 @@ SQInteger Field::GetUint8() const // ------------------------------------------------------------------------------------------------ SQInteger Field::GetInt16() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -345,7 +395,7 @@ SQInteger Field::GetInt16() const // ------------------------------------------------------------------------------------------------ SQInteger Field::GetUint16() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -360,7 +410,7 @@ SQInteger Field::GetUint16() const // ------------------------------------------------------------------------------------------------ SQInteger Field::GetInt32() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -375,7 +425,7 @@ SQInteger Field::GetInt32() const // ------------------------------------------------------------------------------------------------ SQInteger Field::GetUint32() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -390,7 +440,7 @@ SQInteger Field::GetUint32() const // ------------------------------------------------------------------------------------------------ Object Field::GetInt64() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Obtain the initial stack size const StackGuard sg; // Should we retrieve the value from the bind wrapper? @@ -414,7 +464,7 @@ Object Field::GetInt64() const // ------------------------------------------------------------------------------------------------ Object Field::GetUint64() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Obtain the initial stack size const StackGuard sg; // Should we retrieve the value from the bind wrapper? @@ -438,7 +488,7 @@ Object Field::GetUint64() const // ------------------------------------------------------------------------------------------------ SQFloat Field::GetFloat32() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -453,7 +503,7 @@ SQFloat Field::GetFloat32() const // ------------------------------------------------------------------------------------------------ SQFloat Field::GetFloat64() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Should we retrieve the value from the bind wrapper? if (m_Handle->mStatement) { @@ -468,7 +518,7 @@ SQFloat Field::GetFloat64() const // ------------------------------------------------------------------------------------------------ Object Field::GetString() const { - SQMOD_VALIDATE_CREATED(*this); + SQMOD_VALIDATE_STEPPED(*this); // Obtain the initial stack size const StackGuard sg; // Retrieve the value directly from the row and push it on the stack diff --git a/modules/mysql/Field.hpp b/modules/mysql/Field.hpp index 1167444d..6ffa3fa4 100644 --- a/modules/mysql/Field.hpp +++ b/modules/mysql/Field.hpp @@ -41,6 +41,15 @@ protected: void ValidateCreated() const; #endif // _DEBUG + /* -------------------------------------------------------------------------------------------- + * Validate the associated result-set handle, field index and row, and throw an error if invalid. + */ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) + void ValidateStepped(CCStr file, Int32 line) const; +#else + void ValidateStepped() const; +#endif // _DEBUG + /* -------------------------------------------------------------------------------------------- * Validate the associated result-set handle and field index, and throw an error if invalid. */ @@ -59,6 +68,15 @@ protected: const ResRef & GetCreated() const; #endif // _DEBUG + /* -------------------------------------------------------------------------------------------- + * Validate the associated result-set handle field index and row, and throw an error if invalid. + */ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) + const ResRef & GetStepped(CCStr file, Int32 line) const; +#else + const ResRef & GetStepped() const; +#endif // _DEBUG + /* -------------------------------------------------------------------------------------------- * Validate the associated result-set handle and field index, and throw an error if invalid. */ diff --git a/modules/mysql/ResultSet.cpp b/modules/mysql/ResultSet.cpp index 797a35aa..c5e0ae1a 100644 --- a/modules/mysql/ResultSet.cpp +++ b/modules/mysql/ResultSet.cpp @@ -19,6 +19,7 @@ SQInteger ResultSet::Typename(HSQUIRRELVM vm) #if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) void ResultSet::Validate(CCStr file, Int32 line) const { + // Do we have a valid result-set handle? if (!m_Handle) { SqThrowF("Invalid MySQL result-set reference =>[%s:%d]", file, line); @@ -27,6 +28,7 @@ void ResultSet::Validate(CCStr file, Int32 line) const #else void ResultSet::Validate() const { + // Do we have a valid result-set handle? if (!m_Handle) { SqThrowF("Invalid MySQL result-set reference"); @@ -38,6 +40,7 @@ void ResultSet::Validate() const #if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) void ResultSet::ValidateCreated(CCStr file, Int32 line) const { + // Do we have a valid result-set handle? if (!m_Handle) { SqThrowF("Invalid MySQL result-set reference =>[%s:%d]", file, line); @@ -50,6 +53,7 @@ void ResultSet::ValidateCreated(CCStr file, Int32 line) const #else void ResultSet::ValidateCreated() const { + // Do we have a valid result-set handle? if (!m_Handle) { SqThrowF("Invalid MySQL result-set reference"); @@ -61,6 +65,37 @@ void ResultSet::ValidateCreated() const } #endif // _DEBUG +// ------------------------------------------------------------------------------------------------ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) +void ResultSet::ValidateStepped(CCStr file, Int32 line) const +{ + // Do we have a valid result-set handle? + if (!m_Handle) + { + SqThrowF("Invalid MySQL result-set reference =>[%s:%d]", file, line); + } + // Do we have a valid row available? + else if (m_Handle->mRow == nullptr) + { + SqThrowF("No row available in MySQL result-set =>[%s:%d]", file, line); + } +} +#else +void ResultSet::ValidateStepped() const +{ + // Do we have a valid result-set handle? + if (!m_Handle) + { + SqThrowF("Invalid MySQL result-set reference"); + } + // Do we have a valid row available? + else if (m_Handle->mRow == nullptr) + { + SqThrowF("No row available in MySQL result-set"); + } +} +#endif // _DEBUG + // ------------------------------------------------------------------------------------------------ #if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) const ResRef & ResultSet::GetValid(CCStr file, Int32 line) const @@ -91,6 +126,21 @@ const ResRef & ResultSet::GetCreated() const } #endif // _DEBUG +// ------------------------------------------------------------------------------------------------ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) +const ResRef & ResultSet::GetStepped(CCStr file, Int32 line) const +{ + ValidateStepped(file, line); + return m_Handle; +} +#else +const ResRef & ResultSet::GetStepped() const +{ + ValidateStepped(); + return m_Handle; +} +#endif // _DEBUG + // ------------------------------------------------------------------------------------------------ #if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) void ResultSet::ValidateField(Int32 idx, CCStr file, Int32 line) const diff --git a/modules/mysql/ResultSet.hpp b/modules/mysql/ResultSet.hpp index be6c7b64..060c23eb 100644 --- a/modules/mysql/ResultSet.hpp +++ b/modules/mysql/ResultSet.hpp @@ -38,6 +38,15 @@ protected: void ValidateCreated() const; #endif // _DEBUG + /* -------------------------------------------------------------------------------------------- + * Validate the managed statement handle and row, and throw an error if invalid. + */ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) + void ValidateStepped(CCStr file, Int32 line) const; +#else + void ValidateStepped() const; +#endif // _DEBUG + /* -------------------------------------------------------------------------------------------- * Validate the managed statement handle and throw an error if invalid. */ @@ -56,6 +65,15 @@ protected: const ResRef & GetCreated() const; #endif // _DEBUG + /* -------------------------------------------------------------------------------------------- + * Validate the managed statement handle and row, and throw an error if invalid. + */ +#if defined(_DEBUG) || defined(SQMOD_EXCEPTLOC) + const ResRef & GetStepped(CCStr file, Int32 line) const; +#else + const ResRef & GetStepped() const; +#endif // _DEBUG + /* -------------------------------------------------------------------------------------------- * Validate the statement reference and field index, and throw an error if they're invalid. */ @@ -219,7 +237,7 @@ public: */ Field GetField(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field); + return Field(SQMOD_GET_STEPPED(*this), field); } /* -------------------------------------------------------------------------------------------- @@ -227,7 +245,7 @@ public: */ bool GetBoolean(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetBoolean(); + return Field(SQMOD_GET_STEPPED(*this), field).GetBoolean(); } /* -------------------------------------------------------------------------------------------- @@ -235,7 +253,7 @@ public: */ SQChar GetChar(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetChar(); + return Field(SQMOD_GET_STEPPED(*this), field).GetChar(); } /* -------------------------------------------------------------------------------------------- @@ -243,7 +261,7 @@ public: */ SQInteger GetInteger(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetInteger(); + return Field(SQMOD_GET_STEPPED(*this), field).GetInteger(); } /* -------------------------------------------------------------------------------------------- @@ -251,7 +269,7 @@ public: */ SQFloat GetFloat(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetFloat(); + return Field(SQMOD_GET_STEPPED(*this), field).GetFloat(); } /* -------------------------------------------------------------------------------------------- @@ -259,7 +277,7 @@ public: */ SQInteger GetInt8(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetInt8(); + return Field(SQMOD_GET_STEPPED(*this), field).GetInt8(); } /* -------------------------------------------------------------------------------------------- @@ -267,7 +285,7 @@ public: */ SQInteger GetUint8(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetUint8(); + return Field(SQMOD_GET_STEPPED(*this), field).GetUint8(); } /* -------------------------------------------------------------------------------------------- @@ -275,7 +293,7 @@ public: */ SQInteger GetInt16(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetInt16(); + return Field(SQMOD_GET_STEPPED(*this), field).GetInt16(); } /* -------------------------------------------------------------------------------------------- @@ -283,7 +301,7 @@ public: */ SQInteger GetUint16(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetUint16(); + return Field(SQMOD_GET_STEPPED(*this), field).GetUint16(); } /* -------------------------------------------------------------------------------------------- @@ -291,7 +309,7 @@ public: */ SQInteger GetInt32(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetInt32(); + return Field(SQMOD_GET_STEPPED(*this), field).GetInt32(); } /* -------------------------------------------------------------------------------------------- @@ -299,7 +317,7 @@ public: */ SQInteger GetUint32(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetUint32(); + return Field(SQMOD_GET_STEPPED(*this), field).GetUint32(); } /* -------------------------------------------------------------------------------------------- @@ -307,7 +325,7 @@ public: */ Object GetInt64(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetInt64(); + return Field(SQMOD_GET_STEPPED(*this), field).GetInt64(); } /* -------------------------------------------------------------------------------------------- @@ -315,7 +333,7 @@ public: */ Object GetUint64(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetUint64(); + return Field(SQMOD_GET_STEPPED(*this), field).GetUint64(); } /* -------------------------------------------------------------------------------------------- @@ -323,7 +341,7 @@ public: */ SQFloat GetFloat32(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetFloat32(); + return Field(SQMOD_GET_STEPPED(*this), field).GetFloat32(); } /* -------------------------------------------------------------------------------------------- @@ -331,7 +349,7 @@ public: */ SQFloat GetFloat64(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetFloat64(); + return Field(SQMOD_GET_STEPPED(*this), field).GetFloat64(); } /* -------------------------------------------------------------------------------------------- @@ -339,7 +357,7 @@ public: */ Object GetString(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetString(); + return Field(SQMOD_GET_STEPPED(*this), field).GetString(); } /* -------------------------------------------------------------------------------------------- @@ -347,7 +365,7 @@ public: */ Object GetBuffer(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetBuffer(); + return Field(SQMOD_GET_STEPPED(*this), field).GetBuffer(); } /* -------------------------------------------------------------------------------------------- @@ -355,7 +373,7 @@ public: */ Object GetBlob(const Object & field) const { - return Field(SQMOD_GET_CREATED(*this), field).GetBlob(); + return Field(SQMOD_GET_STEPPED(*this), field).GetBlob(); } };