@@ -488,100 +488,62 @@ static bool isTrivialFiller(Expr *E) {
488488 return false ;
489489}
490490
491- static void EmitHLSLAggregateSplatCast (CodeGenFunction &CGF, Address DestVal,
492- QualType DestTy, llvm::Value *SrcVal,
493- QualType SrcTy, SourceLocation Loc) {
491+ // emit an elementwise cast where the RHS is a scalar or vector
492+ // or emit an aggregate splat cast
493+ static void EmitHLSLScalarElementwiseAndSplatCasts (CodeGenFunction &CGF,
494+ LValue DestVal,
495+ llvm::Value *SrcVal,
496+ QualType SrcTy,
497+ SourceLocation Loc) {
494498 // Flatten our destination
495- SmallVector<QualType> DestTypes; // Flattened type
496- SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
497- // ^^ Flattened accesses to DestVal we want to store into
498- CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
499-
500- assert (SrcTy->isScalarType () && " Invalid HLSL Aggregate splat cast." );
501- for (unsigned I = 0 , Size = StoreGEPList.size (); I < Size; ++I) {
502- llvm::Value *Cast =
503- CGF.EmitScalarConversion (SrcVal, SrcTy, DestTypes[I], Loc);
504-
505- // store back
506- llvm::Value *Idx = StoreGEPList[I].second ;
507- if (Idx) {
508- llvm::Value *V =
509- CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
510- Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
511- }
512- CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
513- }
514- }
515-
516- // emit a flat cast where the RHS is a scalar, including vector
517- static void EmitHLSLScalarFlatCast (CodeGenFunction &CGF, Address DestVal,
518- QualType DestTy, llvm::Value *SrcVal,
519- QualType SrcTy, SourceLocation Loc) {
520- // Flatten our destination
521- SmallVector<QualType, 16 > DestTypes; // Flattened type
522- SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
523- // ^^ Flattened accesses to DestVal we want to store into
524- CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
525-
526- assert (SrcTy->isVectorType () && " HLSL Flat cast doesn't handle splatting." );
527- const VectorType *VT = SrcTy->getAs <VectorType>();
528- SrcTy = VT->getElementType ();
529- assert (StoreGEPList.size () <= VT->getNumElements () &&
530- " Cannot perform HLSL flat cast when vector source \
531- object has less elements than flattened destination \
532- object." );
533- for (unsigned I = 0 , Size = StoreGEPList.size (); I < Size; I++) {
534- llvm::Value *Load = CGF.Builder .CreateExtractElement (SrcVal, I, " vec.load" );
499+ SmallVector<LValue, 16 > StoreList;
500+ CGF.FlattenAccessAndTypeLValue (DestVal, StoreList);
501+
502+ bool isVector = false ;
503+ if (auto *VT = SrcTy->getAs <VectorType>()) {
504+ isVector = true ;
505+ SrcTy = VT->getElementType ();
506+ assert (StoreList.size () <= VT->getNumElements () &&
507+ " Cannot perform HLSL flat cast when vector source \
508+ object has less elements than flattened destination \
509+ object." );
510+ }
511+
512+ for (unsigned I = 0 , Size = StoreList.size (); I < Size; I++) {
513+ LValue DestLVal = StoreList[I];
514+ llvm::Value *Load =
515+ isVector ? CGF.Builder .CreateExtractElement (SrcVal, I, " vec.load" )
516+ : SrcVal;
535517 llvm::Value *Cast =
536- CGF.EmitScalarConversion (Load, SrcTy, DestTypes[I], Loc);
537-
538- // store back
539- llvm::Value *Idx = StoreGEPList[I].second ;
540- if (Idx) {
541- llvm::Value *V =
542- CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
543- Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
544- }
545- CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
518+ CGF.EmitScalarConversion (Load, SrcTy, DestLVal.getType (), Loc);
519+ CGF.EmitStoreThroughLValue (RValue::get (Cast), DestLVal);
546520 }
547521}
548522
549523// emit a flat cast where the RHS is an aggregate
550- static void EmitHLSLElementwiseCast (CodeGenFunction &CGF, Address DestVal,
551- QualType DestTy, Address SrcVal,
552- QualType SrcTy, SourceLocation Loc) {
524+ static void EmitHLSLElementwiseCast (CodeGenFunction &CGF, LValue DestVal,
525+ LValue SrcVal, SourceLocation Loc) {
553526 // Flatten our destination
554- SmallVector<QualType, 16 > DestTypes; // Flattened type
555- SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
556- // ^^ Flattened accesses to DestVal we want to store into
557- CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
527+ SmallVector<LValue, 16 > StoreList;
528+ CGF.FlattenAccessAndTypeLValue (DestVal, StoreList);
558529 // Flatten our src
559- SmallVector<QualType, 16 > SrcTypes; // Flattened type
560- SmallVector<std::pair<Address, llvm::Value *>, 16 > LoadGEPList;
561- // ^^ Flattened accesses to SrcVal we want to load from
562- CGF.FlattenAccessAndType (SrcVal, SrcTy, LoadGEPList, SrcTypes);
530+ SmallVector<LValue, 16 > LoadList;
531+ CGF.FlattenAccessAndTypeLValue (SrcVal, LoadList);
563532
564- assert (StoreGEPList .size () <= LoadGEPList .size () &&
565- " Cannot perform HLSL flat cast when flattened source object \
533+ assert (StoreList .size () <= LoadList .size () &&
534+ " Cannot perform HLSL elementwise cast when flattened source object \
566535 has less elements than flattened destination object." );
567- // apply casts to what we load from LoadGEPList
536+ // apply casts to what we load from LoadList
568537 // and store result in Dest
569- for (unsigned I = 0 , E = StoreGEPList.size (); I < E; I++) {
570- llvm::Value *Idx = LoadGEPList[I].second ;
571- llvm::Value *Load = CGF.Builder .CreateLoad (LoadGEPList[I].first , " load" );
572- Load =
573- Idx ? CGF.Builder .CreateExtractElement (Load, Idx, " vec.extract" ) : Load;
574- llvm::Value *Cast =
575- CGF.EmitScalarConversion (Load, SrcTypes[I], DestTypes[I], Loc);
576-
577- // store back
578- Idx = StoreGEPList[I].second ;
579- if (Idx) {
580- llvm::Value *V =
581- CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
582- Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
583- }
584- CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
538+ for (unsigned I = 0 , E = StoreList.size (); I < E; I++) {
539+ LValue DestLVal = StoreList[I];
540+ LValue SrcLVal = LoadList[I];
541+ RValue RVal = CGF.EmitLoadOfLValue (SrcLVal, Loc);
542+ assert (RVal.isScalar () && " All flattened source values should be scalars" );
543+ llvm::Value *Val = RVal.getScalarVal ();
544+ llvm::Value *Cast = CGF.EmitScalarConversion (Val, SrcLVal.getType (),
545+ DestLVal.getType (), Loc);
546+ CGF.EmitStoreThroughLValue (RValue::get (Cast), DestLVal);
585547 }
586548}
587549
@@ -988,31 +950,33 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
988950 Expr *Src = E->getSubExpr ();
989951 QualType SrcTy = Src->getType ();
990952 RValue RV = CGF.EmitAnyExpr (Src);
991- QualType DestTy = E->getType ();
992- Address DestVal = Dest.getAddress ();
953+ LValue DestLVal = CGF.MakeAddrLValue (Dest.getAddress (), E->getType ());
993954 SourceLocation Loc = E->getExprLoc ();
994955
995- assert (RV.isScalar () && " RHS of HLSL splat cast must be a scalar." );
956+ assert (RV.isScalar () && SrcTy->isScalarType () &&
957+ " RHS of HLSL splat cast must be a scalar." );
996958 llvm::Value *SrcVal = RV.getScalarVal ();
997- EmitHLSLAggregateSplatCast (CGF, DestVal, DestTy , SrcVal, SrcTy, Loc);
959+ EmitHLSLScalarElementwiseAndSplatCasts (CGF, DestLVal , SrcVal, SrcTy, Loc);
998960 break ;
999961 }
1000962 case CK_HLSLElementwiseCast: {
1001963 Expr *Src = E->getSubExpr ();
1002964 QualType SrcTy = Src->getType ();
1003965 RValue RV = CGF.EmitAnyExpr (Src);
1004- QualType DestTy = E->getType ();
1005- Address DestVal = Dest.getAddress ();
966+ LValue DestLVal = CGF.MakeAddrLValue (Dest.getAddress (), E->getType ());
1006967 SourceLocation Loc = E->getExprLoc ();
1007968
1008969 if (RV.isScalar ()) {
1009970 llvm::Value *SrcVal = RV.getScalarVal ();
1010- EmitHLSLScalarFlatCast (CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
971+ assert (SrcTy->isVectorType () &&
972+ " HLSL Elementwise cast doesn't handle splatting." );
973+ EmitHLSLScalarElementwiseAndSplatCasts (CGF, DestLVal, SrcVal, SrcTy, Loc);
1011974 } else {
1012975 assert (RV.isAggregate () &&
1013976 " Can't perform HLSL Aggregate cast on a complex type." );
1014977 Address SrcVal = RV.getAggregateAddress ();
1015- EmitHLSLElementwiseCast (CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
978+ EmitHLSLElementwiseCast (CGF, DestLVal, CGF.MakeAddrLValue (SrcVal, SrcTy),
979+ Loc);
1016980 }
1017981 break ;
1018982 }
0 commit comments