Clean Clojure: Small Functions

This is part 2 in a series on clean Clojure. Previously: meaningful names.

In Clean Code, Uncle Bob proposes two rules for good functions: “The first rule of functions is that they should be small. The second rule of functions is that they should be smaller than that.” Useful rules, but Clojure requires one more: they’re still not small enough.

Functional languages and immutable data make reasoning easy by making functions simpler. Functions take input, transform it, and return new output. Data passes through functions, flowing rather than mutating. Complicated functions make simple hard, and they can be dangerously easy to write.

Hyperbole aside, there really are two simple rules for functions: they should be small, and do one thing. In his presentation on functions, Uncle Bob describes a simple algorithm for cleaning up crufty functions:

  1. Pick a function.
  2. Extract functions until it does one thing.
  3. Recur on extracted functions.

Instead of a contrived wombat example, I’ll use one of my own disgusting old 4Clojure solutions as a specimen of atrocious code. (But cut me some slack! I was young and naive.)

Here’s the problem:

“Write a function which takes a collection of integers as an argument. Return the count of how many elements are smaller than the sum of their squared component digits. For example: 10 is larger than 1 squared plus 0 squared; whereas 15 is smaller than 1 squared plus 5 squared.”
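To make the two examples in the problem statement concrete, here is the arithmetic as it might be checked at a Clojure REPL:

```clojure
;; 10: the squared digits sum to 1*1 + 0*0 = 1, and 10 is not smaller than 1
(< 10 (+ (* 1 1) (* 0 0)))  ; => false

;; 15: the squared digits sum to 1*1 + 5*5 = 26, and 15 is smaller than 26
(< 15 (+ (* 1 1) (* 5 5)))  ; => true
```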

And here’s my answer (hide the children):

(fn [elements]
  (let [lt-sqd-components
        (fn [n]
          (let [digits  (map #(- (int %) 48) (seq (str n)))
                squares (map #(* % %) digits)
                sqsum   (reduce + squares)]
            (if (< n sqsum)
              true
              false)))]
    (count (filter true? (map lt-sqd-components elements)))))

Like deeply nested blocks in other languages, code that sprawls rightward signals a problem, and in Clojure it can happen fast. To start, we’ll extract lt-sqd-components from the let binding into its own top-level def. (Binding an anonymous fn in a let is a common, awful 4Clojure hack for defining a named function inside an anonymous one, though the discerning 4Clojurist uses letfn.)
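For comparison, here is a sketch of what the letfn version from that parenthetical might look like. The logic is deliberately left unchanged; only the local helper moves from let to letfn. The name count-lt-sqd is hypothetical, added only so the anonymous function can be called at a REPL.

```clojure
;; Same logic as the original answer, but the local helper is defined with
;; letfn instead of binding an anonymous fn inside a let.
(def count-lt-sqd                       ; hypothetical name, for REPL testing
  (fn [elements]
    (letfn [(lt-sqd-components [n]
              (let [digits  (map #(- (int %) 48) (seq (str n)))
                    squares (map #(* % %) digits)
                    sqsum   (reduce + squares)]
                (if (< n sqsum)
                  true
                  false)))]
      (count (filter true? (map lt-sqd-components elements))))))

(count-lt-sqd (range 10)) ; => 8
```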

(def lt-sqd-components
     (fn [n]
       (let [digits  (map #(- (int %) 48) (seq (str n)))
             squares (map #(* % %) digits)
             sqsum   (reduce + squares)]
         (if (< n sqsum)
              true
              false))))

(fn [elements]
  (count (filter true? (map lt-sqd-components elements))))

The original function is almost readable, but we can do better. It looks like I didn’t understand filter when I wrote this: filter already applies its predicate to each element, so mapping lt-sqd-components over the collection and then filtering on true? is redundant. lt-sqd-components is already a predicate that returns true or false.
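The redundancy is easy to check at a REPL with a predicate from clojure.core such as even? (nums is just sample data). One subtlety: filter keeps every element for which the predicate returns a truthy value, whereas true? matches only the literal true, so the two forms agree only when the predicate returns booleans, as lt-sqd-components does.

```clojure
(def nums (range 10))

;; Map to booleans, then keep only the true results...
(count (filter true? (map even? nums)))    ; => 5

;; ...or hand the predicate straight to filter.
(count (filter even? nums))                ; => 5

;; With a predicate that returns truthy non-boolean values, such as a set,
;; the two forms diverge: true? rejects 2 and 4, filter keeps them.
(count (filter true? (map #{2 4} nums)))   ; => 0
(count (filter #{2 4} nums))               ; => 2
```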

(def lt-sqd-components
     (fn [n]
       (let [digits  (map #(- (int %) 48) (seq (str n)))
             squares (map #(* % %) digits)
             sqsum   (reduce + squares)]
         (if (< n sqsum)
              true
              false))))

(fn [elements]
  (count (filter lt-sqd-components elements)))

This does one thing, so let’s clean it up and move on. It needs a name, and the function we’re filtering against needs a question mark.

(def lt-sqd-components?
     (fn [n]
       (let [digits  (map #(- (int %) 48) (seq (str n)))
             squares (map #(* % %) digits)
             sqsum   (reduce + squares)]
         (if (< n sqsum)
              true
              false))))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter lt-sqd-components? coll)))

And now the recursive step. Let’s look at the terribly named predicate lt-sqd-components?: each line in its let binding does something different. One splits a number into a sequence of its digits:

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(def lt-sqd-components?
     (fn [n]
       (let [digits  (split-digits n)
             squares (map #(* % %) digits)
             sqsum   (reduce + squares)]
         (if (< n sqsum)
              true
              false))))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter lt-sqd-components? coll)))

One squares every element in a sequence:

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(defn square-all [digits]
  (map #(* % %) digits))

(def lt-sqd-components?
     (fn [n]
       (let [digits  (split-digits n)
             squares (square-all digits)
             sqsum   (reduce + squares)]
         (if (< n sqsum)
              true
              false))))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter lt-sqd-components? coll)))

And one takes the sum of the collection.

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(defn square-all [digits]
  (map #(* % %) digits))

(defn sum-of-squares [squares]
  (reduce + squares))

(def lt-sqd-components?
     (fn [n]
       (let [digits  (split-digits n)
             squares (square-all digits)
             sqsum   (sum-of-squares squares)]
         (if (< n sqsum)
              true
              false))))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter lt-sqd-components? coll)))

One more function to extract: the let binding should be its own function. One might argue that this function does one thing—all it does is check whether a number is less than the sum of its squared components! But it’s operating on several different levels of abstraction: digits, a sequence of digits, and their sum. A helpful guideline is limiting functions to one level of abstraction. In this case, the function should only know about the sum.

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(defn square-all [digits]
  (map #(* % %) digits))

(defn sum-of-squares [squares]
  (reduce + squares))

(defn sum-of-squared-components [n]
  (let [digits  (split-digits n)
        squares (square-all digits)
        sqsum   (sum-of-squares squares)]
    sqsum))

(def lt-sqd-components?
     (fn [n]
       (< n (sum-of-squared-components n))))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter lt-sqd-components? coll)))

Despite its dumb name, lt-sqd-components? is doing one thing. Let’s clean it up. I prefer “digits” to “components”, and it should use defn.

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(defn square-all [digits]
  (map #(* % %) digits))

(defn sum-of-squares [squares]
  (reduce + squares))

(defn sum-of-squared-digits [n]
  (let [digits  (split-digits n)
        squares (square-all digits)
        sqsum   (sum-of-squares squares)]
    sqsum))

(defn less-than-sum-of-digits-squared?
  [n]
  (< n (sum-of-squared-digits n)))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter less-than-sum-of-digits-squared? coll)))
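Switching from def to defn changes nothing about behavior: defn is a macro that expands into a def of an fn, and it also accepts a docstring and metadata. A quick REPL check, using the hypothetical names square-def and square-defn:

```clojure
;; The same function, defined both ways.
(def square-def (fn [n] (* n n)))
(defn square-defn [n] (* n n))

(= (square-def 7) (square-defn 7))                      ; => true

;; defn is a macro; expanding it yields a def form.
(first (macroexpand '(defn square-defn [n] (* n n))))  ; => def
```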

On to sum-of-squared-digits. Using the thread-first macro, ->, we can replace the let binding with a pipeline that passes each result on to the next function (as suggested in the comments on my last post).

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(defn square-all [digits]
  (map #(* % %) digits))

(defn sum-of-squares [squares]
  (reduce + squares))

(defn sum-of-squared-digits [n]
  (-> n
      split-digits
      square-all
      sum-of-squares))

(defn less-than-sum-of-digits-squared?
  [n]
  (< n (sum-of-squared-digits n)))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter less-than-sum-of-digits-squared? coll)))
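The -> pipeline above is purely syntactic: the macro rewrites it into ordinary nested calls, passing each result as the first argument of the next step. macroexpand shows the rewrite on a quoted form, so none of the functions need to be defined to try it:

```clojure
;; Quoted, so nothing is evaluated; we only inspect the expansion.
(macroexpand '(-> n split-digits square-all sum-of-squares))
;; => (sum-of-squares (square-all (split-digits n)))

;; ->> threads into the last argument instead; with one-argument steps
;; like these, the expansion is identical.
(macroexpand '(->> n split-digits square-all sum-of-squares))
;; => (sum-of-squares (square-all (split-digits n)))
```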

We can do better. I don’t like the intermediate square-all step, which should be hidden in sum-of-squares:

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(defn square-all [digits]
  (map #(* % %) digits))

(defn sum-of-squares [digits]
  (reduce + (square-all digits)))

(defn sum-of-squared-digits [n]
  (-> n
      split-digits
      sum-of-squares))

(defn less-than-sum-of-digits-squared?
  [n]
  (< n (sum-of-squared-digits n)))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter less-than-sum-of-digits-squared? coll)))
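If the intermediate lazy sequence built by square-all ever matters, a transducer can fuse the squaring and the summing into a single pass. This is an alternative sketch rather than a step in this refactoring, and sum-of-squares-xf is a hypothetical name; the explicit map-then-reduce version reads more directly, which is the point of the exercise.

```clojure
;; (map f) with no collection returns a transducer; transduce applies it while
;; reducing with +, and the zero-arity call (+) supplies the initial value 0.
(defn sum-of-squares-xf [digits]
  (transduce (map #(* % %)) + digits))

(sum-of-squares-xf [1 5]) ; => 26
(sum-of-squares-xf [])    ; => 0
```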

Extract the function literal in square-all. I’ve got a great name for it:

(defn split-digits [number]
  (map #(- (int %) 48) (seq (str number))))

(defn square [n] (* n n))

(defn square-all [digits]
  (map square digits))

(defn sum-of-squares [digits]
  (reduce + (square-all digits)))

(defn sum-of-squared-digits [n]
  (-> n
      split-digits
      sum-of-squares))

(defn less-than-sum-of-digits-squared?
  [n]
  (< n (sum-of-squared-digits n)))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter less-than-sum-of-digits-squared? coll)))

And there’s only one function left: splitting a number into a sequence of digits. Let’s extract and name the function literal:

(defn char->num [character]
  (- (int character) 48))

(defn split-digits [number]
  (map char->num (seq (str number))))

(defn square [n] (* n n))

(defn square-all [digits]
  (map square digits))

(defn sum-of-squares [digits]
  (reduce + (square-all digits)))

(defn sum-of-squared-digits [n]
  (-> n
      split-digits
      sum-of-squares))

(defn less-than-sum-of-digits-squared?
  [n]
  (< n (sum-of-squared-digits n)))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter less-than-sum-of-digits-squared? coll)))
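The magic number 48 in char->num is the code point of the character \0. Because the digit characters \0 through \9 are contiguous, subtracting it maps them onto 0 through 9, and writing (int \0) makes that intent explicit. Java’s Character/digit, which takes a character and a radix, is another option. Both variants below use hypothetical names and assume non-negative input, since a leading minus sign is not a digit.

```clojure
(int \0)               ; => 48
(- (int \7) (int \0))  ; => 7

;; Hypothetical variant of char->num that names the offset instead of using 48.
(defn char->num-offset [character]
  (- (int character) (int \0)))

;; Hypothetical variant using java.lang.Character/digit with radix 10.
(defn char->num-digit [character]
  (Character/digit character 10))

(map char->num-offset (str 1503)) ; => (1 5 0 3)
(map char->num-digit (str 1503))  ; => (1 5 0 3)
```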

And finally, clean it up by using Integer/parseInt instead of hacky subtraction, and drop the redundant seq call, since map already treats a string as a sequence of characters:

(defn char->num [character]
  (Integer/parseInt (str character)))

(defn split-digits [number]
  (map char->num (str number)))

(defn square [n] (* n n))

(defn square-all [digits]
  (map square digits))

(defn sum-of-squares [digits]
  (reduce + (square-all digits)))

(defn sum-of-squared-digits [n]
  (-> n
      split-digits
      sum-of-squares))

(defn less-than-sum-of-digits-squared?
  [number]
  (< number
     (sum-of-squared-digits number)))

(defn count-less-than-sum-of-squares
  [coll]
  (count (filter less-than-sum-of-digits-squared? coll)))

And there it is: clean, readable functions at every level of abstraction, minimal nesting, and no function body longer than three lines. Reading from the top, low-level functions build into bigger abstractions through combination and composition. Each step is easy to read and comprehend.

As Uncle Bob puts it in Clean Code:

Master programmers think of systems as stories to be told rather than programs to be written. They use the facilities of their chosen programming language to construct a much richer and more expressive language that can be used to tell that story. Part of that domain-specific language is the hierarchy of functions that describe all the actions that take place within that system. In an artful act of recursion, those actions are written to use the very domain-specific language they define to tell their own small part of the story.

Extract. Simplify. Recur. Take the time to consider each line, and clean code comes naturally.
