Part One
[.##.] (3) (1,3) (2) (2,3) (0,2) (0,1) {3,5,4,7} 初始时所有指示灯都为关闭状态,每个圆括号表示一个按钮(括号内是该按钮会切换的指示灯编号),求出达到目标状态所需的最少按键次数
通过位运算,来实现不同按钮的组合。并确保这些组合最后的结果,与目标一致。
-- origin: one row per input line; target is the light pattern between '[' and ']',
-- schematics is the space-separated button list before ' {'.
WITH origin AS (
    SELECT
        row_number() OVER () AS rn,
        substring(split_part(line, '] ', 1), 2) AS target,
        split_part(split_part(line, '] ', 2), ' {', 1) AS schematics
    FROM lance_input
),
-- split_schematics: one row per button, idx is its 1-based position on the line.
split_schematics AS (
    SELECT
        o.rn,
        part.button,
        part.idx
    FROM origin AS o
    CROSS JOIN LATERAL regexp_split_to_table(o.schematics, ' ') WITH ORDINALITY AS part(button, idx)
),
-- button_cnt: number of buttons on each line.
button_cnt AS (
    SELECT
        rn,
        count(*) AS total
    FROM split_schematics
    GROUP BY rn
),
-- power_seq: every bitmask 0 .. 2^max(total) - 1; bit (idx - 1) selects button idx.
power_seq AS (
    SELECT subset.seq
    FROM (
        SELECT max(total) AS max_total
        FROM button_cnt
    ) AS widest
    CROSS JOIN LATERAL generate_series(0, power(2, widest.max_total)::integer - 1) AS subset(seq)
),
-- cross_join: for each line and each bitmask valid for that line, the buttons
-- (parentheses stripped) whose bit is set in the mask.
cross_join AS (
    SELECT
        s.rn,
        substring(s.button FROM 2 FOR length(s.button) - 2) AS buttons,
        p.seq
    FROM split_schematics AS s
    INNER JOIN button_cnt AS cnt
        ON cnt.rn = s.rn
    INNER JOIN power_seq AS p
        ON p.seq < power(2, cnt.total)::integer
        AND (p.seq & power(2, s.idx - 1)::integer) = power(2, s.idx - 1)::integer
)
举例来说,例子中的第一条记录中,当seq为19(二进制为10011,即第1、2、5位为1)时,包含1、2、5这三个button
-- Buttons selected by bitmask 19 on the first input line.
SELECT
    rn,
    buttons,
    seq
FROM cross_join
WHERE seq = 19
    AND rn = 1;
rn | buttons | seq
----+---------+-----
1 | 3 | 19
1 | 1,3 | 19
1 | 0,2 | 19
(3 rows)
将这些button按逗号拆分成单个指示灯编号后进行聚合:某个编号出现的次数为奇数则代表该指示灯最终打开,偶数则为关闭。
-- button_agg: for each (line, bitmask) the light indexes that were toggled an odd
-- number of times, concatenated in ascending numeric order.
button_agg AS (
    SELECT
        lit.rn,
        lit.seq,
        string_agg(lit.button, '' ORDER BY lit.button::integer) AS buttons
    FROM (
        SELECT
            cj.rn,
            cj.seq,
            light.button
        FROM cross_join AS cj
        CROSS JOIN LATERAL regexp_split_to_table(cj.buttons, ',') AS light(button)
        GROUP BY cj.rn, cj.seq, light.button
        -- keep only indexes with an odd occurrence count
        HAVING count(*) % 2 <> 0
    ) AS lit
    GROUP BY lit.rn, lit.seq
)
-- Lights left on by bitmask 19 on the first input line.
SELECT
    rn,
    seq,
    buttons
FROM button_agg
WHERE rn = 1
    AND seq = 19;
rn | seq | buttons
----+-----+---------
1 | 19 | 012
(1 row)
最终在结果与目标状态一致的组合中,找出按下按钮个数最少的那个(即seq的二进制表示中1的个数最小)即可
-- origin_result: the target pattern of each line written the same way as
-- button_agg.buttons: 0-based positions of the '#' cells, concatenated in order.
origin_result AS (
    SELECT
        o.rn,
        string_agg((cell.idx - 1)::text, '' ORDER BY cell.idx) AS buttons
    FROM origin AS o
    CROSS JOIN LATERAL regexp_split_to_table(o.target, '') WITH ORDINALITY AS cell(button, idx)
    WHERE cell.button = '#'
    GROUP BY o.rn
)
-- Per line, take the matching bitmask with the fewest set bits (the number of
-- '1' characters in its 32-bit text form), then add those minima up.
SELECT sum(best.len)
FROM (
    SELECT
        agg.rn,
        min(length(replace(agg.seq::bit(32)::text, '0', ''))) AS len
    FROM button_agg AS agg
    INNER JOIN origin_result AS wanted
        ON wanted.rn = agg.rn
        AND wanted.buttons = agg.buttons
    GROUP BY agg.rn
) AS best;
Part Two
第二部分比较复杂,不再是判断开关状态,而是需要达到目标的次数。直接穷举法的话,次数太多了,单个上限差不多为300,按钮数最多到10左右,300^10是个惊人的数字。
如果将每个按钮的按下次数当作变量,则每个计数器的目标值都对应一个关于这些变量的线性方程,合在一起就是一个线性方程组。求解之后,部分变量可以得到确定值(或者由自由变量表示),而自由变量本身无法由方程组唯一确定,需要另外枚举。求解方程组,可以使用高斯消元法。
高斯消元法(Gaussian elimination)
python中可以使用scipy等组件来求解方程,而这里通过大模型生成了高斯消元法的函数。函数的返回值,就是自由变量以及被求解变量的列表。
-- Solution of a single solved (pivot) variable, expressed through the free variables.
-- CASCADE also drops objects that depend on the previous definition of the type.
DROP TYPE IF EXISTS variable_solution CASCADE;
CREATE TYPE variable_solution AS (
    var_index        integer,            -- variable number (x1, x2, ...)
    constant         double precision,   -- constant term
    free_var_indices integer[],          -- numbers of the free variables involved
    free_var_coeffs  double precision[]  -- matching coefficients (subtracted in the solution)
);
-- Complete solution of a linear system.
-- CASCADE also drops objects that depend on the previous definition of the type.
DROP TYPE IF EXISTS linear_system_solution CASCADE;
CREATE TYPE linear_system_solution AS (
    status           text,                -- 'unique' | 'infinite' | 'inconsistent'
    rank             integer,             -- rank of the matrix
    num_variables    integer,             -- total number of variables
    free_variables   integer[],           -- list of free variables
    solved_variables variable_solution[]  -- solved variables together with their solutions
);
-- ========== Main function ==========
-- solve_linear_system: bring the augmented matrix [A|b] to reduced row echelon
-- form (Gauss-Jordan elimination with partial pivoting) and describe the
-- solution set of A x = b.
--
-- Parameters:
--   matrix_a  coefficient matrix, flattened row by row; exactly m * n elements
--   vector_b  constant vector; exactly m elements
--   m         number of equations
--   n         number of variables
--
-- Returns a linear_system_solution:
--   status            'unique' | 'infinite' | 'inconsistent'
--   rank              number of pivot columns
--   num_variables     n
--   free_variables    1-based numbers of the non-pivot columns
--   solved_variables  one entry per pivot variable x_p, meaning
--                       x_p = constant - SUM(free_var_coeffs[t] * x_(free_var_indices[t]))
--                     (empty when status is 'inconsistent')
--
-- Raises an exception when m or n is NULL or negative, when an array is not
-- one-dimensional, when the array sizes do not match m and n, or when an
-- element is NULL (such inputs previously produced meaningless results).
CREATE OR REPLACE FUNCTION solve_linear_system(
    matrix_a DOUBLE PRECISION[],  -- coefficient matrix (m x n)
    vector_b DOUBLE PRECISION[],  -- constant vector (m)
    m INTEGER,                    -- number of equations
    n INTEGER                     -- number of variables
)
RETURNS linear_system_solution
LANGUAGE plpgsql
AS $$
DECLARE
    augmented    DOUBLE PRECISION[];
    cols         INTEGER;
    current_row  INTEGER;
    pivot_col    INTEGER;
    max_row      INTEGER;
    max_val      DOUBLE PRECISION;
    temp         DOUBLE PRECISION;
    factor       DOUBLE PRECISION;
    pivot        DOUBLE PRECISION;
    eps          DOUBLE PRECISION := 1e-10;  -- magnitudes below this are treated as zero
    pivot_cols   INTEGER[] := ARRAY[]::INTEGER[];
    free_cols    INTEGER[] := ARRAY[]::INTEGER[];
    result       linear_system_solution;
    var_sol      variable_solution;
    solutions    variable_solution[] := ARRAY[]::variable_solution[];
    coeff        DOUBLE PRECISION;
    free_indices INTEGER[];
    free_coeffs  DOUBLE PRECISION[];
BEGIN
    -- ========== Input validation ==========
    IF m IS NULL OR n IS NULL OR m < 0 OR n < 0 THEN
        RAISE EXCEPTION 'solve_linear_system: m and n must be non-negative (got m=%, n=%)', m, n;
    END IF;
    IF array_ndims(matrix_a) > 1 OR array_ndims(vector_b) > 1 THEN
        RAISE EXCEPTION 'solve_linear_system: matrix_a and vector_b must be one-dimensional arrays';
    END IF;
    IF COALESCE(cardinality(matrix_a), 0) <> m * n THEN
        RAISE EXCEPTION 'solve_linear_system: matrix_a has % elements, expected %',
            COALESCE(cardinality(matrix_a), 0), m * n;
    END IF;
    IF COALESCE(cardinality(vector_b), 0) <> m THEN
        RAISE EXCEPTION 'solve_linear_system: vector_b has % elements, expected %',
            COALESCE(cardinality(vector_b), 0), m;
    END IF;
    IF EXISTS (SELECT 1 FROM unnest(matrix_a) AS a(val) WHERE a.val IS NULL)
       OR EXISTS (SELECT 1 FROM unnest(vector_b) AS b(val) WHERE b.val IS NULL) THEN
        RAISE EXCEPTION 'solve_linear_system: matrix_a and vector_b must not contain NULL elements';
    END IF;

    cols := n + 1;

    -- Build the augmented matrix [A|b], flattened row by row (cols entries per row).
    augmented := ARRAY[]::DOUBLE PRECISION[];
    FOR i IN 1..m LOOP
        FOR j IN 1..n LOOP
            augmented := array_append(augmented, matrix_a[(i-1)*n + j]);
        END LOOP;
        augmented := array_append(augmented, vector_b[i]);
    END LOOP;

    -- ========== Gauss-Jordan elimination (to RREF) ==========
    current_row := 1;
    pivot_col := 1;
    WHILE current_row <= m AND pivot_col <= n LOOP
        -- Find the largest pivot candidate in this column, at or below current_row.
        max_row := current_row;
        max_val := ABS(augmented[(current_row-1)*cols + pivot_col]);
        FOR i IN (current_row+1)..m LOOP
            IF ABS(augmented[(i-1)*cols + pivot_col]) > max_val THEN
                max_val := ABS(augmented[(i-1)*cols + pivot_col]);
                max_row := i;
            END IF;
        END LOOP;

        -- No usable pivot: the column belongs to a free variable, try the next one.
        IF max_val < eps THEN
            pivot_col := pivot_col + 1;
            CONTINUE;
        END IF;
        pivot_cols := array_append(pivot_cols, pivot_col);

        -- Swap rows so the pivot sits in current_row.
        IF max_row != current_row THEN
            FOR j IN 1..cols LOOP
                temp := augmented[(current_row-1)*cols + j];
                augmented[(current_row-1)*cols + j] := augmented[(max_row-1)*cols + j];
                augmented[(max_row-1)*cols + j] := temp;
            END LOOP;
        END IF;

        -- Normalise the pivot row so the pivot becomes 1.
        pivot := augmented[(current_row-1)*cols + pivot_col];
        FOR j IN pivot_col..cols LOOP
            augmented[(current_row-1)*cols + j] := augmented[(current_row-1)*cols + j] / pivot;
        END LOOP;

        -- Eliminate the pivot column from every other row (above and below) to get RREF.
        FOR i IN 1..m LOOP
            IF i != current_row THEN
                factor := augmented[(i-1)*cols + pivot_col];
                IF ABS(factor) > eps THEN
                    FOR j IN pivot_col..cols LOOP
                        augmented[(i-1)*cols + j] :=
                            augmented[(i-1)*cols + j] - factor * augmented[(current_row-1)*cols + j];
                    END LOOP;
                END IF;
            END IF;
        END LOOP;

        current_row := current_row + 1;
        pivot_col := pivot_col + 1;
    END LOOP;

    -- Every column without a pivot is a free variable.
    FOR j IN 1..n LOOP
        IF NOT (j = ANY(pivot_cols)) THEN
            free_cols := array_append(free_cols, j);
        END IF;
    END LOOP;

    -- The rank is the number of pivots found.
    result.rank := COALESCE(array_length(pivot_cols, 1), 0);
    result.num_variables := n;
    result.free_variables := free_cols;

    -- ========== Check for inconsistency ==========
    -- Rows below the last pivot row carry no pivot; a non-zero constant there
    -- means the system has no solution.
    FOR i IN (result.rank + 1)..m LOOP
        IF ABS(augmented[(i-1)*cols + cols]) > eps THEN
            result.status := 'inconsistent';
            result.solved_variables := ARRAY[]::variable_solution[];
            RETURN result;
        END IF;
    END LOOP;

    -- ========== Build the solution ==========
    IF array_length(free_cols, 1) IS NULL OR array_length(free_cols, 1) = 0 THEN
        result.status := 'unique';
    ELSE
        result.status := 'infinite';
    END IF;

    -- Pivot row i solves the variable in column pivot_cols[i].
    FOR i IN 1..result.rank LOOP
        var_sol.var_index := pivot_cols[i];
        var_sol.constant := augmented[(i-1)*cols + cols];

        -- Collect the non-negligible coefficients of the free variables in this row.
        free_indices := ARRAY[]::INTEGER[];
        free_coeffs := ARRAY[]::DOUBLE PRECISION[];
        -- array_length() is NULL for an empty array, which skips the loop.
        IF array_length(free_cols, 1) > 0 THEN
            FOR j IN 1..array_length(free_cols, 1) LOOP
                coeff := augmented[(i-1)*cols + free_cols[j]];
                IF ABS(coeff) > eps THEN
                    free_indices := array_append(free_indices, free_cols[j]);
                    free_coeffs := array_append(free_coeffs, coeff);
                END IF;
            END LOOP;
        END IF;

        var_sol.free_var_indices := free_indices;
        var_sol.free_var_coeffs := free_coeffs;
        solutions := array_append(solutions, var_sol);
    END LOOP;

    result.solved_variables := solutions;
    RETURN result;
END;
$$;
求解方程(Solve the equation)
经过高斯消元后,自由变量的个数基本控制在3个以内,这样后续迭代的成本就会小很多。
INFO: free:1
INFO: free:<NULL>
INFO: free:<NULL>
INFO: free:1
INFO: free:2
INFO: free:2
INFO: free:3
因此通过对自由变量的穷举,再计算出被求解变量的值,最终可以找到最小的组合次数。
-- calculate_joltage: minimum total number of button presses that brings every
-- counter exactly to its target value.
--
-- Parameters:
--   buttons  one text per button, e.g. '(1,3)': the 0-based counters it increments
--   targets  required value of each counter (counter 0 first)
--
-- Returns the smallest sum of press counts over all non-negative integer
-- solutions. If none is found, a WARNING is raised and the upper bound
-- max(targets) * number_of_buttons is returned.
--
-- Raises an exception when either array is NULL or empty, or when the equation
-- system is inconsistent.
--
-- Method: one equation per counter and one 0/1 coefficient per button;
-- solve_linear_system() expresses the pivot variables through the free ones,
-- and every combination of free-variable values in 0..max(targets) is tried.
-- The range is inclusive: a button may need exactly max(targets) presses.
CREATE OR REPLACE FUNCTION calculate_joltage(buttons text[], targets integer[])
RETURNS INTEGER AS $$
DECLARE
    row_len       integer;             -- number of counters (equations)
    col_len       integer;             -- number of buttons (variables)
    matrix        double precision[];  -- 0/1 coefficients, flattened row by row
    linear_result linear_system_solution;
    solved        variable_solution;
    max_coeff     integer;             -- largest target value
    radix         integer;             -- number of values tried per free variable
    n_free        integer;             -- number of free variables
    combo_count   integer;             -- radix ^ n_free
    rest          integer;             -- not yet decoded part of the combination number
    curr_coeff    double precision;
    total_presses integer;
    min_val       integer := 0;
    has_solution  boolean := false;
    var_solution  integer[];
    eps           double precision := 1e-3;  -- tolerance when checking for whole numbers
BEGIN
    row_len := array_length(targets, 1);
    col_len := array_length(buttons, 1);
    IF row_len IS NULL OR col_len IS NULL THEN
        RAISE EXCEPTION 'calculate_joltage: buttons and targets must be non-empty arrays';
    END IF;

    SELECT max(t.val) INTO max_coeff FROM unnest(targets) AS t(val);

    -- Coefficient (counter, button) is 1 when the button increments the counter.
    -- '(1,3)' is turned into the array literal '{1,3}' before the membership test.
    SELECT array_agg(
               (CASE
                    WHEN (cnt.counter_no - 1) = ANY (translate(btn.spec, '()', '{}')::integer[]) THEN 1
                    ELSE 0
                END)::double precision
               ORDER BY cnt.counter_no, btn.button_no
           )
    INTO matrix
    FROM generate_series(1, row_len) AS cnt(counter_no)
    CROSS JOIN unnest(buttons) WITH ORDINALITY AS btn(spec, button_no);

    linear_result := solve_linear_system(matrix, targets::double precision[], row_len, col_len);
    raise info 'free:%', array_length(linear_result.free_variables, 1);

    IF linear_result.status = 'inconsistent' THEN
        RAISE EXCEPTION 'calculate_joltage: no press counts can produce targets %', targets;
    END IF;

    -- Each free variable takes the values 0 .. max_coeff inclusive.
    -- With no free variables combo_count stays 1: the unique solution is
    -- checked once by the same code path.
    n_free := COALESCE(array_length(linear_result.free_variables, 1), 0);
    radix := max_coeff + 1;
    combo_count := 1;
    FOR j IN 1..n_free LOOP
        combo_count := combo_count * radix;
    END LOOP;

    -- Upper bound of the answer: max_coeff presses on every button.
    min_val := max_coeff * col_len;

    <<outer_loop>>
    FOR combo IN 0..combo_count - 1 LOOP
        var_solution := array_fill(0, array[col_len]);

        -- Decode combo as n_free digits in base radix, one digit per free variable.
        rest := combo;
        FOR j IN 1..n_free LOOP
            var_solution[linear_result.free_variables[j]] := rest % radix;
            rest := rest / radix;
        END LOOP;

        -- Derive every solved variable from the chosen free-variable values.
        FOR k IN 1..COALESCE(array_length(linear_result.solved_variables, 1), 0) LOOP
            solved := linear_result.solved_variables[k];
            curr_coeff := solved.constant;
            FOR t IN 1..COALESCE(array_length(solved.free_var_indices, 1), 0) LOOP
                curr_coeff := curr_coeff
                    - var_solution[solved.free_var_indices[t]] * solved.free_var_coeffs[t];
            END LOOP;

            -- A press count must be a whole number and must not be negative.
            IF abs(curr_coeff - round(curr_coeff)) > eps OR round(curr_coeff) < 0 THEN
                CONTINUE outer_loop;
            END IF;
            var_solution[solved.var_index] := round(curr_coeff);
        END LOOP;

        SELECT sum(v.val) INTO total_presses FROM unnest(var_solution) AS v(val);
        IF NOT has_solution OR total_presses < min_val THEN
            min_val := total_presses;
        END IF;
        has_solution := true;
    END LOOP;

    IF NOT has_solution THEN
        RAISE WARNING 'calculate_joltage: no non-negative integer solution found for targets %, returning upper bound %',
            targets, min_val;
    END IF;

    RETURN min_val;
END;
$$ LANGUAGE plpgsql;
最终通过该函数求解出每一条记录的最小次数,最终汇总即可
-- raw_machine: per input line, the '{...}' target list (brace restored after the
-- split) and the space-separated button list found between '] ' and ' {'.
WITH raw_machine AS (
    SELECT
        row_number() OVER () AS rn,
        '{' || split_part(line, ' {', 2) AS targets,
        split_part(split_part(line, '] ', 2), ' {', 1) AS schematics
    FROM lance_input
),
-- machine_input: the same data converted to the array types calculate_joltage expects.
machine_input AS (
    SELECT
        rm.rn,
        string_to_array(rm.schematics, ' ') AS buttons,
        rm.targets::integer[] AS targets
    FROM raw_machine AS rm
)
-- Add up the minimum press count of every line.
SELECT sum(calculate_joltage(mi.buttons, mi.targets))
FROM machine_input AS mi;
方程计算过程中,需要注意的几点包括
- 变量解必须大于等于0
- 变量解不能带小数
- 虽然变量的系数可能为小数,但是最终的解会是整数