
I have an array of arrival times and I want to convert it to count data using PyTorch in a differentiable way.

Example arrival times:

arrival_times = [2.1, 2.9, 5.1]

and let's say the total range is 6 seconds. What I want to have is:

counts = [0, 0, 2, 2, 2, 3]

For this task, a non-differentiable way works perfectly:

x = [1, 2, 3, 4, 5, 6]
counts = torch.sum(torch.Tensor(arrival_times)[:, None] < torch.Tensor(x), dim=0)

However, the < operation here is not differentiable, so I need a differentiable approximation of it.

What I could think of is to subtract x from arrival_times with broadcasting, which leads to the following matrix:

[[ 1.1,  0.1, -0.9, -1.9, -2.9, -3.9],
 [ 1.9,  0.9, -0.1, -1.1, -2.1, -3.1],
 [ 4.1,  3.1,  2.1,  1.1,  0.1, -0.9]]

And then somehow count, per column, the number of negative (and preferably also zero) elements, which would give us the counts [0, 0, 2, 2, 2, 3].
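Here is a minimal sketch of the broadcast subtraction described above, using the example arrays from the beginning:

import torch

arrival_times = torch.tensor([2.1, 2.9, 5.1])
x = torch.tensor([1., 2., 3., 4., 5., 6.])

# shape (3, 1) minus shape (6,) broadcasts to shape (3, 6):
# one row per arrival time, one column per second in x
diffs = arrival_times[:, None] - x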

Is there a way to do this, or a completely different idea for such an approximation?

  • How many elements should the result contain? Your first example will return 3 values, whereas in your second example you are showing a list containing 6 values. Commented Nov 8, 2021 at 16:12
  • My first example also shows 6 count values. The number of elements returned is the length of the array x. Commented Nov 8, 2021 at 16:20
  • When running the code from your first example I get back a tensor containing only three values (assuming that arrival_times is also a tensor): tensor([4, 4, 1]). Commented Nov 8, 2021 at 16:26
  • You are right, I edited the question. The dim in torch.sum was not correct. Commented Nov 9, 2021 at 9:51

1 Answer


I tried a hacky way to do this. Still open to suggestions.

diffs = arrival_times[..., None] - torch.Tensor(x)      # shape (num_arrivals, len(x))
zeros = torch.zeros_like(diffs)
minimums = torch.minimum(diffs, zeros)                  # keep only the negative differences
eps, r_eps = 0.001, 1000
epsilons = torch.ones_like(diffs) * -eps
maximums = torch.maximum(minimums, epsilons) * -r_eps   # clamp to [-eps, 0], rescale to [0, 1]
counts = torch.sum(maximums, dim=0)                     # sum over arrivals: one count per second in x

The hackiness comes from the eps clamp: any arrival time that falls within 0.001 of its closest second contributes only a fraction instead of a full 1, polluting the count. It is still differentiable, but it did not give good results in training for my case.
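Another option worth trying (not part of the answer above, just a common soft-threshold sketch): replace the clamp with a temperature-scaled sigmoid, so the hard step arrival < x becomes a smooth function. The soft_counts name and the temperature value below are assumptions and would need tuning for a real training setup.

import torch

def soft_counts(arrival_times, x, temperature=0.01):
    # x - arrival is positive exactly when arrival < x, so a sharp sigmoid
    # of that difference approximates the indicator while keeping gradients.
    diffs = x - arrival_times[:, None]              # shape (num_arrivals, len(x))
    soft_indicator = torch.sigmoid(diffs / temperature)
    return soft_indicator.sum(dim=0)                # soft count per second in x

arrival_times = torch.tensor([2.1, 2.9, 5.1])
x = torch.tensor([1., 2., 3., 4., 5., 6.])
print(soft_counts(arrival_times, x))                # approximately [0, 0, 2, 2, 2, 3]

Smaller temperatures approximate the hard count more closely but shrink the gradients away from the boundaries, so there is a trade-off to tune.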

