this post was submitted on 22 May 2025

dynomight internet forum

YetiBeets@lemmy.world 2 points 6 days ago

In the same vein, could we just move entirely to Einstein summation? It seems like your solution is 90% of the way there.

I assume there is a good reason why you didn't.

dynomight@lemmy.world 2 points 6 days ago

Well, Einstein summation is good, but it only does multiplication and sums. (Or, more generally, some scalar operation and some scalar reduction.) I want a notation that works for ANY type of operation, including non-scalar ones, and that's what DumPy does. So I'd argue it goes further than Einstein summation.
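To make the "scalar operation plus scalar reduction" point concrete, here is a minimal sketch in plain NumPy (standing in for the JAX used elsewhere in the thread; the helper name `contract` is hypothetical). It generalizes the einsum pattern `ic,jc->ij`: with multiply and sum it reproduces `einsum`, and swapping in add and max gives a max-plus product with the same loop structure. Something like a per-pair `solve` does not fit this shape at all, which is the gap being described.

```python
import numpy as np

def contract(X, Y, op=np.multiply, reduce=np.sum):
    """Hypothetical generalization of einsum's "ic,jc->ij".

    Pairs every row X[i] with every row Y[j], applies the scalar `op`
    elementwise along c, then collapses c with the scalar `reduce`.
    """
    # X[:, None, :] has shape (i, 1, c); Y[None, :, :] has shape (1, j, c).
    pairwise = op(X[:, None, :], Y[None, :, :])  # shape (i, j, c)
    return reduce(pairwise, axis=-1)             # shape (i, j)

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))
Y = rng.standard_normal((5, 3))

# Multiply + sum is exactly the case einsum covers.
Z_sum = contract(X, Y)

# Add + max (a "max-plus" product): same loops, different scalar ops.
Z_maxplus = contract(X, Y, op=np.add, reduce=np.max)
```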

ferflo@lemmy.world 1 points 3 days ago

There's einx, which allows expressing most tensor operations using einsum-like notation: https://github.com/fferflo/einx (Disclaimer: I'm the author). DumPy and einx actually seem similar to me in that they both use axis names to represent for-loops/vectorization over some simpler, underlying operation.

dynomight@lemmy.world 2 points 3 days ago

Hey, thanks for pointing this out! I quite like the bracket notation for indicating axes that operations should be applied "to" vs. "over".

One question I have—is it possible for me as a user to define my own function and then apply it with einx-type notation?

ferflo@lemmy.world 1 points 3 days ago

Thanks! You can use einx.vmap for custom operations:

import einx
import jax.numpy as jnp

def my_dot(x, y):
    return jnp.sum(x * y)

z = einx.vmap("a [c], b [c] -> b a", x, y, op=my_dot)  # x: (a, c), y: (b, c)

Or like so:

from functools import partial

def my_dot(x, y):
    return jnp.sum(x * y)
my_dot = partial(einx.vmap, op=my_dot)  # bind op; the pattern stays the first argument

z = my_dot("a [c], b [c] -> b a", x, y)
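Read as loops, the pattern says: the unbracketed axes `a` and `b` are looped over, the bracketed `c` is handed to `my_dot` whole, and the right-hand side `b a` fixes the output order. A plain NumPy sketch of that reading (it shows the semantics only; how einx actually executes the call is up to einx and its backend, and the shapes here are arbitrary):

```python
import numpy as np

def my_dot(x, y):
    return np.sum(x * y)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))  # axes "a c"
y = rng.standard_normal((2, 4))  # axes "b c"

# "a [c], b [c] -> b a": loop over a and b, pass full c-vectors to my_dot,
# and store each result at [b, a] because the output axes are ordered "b a".
z = np.empty((y.shape[0], x.shape[0]))
for a in range(x.shape[0]):
    for b in range(y.shape[0]):
        z[b, a] = my_dot(x[a], y[b])
```

Because `my_dot` is itself multiply-and-sum, this particular case also equals `np.einsum("ac,bc->ba", x, y)`; the vmap form only becomes necessary when the op is something einsum can't express.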

dynomight@lemmy.world 1 points 2 days ago (last edited 2 days ago)

OK, I gave it a shot on the initial example in my post:

import einx
from jax import numpy as jnp
import numpy as onp
import jax

X = jnp.array(onp.random.randn(20,5))
Y = jnp.array(onp.random.randn(30,5))
A = jnp.array(onp.random.randn(20,30,5,5))

def my_op(x,y,a):
    print(x.shape)
    return y @ jnp.linalg.solve(a,x)

Z = einx.vmap("i [m], j [n], i j [m n]->i j", X, Y, A, op=my_op)

Aaaaand, it seemed to work the first time! Well done!

I am a little confused though, because if I use "i [a], j [b], i j [c d]->i j" it still seems to work, so maybe I don't actually 100% understand that bracket notation after all...
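One way to see why renaming the bracketed axes changes nothing when the shapes are right: written as loops, the call computes `Z[i, j] = my_op(X[i], Y[j], A[i, j])`, and the bracketed names never appear in the loop at all. A NumPy sketch with the same shapes as the example (NumPy stands in for JAX; `Z_batched` is a hand-vectorized cross-check, not anything einx produces):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))         # axes "i m"
Y = rng.standard_normal((30, 5))         # axes "j n"
A = rng.standard_normal((20, 30, 5, 5))  # axes "i j m n"

def my_op(x, y, a):
    return y @ np.linalg.solve(a, x)

# Loop over the unbracketed axes i and j; the core axes go to my_op whole.
Z = np.empty((20, 30))
for i in range(20):
    for j in range(30):
        Z[i, j] = my_op(X[i], Y[j], A[i, j])

# Manual batching for comparison: give X a singleton j axis and a trailing
# column axis so it lines up with A, solve all 600 systems at once,
# then contract the solutions with Y over n.
sol = np.linalg.solve(A, X[:, None, :, None])[..., 0]  # shape (i, j, n)
Z_batched = np.einsum("jn,ijn->ij", Y, sol)
```

The manual version is the index bookkeeping that named axes are meant to spare you.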

Two more thoughts:

  1. I added a link.
  2. You gotta add def wrap(fun): return partial(vmap, op=fun) for easy wrapping. :)

ferflo@lemmy.world 1 points 2 days ago

Thanks for the mention!

Regarding the naming of axes: einx.vmap doesn't know anything about my_op, other than that it has the signature "m, n, m n -> " in the first case and "a, b, c d -> " in the second case. Both are valid if you pass the right input shapes. You get different behavior for incorrect input shapes, though: in the first case, einx will raise an exception before calling my_op, because shape resolution fails (e.g. because m would need two different values). In the second case, einx will assume the shapes are correct (it can't know they aren't until my_op is called), so the error will be raised somewhere inside my_op.
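NumPy's `np.vectorize` with a `signature` argument shows an analogous split between early and late errors, which makes it easy to try without einx installed. In this sketch (NumPy standing in for einx; `calls`, `strict`, and `loose` are illustrative names), the repeated name `m` in `(m),(n),(m,n)->()` lets NumPy reject a mismatched `X` before `my_op` ever runs, while `(a),(b),(c,d)->()` ties no sizes together, so the call goes through and the error comes from `np.linalg.solve` inside `my_op`.

```python
import numpy as np

calls = []  # records every time my_op actually runs

def my_op(x, y, a):
    calls.append(1)
    return y @ np.linalg.solve(a, x)

strict = np.vectorize(my_op, signature="(m),(n),(m,n)->()", otypes=[float])
loose = np.vectorize(my_op, signature="(a),(b),(c,d)->()", otypes=[float])

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4))         # deliberately wrong: m should be 5
Y = rng.standard_normal((30, 5))
A = rng.standard_normal((20, 30, 5, 5))

# np.vectorize matches loop axes by position, not name, so add singleton
# axes by hand: X -> (i, 1, m), Y -> (1, j, n), A stays (i, j, m, n).
Xb, Yb = X[:, None, :], Y[None, :, :]

try:
    strict(Xb, Yb, A)
except ValueError as e:
    strict_error, strict_calls = str(e), len(calls)  # rejected up front

calls.clear()
try:
    loose(Xb, Yb, A)
except ValueError as e:
    loose_error, loose_calls = str(e), len(calls)    # failed inside my_op
```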

The decorator for einx.vmap is a good point. I only realized while typing the comment above that wrapping is a nice way of writing the operation in the first place. :D

dynomight@lemmy.world 1 points 2 days ago (last edited 2 days ago)

Ah, I see, very nice. I wonder if it might make sense to declare the dimensions that are supposed to match once and for all when you wrap the function?

E.g. perhaps you could write:

@new_wrap('m, n, m n->')
def my_op(x,y,a):
    return y @ jnp.linalg.solve(a,x)

to declare the matching dimensions of the wrapped function and then call it with something like

Z = my_op('i [:], j [:], i j [: :]->i j', X, Y, A)

It's a small thing, but it seems like the matching should be declared "once and for all"?

(On the other hand, I guess there might be cases where the way things match depends on the arguments...)

Edit: Or perhaps, if you declare the matching shapes when you wrap the function, you wouldn't actually need brackets at all, and could just call it as:

Z = my_op('i :, j :, i j : :->i j', X, Y, A)

?
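NumPy's `np.vectorize` is one existing take on "declare the core dimensions once, at wrap time": the signature lives on the decorator, and call sites carry no brackets at all. A minimal sketch, where `new_wrap` is the hypothetical name from the comment above implemented as a thin `partial` over `np.vectorize`, and the signature uses NumPy's `(m),(n),(m,n)->()` syntax rather than einx's. The trade-off shows up at the call site: without names, loop axes are matched by position, so the caller has to insert singleton axes by hand.

```python
from functools import partial

import numpy as np

def new_wrap(signature):
    """Hypothetical decorator: fix the core-dimension signature once."""
    return partial(np.vectorize, signature=signature, otypes=[float])

@new_wrap("(m),(n),(m,n)->()")
def my_op(x, y, a):
    return y @ np.linalg.solve(a, x)

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
Y = rng.standard_normal((30, 5))
A = rng.standard_normal((20, 30, 5, 5))

# No brackets at the call site; loop axes line up by position instead of
# by name, so X becomes (i, 1, m) and Y becomes (1, j, n).
Z = my_op(X[:, None, :], Y[None, :, :], A)
```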

ferflo@lemmy.world 1 points 2 days ago

I did at some point consider adding a new symbol (like ":") that would act as a new axis with a unique name, but I've been hesitant so far, since it adds to the complexity of the notation. There are a bunch of ideas for quality-of-life improvements in einx, but so far I've tried to err on the side of less complexity (and there are probably some cases where I should have stuck to this more), to keep the barrier to entry low and to avoid ending up with the mess of many different rules that classical tensor notation is in (you made good points about that here...). There are indeed cases where the operation depends on the names in the brackets, e.g. einx.dot("a [b], [b c], [c] d", ...), so ":" would be an additional variant rather than a simplification.

What I also like about using actual names is that they convey semantics:

einx.softmax("b h q [k]", x) # Ok, softmax along key dimension
einx.softmax("b h q [:]", x) # Softmax along... last dimension?

and indicate corresponding axes in consecutive operations (although this is not enforced strictly).

Defining the shapes of a vmapped operation in the decorator sounds like a good idea. It would probably require a kind of pattern matching to align the inner and outer expressions (e.g. to also allow for n-dimensional inputs or variadic arguments to the custom, vmapped function).

YetiBeets@lemmy.world 2 points 5 days ago

I knew there was a reason lol
