From dd3cabf11bfb77357c6dc685be228effd1a70aa1 Mon Sep 17 00:00:00 2001 From: Camden Dixie O'Brien Date: Sun, 9 Mar 2025 10:41:43 +0000 Subject: [PATCH] Use queue for tiles to probe This allows updating and searching until no more inferences can be made, resulting in a ~5% improvement (also better performance). --- main.c | 179 ++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 106 insertions(+), 73 deletions(-) diff --git a/main.c b/main.c index 9976c54..9930d1b 100644 --- a/main.c +++ b/main.c @@ -14,6 +14,7 @@ #define NRUNS 10000 #define MAX_ADJ 8 +#define QUEUE_MAX 128 typedef enum { FOUND_MINE, @@ -22,31 +23,71 @@ typedef enum { FOUND_CONTRADICTION, } update_res_t; +typedef struct { + int x, y; +} coord_t; + +typedef struct { + coord_t buf[QUEUE_MAX], *start, *end; +} queue_t; + typedef struct { int mines, unknown; + queue_t queue; puzz_t field; } state_t; +static bool empty(const queue_t *queue) +{ + return queue->start == queue->end; +} + +static void enqueue(queue_t *queue, int x, int y) +{ + queue->end->x = x; + queue->end->y = y; + + ++queue->end; + if (queue->end == queue->buf + QUEUE_MAX) + queue->end = queue->buf; + assert(queue->end != queue->start); +} + +static void dequeue(queue_t *queue, int *x_out, int *y_out) +{ + assert(queue->start != queue->end); + *x_out = queue->start->x; + *y_out = queue->start->y; + + ++queue->start; + if (queue->start == queue->buf + QUEUE_MAX) + queue->start = queue->buf; +} + +static void init(state_t *state) +{ + state->mines = 0; + state->unknown = WIDTH * HEIGHT; + state->queue.start = state->queue.end = state->queue.buf; + memset(state->field, UNKNOWN, sizeof(puzz_t)); +} + +static void dup(state_t *orig, state_t *copy) +{ + memcpy(copy, orig, sizeof(state_t)); + const int startpos = orig->queue.start - orig->queue.buf; + const int endpos = orig->queue.end - orig->queue.buf; + copy->queue.start = copy->queue.buf + startpos; + copy->queue.end = copy->queue.buf + endpos; +} + static void setadj(puzz_t field, int x, int y, uint8_t from, uint8_t to) { FORADJ(x, y, xi, yi) - field[xi][yi] = field[xi][yi] == from ? to : field[xi][yi]; + field[xi][yi] = field[xi][yi] == from ? to : field[xi][yi]; } -static void getadj(puzz_t field, int *x, int *y, uint8_t val) -{ - FORADJ(*x, *y, xi, yi) - { - if (field[xi][yi] == val) { - *x = xi; - *y = yi; - return; - } - } - assert(false); -} - -static update_res_t update(state_t *state, int *x_out, int *y_out) +static update_res_t update(state_t *state) { state->mines = state->unknown = 0; @@ -83,9 +124,12 @@ static update_res_t update(state_t *state, int *x_out, int *y_out) } if (mines == state->field[x][y]) { - getadj(state->field, &x, &y, UNKNOWN); - *x_out = x; - *y_out = y; + FORADJ(x, y, xi, yi) + { + if (state->field[xi][yi] == UNKNOWN) + enqueue(&state->queue, xi, yi); + } + setadj(state->field, x, y, UNKNOWN, SAFE); return FOUND_SAFE; } } @@ -94,63 +138,43 @@ static update_res_t update(state_t *state, int *x_out, int *y_out) return FOUND_NOTHING; } -static update_res_t -search_at(state_t *state, int x, int y, int *x_out, int *y_out) +static update_res_t search_at(state_t *state, int x, int y) { update_res_t res; state_t with_mine; - memcpy(&with_mine, state, sizeof(state_t)); + dup(state, &with_mine); with_mine.field[x][y] = MINE; do { - int res_x, res_y; - switch (res = update(&with_mine, &res_x, &res_y)) { - case FOUND_MINE: - break; - case FOUND_SAFE: - with_mine.field[res_x][res_y] = SAFE; - break; - case FOUND_NOTHING: - break; - case FOUND_CONTRADICTION: - *x_out = x; - *y_out = y; + if ((res = update(&with_mine)) == FOUND_CONTRADICTION) { + state->field[x][y] = SAFE; + enqueue(&state->queue, x, y); return FOUND_SAFE; } - } while (res == FOUND_MINE); + } while (res != FOUND_NOTHING); state_t with_safe; - memcpy(&with_safe, state, sizeof(state_t)); + dup(state, &with_safe); with_safe.field[x][y] = SAFE; do { - int res_x, res_y; - switch (res = update(&with_safe, &res_x, &res_y)) { - case FOUND_MINE: - break; - case FOUND_SAFE: - with_safe.field[res_x][res_y] = SAFE; - break; - case FOUND_NOTHING: - break; - case FOUND_CONTRADICTION: + if ((res = update(&with_safe)) == FOUND_CONTRADICTION) { state->field[x][y] = MINE; return FOUND_MINE; } - } while (res == FOUND_MINE); + } while (res != FOUND_NOTHING); return FOUND_NOTHING; } -static update_res_t search(state_t *state, int *x_out, int *y_out) +static update_res_t search(state_t *state) { for (int y = 0; y < HEIGHT; ++y) { for (int x = 0; x < WIDTH; ++x) { if (state->field[x][y] != UNKNOWN) continue; - if (countadj(state->field, x, y, UNKNOWN) != MAX_ADJ) { - update_res_t res = search_at(state, x, y, x_out, y_out); - if (res == FOUND_MINE || res == FOUND_SAFE) + update_res_t res = search_at(state, x, y); + if (res != FOUND_NOTHING) return res; } } @@ -158,44 +182,53 @@ static update_res_t search(state_t *state, int *x_out, int *y_out) return FOUND_NOTHING; } -static status_t solve(int *turns_out) +static status_t solve(int *probes_out) { - state_t state = { .mines = 0, .unknown = WIDTH * HEIGHT }; - memset(state.field, UNKNOWN, sizeof(puzz_t)); + state_t state; + init(&state); - int x = rand() % WIDTH; - int y = rand() % HEIGHT; + enqueue(&state.queue, rand() % WIDTH, rand() % HEIGHT); - status_t status; - int turns = 0; + int probes = 0; do { - ++turns; - if (state.field[x][y] != MINE - && (status = probe(x, y, state.field)) == DEAD) - break; + while (!empty(&state.queue)) { + int x, y; + dequeue(&state.queue, &x, &y); + if (state.field[x][y] != SAFE && state.field[x][y] != UNKNOWN) + continue; + ++probes; + if (probe(x, y, state.field) == DEAD) { + *probes_out = probes; + return DEAD; + } + } update_res_t res; do { - res = update(&state, &x, &y); + res = update(&state); if (res == FOUND_NOTHING) - res = search(&state, &x, &y); - } while (res == FOUND_MINE); + res = search(&state); + } while (res != FOUND_NOTHING); if (check(state.field) != OK) { - printf("Incorrect inference! State:\n"); + printf("Incorrect inference!\n"); print(state.field); printsoln(); return INCORRECT; } - if (res == FOUND_NOTHING) { - x = rand() % WIDTH; - y = rand() % HEIGHT; + if (empty(&state.queue)) { + int x, y; + do { + x = rand() % WIDTH; + y = rand() % HEIGHT; + } while (state.field[x][y] == MINE); + enqueue(&state.queue, x, y); } } while (state.mines < NMINES || state.unknown > 0); - *turns_out = turns; - return status; + *probes_out = probes; + return OK; } int main(void) @@ -207,12 +240,12 @@ int main(void) } srand(tv.tv_usec); - int nsolved = 0, nfirst = 0, nincorrect = 0, turns; + int nsolved = 0, nfirst = 0, nincorrect = 0, probes; for (int i = 0; i < NRUNS; ++i) { gen(); - switch (solve(&turns)) { + switch (solve(&probes)) { case DEAD: - if (turns == 1) + if (probes == 1) ++nfirst; break; case OK: