Skip to content

Commit 0550db2

Browse files
committed
LLVMBuildUtils: Add runtime optimization for string comparison
1 parent 3be5c62 commit 0550db2

File tree

2 files changed

+161
-4
lines changed

2 files changed

+161
-4
lines changed

src/engine/internal/llvm/llvmbuildutils.cpp

Lines changed: 157 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,16 +1333,42 @@ llvm::Value *LLVMBuildUtils::createStringComparison(LLVMRegister *arg1, LLVMRegi
13331333
return m_builder.getInt1(result);
13341334
} else {
13351335
// Optimize number and string constant comparison
1336+
// If there's a non-numeric string constant and the other operand is a valid number, the result is false
13361337
// TODO: Optimize bool and string constant comparison (in compare() as well)
1337-
if ((type1 == Compiler::StaticType::Number && type2 == Compiler::StaticType::String && arg2->isConst() && !arg2->constValue().isValidNumber()) ||
1338-
(type1 == Compiler::StaticType::String && type2 == Compiler::StaticType::Number && arg1->isConst() && !arg1->constValue().isValidNumber()))
1339-
return m_builder.getInt1(false);
1338+
llvm::Value *optimize;
1339+
1340+
if (type1 == Compiler::StaticType::String && arg1->isConst() && !arg1->constValue().isValidNumber())
1341+
optimize = valueIsValidNumber(arg2);
1342+
else if (type2 == Compiler::StaticType::String && arg2->isConst() && !arg2->constValue().isValidNumber())
1343+
optimize = valueIsValidNumber(arg1);
1344+
else
1345+
optimize = m_builder.getInt1(false);
1346+
1347+
llvm::BasicBlock *optimizedBlock = llvm::BasicBlock::Create(m_llvmCtx, "stringComparison.optimized", m_function);
1348+
llvm::BasicBlock *compareBlock = llvm::BasicBlock::Create(m_llvmCtx, "stringComparison.compare", m_function);
1349+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "stringComparison.next", m_function);
1350+
m_builder.CreateCondBr(optimize, optimizedBlock, compareBlock);
1351+
1352+
m_builder.SetInsertPoint(optimizedBlock);
1353+
m_builder.CreateBr(nextBlock);
1354+
1355+
m_builder.SetInsertPoint(compareBlock);
13401356

13411357
// Explicitly cast to string
13421358
llvm::Value *string1 = castValue(arg1, Compiler::StaticType::String);
13431359
llvm::Value *string2 = castValue(arg2, Compiler::StaticType::String);
13441360
llvm::Value *cmp = m_builder.CreateCall(caseSensitive ? m_functions.resolve_string_compare_case_sensitive() : m_functions.resolve_string_compare_case_insensitive(), { string1, string2 });
1345-
return m_builder.CreateICmpEQ(cmp, m_builder.getInt32(0));
1361+
llvm::Value *result = m_builder.CreateICmpEQ(cmp, m_builder.getInt32(0));
1362+
m_builder.CreateBr(nextBlock);
1363+
1364+
llvm::BasicBlock *compareBlockNext = m_builder.GetInsertBlock();
1365+
m_builder.SetInsertPoint(nextBlock);
1366+
1367+
llvm::PHINode *phi = m_builder.CreatePHI(m_builder.getInt1Ty(), 2, "stringComparison.result");
1368+
phi->addIncoming(m_builder.getInt1(false), optimizedBlock);
1369+
phi->addIncoming(result, compareBlockNext);
1370+
1371+
return phi;
13461372
}
13471373
}
13481374

@@ -1581,6 +1607,133 @@ llvm::Constant *LLVMBuildUtils::castConstValue(const Value &value, Compiler::Sta
15811607
}
15821608
}
15831609

1610+
llvm::Value *LLVMBuildUtils::valueIsValidNumber(LLVMRegister *reg)
1611+
{
1612+
if (reg->isConst())
1613+
return m_builder.getInt1(reg->constValue().isValidNumber());
1614+
1615+
if (reg->isRawValue)
1616+
return rawValueIsValidNumber(reg);
1617+
1618+
assert(reg->type() != Compiler::StaticType::Void);
1619+
1620+
// Handle multiple type cases with runtime switch
1621+
llvm::Value *typePtr = getValueTypePtr(reg);
1622+
llvm::Value *loadedType = m_builder.CreateLoad(m_builder.getInt32Ty(), typePtr);
1623+
1624+
llvm::BasicBlock *mergeBlock = llvm::BasicBlock::Create(m_llvmCtx, "merge", m_function);
1625+
llvm::BasicBlock *defaultBlock = llvm::BasicBlock::Create(m_llvmCtx, "default", m_function);
1626+
1627+
llvm::SwitchInst *sw = m_builder.CreateSwitch(loadedType, defaultBlock, 4);
1628+
std::vector<std::pair<llvm::BasicBlock *, llvm::Value *>> results;
1629+
1630+
Compiler::StaticType type = reg->type();
1631+
1632+
// Number case
1633+
if ((type & Compiler::StaticType::Number) == Compiler::StaticType::Number) {
1634+
llvm::BasicBlock *numberBlock = llvm::BasicBlock::Create(m_llvmCtx, "isValidNumber.number", m_function);
1635+
sw->addCase(m_builder.getInt32(static_cast<uint32_t>(ValueType::Number)), numberBlock);
1636+
1637+
m_builder.SetInsertPoint(numberBlock);
1638+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0);
1639+
llvm::Value *number = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr);
1640+
llvm::Value *numberResult = m_builder.CreateOr(reg->isInt, m_builder.CreateNot(isNaN(number)));
1641+
m_builder.CreateBr(mergeBlock);
1642+
results.push_back({ numberBlock, numberResult });
1643+
}
1644+
1645+
// Bool case
1646+
if ((type & Compiler::StaticType::Bool) == Compiler::StaticType::Bool) {
1647+
llvm::BasicBlock *boolBlock = llvm::BasicBlock::Create(m_llvmCtx, "isValidNumber.bool", m_function);
1648+
sw->addCase(m_builder.getInt32(static_cast<uint32_t>(ValueType::Bool)), boolBlock);
1649+
1650+
m_builder.SetInsertPoint(boolBlock);
1651+
llvm::Value *boolResult = m_builder.getInt1(true);
1652+
m_builder.CreateBr(mergeBlock);
1653+
results.push_back({ boolBlock, boolResult });
1654+
}
1655+
1656+
// String case
1657+
if ((type & Compiler::StaticType::String) == Compiler::StaticType::String) {
1658+
llvm::BasicBlock *stringBlock = llvm::BasicBlock::Create(m_llvmCtx, "isValidNumber.string", m_function);
1659+
sw->addCase(m_builder.getInt32(static_cast<uint32_t>(ValueType::String)), stringBlock);
1660+
1661+
m_builder.SetInsertPoint(stringBlock);
1662+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0);
1663+
llvm::Value *stringPtr = m_builder.CreateLoad(m_stringPtrType->getPointerTo(), ptr);
1664+
1665+
llvm::Value *stringResult = stringIsValidNumber(stringPtr);
1666+
1667+
m_builder.CreateBr(mergeBlock);
1668+
results.push_back({ m_builder.GetInsertBlock(), stringResult });
1669+
}
1670+
1671+
// Default case
1672+
m_builder.SetInsertPoint(defaultBlock);
1673+
1674+
// All possible types are covered, mark as unreachable
1675+
m_builder.CreateUnreachable();
1676+
1677+
// Create phi node to merge results
1678+
m_builder.SetInsertPoint(mergeBlock);
1679+
1680+
llvm::PHINode *result = m_builder.CreatePHI(m_builder.getInt1Ty(), results.size());
1681+
1682+
for (auto &pair : results)
1683+
result->addIncoming(pair.second, pair.first);
1684+
1685+
return result;
1686+
}
1687+
1688+
llvm::Value *LLVMBuildUtils::rawValueIsValidNumber(LLVMRegister *reg)
1689+
{
1690+
switch (reg->type()) {
1691+
case Compiler::StaticType::Number:
1692+
return m_builder.CreateOr(reg->isInt, m_builder.CreateNot(isNaN(reg->value)));
1693+
1694+
case Compiler::StaticType::Bool:
1695+
return m_builder.getInt1(true);
1696+
1697+
case Compiler::StaticType::String:
1698+
return stringIsValidNumber(reg->value);
1699+
1700+
default:
1701+
assert(false);
1702+
return nullptr;
1703+
}
1704+
}
1705+
1706+
llvm::Value *LLVMBuildUtils::stringIsValidNumber(llvm::Value *stringPtr)
1707+
{
1708+
llvm::Value *stringSizePtr = m_builder.CreateStructGEP(m_stringPtrType, stringPtr, 1);
1709+
llvm::Value *stringSize = m_builder.CreateLoad(m_builder.getInt64Ty(), stringSizePtr);
1710+
llvm::Value *empty = m_builder.CreateICmpEQ(stringSize, m_builder.getInt64(0));
1711+
1712+
// If the string is empty, return true, otherwise call value_stringToDoubleWithCheck()
1713+
llvm::BasicBlock *emptyBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
1714+
llvm::BasicBlock *castBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
1715+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
1716+
m_builder.CreateCondBr(empty, emptyBlock, castBlock);
1717+
1718+
m_builder.SetInsertPoint(emptyBlock);
1719+
m_builder.CreateBr(nextBlock);
1720+
1721+
m_builder.SetInsertPoint(castBlock);
1722+
llvm::Value *okPtr = addAlloca(m_builder.getInt1Ty());
1723+
m_builder.CreateCall(m_functions.resolve_value_stringToDoubleWithCheck(), { stringPtr, okPtr });
1724+
1725+
llvm::Value *ok = m_builder.CreateLoad(m_builder.getInt1Ty(), okPtr);
1726+
m_builder.CreateBr(nextBlock);
1727+
1728+
m_builder.SetInsertPoint(nextBlock);
1729+
1730+
llvm::PHINode *phi = m_builder.CreatePHI(m_builder.getInt1Ty(), 2, "stringIsValidNumber");
1731+
phi->addIncoming(m_builder.getInt1(true), emptyBlock);
1732+
phi->addIncoming(ok, castBlock);
1733+
1734+
return phi;
1735+
}
1736+
15841737
void LLVMBuildUtils::createValueCopy(llvm::Value *source, llvm::Value *target)
15851738
{
15861739
// NOTE: This doesn't copy strings, but only the pointers

src/engine/internal/llvm/llvmbuildutils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ class LLVMBuildUtils
129129
llvm::Value *castRawValue(LLVMRegister *reg, Compiler::StaticType targetType, NumberType targetNumType);
130130
llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType, NumberType targetNumType);
131131

132+
llvm::Value *valueIsValidNumber(LLVMRegister *reg);
133+
llvm::Value *rawValueIsValidNumber(LLVMRegister *reg);
134+
llvm::Value *stringIsValidNumber(llvm::Value *stringPtr);
135+
132136
void createValueCopy(llvm::Value *source, llvm::Value *target);
133137
void copyStructField(llvm::Value *source, llvm::Value *target, int index, llvm::StructType *structType, llvm::Type *fieldType);
134138

0 commit comments

Comments
 (0)