Simple is best: do you really need your code to be fast?
Introduction
In my post from last week I tried to convince novice Julia users to focus on writing simple code, and to start writing fully generic code only once they have learned more.
This week I continue my “simple is best” miniseries of posts. This time I want to argue that you do not need to try to write code that is maximally fast. Usually it is enough that your code is just fast (fortunately, in Julia, unless you make a serious mistake your code is not likely to be slow). Instead, I recommend that you try to keep your code simple and follow common patterns. In the long run such code will be much easier to read and update.
For the post I use DataFrames.jl as an example. It was tested under Julia 1.8.0, DataFrames.jl 1.3.4, ShiftedArrays.jl 1.0.0, and BenchmarkTools.jl 1.3.1.
An example task
Recently one of the users asked the following question. Given a data frame with column col, add 5 columns to it with lags of column col ranging from 1 to 5.
Here is an example of how you can do this task (in the example the column is named a):
julia> using DataFrames
julia> using ShiftedArrays
julia> using BenchmarkTools
julia> df = DataFrame(a=1:10)
10×1 DataFrame
 Row │ a
     │ Int64
─────┼───────
   1 │     1
   2 │     2
   3 │     3
   4 │     4
   5 │     5
   6 │     6
   7 │     7
   8 │     8
   9 │     9
  10 │    10
julia> function add_lags(df)
           out_df = copy(df)
           for i in 1:5
               out_df[:, "a$i"] = lag(df.a, i)
           end
           return out_df
       end
add_lags (generic function with 1 method)
julia> add_lags(df)
10×6 DataFrame
 Row │ a      a1       a2       a3       a4       a5
     │ Int64  Int64?   Int64?   Int64?   Int64?   Int64?
─────┼────────────────────────────────────────────────────
   1 │     1  missing  missing  missing  missing  missing
   2 │     2        1  missing  missing  missing  missing
   3 │     3        2        1  missing  missing  missing
   4 │     4        3        2        1  missing  missing
   5 │     5        4        3        2        1  missing
   6 │     6        5        4        3        2        1
   7 │     7        6        5        4        3        2
   8 │     8        7        6        5        4        3
   9 │     9        8        7        6        5        4
  10 │    10        9        8        7        6        5
julia> @btime add_lags($df);
2.900 μs (50 allocations: 3.26 KiB)
I have also included the timing of this operation for reference.
The add_lags function is meant to be an example of “fast code” (it could still be made faster, or I could have chosen a more complex example, but I wanted to pick something that is easy to understand).
Adding lags using the high-level API function
The functions combine, select[!], transform[!], and subset[!] support the operation specification syntax and are called the high-level API in DataFrames.jl.
A natural way to add lags using the transform function is:
julia> transform(df, ["a" => (x -> lag(x, i)) => "a$i" for i in 1:5])
10×6 DataFrame
 Row │ a      a1       a2       a3       a4       a5
     │ Int64  Int64?   Int64?   Int64?   Int64?   Int64?
─────┼────────────────────────────────────────────────────
   1 │     1  missing  missing  missing  missing  missing
   2 │     2        1  missing  missing  missing  missing
   3 │     3        2        1  missing  missing  missing
   4 │     4        3        2        1  missing  missing
   5 │     5        4        3        2        1  missing
   6 │     6        5        4        3        2        1
   7 │     7        6        5        4        3        2
   8 │     8        7        6        5        4        3
   9 │     9        8        7        6        5        4
  10 │    10        9        8        7        6        5
The problem with this solution that users often raise is that it is slower than the previous one:
julia> @btime transform($df, ["a" => (x -> lag(x, i)) => "a$i" for i in 1:5]);
61.600 μs (523 allocations: 27.46 KiB)
There are two questions that one naturally asks. The first is what causes this timing difference, and the second is whether we care.
So first let me explain the reason for the slowdown. The transform function is flexible and allows for many different operations. Therefore, apart from doing the core transformations, it does a lot of extra pre- and post-processing that is needed to provide this flexibility.
Now let us try to answer the question of whether we care. First, let us check a bigger data frame:
julia> df2 = DataFrame(a=1:10^8);
julia> @btime add_lags($df2);
1.354 s (62 allocations: 4.94 GiB)
julia> @btime transform($df2, ["a" => (x -> lag(x, i)) => "a$i" for i in 1:5]);
1.318 s (539 allocations: 4.94 GiB)
The timings are roughly the same.
As you can see, the pre- and post-processing time in transform does not grow with the size of the data. Therefore we reach the following conclusions:
- if you have a few small data frames, it does not matter what you use; things will be fast;
- if you have a large data frame, it does not matter what you use; things will have a comparable speed;
- if you have a lot of small data frames, then you should be careful, as the pre- and post-processing time of transform might end up being a significant portion of the total run time.
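To check whether the last case applies to your workload, you can run both approaches over many small data frames. Below is a minimal sketch, assuming the add_lags function defined earlier (repeated here so that the block is self-contained). The names small_dfs, add_lags_all, and transform_all, as well as the choice of 1,000 data frames, are hypothetical and serve only as an illustration. I do not show timings, as they depend on your machine and package versions.

```julia
using DataFrames
using ShiftedArrays

# Same loop-based helper as earlier in the post.
function add_lags(df)
    out_df = copy(df)
    for i in 1:5
        out_df[:, "a$i"] = lag(df.a, i)
    end
    return out_df
end

# Hypothetical workload: many tiny data frames instead of one big one.
small_dfs = [DataFrame(a=1:10) for _ in 1:1_000]

# Hypothetical wrappers applying each approach to every data frame.
add_lags_all(dfs) = [add_lags(df) for df in dfs]
transform_all(dfs) =
    [transform(df, ["a" => (x -> lag(x, i)) => "a$i" for i in 1:5]) for df in dfs]

res_loop = add_lags_all(small_dfs)
res_transform = transform_all(small_dfs)

# Both approaches should hold the same data; only the per-call overhead differs.
same_result = all(isequal(x, y) for (x, y) in zip(res_loop, res_transform))

# To compare run times you could use BenchmarkTools.jl, for example:
# using BenchmarkTools
# @btime add_lags_all($small_dfs);
# @btime transform_all($small_dfs);
```

If the difference you measure is a small fraction of the total run time of your program, the simpler transform version is most likely the better choice.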
You might think that the code:
["a" => (x -> lag(x, i)) => "a$i" for i in 1:5]
is not so easy to read. This might be true for a newcomer, as this is something new you need to learn. However, my experience is that after some practice with the DataFrames.jl operation specification language it becomes quite easy to write and read. You just need to embrace the pattern:
[input column] => [transformation function] => [output column name]
and the rest is a natural consequence (one of the benefits of this pattern is that it is clear visually as => nicely separates its parts).
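It may also help to know that => is right-associative in Julia, so such a specification is just a nested Pair. Here is a small sketch using only Base Julia that reads the three parts back; the names add_one and spec are hypothetical stand-ins (add_one plays the role of a transformation function such as x -> lag(x, 1)):

```julia
# Hypothetical stand-in for a transformation function.
add_one(x) = x .+ 1

# [input column] => [transformation function] => [output column name]
spec = "a" => add_one => "a1"

# `=>` groups to the right, so spec is Pair("a", Pair(add_one, "a1")).
input_col = first(spec)
fun, output_col = last(spec)   # destructure the inner Pair

println(input_col)       # a
println(output_col)      # a1
println(fun([1, 2, 3]))  # [2, 3, 4]
```

This is why a vector of such specifications can be produced with an ordinary comprehension: each element is a regular Julia value.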
What is the benefit of using the high-level API?
The benefit of the high-level API is that most likely the lagging operation we have considered needs to be done in groups. That is, you have a grouping variable, call it id, and you want to perform lagging per id value (e.g. you have different stocks and their quotations for different days that you want to lag).
With transform this is a breeze. You just need to add groupby to your code:
julia> df3 = DataFrame(id=repeat(1:2, inner=8), a=1:16)
16×2 DataFrame
 Row │ id     a
     │ Int64  Int64
─────┼──────────────
   1 │     1      1
   2 │     1      2
   3 │     1      3
   4 │     1      4
   5 │     1      5
   6 │     1      6
   7 │     1      7
   8 │     1      8
   9 │     2      9
  10 │     2     10
  11 │     2     11
  12 │     2     12
  13 │     2     13
  14 │     2     14
  15 │     2     15
  16 │     2     16
julia> transform(groupby(df3, :id),
                 ["a" => (x -> lag(x, i)) => "a$i" for i in 1:5])
16×7 DataFrame
 Row │ id     a      a1       a2       a3       a4       a5
     │ Int64  Int64  Int64?   Int64?   Int64?   Int64?   Int64?
─────┼───────────────────────────────────────────────────────────
   1 │     1      1  missing  missing  missing  missing  missing
   2 │     1      2        1  missing  missing  missing  missing
   3 │     1      3        2        1  missing  missing  missing
   4 │     1      4        3        2        1  missing  missing
   5 │     1      5        4        3        2        1  missing
   6 │     1      6        5        4        3        2        1
   7 │     1      7        6        5        4        3        2
   8 │     1      8        7        6        5        4        3
   9 │     2      9  missing  missing  missing  missing  missing
  10 │     2     10        9  missing  missing  missing  missing
  11 │     2     11       10        9  missing  missing  missing
  12 │     2     12       11       10        9  missing  missing
  13 │     2     13       12       11       10        9  missing
  14 │     2     14       13       12       11       10        9
  15 │     2     15       14       13       12       11       10
  16 │     2     16       15       14       13       12       11
It is not obvious how to use the add_lags function to do what you want. After some thinking (and if you know DataFrames.jl well enough) you come to the conclusion that the following will work:
julia> transform(groupby(df3, :id), add_lags)
16×7 DataFrame
 Row │ id     a      a1       a2       a3       a4       a5
     │ Int64  Int64  Int64?   Int64?   Int64?   Int64?   Int64?
─────┼───────────────────────────────────────────────────────────
   1 │     1      1  missing  missing  missing  missing  missing
   2 │     1      2        1  missing  missing  missing  missing
   3 │     1      3        2        1  missing  missing  missing
   4 │     1      4        3        2        1  missing  missing
   5 │     1      5        4        3        2        1  missing
   6 │     1      6        5        4        3        2        1
   7 │     1      7        6        5        4        3        2
   8 │     1      8        7        6        5        4        3
   9 │     2      9  missing  missing  missing  missing  missing
  10 │     2     10        9  missing  missing  missing  missing
  11 │     2     11       10        9  missing  missing  missing
  12 │     2     12       11       10        9  missing  missing
  13 │     2     13       12       11       10        9  missing
  14 │     2     14       13       12       11       10        9
  15 │     2     15       14       13       12       11       10
  16 │     2     16       15       14       13       12       11
so things are still easy to code (but need transform).
Let us compare the performance of both operations on a larger data frame:
julia> df4 = DataFrame(id=repeat(1:10, inner=10^7), a=1:10^8);
julia> @btime transform(groupby($df4, :id),
                        ["a" => (x -> lag(x, i)) => "a$i" for i in 1:5]);
10.124 s (1731 allocations: 19.35 GiB)
julia> @btime transform(groupby($df4, :id), add_lags);
11.747 s (1552 allocations: 24.59 GiB)
We see that using add_lags in this case is slightly slower. Of course we could write a custom add_lags function that would work by groups, even without using groupby, but this would not be so easy (I recommend you try it if you are not convinced).
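To give a feeling for what such a custom function involves, here is one possible sketch. It is not how DataFrames.jl does it internally, and it relies on a strong assumption that transform with groupby does not need: all rows with the same id must be stored contiguously. The name add_lags_by_id is hypothetical.

```julia
using DataFrames

# Hypothetical helper: per-id lags without groupby.
# Assumption: all rows with the same id are stored contiguously.
function add_lags_by_id(df)
    out_df = copy(df)
    id, a = df.id, df.a
    for i in 1:5
        # Start with all-missing values and fill in only valid lags.
        col = Vector{Union{Missing, eltype(a)}}(missing, length(a))
        for j in (i + 1):length(a)
            # With contiguous groups, equal ids at rows j and j - i mean
            # that all rows in between belong to the same group.
            if id[j] == id[j - i]
                col[j] = a[j - i]
            end
        end
        out_df[!, "a$i"] = col
    end
    return out_df
end

df3 = DataFrame(id=repeat(1:2, inner=8), a=1:16)
add_lags_by_id(df3)
```

Even in this simplified form you have to handle group boundaries and missing values yourself, and the function stops being correct as soon as the rows of one group are not adjacent.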
Here we see the benefit of the general design of transform: it handles grouped data in the same way as it handles a data frame.
Conclusions
Let me summarize my thinking about code performance in Julia in general, and in DataFrames.jl in particular:
- If you have small data and a small problem, it does not matter how fast your code is, as it will be fast enough anyway, so you do not have to think about speed.
- If you have large data, then make sure that the solution you use is fast in the core of the processing (in DataFrames.jl these are things like grouping data, sorting, joining, and reshaping); the pre- and post-processing that some functions (like transform in our case) do will be negligible anyway (this logic is something that you most likely already know, as every Julia user learns that compilation time is negligible when your code runs for several hours). It is usually not worth optimizing every part of your code; optimize only the parts that are expensive when you scale your computations.
- For DataFrames.jl there is, in my experience, only one case when you need to think carefully about how to write your code and which functions to use. This case is when you have a lot (like millions) of small data frames. The reason is that then the cost of pre- and post-processing in functions from the high-level API of DataFrames.jl might indeed be an important part of the total run time of your code.