Просмотр исходного кода

TSL: Add matrix operations support (#30370)

* TSL: Add matrix operations support

* fix lint

* fix lint

* fix CI
Renaud Rohlinger 1 год назад
Родитель
Сommit
7bc5b17826
1 измененных файлов с 80 добавлено и 17 удалено
  1. 80 17
      src/nodes/math/OperatorNode.js

+ 80 - 17
src/nodes/math/OperatorNode.js

@@ -118,23 +118,39 @@ class OperatorNode extends TempNode {
 
 		} else {
 
-			if ( typeA === 'float' && builder.isMatrix( typeB ) ) {
+			// Handle matrix operations
+			if ( builder.isMatrix( typeA ) ) {
 
-				return typeB;
+				if ( typeB === 'float' ) {
+
+					return typeA; // matrix * scalar = matrix
+
+				} else if ( builder.isVector( typeB ) ) {
+
+					return builder.getVectorFromMatrix( typeA ); // matrix * vector
+
+				} else if ( builder.isMatrix( typeB ) ) {
 
-			} else if ( builder.isMatrix( typeA ) && builder.isVector( typeB ) ) {
+					return typeA; // matrix * matrix
 
-				// matrix x vector
+				}
+
+			} else if ( builder.isMatrix( typeB ) ) {
+
+				if ( typeA === 'float' ) {
 
-				return builder.getVectorFromMatrix( typeA );
+					return typeB; // scalar * matrix = matrix
 
-			} else if ( builder.isVector( typeA ) && builder.isMatrix( typeB ) ) {
+				} else if ( builder.isVector( typeA ) ) {
 
-				// vector x matrix
+					return builder.getVectorFromMatrix( typeB ); // vector * matrix
+
+				}
 
-				return builder.getVectorFromMatrix( typeB );
+			}
 
-			} else if ( builder.getTypeLength( typeB ) > builder.getTypeLength( typeA ) ) {
+			// Handle non-matrix cases
+			if ( builder.getTypeLength( typeB ) > builder.getTypeLength( typeA ) ) {
 
 				// anytype x anytype: use the greater length vector
 
@@ -182,17 +198,43 @@ class OperatorNode extends TempNode {
 				typeA = type;
 				typeB = builder.changeComponentType( typeB, 'uint' );
 
-			} else if ( builder.isMatrix( typeA ) && builder.isVector( typeB ) ) {
+			} else if ( builder.isMatrix( typeA ) ) {
+
+				if ( typeB === 'float' ) {
+
+					// Keep matrix type for typeA, but ensure typeB stays float
+					typeB = 'float';
+
+				} else if ( builder.isVector( typeB ) ) {
+
+					// matrix x vector
+					typeB = builder.getVectorFromMatrix( typeA );
+
+				} else if ( builder.isMatrix( typeB ) ) {
+					// matrix x matrix - keep both types
+				} else {
+
+					typeA = typeB = type;
+
+				}
 
-				// matrix x vector
+			} else if ( builder.isMatrix( typeB ) ) {
 
-				typeB = builder.getVectorFromMatrix( typeA );
+				if ( typeA === 'float' ) {
 
-			} else if ( builder.isVector( typeA ) && builder.isMatrix( typeB ) ) {
+					// Keep matrix type for typeB, but ensure typeA stays float
+					typeA = 'float';
 
-				// vector x matrix
+				} else if ( builder.isVector( typeA ) ) {
 
-				typeA = builder.getVectorFromMatrix( typeB );
+					// vector x matrix
+					typeA = builder.getVectorFromMatrix( typeB );
+
+				} else {
+
+					typeA = typeB = type;
+
+				}
 
 			} else {
 
@@ -274,7 +316,20 @@ class OperatorNode extends TempNode {
 
 			} else {
 
-				return builder.format( `( ${ a } ${ op } ${ b } )`, type, output );
+				// Handle matrix operations
+				if ( builder.isMatrix( typeA ) && typeB === 'float' ) {
+
+					return builder.format( `( ${ b } ${ op } ${ a } )`, type, output );
+
+				} else if ( typeA === 'float' && builder.isMatrix( typeB ) ) {
+
+					return builder.format( `${ a } ${ op } ${ b }`, type, output );
+
+				} else {
+
+					return builder.format( `( ${ a } ${ op } ${ b } )`, type, output );
+
+				}
 
 			}
 
@@ -286,7 +341,15 @@ class OperatorNode extends TempNode {
 
 			} else {
 
-				return builder.format( `${ a } ${ op } ${ b }`, type, output );
+				if ( builder.isMatrix( typeA ) && typeB === 'float' ) {
+
+					return builder.format( `${ b } ${ op } ${ a }`, type, output );
+
+				} else {
+
+					return builder.format( `${ a } ${ op } ${ b }`, type, output );
+
+				}
 
 			}
 

粤ICP备19079148号