From 77c7b4afafa7e1e4b488da64fa6588d2995a956f Mon Sep 17 00:00:00 2001
From: Uros Majstorovic <majstor@majstor.org>
Date: Fri, 3 May 2024 23:32:44 +0200
Subject: fixed conn new / key checker interface; fixed memory leak when conn
 create fails

---
 ecp/src/ecp/core.c | 33 +++++++++++++++++----------------
 ecp/src/ecp/core.h |  6 +++---
 2 files changed, 20 insertions(+), 19 deletions(-)

(limited to 'ecp')

diff --git a/ecp/src/ecp/core.c b/ecp/src/ecp/core.c
index a049bb7..7d61def 100644
--- a/ecp/src/ecp/core.c
+++ b/ecp/src/ecp/core.c
@@ -912,10 +912,10 @@ static int conn_shkey_set(ECPConnection *conn, unsigned char s_idx, unsigned cha
     return ECP_OK;
 }
 
-ECPConnection *ecp_conn_new_inb(ECPSocket *sock, unsigned char ctype) {
+ECPConnection *ecp_conn_new_inb(ECPSocket *sock, ECPConnection *parent, unsigned char ctype) {
     ECPContext *ctx = sock->ctx;
 
-    if (ctx->conn_new) return ctx->conn_new(sock, ctype);
+    if (ctx->conn_new) return ctx->conn_new(sock, parent, ctype);
     return NULL;
 }
 
@@ -1106,17 +1106,12 @@ int ecp_conn_create_outb(ECPConnection *conn, ECPConnection *parent, ECPNode *no
     conn->key[conn->key_curr] = key;
 
     rv = ecp_conn_create(conn, parent);
-    if (rv) return rv;
-
-    return ECP_OK;
+    return rv;
 }
 
 void ecp_conn_destroy(ECPConnection *conn) {
 #ifdef ECP_WITH_VCONN
-    if (conn->parent) {
-        ecp_conn_refcount_dec(conn->parent);
-        conn->parent = NULL;
-    }
+    if (conn->parent) ecp_conn_refcount_dec(conn->parent);
 #endif
 
     ecp_ext_conn_destroy(conn);
@@ -1292,7 +1287,10 @@ int _ecp_conn_open(ECPConnection *conn, ECPConnection *parent, ECPNode *node, in
     ssize_t _rv;
 
     rv = ecp_conn_create_outb(conn, parent, node);
-    if (rv) return rv;
+    if (rv) {
+        ecp_conn_free(conn);
+        return rv;
+    }
 
     rv = ecp_conn_insert(conn);
     if (rv) {
@@ -1894,18 +1892,21 @@ ssize_t ecp_handle_open_req(ECPSocket *sock, ECPConnection *parent, unsigned cha
         if (memcmp(vbox_buf, public_buf, ECP_SIZE_ECDH_PUB) != 0) return ECP_ERR_VBOX;
         if (memcmp(vbox_buf+ECP_SIZE_ECDH_PUB, &key_curr.public, ECP_SIZE_ECDH_PUB) != 0) return ECP_ERR_VBOX;
         rkey_perma.valid = 1;
+    }
 
-        if (sock->ctx->key_checker) {
-            _rv = sock->ctx->key_checker(sock, ctype, &rkey_perma.public);
-            if (_rv) return _rv;
-        }
+    if (sock->ctx->key_checker) {
+        _rv = sock->ctx->key_checker(sock, parent, ctype, rkey_perma.valid ? &rkey_perma.public : NULL);
+        if (!_rv) return ECP_ERR_VBOX;
     }
 
-    conn = ecp_conn_new_inb(sock, ctype);
+    conn = ecp_conn_new_inb(sock, parent, ctype);
     if (conn == NULL) return ECP_ERR_ALLOC;
 
     _rv = ecp_conn_create_inb(conn, parent, s_idx, c_idx, (ecp_ecdh_public_t *)public_buf, rkey_perma.valid ? &rkey_perma : NULL, shkey);
-    if (_rv) return _rv;
+    if (_rv) {
+        ecp_conn_free(conn);
+        return _rv;
+    }
 
     _rv = ecp_conn_insert(conn);
     if (_rv) {
diff --git a/ecp/src/ecp/core.h b/ecp/src/ecp/core.h
index 62c1de3..5d8cd47 100644
--- a/ecp/src/ecp/core.h
+++ b/ecp/src/ecp/core.h
@@ -227,9 +227,9 @@ struct ECPFragIter;
 typedef int (*ecp_conn_expired_t) (struct ECPConnection *conn, ecp_sts_t now);
 
 typedef void (*ecp_err_handler_t) (struct ECPConnection *conn, unsigned char mtype, int err);
-typedef struct ECPConnection * (*ecp_conn_new_t) (struct ECPSocket *sock, unsigned char type);
+typedef struct ECPConnection * (*ecp_conn_new_t) (struct ECPSocket *sock, struct ECPConnection *parent, unsigned char type);
 typedef void (*ecp_conn_free_t) (struct ECPConnection *conn);
-typedef int (*ecp_key_checker_t) (struct ECPSocket *sock, unsigned char ctype, ecp_ecdh_public_t *pub);
+typedef int (*ecp_key_checker_t) (struct ECPSocket *sock, struct ECPConnection *parent, unsigned char ctype, ecp_ecdh_public_t *pub);
 
 typedef ssize_t (*ecp_msg_handler_t) (struct ECPConnection *conn, ecp_seq_t seq, unsigned char mtype, unsigned char *msg, size_t msg_size, struct ECP2Buffer *b);
 typedef int (*ecp_open_handler_t) (struct ECPConnection *conn, struct ECP2Buffer *b);
@@ -401,7 +401,7 @@ void ecp_sock_expire(ECPSocket *sock, ecp_conn_expired_t conn_expired);
 void ecp_atag_gen(ECPSocket *sock, unsigned char *public_buf, unsigned char *atag, ecp_nonce_t *nonce);
 int ecp_cookie_verify(ECPSocket *sock, unsigned char *cookie, unsigned char *public_buf);
 
-ECPConnection *ecp_conn_new_inb(ECPSocket *sock, unsigned char ctype);
+ECPConnection *ecp_conn_new_inb(ECPSocket *sock, ECPConnection *parent, unsigned char ctype);
 void ecp_conn_init(ECPConnection *conn, ECPSocket *sock, unsigned char ctype);
 void ecp_conn_set_flags(ECPConnection *conn, unsigned char flags);
 void ecp_conn_clr_flags(ECPConnection *conn, unsigned char flags);
-- 
cgit v1.2.3