AdventOfCode 2024 Day 12

Part One

AAAA
BBCD
BBCC
EEEC

按照植物的类别建立围栏,求最后围栏的长度

+-+-+-+-+
|A A A A|
+-+-+-+-+     +-+
              |D|
+-+-+   +-+   +-+
|B B|   |C|
+   +   + +-+
|B B|   |C C|
+-+-+   +-+ +
          |C|
+-+-+-+   +-+
|E E E|
+-+-+-+

邻居围栏拆除

最直观的一个解法,就是每个植物默认4个围栏,但是相邻的围栏直接拆除掉,所以每有一个邻居,就可以节省2个围栏,如下图所示

+-+ +-+                       +--+ 
|A| |A|            ->         |AA|
+-+ +-+                       +--+

因此,只要求出所有的邻居数量即可。在求邻居的时候,注意避免重复计算,让坐标偏小的在左即可。

那么公式就简化为植物数量 * 4 – 邻居数 * 2

with rows as (
    SELECT row_number() over () as _row, line
    FROM lance_input
), matrix AS (
    SELECT _row :: INTEGER, x.idx :: INTEGER as _col, x.pos as plant
    FROM rows, regexp_split_to_table(line, '') with ordinality as x(pos, idx)
), plant_cnt AS (
    SELECT plant, count(*) as cnt
    FROM matrix
    GROUP BY plant
), plant_neighbour AS (
    SELECT a.plant, count(*) as CNT
    FROM matrix a, matrix b
    WHERE a.plant = b.plant
    AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
    GROUP BY a.plant
)
SELECT sum(a.cnt * (a.cnt * 4 - coalesce(b.cnt, 0) * 2))
FROM plant_cnt a 
LEFT JOIN plant_neighbour b
ON a.plant = b.plant;

相同植物多个区域

但是答案却是错误的,因为没有考虑相同植物种植在多个区域的问题,每个区域的围栏数是需要单独计算的。这就给我们带来了一个比较大的挑战,因为事先并不知晓每种植物到底会分布在多少个区域中。

OOOOO
OXOXO
OOOOO
OXOXO
OOOOO

在上面的例子中,X被分布在4个区域中,但是计算规则是植物数量 * 围栏数量,正确的结果是1 * 4 * 4 = 16。而合并在一起计算,计算的结果是4 * 16 = 64,因此还是需要找出X到底分布在那几个区域。

植物区域计算

一个直观的算法,就是flood fill算法,从某一点出发,尽可能去找相邻,找到的都是在同一个region中。而未找到的植物则肯定在新的一个region中。

针对matrix中的每一处坐标,都需要计算出region_seq,默认从0开始。每发现一个新的region,则seq加一。

这里可以通过函数循环update表中的记录,也可以通过递归查询持续输出最新的计算结果,这里就用递归查询的方式来计算。

第一步:初始化最左上

    SELECT _row, _col, plant, 
           rn = 1 as find,
           0 as region_seq,
           1 as round
    FROM (
        SELECT _row, _col, plant,
               row_number() over (partition by plant order by _row, _col) as rn
        FROM matrix
    ) t

第二步:持续找出邻居,并更新其region_seq。如果是邻居,则region_seq不变,如果不是,则seq+1

SELECT b._row, b._col, b.plant,
       abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
       case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
FROM _region a
JOIN _region b
ON a.plant = b.plant
AND a.find
AND NOT b.find

第三步:按照邻居->非邻居的优先级,从而得到最终的region_seq。如果是邻居,则order为1,否则为0,并带上坐标,避免一次发现多个新region。因此一次会发现多个邻居,因此这里用rank,而不是row_number。

            SELECT DISTINCT _row, _col, plant, region_seq
            FROM (
                SELECT _row, _col, plant, region_seq,
                       rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
                FROM (
                    SELECT b._row, b._col, b.plant,
                           abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
                           case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
                    FROM _region a
                    JOIN _region b
                    ON a.plant = b.plant
                    AND a.find
                    AND NOT b.find
                ) t
            ) t
            WHERE rn = 1

终止条件:所有的坐标都已经处理过了,总的迭代次数应该是和matrix的边长成正比。

WHERE exists (select * From _region where not find limit 1)

完整的计算逻辑如下所示:

with recursive rows as (
    SELECT row_number() over () as _row, line
    FROM lance_input
), matrix AS (
    SELECT _row :: INTEGER, x.idx :: INTEGER as _col, x.pos as plant
    FROM rows, regexp_split_to_table(line, '') with ordinality as x(pos, idx)
), region AS (
    SELECT _row, _col, plant, 
           rn = 1 as find,
           0 as region_seq,
           1 as round
    FROM (
        SELECT _row, _col, plant,
               row_number() over (partition by plant order by _row, _col) as rn
        FROM matrix
    ) t

    UNION ALL

    SELECT _row, _col, plant, find, region_seq, round
    FROM (
        WITH _region AS (SELECT * FROM region)
        SELECT a._row, a._col, a.plant, 
               a.find OR b._row is not null as find,
               coalesce(b.region_seq, a.region_seq) as region_seq,
               a.round + 1 as round
        FROM _region a
        LEFT JOIN (
            SELECT DISTINCT _row, _col, plant, region_seq
            FROM (
                SELECT _row, _col, plant, region_seq,
                       rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
                FROM (
                    SELECT b._row, b._col, b.plant,
                           abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
                           case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
                    FROM _region a
                    JOIN _region b
                    ON a.plant = b.plant
                    AND a.find
                    AND NOT b.find
                ) t
            ) t
            WHERE rn = 1
        ) b
        ON a._row = b._row
        AND a._col = b._col
        WHERE exists (select * From _region where not find limit 1)
    ) t
)

性能优化

然而计算性能却非常慢,而且越到后面越慢。输出前30W行需要9s,输出60W行则变成了38秒。

postgres(# ) SELECT * from region limit 300000;
SELECT 300000
Time: 9061.534 ms (00:09.062)

postgres(# ) SELECT * from region limit 600000;
SELECT 600000
Time: 38549.538 ms (00:38.550)

究其原因,就是越到后面,find的数量越多,因此join的代价也就越大。

postgres=# select round, count(*) from lance_test where find group by round order by 1;
 round | count 
-------+-------
     1 |    26
     2 |    65
     3 |   119
     4 |   183
     5 |   266
     6 |   372
     7 |   483
     8 |   605
     9 |   733
    10 |   859
    11 |   983
    12 |  1105
    13 |  1215
    14 |  1309
    15 |  1398
    16 |  1482
    17 |  1561
    18 |  1655
    19 |  1756
    20 |  1858
    21 |  1969

但是,每次寻找邻居的时候,应该让新加入的邻居再去找它们的邻居,已经找过邻居的节点完全没有必要再去计算了。所以,通过一个active flag来标识哪些节点是刚加入的,只过滤出这些节点即可。

        WITH _region AS (SELECT * FROM region)
        SELECT a._row, a._col, a.plant, 
               a.find OR b._row is not null as find,
               b._row is not null as active,
               coalesce(b.region_seq, a.region_seq) as region_seq,
               a.round + 1 as round
        FROM _region a
        LEFT JOIN (
            SELECT DISTINCT _row, _col, plant, region_seq
            FROM (
                SELECT _row, _col, plant, region_seq,
                       rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
                FROM (
                    SELECT b._row, b._col, b.plant,
                           abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
                           case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
                    FROM _region a
                    JOIN _region b
                    ON a.plant = b.plant
                    AND a.active
                    AND NOT b.find
                ) t
            ) t
            WHERE rn = 1
        ) b
        ON a._row = b._row
        AND a._col = b._col
        WHERE exists (select * From _region where not find limit 1)

Part Two

如果多个围栏在同一直线上,则可以合并为一个,求新规则下的围栏数量。

+-+-+-+-+          +- - - -+
|A A A A|      ->  |A A A A| 
+-+-+-+-+          +- - - -+

首先,构想一下,哪些场景是围栏可以合并的。

比如左右相邻的节点,只要都没有向上的邻居,则上面的围栏可以合并,如下图所示:

+-+-+       +- -+
|A A|    -> |A A| 
+ +-+       + +-+
|A|         |A|

同理,没有向下邻居可以合并下面的围栏,还有向左向右等等。

我们要求出每个节点是否有四个方向上的邻居

plant_stats AS (
    SELECT a._row, a._col, a.plant, a.region_seq, 
           max(case when b._row = a._row - 1 then 1 else 0 end) as has_up,
           max(case when b._row = a._row + 1 then 1 else 0 end) as has_down,
           max(case when b._col = a._col - 1 then 1 else 0 end) as has_left,
           max(case when b._col = a._col + 1 then 1 else 0 end) as has_right
    FROM final_round a
    LEFT JOIN final_round b
    ON a.plant = b.plant
    AND abs(a._row - b._row) + abs(a._col - b._col) = 1
    GROUP BY a._row, a._col, a.plant, a.region_seq
)

那么可以合并的围栏数,计算逻辑就比较简单了

plant_straight AS (
    SELECT a.plant, a.region_seq,
           sum(case when a._col + 1 = b._col AND a.has_up = 0 AND b.has_up = 0 then 1 else 0 end) as up_trim,
           sum(case when a._col + 1 = b._col AND a.has_down = 0 AND b.has_down = 0 then 1 else 0 end) as down_trim,
           sum(case when a._row + 1 = b._row AND a.has_left = 0 AND b.has_left = 0 then 1 else 0 end) as left_trim,
           sum(case when a._row + 1 = b._row AND a.has_right = 0 AND b.has_right = 0 then 1 else 0 end) as right_trim
    FROM plant_stats a, plant_stats b
    WHERE a.plant = b.plant
    AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
    GROUP BY a.plant, a.region_seq
)

最后的算法就是植物数 * 4 – 邻居数 * 2 – 向上合并数 – 向下合并数 – 向左合并数 – 向右合并数。

完整的SQL如下:

with recursive rows as (
    SELECT row_number() over () as _row, line
    FROM lance_input
), matrix AS (
    SELECT _row :: INTEGER, x.idx :: INTEGER as _col, x.pos as plant
    FROM rows, regexp_split_to_table(line, '') with ordinality as x(pos, idx)
), region AS (
    SELECT _row, _col, plant, 
           rn = 1 as find,
           rn = 1 as active,
           0 as region_seq,
           1 as round
    FROM (
        SELECT _row, _col, plant,
               row_number() over (partition by plant order by _row, _col) as rn
        FROM matrix
    ) t

    UNION ALL

    SELECT _row, _col, plant, find, active, region_seq, round
    FROM (
        WITH _region AS (SELECT * FROM region)
        SELECT a._row, a._col, a.plant, 
               a.find OR b._row is not null as find,
               b._row is not null as active,
               coalesce(b.region_seq, a.region_seq) as region_seq,
               a.round + 1 as round
        FROM _region a
        LEFT JOIN (
            SELECT DISTINCT _row, _col, plant, region_seq
            FROM (
                SELECT _row, _col, plant, region_seq,
                       rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
                FROM (
                    SELECT b._row, b._col, b.plant,
                           abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
                           case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
                    FROM _region a
                    JOIN _region b
                    ON a.plant = b.plant
                    AND a.active
                    AND NOT b.find
                ) t
            ) t
            WHERE rn = 1
        ) b
        ON a._row = b._row
        AND a._col = b._col
        WHERE exists (select * From _region where not find limit 1)
    ) t
), final_round AS (
    SELECT _row, _col, plant, find, active, region_seq, round
    FROM region
    WHERE round = (select max(round) from region)
), plant_cnt AS (
    SELECT plant, region_seq, count(*) as cnt
    FROM final_round
    GROUP BY plant, region_seq
), plant_stats AS (
    SELECT a._row, a._col, a.plant, a.region_seq, 
           max(case when b._row = a._row - 1 then 1 else 0 end) as has_up,
           max(case when b._row = a._row + 1 then 1 else 0 end) as has_down,
           max(case when b._col = a._col - 1 then 1 else 0 end) as has_left,
           max(case when b._col = a._col + 1 then 1 else 0 end) as has_right
    FROM final_round a
    LEFT JOIN final_round b
    ON a.plant = b.plant
    AND abs(a._row - b._row) + abs(a._col - b._col) = 1
    GROUP BY a._row, a._col, a.plant, a.region_seq
), plant_neighbour AS (
    SELECT a.plant, a.region_seq, count(*) as cnt
    FROM final_round a, final_round b
    WHERE a.plant = b.plant
    AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
    GROUP BY a.plant, a.region_seq
), plant_straight AS (
    SELECT a.plant, a.region_seq,
           sum(case when a._col + 1 = b._col AND a.has_up = 0 AND b.has_up = 0 then 1 else 0 end) as up_trim,
           sum(case when a._col + 1 = b._col AND a.has_down = 0 AND b.has_down = 0 then 1 else 0 end) as down_trim,
           sum(case when a._row + 1 = b._row AND a.has_left = 0 AND b.has_left = 0 then 1 else 0 end) as left_trim,
           sum(case when a._row + 1 = b._row AND a.has_right = 0 AND b.has_right = 0 then 1 else 0 end) as right_trim
    FROM plant_stats a, plant_stats b
    WHERE a.plant = b.plant
    AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
    GROUP BY a.plant, a.region_seq
)
SELECT sum(a.cnt * (a.cnt * 4 - coalesce(b.cnt, 0) * 2 - coalesce(c.up_trim, 0) - coalesce(c.down_trim, 0) - coalesce(c.left_trim, 0) - coalesce(c.right_trim, 0)))
FROM plant_cnt a 
LEFT JOIN plant_neighbour b
ON a.plant = b.plant
AND a.region_seq = b.region_seq
LEFT JOIN plant_straight c
ON a.plant = c.plant
AND a.region_seq = c.region_seq;

发表评论