Part One
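Each garden plot holds one kind of plant, for example:
AAAA
BBCD
BBCC
EEEC
Plots of the same plant that touch horizontally or vertically form a region, and each region is fenced separately. A region's price is its area (number of plants) times its perimeter (length of its fence). The task is to find the total price of all regions. For the example, the fences look like this:
+-+-+-+-+
|A A A A|
+-+-+-+-+     +-+
              |D|
+-+-+   +-+   +-+
|B B|   |C|
+   +   + +-+
|B B|   |C C|
+-+-+   +-+ +
          |C|
+-+-+-+   +-+
|E E E|
+-+-+-+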
Removing fences between neighbours
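The most intuitive approach: give every plant 4 fence segments by default, then remove the fences between adjacent plants of the same kind. Each neighbour pair therefore saves 2 fence segments (one from each side), as shown below: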
+-+ +-+      +--+
|A| |A|  ->  |AA|
+-+ +-+      +--+
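So all we need is the number of neighbour pairs. To avoid counting a pair twice, only pair a cell with the cell directly to its right or directly below it.
The formula then simplifies to: perimeter = number of plants * 4 - number of neighbour pairs * 2, and the price is number of plants * perimeter.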
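with rows as (
-- number the input lines (assumes lance_input returns them in file order)
SELECT row_number() over () as _row, line
FROM lance_input
), matrix AS (
-- one row per character: 1-based _row, _col and the plant letter
SELECT _row :: INTEGER, x.idx :: INTEGER as _col, x.pos as plant
FROM rows, regexp_split_to_table(line, '') with ordinality as x(pos, idx)
), plant_cnt AS (
-- area of each plant type (all cells of a plant lumped together)
SELECT plant, count(*) as cnt
FROM matrix
GROUP BY plant
), plant_neighbour AS (
-- neighbour pairs, each counted once: b is directly below or directly right of a
SELECT a.plant, count(*) as cnt
FROM matrix a, matrix b
WHERE a.plant = b.plant
AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
GROUP BY a.plant
)
-- price = area * perimeter, where perimeter = area * 4 - neighbour pairs * 2
SELECT sum(a.cnt * (a.cnt * 4 - coalesce(b.cnt, 0) * 2))
FROM plant_cnt a
LEFT JOIN plant_neighbour b
ON a.plant = b.plant;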
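To sanity-check the query on the small example without a database, here is a minimal Python sketch of the same computation. It assumes the input is a list of strings (one per line) and uses 0-based coordinates; the name naive_price is illustrative only. Like the query, it lumps all cells of one plant together.

```python
from collections import Counter


def naive_price(grid):
    """Per-plant price, grouping every cell of a plant together (like the query above)."""
    cells = {(r, c): ch for r, line in enumerate(grid) for c, ch in enumerate(line)}
    area = Counter(cells.values())  # plays the role of plant_cnt
    pairs = Counter()               # plays the role of plant_neighbour
    for (r, c), ch in cells.items():
        # look only right and down so that each adjacent pair is counted once
        for nb in ((r, c + 1), (r + 1, c)):
            if cells.get(nb) == ch:
                pairs[ch] += 1
    # area * perimeter, with perimeter = area * 4 - pairs * 2
    return sum(n * (n * 4 - pairs[ch] * 2) for ch, n in area.items())


print(naive_price(["AAAA", "BBCD", "BBCC", "EEEC"]))  # 140
```

In this example every plant forms exactly one region, so lumping by plant happens to give the right total.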
Multiple regions of the same plant
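But this answer is wrong, because it ignores the case where the same plant grows in several separate regions: the fence of each region has to be priced separately. This is a real challenge, because we do not know in advance how many regions each plant is split into.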
OOOOO
OXOXO
OOOOO
OXOXO
OOOOO
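In the example above, X is split across 4 regions. Since the price is computed per region as area * perimeter, the correct result for X is 1 * 4 * 4 = 16 (four regions of area 1 and perimeter 4). Lumping them together gives 4 * 16 = 64 instead, so we still need to work out which region each X belongs to.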
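The arithmetic for X can be checked with a tiny Python sketch of the perimeter formula; the helper name perimeter is illustrative.

```python
def perimeter(n_cells, n_pairs):
    # each cell starts with 4 fence segments; each adjacent pair removes 2 of them
    return n_cells * 4 - n_pairs * 2


# The four X cells touch no other X, so there are no neighbour pairs at all.
separate = 4 * (1 * perimeter(1, 0))  # four regions, each area 1 and perimeter 4
combined = 4 * perimeter(4, 0)        # one lumped group: area 4, perimeter 16
print(separate, combined)  # 16 64
```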
Finding the regions
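An intuitive algorithm is flood fill: start from one cell and keep expanding to adjacent cells of the same plant. Everything reached belongs to the same region, and any cell of that plant that is never reached must belong to another region.
Every coordinate in matrix needs a region_seq, starting from 0 for each plant; each time a new region of that plant is discovered, the seq increases by one.
This could be done with a function that loops and UPDATEs a table, or with a recursive query that keeps emitting the latest state. Here we use the recursive query.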
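For comparison, here is a conventional queue-based (BFS) flood fill in Python. It is a swapped-in technique for cross-checking, not the recursive-query method used below. It assumes a list-of-strings grid and 0-based coordinates, and label_regions / price are hypothetical names. As in the SQL, region_seq restarts at 0 for each plant.

```python
from collections import Counter, deque


def label_regions(grid):
    """Map every (row, col) to (plant, region_seq) using a queue-based flood fill."""
    rows, cols = len(grid), len(grid[0])
    label, next_seq = {}, Counter()  # next_seq: plant -> next unused region_seq
    for r in range(rows):
        for c in range(cols):
            if (r, c) in label:
                continue
            plant = grid[r][c]
            key = (plant, next_seq[plant])  # an unvisited cell starts a new region
            next_seq[plant] += 1
            label[(r, c)] = key
            queue = deque([(r, c)])
            while queue:
                cr, cc = queue.popleft()
                for nr, nc in ((cr - 1, cc), (cr + 1, cc), (cr, cc - 1), (cr, cc + 1)):
                    if (0 <= nr < rows and 0 <= nc < cols
                            and (nr, nc) not in label and grid[nr][nc] == plant):
                        label[(nr, nc)] = key
                        queue.append((nr, nc))
    return label


def price(grid):
    """Part One: sum of area * perimeter, computed per region instead of per plant."""
    label = label_regions(grid)
    area, pairs = Counter(label.values()), Counter()
    for (r, c), key in label.items():
        for nb in ((r, c + 1), (r + 1, c)):
            if label.get(nb) == key:
                pairs[key] += 1
    return sum(n * (n * 4 - pairs[k] * 2) for k, n in area.items())


OX = ["OOOOO", "OXOXO", "OOOOO", "OXOXO", "OOOOO"]
print(len(set(label_regions(OX).values())), price(OX))  # 5 772
```

Pricing each (plant, region_seq) group separately is exactly what the per-plant grouping of the first query was missing.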
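Step 1: seed the top-left cell. For each plant, the first cell in (row, col) order is marked as found, with region_seq 0 and round 1: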
SELECT _row, _col, plant,
rn = 1 as find,
0 as region_seq,
1 as round
FROM (
SELECT _row, _col, plant,
row_number() over (partition by plant order by _row, _col) as rn
FROM matrix
) t
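Step 2: keep finding neighbours and update their region_seq. Every found cell a is paired with every unfound cell b of the same plant: if b is adjacent to a, it keeps a's region_seq; if not, it gets region_seq + 1 as a candidate for a new region.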
SELECT b._row, b._col, b.plant,
abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
FROM _region a
JOIN _region b
ON a.plant = b.plant
AND a.find
AND NOT b.find
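Step 3: pick the final region_seq per plant, giving neighbours priority over non-neighbours. A neighbour gets the sort key '1'; a non-neighbour gets '0' followed by its zero-padded coordinates, so that at most one new region is started per round (with the descending order, the non-neighbour with the largest coordinate wins). Several neighbours can be found in the same round and must all be kept, so rank is used here instead of row_number; DISTINCT then removes duplicates of a cell that is adjacent to more than one found cell.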
SELECT DISTINCT _row, _col, plant, region_seq
FROM (
SELECT _row, _col, plant, region_seq,
rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
FROM (
SELECT b._row, b._col, b.plant,
abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
FROM _region a
JOIN _region b
ON a.plant = b.plant
AND a.find
AND NOT b.find
) t
) t
WHERE rn = 1
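The rank() = 1 filter plus DISTINCT can be mirrored in Python for one plant's candidate rows. This is a sketch assuming candidates are (row, col, neighbour, region_seq) tuples; order_key and round_winners are illustrative names. It rebuilds the same sort key as the SQL and keeps every row that ties with the best (key, region_seq) pair, which is what rank() = 1 keeps.

```python
def order_key(row, col, neighbour):
    # same string the SQL builds: '1' for neighbours, '0' + zero-padded coordinates otherwise
    return "1" if neighbour else "0" + f"{row:03d}" + f"{col:03d}"


def round_winners(candidates):
    """Rank-1 rows of one plant's candidates, de-duplicated like SELECT DISTINCT.

    candidates: list of (row, col, neighbour, region_seq) tuples.
    """
    if not candidates:
        return set()
    # ORDER BY key DESC, region_seq DESC: rank 1 = every row equal to the maximum
    best = max((order_key(r, c, nb), seq) for r, c, nb, seq in candidates)
    return {(r, c, seq) for r, c, nb, seq in candidates
            if (order_key(r, c, nb), seq) == best}


# (2, 3) touches two found cells; (5, 5) is not adjacent to anything found yet.
print(sorted(round_winners([(2, 3, True, 0), (2, 3, True, 0),
                            (2, 4, True, 0), (5, 5, False, 1)])))  # [(2, 3, 0), (2, 4, 0)]
# No neighbours left: only the largest unfound coordinate seeds region 1.
print(round_winners([(4, 1, False, 1), (5, 5, False, 1), (5, 2, False, 1)]))  # {(5, 5, 1)}
```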
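Termination condition: stop once every coordinate has been found. Each round either grows a plant's current region by one layer of neighbours or starts that plant's next region, so the number of rounds depends on how many regions each plant has and how far their cells are from each region's starting cell, not simply on the side length of the matrix.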
WHERE exists (select * From _region where not find limit 1)
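Putting steps 1 to 3 and the termination check together, the rounds can be simulated in Python to see which region_seq each cell receives and how many rounds are needed. This sketch assumes 0-based coordinates (their relative order matches the SQL's 1-based ones), and flood_by_rounds is a hypothetical name. It already restricts the left side of the join to the cells found in the previous round; that should yield the same labels as joining on all found cells, because cells found earlier have no unfound neighbours left.

```python
def flood_by_rounds(grid):
    """Round-by-round simulation of the recursive CTE (using the newly found cells only).

    Returns ({(row, col): region_seq}, number_of_rounds); coordinates are 0-based.
    """
    cells = {(r, c): ch for r, line in enumerate(grid) for c, ch in enumerate(line)}
    plants = set(cells.values())
    seq, found, active = {}, set(), set()
    # Step 1: seed the first cell of every plant in (row, col) order, round 1
    for plant in plants:
        first = min(p for p, ch in cells.items() if ch == plant)
        seq[first] = 0
        found.add(first)
        active.add(first)
    rounds = 1
    # Termination: keep going while some cell is still unfound
    while len(found) < len(cells):
        new_active = set()
        for plant in plants:
            frontier = [p for p in active if cells[p] == plant]
            unfound = [p for p, ch in cells.items() if ch == plant and p not in found]
            if not frontier or not unfound:
                continue
            current = seq[frontier[0]]  # all frontier cells share one region_seq
            # Step 2: unfound cells next to the frontier inherit its region_seq
            neighbours = [p for p in unfound
                          if any(abs(p[0] - q[0]) + abs(p[1] - q[1]) == 1 for q in frontier)]
            if neighbours:
                chosen, new_seq = neighbours, current
            else:
                # Step 3: no neighbour left, so the largest unfound coordinate starts a new region
                chosen, new_seq = [max(unfound)], current + 1
            for p in chosen:
                seq[p] = new_seq
            new_active.update(chosen)
        found |= new_active
        active = new_active
        rounds += 1
    return seq, rounds


OX = ["OOOOO", "OXOXO", "OOOOO", "OXOXO", "OOOOO"]
seq, rounds = flood_by_rounds(OX)
print(rounds, [seq[p] for p in [(1, 1), (3, 3), (3, 1), (1, 3)]])  # 9 [0, 1, 2, 3]
```

On this grid O needs 8 expansion rounds after its seed (its farthest cell is 8 steps away), while the four X regions are seeded one per round, largest coordinate first, so the slowest plant determines the total.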
The complete logic is as follows:
with recursive rows as (
SELECT row_number() over () as _row, line
FROM lance_input
), matrix AS (
SELECT _row :: INTEGER, x.idx :: INTEGER as _col, x.pos as plant
FROM rows, regexp_split_to_table(line, '') with ordinality as x(pos, idx)
), region AS (
SELECT _row, _col, plant,
rn = 1 as find,
0 as region_seq,
1 as round
FROM (
SELECT _row, _col, plant,
row_number() over (partition by plant order by _row, _col) as rn
FROM matrix
) t
UNION ALL
SELECT _row, _col, plant, find, region_seq, round
FROM (
WITH _region AS (SELECT * FROM region)
SELECT a._row, a._col, a.plant,
a.find OR b._row is not null as find,
coalesce(b.region_seq, a.region_seq) as region_seq,
a.round + 1 as round
FROM _region a
LEFT JOIN (
SELECT DISTINCT _row, _col, plant, region_seq
FROM (
SELECT _row, _col, plant, region_seq,
rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
FROM (
SELECT b._row, b._col, b.plant,
abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
FROM _region a
JOIN _region b
ON a.plant = b.plant
AND a.find
AND NOT b.find
) t
) t
WHERE rn = 1
) b
ON a._row = b._row
AND a._col = b._col
WHERE exists (select * From _region where not find limit 1)
) t
)
Performance tuning
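However, the query is very slow, and it gets slower the longer it runs: emitting the first 300,000 rows takes 9 s, while emitting 600,000 rows takes 38 s.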
postgres(# ) SELECT * from region limit 300000;
SELECT 300000
Time: 9061.534 ms (00:09.062)
postgres(# ) SELECT * from region limit 600000;
SELECT 600000
Time: 38549.538 ms (00:38.550)
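The root cause is that the number of found cells keeps growing from round to round, and the join pairs every found cell with every unfound cell of the same plant, so the join becomes more and more expensive: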
postgres=# select round, count(*) from lance_test where find group by round order by 1;
round | count
-------+-------
1 | 26
2 | 65
3 | 119
4 | 183
5 | 266
6 | 372
7 | 483
8 | 605
9 | 733
10 | 859
11 | 983
12 | 1105
13 | 1215
14 | 1309
15 | 1398
16 | 1482
17 | 1561
18 | 1655
19 | 1756
20 | 1858
21 | 1969
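Yet in each round only the cells that were just added need to look for their neighbours; cells that have already done so cannot produce anything new. So an active flag marks the cells added in the previous round, and only those cells are used on the left side of the join (a.active instead of a.find):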
WITH _region AS (SELECT * FROM region)
SELECT a._row, a._col, a.plant,
a.find OR b._row is not null as find,
b._row is not null as active,
coalesce(b.region_seq, a.region_seq) as region_seq,
a.round + 1 as round
FROM _region a
LEFT JOIN (
SELECT DISTINCT _row, _col, plant, region_seq
FROM (
SELECT _row, _col, plant, region_seq,
rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
FROM (
SELECT b._row, b._col, b.plant,
abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
FROM _region a
JOIN _region b
ON a.plant = b.plant
AND a.active
AND NOT b.find
) t
) t
WHERE rn = 1
) b
ON a._row = b._row
AND a._col = b._col
WHERE exists (select * From _region where not find limit 1)
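To get a feel for the difference, here is a small Python model that counts the rows the inner join produces for one plant flooding an n x n grid from its corner: found x unfound per round with a.find, versus active x unfound with a.active. It is a simplified model under stated assumptions: it ignores the LEFT JOIN that re-emits every cell each round, and join_sizes is an illustrative name.

```python
from collections import Counter


def join_sizes(n):
    """Rows produced by the inner join over all rounds, for one plant filling an n x n grid.

    Flooding starts at (0, 0), so BFS layer d is every cell with row + col == d.
    Returns (rows with a.find, rows with a.active).
    """
    layer = Counter(r + c for r in range(n) for c in range(n))
    total = n * n
    found = with_find = with_active = 0
    for d in range(2 * n - 1):
        found += layer[d]                    # cells found once layer d has been added
        unfound = total - found
        with_find += found * unfound         # a.find AND NOT b.find
        with_active += layer[d] * unfound    # a.active AND NOT b.find
    return with_find, with_active


print(join_sizes(3))  # (52, 31)
```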
Part Two
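Fence segments that lie on the same straight line can now be merged into one side. Compute the total price under this new rule, where a region's price is its area times its number of sides.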
+-+-+-+-+      +- - - -+
|A A A A|  ->  |A A A A|
+-+-+-+-+      +- - - -+
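First, consider which situations allow fence segments to be merged.
For example, for two horizontally adjacent cells, if neither has a neighbour above, their top fences can be merged, as shown below: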
+-+-+      +- -+
|A A|  ->  |A A|
+ +-+      + +-+
|A|        |A|
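Likewise, horizontally adjacent cells with no neighbour below can merge their bottom fences, and vertically adjacent cells with no neighbour on the left (or right) can merge their left (or right) fences.
So for every cell we need to know whether it has a neighbour in each of the four directions: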
plant_stats AS (
SELECT a._row, a._col, a.plant, a.region_seq,
max(case when b._row = a._row - 1 then 1 else 0 end) as has_up,
max(case when b._row = a._row + 1 then 1 else 0 end) as has_down,
max(case when b._col = a._col - 1 then 1 else 0 end) as has_left,
max(case when b._col = a._col + 1 then 1 else 0 end) as has_right
FROM final_round a
LEFT JOIN final_round b
ON a.plant = b.plant
AND abs(a._row - b._row) + abs(a._col - b._col) = 1
GROUP BY a._row, a._col, a.plant, a.region_seq
)
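The same four flags in Python, as a sketch over a {(row, col): plant} dict with 0-based coordinates; direction_flags is an illustrative name.

```python
def direction_flags(cells):
    """Per cell: 1/0 for a same-plant neighbour above, below, left and right (like plant_stats)."""
    directions = {"has_up": (-1, 0), "has_down": (1, 0), "has_left": (0, -1), "has_right": (0, 1)}
    return {
        (r, c): {name: int(cells.get((r + dr, c + dc)) == plant)
                 for name, (dr, dc) in directions.items()}
        for (r, c), plant in cells.items()
    }


grid = ["AAAA", "BBCD", "BBCC", "EEEC"]
cells = {(r, c): ch for r, line in enumerate(grid) for c, ch in enumerate(line)}
print(direction_flags(cells)[(2, 2)])
# {'has_up': 1, 'has_down': 0, 'has_left': 0, 'has_right': 1}
```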
Counting the mergeable fence segments is then straightforward:
plant_straight AS (
SELECT a.plant, a.region_seq,
sum(case when a._col + 1 = b._col AND a.has_up = 0 AND b.has_up = 0 then 1 else 0 end) as up_trim,
sum(case when a._col + 1 = b._col AND a.has_down = 0 AND b.has_down = 0 then 1 else 0 end) as down_trim,
sum(case when a._row + 1 = b._row AND a.has_left = 0 AND b.has_left = 0 then 1 else 0 end) as left_trim,
sum(case when a._row + 1 = b._row AND a.has_right = 0 AND b.has_right = 0 then 1 else 0 end) as right_trim
FROM plant_stats a, plant_stats b
WHERE a.plant = b.plant
AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
GROUP BY a.plant, a.region_seq
)
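The number of sides of a region is therefore: plants * 4 - neighbour pairs * 2 - up merges - down merges - left merges - right merges, and its price is the number of plants times that.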
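An end-to-end Python cross-check of this formula: label regions with a BFS flood fill (a swapped-in technique, not the recursive query), then subtract neighbour pairs and mergeable fence pairs. It assumes a list-of-strings grid with 0-based coordinates; label_regions and sides_price are hypothetical names, and the grids are small illustrative inputs.

```python
from collections import Counter, deque


def label_regions(grid):
    """BFS flood fill: map every (row, col) to (plant, region_seq)."""
    rows, cols = len(grid), len(grid[0])
    label, next_seq = {}, Counter()
    for r in range(rows):
        for c in range(cols):
            if (r, c) in label:
                continue
            plant = grid[r][c]
            key = (plant, next_seq[plant])
            next_seq[plant] += 1
            label[(r, c)] = key
            queue = deque([(r, c)])
            while queue:
                cr, cc = queue.popleft()
                for nr, nc in ((cr - 1, cc), (cr + 1, cc), (cr, cc - 1), (cr, cc + 1)):
                    if (0 <= nr < rows and 0 <= nc < cols
                            and (nr, nc) not in label and grid[nr][nc] == plant):
                        label[(nr, nc)] = key
                        queue.append((nr, nc))
    return label


def sides_price(grid):
    """Sum of area * sides, with sides = area * 4 - pairs * 2 - mergeable fence pairs."""
    label = label_regions(grid)
    area, pairs, merges = Counter(label.values()), Counter(), Counter()

    def lacks(pos, dr, dc):
        # True when the cell has no same-region neighbour in direction (dr, dc)
        return label.get((pos[0] + dr, pos[1] + dc)) != label[pos]

    for (r, c), key in label.items():
        right, down = (r, c + 1), (r + 1, c)
        if label.get(right) == key:  # horizontal pair: top/bottom fences may merge
            pairs[key] += 1
            merges[key] += int(lacks((r, c), -1, 0) and lacks(right, -1, 0))
            merges[key] += int(lacks((r, c), 1, 0) and lacks(right, 1, 0))
        if label.get(down) == key:   # vertical pair: left/right fences may merge
            pairs[key] += 1
            merges[key] += int(lacks((r, c), 0, -1) and lacks(down, 0, -1))
            merges[key] += int(lacks((r, c), 0, 1) and lacks(down, 0, 1))
    return sum(n * (n * 4 - pairs[k] * 2 - merges[k]) for k, n in area.items())


print(sides_price(["AAAA", "BBCD", "BBCC", "EEEC"]))  # 80
```

For the first example this gives 80: A, B, D and E each have 4 sides and C has 8, so 4*4 + 4*4 + 4*8 + 1*4 + 3*4 = 80.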
The complete SQL is as follows:
with recursive rows as (
SELECT row_number() over () as _row, line
FROM lance_input
), matrix AS (
SELECT _row :: INTEGER, x.idx :: INTEGER as _col, x.pos as plant
FROM rows, regexp_split_to_table(line, '') with ordinality as x(pos, idx)
), region AS (
SELECT _row, _col, plant,
rn = 1 as find,
rn = 1 as active,
0 as region_seq,
1 as round
FROM (
SELECT _row, _col, plant,
row_number() over (partition by plant order by _row, _col) as rn
FROM matrix
) t
UNION ALL
SELECT _row, _col, plant, find, active, region_seq, round
FROM (
WITH _region AS (SELECT * FROM region)
SELECT a._row, a._col, a.plant,
a.find OR b._row is not null as find,
b._row is not null as active,
coalesce(b.region_seq, a.region_seq) as region_seq,
a.round + 1 as round
FROM _region a
LEFT JOIN (
SELECT DISTINCT _row, _col, plant, region_seq
FROM (
SELECT _row, _col, plant, region_seq,
rank() over (partition by plant order by case when neighbour then '1' else '0' || lpad(_row::varchar, 3, '0') || lpad(_col::varchar, 3, '0') end desc, region_seq desc) as rn
FROM (
SELECT b._row, b._col, b.plant,
abs(a._row - b._row) + abs(a._col - b._col) = 1 as neighbour,
case when abs(a._row - b._row) + abs(a._col - b._col) = 1 then a.region_seq else a.region_seq + 1 end as region_seq
FROM _region a
JOIN _region b
ON a.plant = b.plant
AND a.active
AND NOT b.find
) t
) t
WHERE rn = 1
) b
ON a._row = b._row
AND a._col = b._col
WHERE exists (select * From _region where not find limit 1)
) t
), final_round AS (
SELECT _row, _col, plant, find, active, region_seq, round
FROM region
WHERE round = (select max(round) from region)
), plant_cnt AS (
SELECT plant, region_seq, count(*) as cnt
FROM final_round
GROUP BY plant, region_seq
), plant_stats AS (
SELECT a._row, a._col, a.plant, a.region_seq,
max(case when b._row = a._row - 1 then 1 else 0 end) as has_up,
max(case when b._row = a._row + 1 then 1 else 0 end) as has_down,
max(case when b._col = a._col - 1 then 1 else 0 end) as has_left,
max(case when b._col = a._col + 1 then 1 else 0 end) as has_right
FROM final_round a
LEFT JOIN final_round b
ON a.plant = b.plant
AND abs(a._row - b._row) + abs(a._col - b._col) = 1
GROUP BY a._row, a._col, a.plant, a.region_seq
), plant_neighbour AS (
SELECT a.plant, a.region_seq, count(*) as cnt
FROM final_round a, final_round b
WHERE a.plant = b.plant
AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
GROUP BY a.plant, a.region_seq
), plant_straight AS (
SELECT a.plant, a.region_seq,
sum(case when a._col + 1 = b._col AND a.has_up = 0 AND b.has_up = 0 then 1 else 0 end) as up_trim,
sum(case when a._col + 1 = b._col AND a.has_down = 0 AND b.has_down = 0 then 1 else 0 end) as down_trim,
sum(case when a._row + 1 = b._row AND a.has_left = 0 AND b.has_left = 0 then 1 else 0 end) as left_trim,
sum(case when a._row + 1 = b._row AND a.has_right = 0 AND b.has_right = 0 then 1 else 0 end) as right_trim
FROM plant_stats a, plant_stats b
WHERE a.plant = b.plant
AND ((a._row + 1 = b._row AND a._col = b._col) OR (a._col + 1 = b._col AND a._row = b._row))
GROUP BY a.plant, a.region_seq
)
SELECT sum(a.cnt * (a.cnt * 4 - coalesce(b.cnt, 0) * 2 - coalesce(c.up_trim, 0) - coalesce(c.down_trim, 0) - coalesce(c.left_trim, 0) - coalesce(c.right_trim, 0)))
FROM plant_cnt a
LEFT JOIN plant_neighbour b
ON a.plant = b.plant
AND a.region_seq = b.region_seq
LEFT JOIN plant_straight c
ON a.plant = c.plant
AND a.region_seq = c.region_seq;