|
|
@@ -118,23 +118,39 @@ class OperatorNode extends TempNode {
|
|
|
|
|
|
} else {
|
|
|
|
|
|
- if ( typeA === 'float' && builder.isMatrix( typeB ) ) {
|
|
|
+ // Handle matrix operations
|
|
|
+ if ( builder.isMatrix( typeA ) ) {
|
|
|
|
|
|
- return typeB;
|
|
|
+ if ( typeB === 'float' ) {
|
|
|
+
|
|
|
+ return typeA; // matrix * scalar = matrix
|
|
|
+
|
|
|
+ } else if ( builder.isVector( typeB ) ) {
|
|
|
+
|
|
|
+ return builder.getVectorFromMatrix( typeA ); // matrix * vector
|
|
|
+
|
|
|
+ } else if ( builder.isMatrix( typeB ) ) {
|
|
|
|
|
|
- } else if ( builder.isMatrix( typeA ) && builder.isVector( typeB ) ) {
|
|
|
+ return typeA; // matrix * matrix
|
|
|
|
|
|
- // matrix x vector
|
|
|
+ }
|
|
|
+
|
|
|
+ } else if ( builder.isMatrix( typeB ) ) {
|
|
|
+
|
|
|
+ if ( typeA === 'float' ) {
|
|
|
|
|
|
- return builder.getVectorFromMatrix( typeA );
|
|
|
+ return typeB; // scalar * matrix = matrix
|
|
|
|
|
|
- } else if ( builder.isVector( typeA ) && builder.isMatrix( typeB ) ) {
|
|
|
+ } else if ( builder.isVector( typeA ) ) {
|
|
|
|
|
|
- // vector x matrix
|
|
|
+ return builder.getVectorFromMatrix( typeB ); // vector * matrix
|
|
|
+
|
|
|
+ }
|
|
|
|
|
|
- return builder.getVectorFromMatrix( typeB );
|
|
|
+ }
|
|
|
|
|
|
- } else if ( builder.getTypeLength( typeB ) > builder.getTypeLength( typeA ) ) {
|
|
|
+ // Handle non-matrix cases
|
|
|
+ if ( builder.getTypeLength( typeB ) > builder.getTypeLength( typeA ) ) {
|
|
|
|
|
|
// anytype x anytype: use the greater length vector
|
|
|
|
|
|
@@ -182,17 +198,43 @@ class OperatorNode extends TempNode {
|
|
|
typeA = type;
|
|
|
typeB = builder.changeComponentType( typeB, 'uint' );
|
|
|
|
|
|
- } else if ( builder.isMatrix( typeA ) && builder.isVector( typeB ) ) {
|
|
|
+ } else if ( builder.isMatrix( typeA ) ) {
|
|
|
+
|
|
|
+ if ( typeB === 'float' ) {
|
|
|
+
|
|
|
+ // Keep matrix type for typeA, but ensure typeB stays float
|
|
|
+ typeB = 'float';
|
|
|
+
|
|
|
+ } else if ( builder.isVector( typeB ) ) {
|
|
|
+
|
|
|
+ // matrix x vector
|
|
|
+ typeB = builder.getVectorFromMatrix( typeA );
|
|
|
+
|
|
|
+ } else if ( builder.isMatrix( typeB ) ) {
|
|
|
+ // matrix x matrix - keep both types
|
|
|
+ } else {
|
|
|
+
|
|
|
+ typeA = typeB = type;
|
|
|
+
|
|
|
+ }
|
|
|
|
|
|
- // matrix x vector
|
|
|
+ } else if ( builder.isMatrix( typeB ) ) {
|
|
|
|
|
|
- typeB = builder.getVectorFromMatrix( typeA );
|
|
|
+ if ( typeA === 'float' ) {
|
|
|
|
|
|
- } else if ( builder.isVector( typeA ) && builder.isMatrix( typeB ) ) {
|
|
|
+ // Keep matrix type for typeB, but ensure typeA stays float
|
|
|
+ typeA = 'float';
|
|
|
|
|
|
- // vector x matrix
|
|
|
+ } else if ( builder.isVector( typeA ) ) {
|
|
|
|
|
|
- typeA = builder.getVectorFromMatrix( typeB );
|
|
|
+ // vector x matrix
|
|
|
+ typeA = builder.getVectorFromMatrix( typeB );
|
|
|
+
|
|
|
+ } else {
|
|
|
+
|
|
|
+ typeA = typeB = type;
|
|
|
+
|
|
|
+ }
|
|
|
|
|
|
} else {
|
|
|
|
|
|
@@ -274,7 +316,20 @@ class OperatorNode extends TempNode {
|
|
|
|
|
|
} else {
|
|
|
|
|
|
- return builder.format( `( ${ a } ${ op } ${ b } )`, type, output );
|
|
|
+ // Handle matrix operations
|
|
|
+ if ( builder.isMatrix( typeA ) && typeB === 'float' ) {
|
|
|
+
|
|
|
+ return builder.format( `( ${ b } ${ op } ${ a } )`, type, output );
|
|
|
+
|
|
|
+ } else if ( typeA === 'float' && builder.isMatrix( typeB ) ) {
|
|
|
+
|
|
|
+ return builder.format( `${ a } ${ op } ${ b }`, type, output );
|
|
|
+
|
|
|
+ } else {
|
|
|
+
|
|
|
+ return builder.format( `( ${ a } ${ op } ${ b } )`, type, output );
|
|
|
+
|
|
|
+ }
|
|
|
|
|
|
}
|
|
|
|
|
|
@@ -286,7 +341,15 @@ class OperatorNode extends TempNode {
|
|
|
|
|
|
} else {
|
|
|
|
|
|
- return builder.format( `${ a } ${ op } ${ b }`, type, output );
|
|
|
+ if ( builder.isMatrix( typeA ) && typeB === 'float' ) {
|
|
|
+
|
|
|
+ return builder.format( `${ b } ${ op } ${ a }`, type, output );
|
|
|
+
|
|
|
+ } else {
|
|
|
+
|
|
|
+ return builder.format( `${ a } ${ op } ${ b }`, type, output );
|
|
|
+
|
|
|
+ }
|
|
|
|
|
|
}
|
|
|
|