Use queue for tiles to probe

This allows updating and searching until no more inferences can be
made, resulting in a ~5% improvement (also better performance).
This commit is contained in:
Camden Dixie O'Brien 2025-03-22 13:33:38 +00:00
parent 5b20a3e97b
commit c5687d5b06

173
main.c
View File

@ -14,6 +14,7 @@
#define NRUNS 10000 #define NRUNS 10000
#define MAX_ADJ 8 #define MAX_ADJ 8
#define QUEUE_MAX 128
typedef enum { typedef enum {
FOUND_MINE, FOUND_MINE,
@ -22,31 +23,71 @@ typedef enum {
FOUND_CONTRADICTION, FOUND_CONTRADICTION,
} update_res_t; } update_res_t;
typedef struct {
int x, y;
} coord_t;
typedef struct {
coord_t buf[QUEUE_MAX], *start, *end;
} queue_t;
typedef struct { typedef struct {
int mines, unknown; int mines, unknown;
queue_t queue;
puzz_t field; puzz_t field;
} state_t; } state_t;
static bool empty(const queue_t *queue)
{
return queue->start == queue->end;
}
static void enqueue(queue_t *queue, int x, int y)
{
queue->end->x = x;
queue->end->y = y;
++queue->end;
if (queue->end == queue->buf + QUEUE_MAX)
queue->end = queue->buf;
assert(queue->end != queue->start);
}
static void dequeue(queue_t *queue, int *x_out, int *y_out)
{
assert(queue->start != queue->end);
*x_out = queue->start->x;
*y_out = queue->start->y;
++queue->start;
if (queue->start == queue->buf + QUEUE_MAX)
queue->start = queue->buf;
}
static void init(state_t *state)
{
state->mines = 0;
state->unknown = WIDTH * HEIGHT;
state->queue.start = state->queue.end = state->queue.buf;
memset(state->field, UNKNOWN, sizeof(puzz_t));
}
static void dup(state_t *orig, state_t *copy)
{
memcpy(copy, orig, sizeof(state_t));
const int startpos = orig->queue.start - orig->queue.buf;
const int endpos = orig->queue.end - orig->queue.buf;
copy->queue.start = copy->queue.buf + startpos;
copy->queue.end = copy->queue.buf + endpos;
}
static void setadj(puzz_t field, int x, int y, uint8_t from, uint8_t to) static void setadj(puzz_t field, int x, int y, uint8_t from, uint8_t to)
{ {
FORADJ(x, y, xi, yi) FORADJ(x, y, xi, yi)
field[xi][yi] = field[xi][yi] == from ? to : field[xi][yi]; field[xi][yi] = field[xi][yi] == from ? to : field[xi][yi];
} }
static void getadj(puzz_t field, int *x, int *y, uint8_t val) static update_res_t update(state_t *state)
{
FORADJ(*x, *y, xi, yi)
{
if (field[xi][yi] == val) {
*x = xi;
*y = yi;
return;
}
}
assert(false);
}
static update_res_t update(state_t *state, int *x_out, int *y_out)
{ {
state->mines = state->unknown = 0; state->mines = state->unknown = 0;
@ -83,9 +124,12 @@ static update_res_t update(state_t *state, int *x_out, int *y_out)
} }
if (mines == state->field[x][y]) { if (mines == state->field[x][y]) {
getadj(state->field, &x, &y, UNKNOWN); FORADJ(x, y, xi, yi)
*x_out = x; {
*y_out = y; if (state->field[xi][yi] == UNKNOWN)
enqueue(&state->queue, xi, yi);
}
setadj(state->field, x, y, UNKNOWN, SAFE);
return FOUND_SAFE; return FOUND_SAFE;
} }
} }
@ -94,63 +138,43 @@ static update_res_t update(state_t *state, int *x_out, int *y_out)
return FOUND_NOTHING; return FOUND_NOTHING;
} }
static update_res_t static update_res_t search_at(state_t *state, int x, int y)
search_at(state_t *state, int x, int y, int *x_out, int *y_out)
{ {
update_res_t res; update_res_t res;
state_t with_mine; state_t with_mine;
memcpy(&with_mine, state, sizeof(state_t)); dup(state, &with_mine);
with_mine.field[x][y] = MINE; with_mine.field[x][y] = MINE;
do { do {
int res_x, res_y; if ((res = update(&with_mine)) == FOUND_CONTRADICTION) {
switch (res = update(&with_mine, &res_x, &res_y)) { state->field[x][y] = SAFE;
case FOUND_MINE: enqueue(&state->queue, x, y);
break;
case FOUND_SAFE:
with_mine.field[res_x][res_y] = SAFE;
break;
case FOUND_NOTHING:
break;
case FOUND_CONTRADICTION:
*x_out = x;
*y_out = y;
return FOUND_SAFE; return FOUND_SAFE;
} }
} while (res == FOUND_MINE); } while (res != FOUND_NOTHING);
state_t with_safe; state_t with_safe;
memcpy(&with_safe, state, sizeof(state_t)); dup(state, &with_safe);
with_safe.field[x][y] = SAFE; with_safe.field[x][y] = SAFE;
do { do {
int res_x, res_y; if ((res = update(&with_safe)) == FOUND_CONTRADICTION) {
switch (res = update(&with_safe, &res_x, &res_y)) {
case FOUND_MINE:
break;
case FOUND_SAFE:
with_safe.field[res_x][res_y] = SAFE;
break;
case FOUND_NOTHING:
break;
case FOUND_CONTRADICTION:
state->field[x][y] = MINE; state->field[x][y] = MINE;
return FOUND_MINE; return FOUND_MINE;
} }
} while (res == FOUND_MINE); } while (res != FOUND_NOTHING);
return FOUND_NOTHING; return FOUND_NOTHING;
} }
static update_res_t search(state_t *state, int *x_out, int *y_out) static update_res_t search(state_t *state)
{ {
for (int y = 0; y < HEIGHT; ++y) { for (int y = 0; y < HEIGHT; ++y) {
for (int x = 0; x < WIDTH; ++x) { for (int x = 0; x < WIDTH; ++x) {
if (state->field[x][y] != UNKNOWN) if (state->field[x][y] != UNKNOWN)
continue; continue;
if (countadj(state->field, x, y, UNKNOWN) != MAX_ADJ) { if (countadj(state->field, x, y, UNKNOWN) != MAX_ADJ) {
update_res_t res = search_at(state, x, y, x_out, y_out); update_res_t res = search_at(state, x, y);
if (res == FOUND_MINE || res == FOUND_SAFE) if (res != FOUND_NOTHING)
return res; return res;
} }
} }
@ -158,44 +182,53 @@ static update_res_t search(state_t *state, int *x_out, int *y_out)
return FOUND_NOTHING; return FOUND_NOTHING;
} }
static status_t solve(int *turns_out) static status_t solve(int *probes_out)
{ {
state_t state = { .mines = 0, .unknown = WIDTH * HEIGHT }; state_t state;
memset(state.field, UNKNOWN, sizeof(puzz_t)); init(&state);
int x = rand() % WIDTH; enqueue(&state.queue, rand() % WIDTH, rand() % HEIGHT);
int y = rand() % HEIGHT;
status_t status; int probes = 0;
int turns = 0;
do { do {
++turns; while (!empty(&state.queue)) {
if (state.field[x][y] != MINE int x, y;
&& (status = probe(x, y, state.field)) == DEAD) dequeue(&state.queue, &x, &y);
break; if (state.field[x][y] != SAFE && state.field[x][y] != UNKNOWN)
continue;
++probes;
if (probe(x, y, state.field) == DEAD) {
*probes_out = probes;
return DEAD;
}
}
update_res_t res; update_res_t res;
do { do {
res = update(&state, &x, &y); res = update(&state);
if (res == FOUND_NOTHING) if (res == FOUND_NOTHING)
res = search(&state, &x, &y); res = search(&state);
} while (res == FOUND_MINE); } while (res != FOUND_NOTHING);
if (check(state.field) != OK) { if (check(state.field) != OK) {
printf("Incorrect inference! State:\n"); printf("Incorrect inference!\n");
print(state.field); print(state.field);
printsoln(); printsoln();
return INCORRECT; return INCORRECT;
} }
if (res == FOUND_NOTHING) { if (empty(&state.queue)) {
int x, y;
do {
x = rand() % WIDTH; x = rand() % WIDTH;
y = rand() % HEIGHT; y = rand() % HEIGHT;
} while (state.field[x][y] == MINE);
enqueue(&state.queue, x, y);
} }
} while (state.mines < NMINES || state.unknown > 0); } while (state.mines < NMINES || state.unknown > 0);
*turns_out = turns; *probes_out = probes;
return status; return OK;
} }
int main(void) int main(void)
@ -207,12 +240,12 @@ int main(void)
} }
srand(tv.tv_usec); srand(tv.tv_usec);
int nsolved = 0, nfirst = 0, nincorrect = 0, turns; int nsolved = 0, nfirst = 0, nincorrect = 0, probes;
for (int i = 0; i < NRUNS; ++i) { for (int i = 0; i < NRUNS; ++i) {
gen(); gen();
switch (solve(&turns)) { switch (solve(&probes)) {
case DEAD: case DEAD:
if (turns == 1) if (probes == 1)
++nfirst; ++nfirst;
break; break;
case OK: case OK: