Atmosphere/exosphere/program/source/boot/secmon_boot_rsa.cpp
2020-06-14 22:07:45 -07:00

159 lines
6.3 KiB
C++

/*
* Copyright (c) 2018-2020 Atmosphère-NX
*
* This program is free software; you can redistribute it and/or modify it
* under the terms and conditions of the GNU General Public License,
* version 2, as published by the Free Software Foundation.
*
* This program is distributed in the hope it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include <exosphere.hpp>
#include "secmon_boot.hpp"
namespace ams::secmon::boot {
namespace {
/* Public exponent passed to se::SetRsaKey by VerifySignature. */
/* The four bytes 00 01 00 01 are 0x00010001 (65537) when read as a big-endian integer. */
constinit const u8 RsaPublicKeyExponent[] = {
0x00, 0x01, 0x00, 0x01,
};
/* Value the last byte of the exponentiated signature is compared against in VerifyRsaPssSha256. */
constexpr inline u8 TailMagic = 0xBC;
bool VerifyRsaPssSha256(const u8 *sig, const void *msg, size_t msg_size) {
/* Define constants. */
constexpr int EmBits = 2047;
constexpr int EmLen = util::DivideUp(EmBits, BITSIZEOF(u8));
constexpr int SaltLen = 0x20;
constexpr int HashLen = se::Sha256HashSize;
/* Define a work buffer. */
u8 work[EmLen];
ON_SCOPE_EXIT { util::ClearMemory(work, sizeof(work)); };
/* Calculate the message hash, first flushing cache to ensure SE sees correct data. */
se::Sha256Hash msg_hash;
hw::FlushDataCache(msg, msg_size);
hw::DataSynchronizationBarrierInnerShareable();
se::CalculateSha256(std::addressof(msg_hash), msg, msg_size);
/* Verify the tail magic. */
bool is_valid = sig[EmLen - 1] == TailMagic;
/* Determine extents of masked db and h. */
const u8 *masked_db = std::addressof(sig[0]);
const u8 *h = std::addressof(sig[EmLen - HashLen - 1]);
/* Verify the extra bits are zero. */
is_valid &= (masked_db[0] >> (BITSIZEOF(u8) - (BITSIZEOF(u8) * EmLen - EmBits))) == 0;
/* Calculate the db mask. */
{
constexpr int MaskLen = EmLen - HashLen - 1;
constexpr int HashIters = util::DivideUp(MaskLen, HashLen);
u8 mgf1_buf[sizeof(u32) + HashLen];
std::memcpy(std::addressof(mgf1_buf[0]), h, HashLen);
std::memset(std::addressof(mgf1_buf[HashLen]), 0, sizeof(u32));
for (int i = 0; i < HashIters; ++i) {
/* Set the counter for this iteration. */
mgf1_buf[sizeof(mgf1_buf) - 1] = i;
/* Calculate the sha256 to the appropriate place in the work buffer. */
auto *mgf1_dst = reinterpret_cast<se::Sha256Hash *>(std::addressof(work[HashLen * i]));
hw::FlushDataCache(mgf1_buf, sizeof(mgf1_buf));
hw::DataSynchronizationBarrierInnerShareable();
se::CalculateSha256(mgf1_dst, mgf1_buf, sizeof(mgf1_buf));
}
}
/* Decrypt masked db using the mask we just generated. */
for (int i = 0; i < EmLen - HashLen - 1; ++i) {
work[i] ^= masked_db[i];
}
/* Mask out the top bits. */
u8 *db = work;
db[0] &= 0xFF >> (BITSIZEOF(u8) * EmLen - EmBits);
/* Verify that DB is of the form 0000...0001 */
constexpr int DbLen = EmLen - HashLen - 1;
int salt_ofs = 0;
{
int looking_for_one = 1;
int invalid_db_padding = 0;
int is_zero;
int is_one;
for (size_t i = 0; i < DbLen; /* ... */) {
is_zero = (db[i] == 0);
is_one = (db[i] == 1);
salt_ofs += (looking_for_one & is_one) * (static_cast<s32>(++i));
looking_for_one &= ~is_one;
invalid_db_padding |= (looking_for_one & ~is_zero);
}
is_valid &= (invalid_db_padding == 0);
}
/* Verify salt. */
is_valid &= (DbLen - salt_ofs) == SaltLen;
/* Setup the message to verify. */
const u8 *salt = std::addressof(db[DbLen - SaltLen]);
u8 verif_msg[8 + HashLen + SaltLen];
ON_SCOPE_EXIT { util::ClearMemory(verif_msg, sizeof(verif_msg)); };
util::ClearMemory(std::addressof(verif_msg[0]), 8);
std::memcpy(std::addressof(verif_msg[8]), std::addressof(msg_hash), HashLen);
std::memcpy(std::addressof(verif_msg[8 + HashLen]), salt, SaltLen);
/* Verify the final hash. */
return VerifyHash(h, reinterpret_cast<uintptr_t>(std::addressof(verif_msg[0])), sizeof(verif_msg));
}
bool VerifyRsaPssSha256(int slot, void *sig, size_t sig_size, const void *msg, size_t msg_size) {
/* Exponentiate the signature, using the signature as the destination buffer. */
se::ModularExponentiate(sig, sig_size, slot, sig, sig_size);
/* Verify the pss padding. */
return VerifyRsaPssSha256(static_cast<const u8 *>(sig), msg, msg_size);
}
}
/* Verifies sig over msg using the supplied modulus and the fixed public exponent. */
/* The contents of sig are overwritten during verification. */
bool VerifySignature(void *sig, size_t sig_size, const void *mod, size_t mod_size, const void *msg, size_t msg_size) {
    /* Install the modulus and public exponent into the temporary RSA keyslot. */
    const int temp_slot = pkg1::RsaKeySlot_Temporary;
    se::SetRsaKey(temp_slot, mod, mod_size, RsaPublicKeyExponent, util::size(RsaPublicKeyExponent));

    /* Exponentiate with that key and check the resulting padding. */
    const bool verified = VerifyRsaPssSha256(temp_slot, sig, sig_size, msg, msg_size);
    return verified;
}
bool VerifyHash(const void *hash, uintptr_t msg, size_t msg_size) {
/* Zero-sized messages are always valid. */
if (msg_size == 0) {
return true;
}
/* Ensure that the SE sees correct data for the message. */
hw::FlushDataCache(reinterpret_cast<void *>(msg), msg_size);
hw::DataSynchronizationBarrierInnerShareable();
/* Calculate the hash. */
se::Sha256Hash calc_hash;
se::CalculateSha256(std::addressof(calc_hash), reinterpret_cast<void *>(msg), msg_size);
/* Verify the result. */
return crypto::IsSameBytes(std::addressof(calc_hash), hash, sizeof(calc_hash));
}
}