Move SessionManager and related classes to own file

This commit is contained in:
Camden Dixie O'Brien 2025-02-28 12:44:40 +00:00
parent 90a6eb3ab4
commit fe8b1ef977
3 changed files with 168 additions and 167 deletions

View File

@ -50,172 +50,5 @@ namespace StudySystemClient {
Thread.usleep(1000 * TASK_PERIOD_MS); Thread.usleep(1000 * TASK_PERIOD_MS);
} }
} }
}
private class SessionManager {
public delegate void ReceiveCallback(uint8[] msg);
private const uint INIT_RECONNECT_WAIT_MS = 500;
private const uint MAX_RECONNECT_WAIT_MS = 60000;
private const double RECONNECT_BACKOFF = 1.6;
private SessionFactory session_factory;
private ReceiveCallback receive_callback;
private Session? session;
private AsyncQueue<OutgoingMessage> queue;
private uint reconnect_wait_ms;
public SessionManager(SessionFactory session_factory,
owned ReceiveCallback receive_callback) {
this.session_factory = session_factory;
this.receive_callback = (owned) receive_callback;
this.session = null;
queue = new AsyncQueue<OutgoingMessage>();
reconnect_wait_ms = INIT_RECONNECT_WAIT_MS;
}
public void send(uint8[] msg) {
queue.push(new OutgoingMessage(msg));
}
public void task() {
if (session != null) {
var failed_msg = session.task(queue);
if (failed_msg != null)
handle_failed_msg(failed_msg);
} else {
try_start_session();
}
}
private void handle_failed_msg(OutgoingMessage msg) {
msg.has_failed();
if (msg.should_retry())
queue.push(msg);
session = null;
}
private void try_start_session() {
try {
session = session_factory.start_session();
session.received.connect(
(msg) => receive_callback(msg));
reconnect_wait_ms = INIT_RECONNECT_WAIT_MS;
} catch (Error _) {
Thread.usleep(1000 * reconnect_wait_ms);
update_reconnect_wait();
}
}
private void update_reconnect_wait() {
var new_wait = RECONNECT_BACKOFF * reconnect_wait_ms;
if (new_wait < MAX_RECONNECT_WAIT_MS)
reconnect_wait_ms = (uint)new_wait;
else
reconnect_wait_ms = MAX_RECONNECT_WAIT_MS;
}
}
private class SessionFactory {
private const string CA_FILENAME = "/ca.pem";
private const string CERT_FILENAME = "/client.pem";
private const uint TIMEOUT_S = 1;
private InetSocketAddress host;
private TlsCertificate cert;
private TlsDatabase ca_db;
public SessionFactory(InetAddress host_addr, uint16 host_port,
string cert_dir) throws Error {
host = new InetSocketAddress(host_addr, host_port);
var cert_path = cert_dir + CERT_FILENAME;
cert = new TlsCertificate.from_file(cert_path);
var ca_path = cert_dir + CA_FILENAME;
var db_type = TlsBackend.get_default().get_file_database_type();
ca_db = Object.new(db_type, "anchors", ca_path) as TlsDatabase;
}
public Session start_session() throws Error {
var plain_client = new SocketClient();
plain_client.set_timeout(TIMEOUT_S);
var plain_connection = plain_client.connect(host);
var connection = TlsClientConnection.new(plain_connection, host);
connection.set_database(ca_db);
connection.set_certificate(cert);
connection.handshake();
return new Session(connection);
}
}
private class Session {
public signal void received(uint8[] msg);
private const uint MAX_BATCH_SIZE = 10;
private const uint MAX_MSG_LEN = 1024;
private TlsClientConnection connection;
public Session(TlsClientConnection connection) {
this.connection = connection;
}
public OutgoingMessage? task(AsyncQueue<OutgoingMessage> queue) {
for (int i = 0; i < MAX_BATCH_SIZE; ++i) {
if (queue.length() == 0)
break;
var msg = queue.pop();
var success = true;
success &= send(msg);
success &= receive();
if (!success)
return msg;
}
return null;
}
private bool send(OutgoingMessage msg) {
try {
size_t written;
connection.output_stream.write_all(msg.content, out written);
return true;
} catch (IOError _) {
return false;
}
}
private bool receive() {
try {
var buffer = new uint8[MAX_MSG_LEN];
var len = connection.input_stream.read(buffer);
if (len <= 0)
return false;
received(buffer[0:len]);
return true;
} catch (IOError _) {
return false;
}
}
}
private class OutgoingMessage {
public uint8[] content { get; private set; }
private const uint MAX_FAIL_COUNT = 4;
private uint fail_count;
public OutgoingMessage(owned uint8[] content) {
this.content = (owned)content;
fail_count = 0;
}
public void has_failed() {
++fail_count;
}
public bool should_retry() {
return fail_count < MAX_FAIL_COUNT;
}
} }
} }

View File

@ -14,6 +14,7 @@ lib = library(
'connection.vala', 'connection.vala',
'der.vala', 'der.vala',
'main_window.vala', 'main_window.vala',
'session_manager.vala',
) + resources, ) + resources,
dependencies: [gtk_dep], dependencies: [gtk_dep],
vala_vapi: 'study-system-client.vapi', vala_vapi: 'study-system-client.vapi',

View File

@ -0,0 +1,167 @@
namespace StudySystemClient {
public class SessionManager {
public delegate void ReceiveCallback(uint8[] msg);
private const uint INIT_RECONNECT_WAIT_MS = 500;
private const uint MAX_RECONNECT_WAIT_MS = 60000;
private const double RECONNECT_BACKOFF = 1.6;
private SessionFactory session_factory;
private ReceiveCallback receive_callback;
private Session? session;
private AsyncQueue<OutgoingMessage> queue;
private uint reconnect_wait_ms;
public SessionManager(SessionFactory session_factory,
owned ReceiveCallback receive_callback) {
this.session_factory = session_factory;
this.receive_callback = (owned) receive_callback;
this.session = null;
queue = new AsyncQueue<OutgoingMessage>();
reconnect_wait_ms = INIT_RECONNECT_WAIT_MS;
}
public void send(uint8[] msg) {
queue.push(new OutgoingMessage(msg));
}
public void task() {
if (session != null) {
var failed_msg = session.task(queue);
if (failed_msg != null)
handle_failed_msg(failed_msg);
} else {
try_start_session();
}
}
private void handle_failed_msg(OutgoingMessage msg) {
msg.has_failed();
if (msg.should_retry())
queue.push(msg);
session = null;
}
private void try_start_session() {
try {
session = session_factory.start_session();
session.received.connect(
(msg) => receive_callback(msg));
reconnect_wait_ms = INIT_RECONNECT_WAIT_MS;
} catch (Error _) {
Thread.usleep(1000 * reconnect_wait_ms);
update_reconnect_wait();
}
}
private void update_reconnect_wait() {
var new_wait = RECONNECT_BACKOFF * reconnect_wait_ms;
if (new_wait < MAX_RECONNECT_WAIT_MS)
reconnect_wait_ms = (uint)new_wait;
else
reconnect_wait_ms = MAX_RECONNECT_WAIT_MS;
}
}
public class SessionFactory {
private const string CA_FILENAME = "/ca.pem";
private const string CERT_FILENAME = "/client.pem";
private const uint TIMEOUT_S = 1;
private InetSocketAddress host;
private TlsCertificate cert;
private TlsDatabase ca_db;
public SessionFactory(InetAddress host_addr, uint16 host_port,
string cert_dir) throws Error {
host = new InetSocketAddress(host_addr, host_port);
var cert_path = cert_dir + CERT_FILENAME;
cert = new TlsCertificate.from_file(cert_path);
var ca_path = cert_dir + CA_FILENAME;
var db_type = TlsBackend.get_default().get_file_database_type();
ca_db = Object.new(db_type, "anchors", ca_path) as TlsDatabase;
}
internal Session start_session() throws Error {
var plain_client = new SocketClient();
plain_client.set_timeout(TIMEOUT_S);
var plain_connection = plain_client.connect(host);
var connection = TlsClientConnection.new(plain_connection, host);
connection.set_database(ca_db);
connection.set_certificate(cert);
connection.handshake();
return new Session(connection);
}
}
private class Session {
public signal void received(uint8[] msg);
private const uint MAX_BATCH_SIZE = 10;
private const uint MAX_MSG_LEN = 1024;
private TlsClientConnection connection;
public Session(TlsClientConnection connection) {
this.connection = connection;
}
public OutgoingMessage? task(AsyncQueue<OutgoingMessage> queue) {
for (int i = 0; i < MAX_BATCH_SIZE; ++i) {
if (queue.length() == 0)
break;
var msg = queue.pop();
var success = true;
success &= send(msg);
success &= receive();
if (!success)
return msg;
}
return null;
}
private bool send(OutgoingMessage msg) {
try {
size_t written;
connection.output_stream.write_all(msg.content, out written);
return true;
} catch (IOError _) {
return false;
}
}
private bool receive() {
try {
var buffer = new uint8[MAX_MSG_LEN];
var len = connection.input_stream.read(buffer);
if (len <= 0)
return false;
received(buffer[0:len]);
return true;
} catch (IOError _) {
return false;
}
}
}
private class OutgoingMessage {
public uint8[] content { get; private set; }
private const uint MAX_FAIL_COUNT = 4;
private uint fail_count;
public OutgoingMessage(owned uint8[] content) {
this.content = (owned)content;
fail_count = 0;
}
public void has_failed() {
++fail_count;
}
public bool should_retry() {
return fail_count < MAX_FAIL_COUNT;
}
}
}