
fastfem.fields

Field dataclass

A class responsible for storing fields on elements as an NDArray of coefficients. There are 3 relevant shapes / axis sets to a field:

  • basis_shape - The shape of the basis. These axes represent the multi-index for the basis function.

  • stack_shape - The shape of the element stack. These axes represent the multi-index for the element.

  • point_shape - The shape of the field value at each point. These axes represent the pointwise, per-element tensor index.

The axes of coefficients are the three shapes concatenated in some order, that is, a block permutation of stack_shape + point_shape + basis_shape. The order is specified by shape_order, a 3-tuple (stack_location, field_location, basis_location) in which each entry (0, 1, or 2) gives that shape's position relative to the other two; field_location refers to point_shape.
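The layout rule above can be sketched on plain shape tuples. The two helpers below are hypothetical (not fastfem API) and assume the docstring's `(stack_location, field_location, basis_location)` reading of `shape_order`. `coefficient_shape` assembles the full coefficient shape, and `infer_stack_shape` recovers `stack_shape` as the axes left over once `point_shape` and `basis_shape` are placed. Unlike the real constructor, this sketch neither left-pads missing axes nor accepts size-1 axes that broadcast.

```python
def coefficient_shape(stack_shape, point_shape, basis_shape, shape_order=(0, 1, 2)):
    """Concatenate the three component shapes in the order given by shape_order.

    shape_order is read as (stack_location, point_location, basis_location),
    following the Field docstring. Illustrative only, not fastfem API.
    """
    if sorted(shape_order) != [0, 1, 2]:
        raise ValueError(f"{shape_order} is not a permutation of (0, 1, 2)")
    slots = [()] * 3
    for shape, location in zip((stack_shape, point_shape, basis_shape), shape_order):
        slots[location] = tuple(shape)
    return slots[0] + slots[1] + slots[2]


def infer_stack_shape(cshape, point_shape, basis_shape, shape_order=(0, 1, 2)):
    """Recover stack_shape as the axes not claimed by point_shape or basis_shape."""
    n_stack = len(cshape) - len(point_shape) - len(basis_shape)
    if n_stack < 0:
        # the real constructor pads with size-1 axes instead; this sketch does not
        raise ValueError("coefficients have too few axes")
    lengths = (n_stack, len(point_shape), len(basis_shape))
    stack_location = shape_order[0]
    # the stack block starts after every component placed before it
    start = sum(n for n, loc in zip(lengths, shape_order) if loc < stack_location)
    return tuple(cshape[start : start + n_stack])
```

With the default `shape_order=(0, 1, 2)`, stack axes come first and basis axes last, so `coefficient_shape((10,), (2,), (3, 3))` gives `(10, 2, 3, 3)`.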

Source code in fastfem/fields/field.py
@dataclass(eq=False, frozen=True, unsafe_hash=False, init=False)
class Field:
    """
    A class responsible for storing fields on elements as an `NDArray` of coefficients.
    There are 3 relevant shapes / axis sets to a field:

    - `basis_shape` - The shape of the basis. These axes represent the multi-index for
            the basis function.

    - `stack_shape` - The shape of the element stack. These axes represent the
            multi-index for the element.

    - `point_shape` - The shape of the field value at each point. These axes
            represent the pointwise, per-element tensor index.

    The axes of `coefficients` are the three shapes concatenated in some order, that
    is, a block permutation of `stack_shape + point_shape + basis_shape`. The order
    is specified by `shape_order`, a 3-tuple `(stack_location, field_location,
    basis_location)` in which each entry (0, 1, or 2) gives that shape's position
    relative to the other two; `field_location` refers to `point_shape`.
    """

    basis_shape: tuple[int, ...]
    stack_shape: tuple[int, ...]
    point_shape: tuple[int, ...]
    coefficients: NDArray | jax.Array
    shape_order: tuple[int, int, int] = dataclasses.field(repr=False, init=False)
    shape_order_inverse: tuple[ShapeComponent, ShapeComponent, ShapeComponent] = (
        dataclasses.field(repr=False, init=False)
    )
    use_jax: bool = dataclasses.field(repr=False, init=False)

    def __init__(
        self,
        basis_shape: tuple[int, ...],
        point_shape: tuple[int, ...],
        coefficients: ArrayLike,
        shape_order: tuple[int, int, int] = (0, 1, 2),
        use_jax: bool | None = None,
    ):
        _verify_is_permutation(shape_order)
        if not isinstance(coefficients, np.ndarray) and not isinstance(
            coefficients, jax.Array
        ):
            coefficients = np.array(coefficients)
        if use_jax is None:
            use_jax = isinstance(coefficients, jax.Array)
        cshape_orig = np.shape(coefficients)
        if len(cshape_orig) < len(basis_shape) + len(point_shape):
            coefficients = coefficients[
                *(
                    (np.newaxis,)
                    * (len(basis_shape) + len(point_shape) - len(cshape_orig))
                ),
                ...,
            ]
            cshape = np.shape(coefficients)
        else:
            cshape = cshape_orig

        # here, coefficients is at least as large as basis and field shapes combined

        # we need to place two markers to index the separations between basis, field,
        # and stack shapes; start with basis_shape (if not in middle)
        stack_start = 0
        stack_end = len(cshape)

        def cshape_slice_positives(a, b):
            return cshape[a:b]

        def cshape_slice_negatives(a, b):
            return cshape[a : (b if b != 0 else None)] if a != 0 else ()

        if shape_order[ShapeComponent.BASIS] == 0:
            if not _is_broadcastable(
                basis_shape, cshape_slice_positives(0, len(basis_shape))
            ):
                raise FieldConstructionError(
                    basis_shape,
                    point_shape,
                    cshape_orig,
                    shape_order,
                    hint="basis_shape cannot be broadcasted at the beginning",
                )
            stack_start = len(basis_shape)
        elif shape_order[ShapeComponent.BASIS] == 2:
            if not _is_broadcastable(
                basis_shape, cshape_slice_negatives(-len(basis_shape), 0)
            ):
                raise FieldConstructionError(
                    basis_shape,
                    point_shape,
                    cshape_orig,
                    shape_order,
                    hint="basis_shape cannot be broadcasted at the end",
                )
            stack_end -= len(basis_shape)
        # then do point_shape
        if shape_order[ShapeComponent.POINT] == 0:
            if not _is_broadcastable(
                point_shape, cshape_slice_positives(0, len(point_shape))
            ):
                raise FieldConstructionError(
                    basis_shape,
                    point_shape,
                    cshape_orig,
                    shape_order,
                    hint="point_shape cannot be broadcasted at the beginning",
                )
            # if basis_shape was in center, we now have the right offset for it
            if shape_order[ShapeComponent.BASIS] == 1:
                if not _is_broadcastable(
                    basis_shape,
                    cshape_slice_positives(
                        len(point_shape), (len(basis_shape) + len(point_shape))
                    ),
                ):
                    raise FieldConstructionError(
                        basis_shape,
                        point_shape,
                        cshape_orig,
                        shape_order,
                        hint="basis_shape cannot be broadcasted in the center",
                    )
                stack_start = len(basis_shape) + len(point_shape)
            else:
                stack_start = len(point_shape)
        elif shape_order[ShapeComponent.POINT] == 2:
            if not _is_broadcastable(
                point_shape, cshape_slice_negatives(-len(point_shape), 0)
            ):
                raise FieldConstructionError(
                    basis_shape,
                    point_shape,
                    cshape_orig,
                    shape_order,
                    hint="point_shape cannot be broadcasted at the end",
                )
            # if basis_shape was in center, we now have the right offset for it
            if shape_order[ShapeComponent.BASIS] == 1:
                if not _is_broadcastable(
                    basis_shape,
                    cshape_slice_negatives(
                        -(len(basis_shape) + len(point_shape)), -len(point_shape)
                    ),
                ):
                    raise FieldConstructionError(
                        basis_shape,
                        point_shape,
                        cshape_orig,
                        shape_order,
                        hint="basis_shape cannot be broadcasted in the center",
                    )
                stack_end -= len(basis_shape) + len(point_shape)
            else:
                stack_end -= len(point_shape)
        elif shape_order[ShapeComponent.POINT] == 1:
            # cases by basis_location
            if shape_order[ShapeComponent.BASIS] == 0:
                if not _is_broadcastable(
                    point_shape,
                    cshape_slice_positives(
                        len(basis_shape), (len(basis_shape) + len(point_shape))
                    ),
                ):
                    raise FieldConstructionError(
                        basis_shape,
                        point_shape,
                        cshape_orig,
                        shape_order,
                        hint="point_shape cannot be broadcasted in the center",
                    )
                stack_start += len(point_shape)
            else:
                if not _is_broadcastable(
                    point_shape,
                    cshape_slice_negatives(
                        -(len(basis_shape) + len(point_shape)), -len(basis_shape)
                    ),
                ):
                    raise FieldConstructionError(
                        basis_shape,
                        point_shape,
                        cshape_orig,
                        shape_order,
                        hint="point_shape cannot be broadcasted in the center",
                    )
                stack_end -= len(point_shape)

        stack_shape = cshape[stack_start:stack_end]
        shapes: list[tuple[int, ...]] = [()] * 3
        shapes[shape_order[ShapeComponent.BASIS]] = basis_shape
        shapes[shape_order[ShapeComponent.STACK]] = stack_shape
        shapes[shape_order[ShapeComponent.POINT]] = point_shape
        object.__setattr__(
            self,
            "coefficients",
            np_or_jnp(use_jax).broadcast_to(
                coefficients, shapes[0] + shapes[1] + shapes[2]
            ),
        )
        object.__setattr__(self, "basis_shape", basis_shape)
        object.__setattr__(self, "point_shape", point_shape)
        object.__setattr__(self, "stack_shape", stack_shape)
        object.__setattr__(self, "shape_order", shape_order)
        object.__setattr__(
            self, "shape_order_inverse", _invert_permutation(shape_order)
        )
        object.__setattr__(self, "use_jax", use_jax)

    @typing.overload
    def __getattr__(
        self, name: Literal["shape"]
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ...
    @typing.overload
    def __getattr__(self, name: Literal["basis"]) -> FieldBasisAccessor: ...
    @typing.overload
    def __getattr__(self, name: Literal["stack"]) -> FieldStackAccessor: ...
    @typing.overload
    def __getattr__(self, name: Literal["point"]) -> FieldPointAccessor: ...

    def __getattr__(self, name):
        if name == "shape":
            return (self.stack_shape, self.basis_shape, self.point_shape)
        if name == "basis":
            return FieldBasisAccessor(self)
        if name == "stack":
            return FieldStackAccessor(self)
        if name == "point":
            return FieldPointAccessor(self)
        raise AttributeError(name)

    def __add__(self, other: Union["Field", float]):
        if isinstance(other, Field):
            a, b = Field.broadcast_fields_full(self, other)
            return Field(
                a.basis_shape,
                a.point_shape,
                a.coefficients + b.coefficients,
                shape_order=a.shape_order,
                use_jax=a.use_jax or b.use_jax,
            )
        if isinstance(other, (int, float)):
            return Field(
                self.basis_shape,
                self.point_shape,
                self.coefficients + other,
                shape_order=self.shape_order,
                use_jax=self.use_jax,
            )
        raise NotImplementedError

    def __radd__(self, other: Union["Field", float]):
        return self.__add__(other)

    def __sub__(self, other: Union["Field", float]):
        if isinstance(other, Field):
            a, b = Field.broadcast_fields_full(self, other)
            return Field(
                a.basis_shape,
                a.point_shape,
                a.coefficients - b.coefficients,
                shape_order=a.shape_order,
                use_jax=a.use_jax or b.use_jax,
            )
        if isinstance(other, (int, float)):
            return Field(
                self.basis_shape,
                self.point_shape,
                self.coefficients - other,
                shape_order=self.shape_order,
                use_jax=self.use_jax,
            )
        raise NotImplementedError

    def __rsub__(self, other: Union["Field", float]):
        if isinstance(other, Field):
            a, b = Field.broadcast_fields_full(self, other)
            return Field(
                a.basis_shape,
                a.point_shape,
                b.coefficients - a.coefficients,
                shape_order=a.shape_order,
                use_jax=a.use_jax or b.use_jax,
            )
        if isinstance(other, (int, float)):
            return Field(
                self.basis_shape,
                self.point_shape,
                other - self.coefficients,
                shape_order=self.shape_order,
                use_jax=self.use_jax,
            )
        raise NotImplementedError

    def __mul__(self, other: Union["Field", float]):
        if isinstance(other, Field):
            a, b = Field.broadcast_fields_full(self, other)
            return Field(
                a.basis_shape,
                a.point_shape,
                a.coefficients * b.coefficients,
                shape_order=a.shape_order,
                use_jax=a.use_jax or b.use_jax,
            )
        if isinstance(other, (int, float)):
            return Field(
                self.basis_shape,
                self.point_shape,
                self.coefficients * other,
                shape_order=self.shape_order,
                use_jax=self.use_jax,
            )
        raise NotImplementedError

    def __rmul__(self, other: Union["Field", float]):
        return self.__mul__(other)

    def __neg__(self):
        return Field(
            self.basis_shape,
            self.point_shape,
            -self.coefficients,
            shape_order=self.shape_order,
            use_jax=self.use_jax,
        )

    def get_shape(self, component: ShapeComponent) -> tuple[int, ...]:
        """Recovers the shape of the specified component. This shape is in the same
        format as `numpy.shape`, that is, a tuple of ints.

        Args:
            component (ShapeComponent): The component whose shape is returned.

        Returns:
            tuple[int,...]: The shape of the specified component.
        """
        return (
            self.basis_shape
            if component == ShapeComponent.BASIS
            else (
                self.stack_shape
                if component == ShapeComponent.STACK
                else self.point_shape
            )
        )

    def _component_offset(self, component: ShapeComponent) -> int:
        """Recovers the axis index of the start of a component.

        Args:
            component (ShapeComponent): The component to sample.

        Returns:
            int: the index of the first axis of the component
        """
        ind = 0
        prec = self.shape_order[component]
        if self.shape_order[ShapeComponent.BASIS] < prec:
            ind += len(self.basis_shape)
        if self.shape_order[ShapeComponent.STACK] < prec:
            ind += len(self.stack_shape)
        if self.shape_order[ShapeComponent.POINT] < prec:
            ind += len(self.point_shape)
        return ind

    def broadcast_to_shape(
        self,
        stack_shape: tuple[int, ...] | None,
        basis_shape: tuple[int, ...] | None,
        point_shape: tuple[int, ...] | None,
        complete_broadcast=True,
    ) -> "Field":
        """This method is the field analogue of NumPy's `broadcast_to` function:
        https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html#numpy.broadcast_to

        The shape of the desired array is given as separate tuples for each component.
        Instead of the `subok` optional argument, the returned field will always have
        the same coefficient array type. Additionally, the value of `use_jax` is inherited.

        Args:
            stack_shape (tuple[int, ...] | None): The shape for the stack shape to be
                broadcasted to, or None, if the shape should be kept as-is.
            basis_shape (tuple[int, ...] | None): The shape for the basis shape to be
                broadcasted to, or None, if the shape should be kept as-is.
            point_shape (tuple[int, ...] | None): The shape for the point shape to be
                broadcasted to, or None, if the shape should be kept as-is.
            complete_broadcast (bool, optional): If False, only dimensions of size 1
                are added. When True, the field's shape matches the requested shape
                exactly.

        Raises:
            FieldShapeError: if the broadcast cannot be done by standard numpy
                broadcasting rules in each component. Note that broadcasting the basis
                shape is permitted, beyond standard compatibility rules.

        Returns:
            Field: The broadcasted field.
        """
        if (
            (
                basis_shape is not None
                and not _is_broadcastable(basis_shape, self.basis_shape)
            )
            or (
                stack_shape is not None
                and not _is_broadcastable(stack_shape, self.stack_shape)
            )
            or (
                point_shape is not None
                and not _is_broadcastable(point_shape, self.point_shape)
            )
        ):
            message = (
                f"Cannot broadcast field of shape {self.shape} into"
                f" shape {(stack_shape,basis_shape,point_shape)}"
            )
            raise FieldShapeError(message)
        slices: list[typing.Any] = [None, None, None]
        shapes: list[typing.Any] = [None, None, None]
        slices[self.shape_order[ShapeComponent.BASIS]] = (
            itertools.chain(
                (np.newaxis for _ in range(len(basis_shape) - len(self.basis_shape))),
                (slice(None) for _ in range(len(self.basis_shape))),
            )
            if basis_shape is not None
            else (slice(None) for _ in range(len(self.basis_shape)))
        )
        slices[self.shape_order[ShapeComponent.STACK]] = (
            itertools.chain(
                (np.newaxis for _ in range(len(stack_shape) - len(self.stack_shape))),
                (slice(None) for _ in range(len(self.stack_shape))),
            )
            if stack_shape is not None
            else (slice(None) for _ in range(len(self.stack_shape)))
        )
        slices[self.shape_order[ShapeComponent.POINT]] = (
            itertools.chain(
                (np.newaxis for _ in range(len(point_shape) - len(self.point_shape))),
                (slice(None) for _ in range(len(self.point_shape))),
            )
            if point_shape is not None
            else (slice(None) for _ in range(len(self.point_shape)))
        )
        shapes[self.shape_order[ShapeComponent.BASIS]] = (
            self.basis_shape if basis_shape is None else basis_shape
        )
        shapes[self.shape_order[ShapeComponent.STACK]] = (
            self.stack_shape if stack_shape is None else stack_shape
        )
        shapes[self.shape_order[ShapeComponent.POINT]] = (
            self.point_shape if point_shape is None else point_shape
        )
        coefs = self.coefficients[*itertools.chain(*slices)]
        return Field(
            self.basis_shape if basis_shape is None else basis_shape,
            self.point_shape if point_shape is None else point_shape,
            (
                np_or_jnp(self).broadcast_to(coefs, shapes[0] + shapes[1] + shapes[2])
                if complete_broadcast
                else coefs
            ),
            shape_order=self.shape_order,
            use_jax=self.use_jax,
        )

    @staticmethod
    def are_broadcastable(*fields: "Field", strict_basis=True) -> bool:
        """Two fields a and b are (fully) broadcastable if they are compatible and have
        broadcastable point shapes. Since this relation is associative,
        more than two fields can be passed in.

        Args:
            fields (tuple[Field, ...]): The fields to broadcast.
            strict_basis (bool, optional): If True, the strict basis rule of
                `are_compatible` holds. Otherwise, only NumPy broadcastability of the
                basis shapes is checked.

        Returns:
            bool: True if the fields are broadcastable. False otherwise.
        """
        return Field.are_compatible(
            *fields, strict_basis=strict_basis
        ) and _is_compatible(*(field.point_shape for field in fields))

    @typing.overload
    @staticmethod
    def broadcast_fields_full(
        *fields: "Field", strict_basis=True, shapes_only: Literal[True]
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ...
    @typing.overload
    @staticmethod
    def broadcast_fields_full(
        *fields: "Field", strict_basis=True, shapes_only: Literal[False] = False
    ) -> tuple["Field", ...]: ...
    @staticmethod
    def broadcast_fields_full(
        *fields: "Field",
        strict_basis=True,
        shapes_only: bool = False,
    ) -> tuple["Field", ...] | tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Broadcasts the given fields to common stack, basis, and point shapes.

        Two fields a and b are (fully) broadcastable if they are compatible and have
        broadcastable point shapes. Since this relation is associative,
        more than two fields can be passed in.

        Args:
            fields (tuple[Field, ...]): The fields to broadcast.
            strict_basis (bool, optional): If True, the strict basis rule of
                `are_compatible` holds. Otherwise, only NumPy broadcastability of the
                basis shapes is checked.
            shapes_only (bool, optional): If True, the target shapes are returned as
                `(stack_shape, basis_shape, point_shape)` and no array broadcasting
                occurs.

        Raises:
            FieldShapeError: If the fields are not fully broadcastable together.

        Returns:
            tuple[Field, ...]: The broadcasted fields, in the order they were given
                (or the target shapes, if `shapes_only` is True).
        """
        if Field.are_broadcastable(*fields, strict_basis=strict_basis):
            basis_shape = np.broadcast_shapes(*[field.basis_shape for field in fields])
            stack_shape = np.broadcast_shapes(*[field.stack_shape for field in fields])
            point_shape = np.broadcast_shapes(*[field.point_shape for field in fields])
            if shapes_only:
                return (stack_shape, basis_shape, point_shape)
            return tuple(
                field.broadcast_to_shape(stack_shape, basis_shape, point_shape)
                for field in fields
            )

        message = "Cannot broadcast fields with incompatible shapes."
        raise FieldShapeError(message)

    @staticmethod
    def are_compatible(*fields: "Field", strict_basis=True) -> bool:
        """Two fields a and b are compatible if they have compatible bases
        (basis_shape equal or at least one of them is size 1 representing a constant)
        and they have broadcastable stack_shapes. This function checks them. Since
        this relation is associative, more than two fields can be passed in.

        Args:
            fields (tuple[Field, ...]): The fields to query compatibility.
            strict_basis (bool, optional): If True, the basis rule holds. Otherwise,
                only NumPy broadcastability of the basis shapes is checked.

        Returns:
            bool: True if the fields are compatible. False otherwise.
        """
        if strict_basis:
            return all(
                map(
                    lambda x: x[1],  # accumulator -> did nonempty tuple change?
                    itertools.accumulate(
                        (field.basis_shape for field in fields),
                        func=lambda a, b: (
                            a[0] if np.prod(b, dtype=int) == 1 else b,  # nonempty tuple
                            (np.prod(a[0], dtype=int) == 1)
                            or (np.prod(b, dtype=int) == 1)
                            or a[0] == b,  # if nonempty, did shape change?
                        ),
                        initial=((), True),
                    ),
                )
            ) and _is_compatible(*(field.stack_shape for field in fields))
        return _is_compatible(
            *(field.basis_shape for field in fields)
        ) and _is_compatible(*(field.stack_shape for field in fields))

    @staticmethod
    def broadcast_field_compatibility(
        *fields: "Field", strict_basis=True
    ) -> tuple["Field", ...]:
        """Two fields a and b are compatible if they have compatible bases
        (basis_shape equal or at least one of them is size 1 representing a constant)
        and they have broadcastable stack_shapes. Since
        this relation is associative, more than two fields can be passed in.

        This function broadcasts the fields to have the same stack and basis shapes if
        they are compatible, or raises an error if they are not.

        Args:
            fields (tuple[Field, ...]): The fields to broadcast.
            strict_basis (bool, optional): If True, the basis rule holds. Otherwise,
                only NumPy broadcastability of the basis shapes is checked.

        Raises:
            FieldShapeError: if the given fields are not all compatible.

        Returns:
            tuple[Field, ...]: The broadcasted fields, in the order they were given.
        """
        if Field.are_compatible(*fields, strict_basis=strict_basis):
            basis_shape = np.broadcast_shapes(*[field.basis_shape for field in fields])
            stack_shape = np.broadcast_shapes(*[field.stack_shape for field in fields])
            return tuple(
                field.broadcast_to_shape(stack_shape, basis_shape, field.point_shape)
                for field in fields
            )

        message = "Cannot broadcast fields with incompatible shapes."
        raise FieldShapeError(message)

    def _axis_field_to_numpy(
        self, index: FieldAxisIndexType, out_of_bounds_check: bool = True
    ) -> int:
        """Recovers the axis index (in terms of numpy) from the given field axis specifier.

        Args:
            index (FieldAxisIndexType): The field index to recover the axis of.
            out_of_bounds_check (bool, optional): Whether or not an out-of-bounds check
                occurs. As in Python indexing, it is expected that
                -len(component_shape) <= index.index < len(component_shape), where
                component_shape is the shape of index.component. If
                out_of_bounds_check is True and this does not hold, an IndexError is
                raised. Defaults to True.

        Raises:
            IndexError: if out_of_bounds_check is True, and the index is out of bounds.

        Returns:
            int: the integer index, in the numpy context.
        """
        comp: ShapeComponent = index[0]
        ind: int = index[1]
        shape = self.get_shape(comp)

        if ind < 0:
            if out_of_bounds_check and ind < -len(shape):
                message = (
                    f"Attempting to access axis {ind} ({-1-ind}) of shape {shape}."
                )
                raise IndexError(message)
            ind = len(shape) + ind
        elif out_of_bounds_check and ind >= len(shape):
            message = f"Attempting to access axis {ind} of shape {shape}."
            raise IndexError(message)

        return ind + self._component_offset(comp)

    def __eq__(self, other) -> bool:
        if not isinstance(other, Field):
            return NotImplemented
        if not Field.are_broadcastable(self, other):
            return False

        return np_or_jnp(self, other).array_equiv(
            *(f.coefficients for f in Field.broadcast_fields_full(self, other))
        )
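The strict basis rule inside `are_compatible` (the `itertools.accumulate` expression) can be restated as an explicit loop. The following is a minimal sketch on plain shape tuples; `bases_compatible`, `shapes_broadcastable`, and `fields_compatible` are hypothetical names, and `shapes_broadcastable` assumes that the private `_is_compatible` helper, which is not shown on this page, applies NumPy's broadcasting rule.

```python
from math import prod


def bases_compatible(*basis_shapes):
    """Strict basis rule: every basis of total size != 1 must match the others.

    A basis of size 1 (including the empty shape ()) stands for a constant and
    is compatible with anything.
    """
    reference = ()  # most recent non-constant basis; prod(()) == 1
    for shape in map(tuple, basis_shapes):
        if prod(shape) == 1:
            continue  # constants never change the reference basis
        if prod(reference) != 1 and shape != reference:
            return False
        reference = shape
    return True


def shapes_broadcastable(*shapes):
    """NumPy rule: aligned from the right, each axis is 1 or agrees with the rest."""
    ndim = max((len(s) for s in shapes), default=0)
    for axis in range(1, ndim + 1):
        sizes = {s[-axis] for s in shapes if len(s) >= axis and s[-axis] != 1}
        if len(sizes) > 1:
            return False
    return True


def fields_compatible(basis_shapes, stack_shapes):
    """Sketch of are_compatible(strict_basis=True) on shapes instead of Fields."""
    return bases_compatible(*basis_shapes) and shapes_broadcastable(*stack_shapes)
```

The strict rule is tighter than NumPy broadcasting: `bases_compatible((2, 2), (2,))` is `False`, although NumPy would broadcast those two shapes.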

get_shape(component)

Recovers the shape of the specified component. This shape is in the same format as numpy.shape, that is, a tuple of ints.

Parameters:

  • component (ShapeComponent) –

    The component whose shape is returned.

Returns:

  • tuple[int, ...]

    The shape of the specified component.
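get_shape returns a component's own shape; locating those axes inside the coefficient array is the job of the private helpers `_component_offset` and `_axis_field_to_numpy` in the class source. A minimal sketch of the same arithmetic on plain tuples follows. The `STACK`, `POINT`, and `BASIS` constants are hypothetical stand-ins for `ShapeComponent`, numbered to follow the docstring's `(stack_location, field_location, basis_location)` order; the real enum values may differ.

```python
# Hypothetical stand-ins for ShapeComponent; the real enum values may differ.
STACK, POINT, BASIS = 0, 1, 2


def component_offset(shapes, shape_order, component):
    """Index of the first coefficient axis that belongs to `component`.

    `shapes` and `shape_order` are both indexed by component.
    """
    position = shape_order[component]
    # count the axes of every component placed before this one
    return sum(
        len(shapes[c]) for c in (STACK, POINT, BASIS) if shape_order[c] < position
    )


def axis_to_numpy(shapes, shape_order, component, index):
    """Map a component-local axis index (negative allowed) to a coefficient axis."""
    n = len(shapes[component])
    if not -n <= index < n:
        raise IndexError(f"axis {index} is out of bounds for shape {shapes[component]}")
    return index % n + component_offset(shapes, shape_order, component)
```

For a stack of 10 elements with a 2-vector point shape and a 3x3 basis in the default order, the last basis axis is coefficient axis 3.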


broadcast_to_shape(stack_shape, basis_shape, point_shape, complete_broadcast=True)

This method is the field analogue of NumPy's broadcast_to function: https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html#numpy.broadcast_to

The shape of the desired array is given as separate tuples for each component. Instead of the subok optional argument, the returned field will always have the same coefficient array type. Additionally, the value of use_jax is inherited.

Parameters:

  • stack_shape (tuple[int, ...] | None) –

    The target stack shape, or None to keep the current stack shape.

  • basis_shape (tuple[int, ...] | None) –

    The target basis shape, or None to keep the current basis shape.

  • point_shape (tuple[int, ...] | None) –

    The target point shape, or None to keep the current point shape.

  • complete_broadcast (bool, default: True ) –

    If False, only size-1 axes are inserted, so each component has the target number of axes but is not expanded. If True, the coefficients are broadcast to exactly the target shape.

Raises:

  • FieldShapeError

    If the broadcast cannot be performed under standard numpy broadcasting rules in some component. Note that the basis shape is only checked for numpy broadcastability here; the stricter basis rule of are_compatible is not enforced.

Returns:

  • Field ( Field ) –

    The broadcasted field.

Source code in fastfem/fields/field.py
def broadcast_to_shape(
    self,
    stack_shape: tuple[int, ...] | None,
    basis_shape: tuple[int, ...] | None,
    point_shape: tuple[int, ...] | None,
    complete_broadcast=True,
) -> "Field":
    """This function is related to the numpy broadcast_to function.
    https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html#numpy.broadcast_to

    The shape of the desired array is given as separate tuples for each component.
    Instead of the `subok` optional argument, the returned field will always have
    the same coefficient array type. Additionally, the value of `use_jax` is inherited.

    Args:
        stack_shape (tuple[int, ...] | None): The shape for the stack shape to be
            broadcasted to, or None, if the shape should be kept as-is.
        basis_shape (tuple[int, ...] | None): The shape for the basis shape to be
            broadcasted to, or None, if the shape should be kept as-is.
        point_shape (tuple[int, ...] | None): The shape for the field shape to be
            broadcasted to, or None, if the shape should be kept as-is.
        complete_broadcast (bool, optional): If False, only dimensions of size 1
            are added. When true, the shape of the field precisely matches.

    Raises:
        FieldShapeError: if the broadcast cannot be done by standard numpy
            broadcasting rules in each component. Note that broadcasting the basis
            shape is permitted, beyond standard compatibility rules.

    Returns:
        Field: The broadcasted field.
    """
    if (
        (
            basis_shape is not None
            and not _is_broadcastable(basis_shape, self.basis_shape)
        )
        or (
            stack_shape is not None
            and not _is_broadcastable(stack_shape, self.stack_shape)
        )
        or (
            point_shape is not None
            and not _is_broadcastable(point_shape, self.point_shape)
        )
    ):
        message = (
            f"Cannot broadcast field of shape {self.shape} into"
            f" shape {(stack_shape,basis_shape,point_shape)}"
        )
        raise FieldShapeError(message)
    slices: list[typing.Any] = [None, None, None]
    shapes: list[typing.Any] = [None, None, None]
    slices[self.shape_order[ShapeComponent.BASIS]] = (
        itertools.chain(
            (np.newaxis for _ in range(len(basis_shape) - len(self.basis_shape))),
            (slice(None) for _ in range(len(self.basis_shape))),
        )
        if basis_shape is not None
        else (slice(None) for _ in range(len(self.basis_shape)))
    )
    slices[self.shape_order[ShapeComponent.STACK]] = (
        itertools.chain(
            (np.newaxis for _ in range(len(stack_shape) - len(self.stack_shape))),
            (slice(None) for _ in range(len(self.stack_shape))),
        )
        if stack_shape is not None
        else (slice(None) for _ in range(len(self.stack_shape)))
    )
    slices[self.shape_order[ShapeComponent.POINT]] = (
        itertools.chain(
            (np.newaxis for _ in range(len(point_shape) - len(self.point_shape))),
            (slice(None) for _ in range(len(self.point_shape))),
        )
        if point_shape is not None
        else (slice(None) for _ in range(len(self.point_shape)))
    )
    shapes[self.shape_order[ShapeComponent.BASIS]] = (
        self.basis_shape if basis_shape is None else basis_shape
    )
    shapes[self.shape_order[ShapeComponent.STACK]] = (
        self.stack_shape if stack_shape is None else stack_shape
    )
    shapes[self.shape_order[ShapeComponent.POINT]] = (
        self.point_shape if point_shape is None else point_shape
    )
    coefs = self.coefficients[*itertools.chain(*slices)]
    return Field(
        self.basis_shape if basis_shape is None else basis_shape,
        self.point_shape if point_shape is None else point_shape,
        (
            np_or_jnp(self).broadcast_to(coefs, shapes[0] + shapes[1] + shapes[2])
            if complete_broadcast
            else coefs
        ),
        shape_order=self.shape_order,
        use_jax=self.use_jax,
    )
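The indexing trick above can be reproduced with plain numpy. The sketch below is a hypothetical helper that uses component names instead of fastfem's ShapeComponent and skips the _is_broadcastable pre-check. It inserts np.newaxis entries on the left of each component's block of axes and then, if requested, calls np.broadcast_to on the concatenated target shape.

```python
import numpy as np


def broadcast_components(coefs, current, target, order, complete=True):
    """Broadcast a coefficient array one component at a time.

    current, target: dicts mapping a component name to its shape tuple.
    order: component names in the order their axis blocks appear in coefs.
    """
    index = []
    for comp in order:
        extra = len(target[comp]) - len(current[comp])
        if extra < 0:
            raise ValueError(f"cannot broadcast {comp} {current[comp]} to {target[comp]}")
        # As in numpy broadcasting, missing axes are added on the left of the block.
        index += [np.newaxis] * extra + [slice(None)] * len(current[comp])
    expanded = coefs[tuple(index)]
    if not complete:
        return expanded  # only size-1 axes were inserted
    full_shape = sum((tuple(target[comp]) for comp in order), ())
    return np.broadcast_to(expanded, full_shape)
```

With complete=False the result only gains size-1 axes, which mirrors complete_broadcast=False above.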

are_broadcastable(*fields, strict_basis=True) staticmethod

Two fields a and b are (fully) broadcastable if they are compatible (see are_compatible) and have broadcastable point shapes. Since broadcasting is associative, more than two fields can be passed in.

Parameters:

  • fields (tuple[Field, ...], default: () ) –

    The fields to check.

  • strict_basis (bool, default: True ) –

    If true, the strict basis rule of are_compatible is enforced. Otherwise, basis shapes are only checked for numpy broadcastability.

Returns:

  • bool ( bool ) –

    True if the fields are broadcastable. False otherwise.

Source code in fastfem/fields/field.py
@staticmethod
def are_broadcastable(*fields: "Field", strict_basis=True) -> bool:
    """Two fields a and b are (fully) broadcastable if they are compatible and have
    broadcastable point shape. Since broadcasting is associative,
    more than two fields can be passed in.

    Args:
        fields (tuple[Field, ...]): The fields to check.
        strict_basis (bool, optional): If true, the basis rule holds. Otherwise,
            only basis shape numpy-broadcastability is checked.

    Returns:
        bool: True if the fields are broadcastable. False otherwise.
    """
    # Forward strict_basis so that the documented option actually takes effect.
    return Field.are_compatible(*fields, strict_basis=strict_basis) and _is_compatible(
        *(field.point_shape for field in fields)
    )
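The private helper _is_compatible is not shown on this page. Assuming it simply tests whether the given shapes broadcast together under numpy's rules, a stand-in could look like the following. It also shows why all shapes are checked jointly: (1,) broadcasts with (3,) and with (4,) separately, but the three together do not.

```python
import numpy as np


def shapes_compatible(*shapes):
    """Hypothetical stand-in for _is_compatible: True if all shapes broadcast together."""
    try:
        np.broadcast_shapes(*shapes)
    except ValueError:
        return False
    return True
```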

broadcast_fields_full(*fields, strict_basis=True, shapes_only=False) staticmethod

Broadcasts the given fields to a common stack, basis, and point shape. Two fields a and b are (fully) broadcastable if they are compatible and have broadcastable point shapes. Since broadcasting is associative, more than two fields can be passed in.

Parameters:

  • fields (tuple[Field, ...], default: () ) –

    The fields to broadcast.

  • strict_basis (bool, default: True ) –

    If true, the strict basis rule of are_compatible is enforced. Otherwise, basis shapes are only checked for numpy broadcastability.

  • shapes_only (bool, default: False ) –

    If true, the target shapes are returned and no array broadcasting occurs.

Raises:

  • FieldShapeError

    If the fields are not fully broadcastable together.

Returns:

  • tuple[Field, ...] | tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]

    The broadcast fields, in the order they were given, or, if shapes_only is true, the target shapes as (stack_shape, basis_shape, point_shape).

Source code in fastfem/fields/field.py
@staticmethod
def broadcast_fields_full(
    *fields: "Field",
    strict_basis=True,
    shapes_only: bool = False,
) -> tuple["Field", ...] | tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    """Broadcasts the given fields to a common stack, basis, and point shape.

    Two fields a and b are (fully) broadcastable if they are compatible and have
    broadcastable point shape. Since broadcasting is associative,
    more than two fields can be passed in.

    Args:
        fields (tuple[Field, ...]): The fields to broadcast.
        strict_basis (bool, optional): If true, the basis rule holds. Otherwise,
            only basis shape numpy-broadcastability is checked.
        shapes_only (bool, optional): If true, the target shapes are returned and no
            array broadcasting occurs.

    Raises:
        FieldShapeError: If the fields are not fully broadcastable together.

    Returns:
        tuple[Field, ...]: The broadcast fields, in the order they were given, or, if
            shapes_only is true, the target (stack_shape, basis_shape, point_shape).
    """
    # Forward strict_basis so that the documented option actually takes effect.
    if Field.are_broadcastable(*fields, strict_basis=strict_basis):
        basis_shape = np.broadcast_shapes(*[field.basis_shape for field in fields])
        stack_shape = np.broadcast_shapes(*[field.stack_shape for field in fields])
        point_shape = np.broadcast_shapes(*[field.point_shape for field in fields])
        if shapes_only:
            return (stack_shape, basis_shape, point_shape)
        return tuple(
            field.broadcast_to_shape(stack_shape, basis_shape, point_shape)
            for field in fields
        )

    message = "Cannot broadcast fields with incompatible shapes."
    raise FieldShapeError(message)
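With shapes_only=True, the function only computes the broadcast target for each component. The same computation on plain shape tuples, in a hypothetical common_shapes helper that takes (stack_shape, basis_shape, point_shape) triples instead of Field objects, is:

```python
import numpy as np


def common_shapes(*fields):
    """Each field is a (stack_shape, basis_shape, point_shape) triple.

    Returns the broadcast target in the same order, like shapes_only=True.
    """
    stacks, bases, points = zip(*fields)
    return (
        np.broadcast_shapes(*stacks),
        np.broadcast_shapes(*bases),
        np.broadcast_shapes(*points),
    )
```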

are_compatible(*fields, strict_basis=True) staticmethod

Two fields a and b are compatible if they have compatible bases (equal basis_shape, or at least one basis of size 1, representing a constant) and broadcastable stack_shapes. This function checks both conditions. Since broadcasting is associative, more than two fields can be passed in.

Parameters:

  • fields (tuple[Field, ...], default: () ) –

    The fields to check for compatibility.

  • strict_basis (bool, default: True ) –

    If true, the basis rule above is enforced. Otherwise, basis shapes are only checked for numpy broadcastability.

Returns:

  • bool ( bool ) –

    True if the fields are compatible. False otherwise.

Source code in fastfem/fields/field.py
@staticmethod
def are_compatible(*fields: "Field", strict_basis=True) -> bool:
    """Two fields a and b are compatible if they have compatible bases
    (basis_shape equal or at least one of them is size 1 representing a constant)
    and they have broadcastable stack_shapes. This function checks both conditions.
    Since broadcasting is associative, more than two fields can be passed in.

    Args:
        fields (tuple[Field, ...]): The fields to check for compatibility.
        strict_basis (bool, optional): If true, the basis rule holds. Otherwise,
            only basis shape numpy-broadcastability is checked.

    Returns:
        bool: True if the fields are compatible. False otherwise.
    """
    if strict_basis:
        return all(
            map(
                lambda x: x[1],  # accumulator -> did nonempty tuple change?
                itertools.accumulate(
                    (field.basis_shape for field in fields),
                    func=lambda a, b: (
                        a[0] if np.prod(b, dtype=int) == 1 else b,  # nonempty tuple
                        (np.prod(a[0], dtype=int) == 1)
                        or (np.prod(b, dtype=int) == 1)
                        or a[0] == b,  # if nonempty, did shape change?
                    ),
                    initial=((), True),
                ),
            )
        ) and _is_compatible(*(field.stack_shape for field in fields))
    return _is_compatible(
        *(field.basis_shape for field in fields)
    ) and _is_compatible(*(field.stack_shape for field in fields))
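The itertools.accumulate expression above tracks the most recent non-constant basis shape and flags any later non-constant shape that differs from it. An equivalent loop (a sketch, not fastfem code) makes the strict basis rule easier to read: every basis shape whose size is not 1 must be identical.

```python
import math


def bases_compatible(*basis_shapes):
    """Strict basis rule: all non-constant basis shapes must be equal."""
    reference = None
    for shape in basis_shapes:
        if math.prod(shape) == 1:
            continue  # size-1 basis: a constant, compatible with anything
        if reference is None:
            reference = shape
        elif shape != reference:
            return False
    return True
```

Note that (2, 3) and (6,) are rejected even though they have the same size, because the rule compares shapes, not sizes.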

broadcast_field_compatibility(*fields, strict_basis=True) staticmethod

Two fields a and b are compatible if they have compatible bases (equal basis_shape, or at least one basis of size 1, representing a constant) and broadcastable stack_shapes. Since broadcasting is associative, more than two fields can be passed in.

This function broadcasts the fields to have the same stack and basis shapes if they are compatible, or raises an error if they are not. Point shapes are left unchanged.

Parameters:

  • fields (tuple[Field, ...], default: () ) –

    The fields to broadcast.

  • strict_basis (bool, default: True ) –

    If true, the strict basis rule of are_compatible is enforced. Otherwise, basis shapes are only checked for numpy broadcastability.

Raises:

  • FieldShapeError

    If the given fields are not all compatible.

Returns:

  • tuple[Field, ...]

    The broadcast fields, in the order they were given.

Source code in fastfem/fields/field.py
@staticmethod
def broadcast_field_compatibility(
    *fields: "Field", strict_basis=True
) -> tuple["Field", ...]:
    """Two fields a and b are compatible if they have compatible bases
    (basis_shape equal or at least one of them is size 1 representing a constant)
    and they have broadcastable stack_shapes. Since broadcasting is associative,
    more than two fields can be passed in.

    This function broadcasts the fields to have the same stack and basis shapes if
    they are compatible, or raises an error if they are not.

    Args:
        fields (tuple[Field, ...]): The fields to broadcast.
        strict_basis (bool, optional): If true, the basis rule holds. Otherwise,
            only basis shape numpy-broadcastability is checked.

    Raises:
        FieldShapeError: if the given fields are not all compatible.

    Returns:
        tuple[Field, ...]: The broadcasted fields, in the order they were given.
    """
    if Field.are_compatible(*fields, strict_basis=strict_basis):
        basis_shape = np.broadcast_shapes(*[field.basis_shape for field in fields])
        stack_shape = np.broadcast_shapes(*[field.stack_shape for field in fields])
        return tuple(
            field.broadcast_to_shape(stack_shape, basis_shape, field.point_shape)
            for field in fields
        )

    message = "Cannot broadcast fields with incompatible shapes."
    raise FieldShapeError(message)
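Unlike broadcast_fields_full, this function leaves each field's point shape alone. A sketch of the resulting target shapes, using hypothetical dictionaries of shape tuples in place of Field objects and skipping the compatibility check:

```python
import numpy as np


def compatibility_targets(*fields):
    """fields: dicts with "stack", "basis" and "point" shape tuples.

    Stack and basis shapes are broadcast jointly; point shapes are kept.
    """
    stack = np.broadcast_shapes(*(f["stack"] for f in fields))
    basis = np.broadcast_shapes(*(f["basis"] for f in fields))
    return [{"stack": stack, "basis": basis, "point": f["point"]} for f in fields]
```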

FieldConstructionError

Bases: FieldShapeError

Raised when constructing a field fails.

Source code in fastfem/fields/field.py
class FieldConstructionError(FieldShapeError):
    """Raised when constructing a field fails."""

    def __init__(self, basis_shape, point_shape, coeff_shape, shape_order, hint=None):
        errmsg = (
            f"Cannot construct Field object with basis_shape {basis_shape}, point_shape"
            f" {point_shape} given the coefficient shape {coeff_shape} and shape order"
            f" {shape_order}."
        )
        if hint is not None:
            errmsg += f" ({hint})"
        super().__init__(errmsg)
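Because FieldConstructionError subclasses FieldShapeError, callers can catch either. The sketch below copies the class from above and pairs it with a stand-in FieldShapeError, since the real base class of FieldShapeError is not shown on this page; the shapes and the hint text are made up for illustration.

```python
class FieldShapeError(Exception):
    """Stand-in for fastfem's FieldShapeError (its real base class is not shown)."""


class FieldConstructionError(FieldShapeError):
    """Raised when constructing a field fails (message format copied from above)."""

    def __init__(self, basis_shape, point_shape, coeff_shape, shape_order, hint=None):
        errmsg = (
            f"Cannot construct Field object with basis_shape {basis_shape}, point_shape"
            f" {point_shape} given the coefficient shape {coeff_shape} and shape order"
            f" {shape_order}."
        )
        if hint is not None:
            errmsg += f" ({hint})"
        super().__init__(errmsg)


try:
    raise FieldConstructionError((4,), (2,), (3, 5), (0, 1, 2), hint="hypothetical hint")
except FieldShapeError as err:  # the subclass is caught through its base class
    message = str(err)
```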

abs(field)

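Applies the absolute value elementwise to the field's coefficients, mirroring numpy.abs: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

Parameters:

  • field (Field) –

    The field whose absolute value is taken.

Returns:

  • Field ( Field ) –

    A new field with the same basis shape, point shape, shape order, and use_jax setting, whose coefficients are the absolute values of the input coefficients.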

Source code in fastfem/fields/numpy_similes.py
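def abs(field: Field) -> Field:
    """Applies the absolute value elementwise to the field's coefficients, mirroring
    numpy.abs: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

    Args:
        field (Field): The field whose absolute value is taken.

    Returns:
        Field: A new field with the same basis shape, point shape, shape order and
            use_jax setting, whose coefficients are the absolute values of the input
            coefficients.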
    """
    return Field(
        field.basis_shape,
        field.point_shape,
        np(field).abs(field.coefficients),
        shape_order=field.shape_order,
        use_jax=field.use_jax,
    )
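The source dispatches through a helper called np(...), whose definition is not shown here; judging from calls such as np(field.use_jax).sum, it appears to select numpy or jax.numpy. A hypothetical equivalent, named array_module to avoid shadowing numpy, together with the coefficient-wise absolute value that abs computes, might look like this:

```python
import numpy


def array_module(use_jax: bool):
    """Hypothetical stand-in for the np(...) helper: pick the array backend."""
    if use_jax:
        import jax.numpy as jnp  # only imported when requested; requires JAX

        return jnp
    return numpy


def abs_coefficients(coefficients, use_jax=False):
    # Elementwise absolute value. The array shape, and therefore every
    # component shape and the shape order, is unchanged.
    return array_module(use_jax).abs(coefficients)
```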

moveaxis(field, source, destination)

This attempts to replicate the numpy moveaxis function. Currently, multiple axes at the same time are not supported. https://numpy.org/doc/stable/reference/generated/numpy.moveaxis.html

Parameters:

  • field (Field) –

    The field whose axes should be reordered

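  • source (FieldAxisIndexType) –

    Original position of the axis to move, as a (component, index) pair. Only a single axis is currently supported.

  • destination (FieldAxisIndexType) –

    Destination position of the moved axis, as a (component, index) pair. A negative index counts from the end of the destination component after the move.

Returns:

  • Field ( Field ) –

    A new field with the axis moved.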

Source code in fastfem/fields/numpy_similes.py
def moveaxis(
    field: Field, source: FieldAxisIndexType, destination: FieldAxisIndexType
) -> Field:
    """This attempts to replicate the numpy `moveaxis` function. Currently, only a
    single axis can be moved at a time.
    https://numpy.org/doc/stable/reference/generated/numpy.moveaxis.html

    Args:
        field (Field): The field whose axes should be reordered
        source (FieldAxisIndexType): Original position of the axis to move, as a
            (component, index) pair.
        destination (FieldAxisIndexType): Destination position of the moved axis, as
            a (component, index) pair. A negative index counts from the end of the
            destination component after the move.

    Returns:
        Field: A new field with the axis moved.
    """

    shapes = {
        ShapeComponent.BASIS: field.basis_shape,
        ShapeComponent.STACK: field.stack_shape,
        ShapeComponent.POINT: field.point_shape,
    }

    src_pos_np = field._axis_field_to_numpy(source)
    srcshape = shapes[source[0]]

    src_pt = source[1] if source[1] >= 0 else len(srcshape) + source[1]
    rem_axis = srcshape[source[1]]  # size of the removed axis
    # when counting from the right, be sure to add one for the new length
    dest_pt = (
        destination[1]
        if destination[1] >= 0
        else len(shapes[destination[0]])
        + (0 if source[0] == destination[0] else 1)
        + destination[1]
    )
    dest_pos_np = field._axis_field_to_numpy((destination[0], dest_pt), False)
    coefs = np(field.use_jax).moveaxis(
        field.coefficients,
        src_pos_np,
        # dest_pos_np - (1 if dest_pos_np > src_pos_np else 0), #account for removals shifting indices
        dest_pos_np
        - (
            1 if field.shape_order[source[0]] < field.shape_order[destination[0]] else 0
        ),
    )
    shapes[source[0]] = srcshape[:src_pt] + srcshape[(src_pt + 1) :]
    destshape = shapes[destination[0]]
    shapes[destination[0]] = destshape[:dest_pt] + (rem_axis,) + destshape[dest_pt:]
    return Field(
        shapes[ShapeComponent.BASIS],
        shapes[ShapeComponent.POINT],
        coefs,
        shape_order=field.shape_order,
        use_jax=field.use_jax,  # keep the array backend of the input field
    )
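The index bookkeeping in moveaxis is easier to follow on plain arrays. This sketch assumes the coefficient axes are ordered (stack, point, basis) and uses hypothetical helper names. Instead of correcting the destination index after the removal, it computes the destination position directly from the updated shapes, which is the output position that np.moveaxis expects.

```python
import numpy as np

# Assumed axis-block order of the coefficient array for this sketch.
ORDER = ("stack", "point", "basis")


def numpy_axis(shapes, comp, i):
    """Coefficient-array position of axis i (negative allowed) of component comp."""
    offset = 0
    for c in ORDER:
        if c == comp:
            return offset + (i if i >= 0 else len(shapes[c]) + i)
        offset += len(shapes[c])
    raise KeyError(comp)


def move_component_axis(coefs, shapes, source, destination):
    """Move one axis between (or within) components.

    source, destination: (component, index) pairs. A negative destination
    index counts from the end of the destination component after the move.
    """
    src_comp, src_i = source
    dst_comp, dst_i = destination
    src_np = numpy_axis(shapes, src_comp, src_i)

    # Remove the axis from its source component...
    src_shape = shapes[src_comp]
    src_i = src_i if src_i >= 0 else len(src_shape) + src_i
    size = src_shape[src_i]
    new_shapes = dict(shapes)
    new_shapes[src_comp] = src_shape[:src_i] + src_shape[src_i + 1 :]

    # ...then insert it into the destination component.
    dst_shape = new_shapes[dst_comp]
    if dst_i < 0:
        dst_i = len(dst_shape) + 1 + dst_i
    new_shapes[dst_comp] = dst_shape[:dst_i] + (size,) + dst_shape[dst_i:]

    # np.moveaxis takes the axis position in the *output* array.
    dst_np = numpy_axis(new_shapes, dst_comp, dst_i)
    return np.moveaxis(coefs, src_np, dst_np), new_shapes
```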

reshape(field, component_selector, shape, order='C', copy=None)

This attempts to replicate the numpy "reshape" function. https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

reshape() applies numpy "reshape" to a single component. This function is also called by the field accessor reshape methods; that is, reshape(field, BASIS, s) is the same as field.basis.reshape(s).

Parameters:

  • field (Field) –

    The field to reshape.

  • component_selector (ShapeComponent) –

    Which component to reshape.

  • shape (int | tuple[int, ...]) –

    The target shape of the component. An integer i is equivalent to (i,).

  • order ({'C', 'F', 'A'}, default: 'C' ) –

    See the numpy documentation.

  • copy (bool | None, default: None ) –

    See the numpy documentation.

Raises:

  • ValueError

    When a copy is required but copy is False.

Returns:

  • Field ( Field ) –

    The reshaped field. This is always a new object, but if the data is not copied, its coefficient array is a view of the original.

Source code in fastfem/fields/numpy_similes.py
def reshape(
    field: Field,
    component_selector: ShapeComponent,
    shape: int | tuple[int, ...],
    order: Literal["C", "F", "A"] = "C",
    copy: bool | None = None,
) -> Field:
    """This attempts to replicate the numpy "reshape" function.
    https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

    reshape() applies numpy "reshape" to a given component. This function is also called
    when using the field accessor reshape methods. That is, `reshape(field,BASIS,s)` is
    the same as `field.basis.reshape(s)`.

    Args:
        field (Field): The field to reshape
        component_selector (ShapeComponent): Which component to reshape.
        shape (int | tuple[int, ...]): The target shape of the component. For any
            integer i, i is equivalent to (i,)
        order ({'C','F','A'}, optional): See the numpy documentation. Defaults to 'C'.
        copy (bool | None, optional): See the numpy documentation. Defaults to None.

    Raises:
        ValueError: when a copy operation is required, but `copy` is False.

    Returns:
        Field: The reshaped field. This is always a new object, but if data is not
        copied, the underlying array is a view of the original.
    """
    """
    return _reshape(field, component_selector, shape, order, copy)
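_reshape itself is not shown here. One way to reshape a single component, sketched under the assumption that the coefficient axes are ordered (stack, point, basis), is to reshape the whole coefficient array with that component's block replaced. Because a component's axes are adjacent, this matches reshaping the block on its own in both 'C' and 'F' order. The helper name is hypothetical, and -1 entries are not handled.

```python
import numpy as np

ORDER = ("stack", "point", "basis")  # assumed block order for this sketch


def reshape_component(coefs, shapes, comp, new_shape, order="C"):
    """Reshape only the axis block of one component (positive sizes only)."""
    if isinstance(new_shape, int):
        new_shape = (new_shape,)  # an integer i means (i,)
    new_shapes = dict(shapes)
    new_shapes[comp] = tuple(new_shape)
    full_shape = sum((new_shapes[c] for c in ORDER), ())
    return np.reshape(coefs, full_shape, order=order), new_shapes
```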

sum(field, axes)

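This attempts to replicate the numpy "sum" function: https://numpy.org/doc/stable/reference/generated/numpy.sum.html

Parameters:

  • field (Field) –

    The field to sum.

  • axes (FieldAxisIndex | tuple[FieldAxisIndex, ...] | ShapeComponent | None) –

    The axes to sum over. A ShapeComponent sums over every axis of that component, and None sums over all axes, producing a field with empty stack, basis, and point shapes.

Returns:

  • Field ( Field ) –

    The summed field.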

Source code in fastfem/fields/numpy_similes.py
def sum(
    field: Field,
    axes: FieldAxisIndexType | tuple[FieldAxisIndexType, ...] | ShapeComponent | None,
) -> Field:
    """This attempts to replicate the numpy "sum" function:
    https://numpy.org/doc/stable/reference/generated/numpy.sum.html

    Args:
        field (Field): The field to sum.
        axes (FieldAxisIndex | tuple[FieldAxisIndex,...] | ShapeComponent | None): The
            axes to sum over. A ShapeComponent sums over every axis of that component,
            and None sums over all axes.

    Returns:
        Field: The summed field.
    """
    if axes is None:
        return Field(
            (), (), np(field.use_jax).sum(field.coefficients), use_jax=field.use_jax
        )
    axes_: tuple[FieldAxisIndexType, ...] = ()
    if isinstance(axes, ShapeComponent):
        axes_ = tuple(
            FieldAxisIndex(axes, i) for i in range(len(field.get_shape(axes)))
        )
    elif isinstance(axes, FieldAxisIndex):
        axes_ = (axes,)
    elif isinstance(axes[0], tuple | FieldAxisIndex):
        # ^ force tuple[tuple]              ^ force tuple[FieldAxisIndex]

        # for some reason pyright doesn't approve
        axes_ = cast(tuple[FieldAxisIndexType, ...], axes)
    else:
        axes_ = cast(tuple[FieldAxisIndexType, ...], (axes,))

    coefs = np(field.use_jax).sum(
        field.coefficients, tuple(map(field._axis_field_to_numpy, axes_))
    )
    shapes = {
        ShapeComponent.BASIS: field.basis_shape,
        ShapeComponent.STACK: field.stack_shape,
        ShapeComponent.POINT: field.point_shape,
    }
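    # Normalize negative indices and drop every summed axis in a single pass, so
    # removing one axis cannot shift the index of another in the same component.
    removed = {
        (ax[0], ax[1] if ax[1] >= 0 else len(shapes[ax[0]]) + ax[1]) for ax in axes_
    }
    shapes = {
        comp: tuple(s for i, s in enumerate(shape) if (comp, i) not in removed)
        for comp, shape in shapes.items()
    }
    return Field(
        shapes[ShapeComponent.BASIS],
        shapes[ShapeComponent.POINT],
        coefs,
        shape_order=field.shape_order,
        use_jax=field.use_jax,
    )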
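A numpy-only sketch of the same operation, again assuming the (stack, point, basis) axis order and using hypothetical helper names, shows the shape bookkeeping. Indices are normalized first and all summed axes are dropped in one pass, so removing one axis cannot shift another in the same component.

```python
import numpy as np

ORDER = ("stack", "point", "basis")  # assumed block order for this sketch


def numpy_axis(shapes, comp, i):
    """Coefficient-array position of axis i (negative allowed) of component comp."""
    offset = 0
    for c in ORDER:
        if c == comp:
            return offset + (i if i >= 0 else len(shapes[c]) + i)
        offset += len(shapes[c])
    raise KeyError(comp)


def sum_component_axes(coefs, shapes, axes):
    """Sum over (component, index) pairs and drop those axes from the shapes."""
    np_axes = tuple(numpy_axis(shapes, c, i) for c, i in axes)
    summed = np.sum(coefs, axis=np_axes)
    removed = {(c, i if i >= 0 else len(shapes[c]) + i) for c, i in axes}
    new_shapes = {
        c: tuple(s for k, s in enumerate(shape) if (c, k) not in removed)
        for c, shape in shapes.items()
    }
    return summed, new_shapes
```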