[PATCH 08/15] virnetsshsession: Pass in username via virNetSSHSessionNew rather than auth functions

Peter Krempa pkrempa at redhat.com
Tue Jan 17 16:20:33 UTC 2023


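We only ever allow one username, so there's no point passing it to each
authentication registration function. Additionally, the only caller
(virNetClientNewLibSSH2) always passes a username, so all the NULL
username checks were pointless, as was the fallback in
virNetSSHSessionAuthAddPasswordAuth which queried the username via
virAuthGetUsernamePath.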

Signed-off-by: Peter Krempa <pkrempa at redhat.com>
---
 src/rpc/virnetsocket.c     | 14 +++----
 src/rpc/virnetsshsession.c | 84 ++++++++++----------------------------
 src/rpc/virnetsshsession.h | 10 ++---
 3 files changed, 29 insertions(+), 79 deletions(-)

diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index b9b7328f87..b248ce24dc 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -909,7 +909,7 @@ virNetSocketNewConnectLibSSH2(const char *host,
     }

     /* create ssh session context */
-    if (!(sess = virNetSSHSessionNew()))
+    if (!(sess = virNetSSHSessionNew(username)))
         goto error;

     /* set ssh session parameters */
@@ -946,17 +946,13 @@ virNetSocketNewConnectLibSSH2(const char *host,
         const char *authMethod = *authMethodNext;

         if (STRCASEEQ(authMethod, "keyboard-interactive")) {
-            ret = virNetSSHSessionAuthAddKeyboardAuth(sess, username, -1);
+            ret = virNetSSHSessionAuthAddKeyboardAuth(sess, -1);
         } else if (STRCASEEQ(authMethod, "password")) {
-            ret = virNetSSHSessionAuthAddPasswordAuth(sess,
-                                                      uri,
-                                                      username);
+            ret = virNetSSHSessionAuthAddPasswordAuth(sess, uri);
         } else if (STRCASEEQ(authMethod, "privkey")) {
-            ret = virNetSSHSessionAuthAddPrivKeyAuth(sess,
-                                                     username,
-                                                     privkey);
+            ret = virNetSSHSessionAuthAddPrivKeyAuth(sess, privkey);
         } else if (STRCASEEQ(authMethod, "agent")) {
-            ret = virNetSSHSessionAuthAddAgentAuth(sess, username);
+            ret = virNetSSHSessionAuthAddAgentAuth(sess);
         } else {
             virReportError(VIR_ERR_INVALID_ARG,
                            _("Invalid authentication method: '%s'"),
diff --git a/src/rpc/virnetsshsession.c b/src/rpc/virnetsshsession.c
index 0454deec16..8f59906b4a 100644
--- a/src/rpc/virnetsshsession.c
+++ b/src/rpc/virnetsshsession.c
@@ -70,7 +70,6 @@ typedef struct _virNetSSHAuthMethod virNetSSHAuthMethod;

 struct _virNetSSHAuthMethod {
     virNetSSHAuthMethods method;
-    char *username;
     char *filename;

     int tries;
@@ -93,6 +92,7 @@ struct _virNetSSHSession {
     int port;

     /* authentication stuff */
+    char *username;
     virConnectAuthPtr cred;
     char *authPath;
     virNetSSHAuthCallbackError authCbErr;
@@ -115,7 +115,6 @@ virNetSSHSessionAuthMethodsClear(virNetSSHSession *sess)
     size_t i;

     for (i = 0; i < sess->nauths; i++) {
-        VIR_FREE(sess->auths[i]->username);
         VIR_FREE(sess->auths[i]->filename);
         VIR_FREE(sess->auths[i]);
     }
@@ -151,6 +150,7 @@ virNetSSHSessionDispose(void *obj)
     g_free(sess->hostname);
     g_free(sess->knownHostsFile);
     g_free(sess->authPath);
+    g_free(sess->username);
 }

 static virClass *virNetSSHSessionClass;
@@ -488,8 +488,7 @@ virNetSSHCheckHostKey(virNetSSHSession *sess)
  *         -1 on error
  */
 static int
-virNetSSHAuthenticateAgent(virNetSSHSession *sess,
-                           virNetSSHAuthMethod *priv)
+virNetSSHAuthenticateAgent(virNetSSHSession *sess)
 {
     struct libssh2_agent_publickey *agent_identity = NULL;
     bool no_identity = true;
@@ -515,7 +514,7 @@ virNetSSHAuthenticateAgent(virNetSSHSession *sess,
                                               agent_identity))) {
         no_identity = false;
         if (!(ret = libssh2_agent_userauth(sess->agent,
-                                           priv->username,
+                                           sess->username,
                                            agent_identity)))
             return 0; /* key accepted */

@@ -575,7 +574,7 @@ virNetSSHAuthenticatePrivkey(virNetSSHSession *sess,

     /* try open the key with no password */
     if ((ret = libssh2_userauth_publickey_fromfile(sess->session,
-                                                   priv->username,
+                                                   sess->username,
                                                    NULL,
                                                    priv->filename,
                                                    NULL)) == 0)
@@ -634,7 +633,7 @@ virNetSSHAuthenticatePrivkey(virNetSSHSession *sess,
     VIR_FREE(tmp);

     ret = libssh2_userauth_publickey_fromfile(sess->session,
-                                              priv->username,
+                                              sess->username,
                                               NULL,
                                               priv->filename,
                                               retr_passphrase.result);
@@ -668,8 +667,7 @@ virNetSSHAuthenticatePrivkey(virNetSSHSession *sess,
  *         -1 on error
  */
 static int
-virNetSSHAuthenticatePassword(virNetSSHSession *sess,
-                              virNetSSHAuthMethod *priv)
+virNetSSHAuthenticatePassword(virNetSSHSession *sess)
 {
     char *password = NULL;
     char *errmsg;
@@ -690,13 +688,13 @@ virNetSSHAuthenticatePassword(virNetSSHSession *sess,
      * connection if maximum number of bad auth tries is exceeded */
     while (true) {
         if (!(password = virAuthGetPasswordPath(sess->authPath, sess->cred,
-                                                "ssh", priv->username,
+                                                "ssh", sess->username,
                                                 sess->hostname)))
             goto cleanup;

         /* tunnelled password authentication */
         if ((rc = libssh2_userauth_password(sess->session,
-                                            priv->username,
+                                            sess->username,
                                             password)) == 0) {
             ret = 0;
             goto cleanup;
@@ -751,7 +749,7 @@ virNetSSHAuthenticateKeyboardInteractive(virNetSSHSession *sess,
      * connection if maximum number of bad auth tries is exceeded */
     while (priv->tries < 0 || priv->tries-- > 0) {
         ret = libssh2_userauth_keyboard_interactive(sess->session,
-                                                    priv->username,
+                                                    sess->username,
                                                     virNetSSHKbIntCb);

         /* check for errors while calling the callback */
@@ -817,9 +815,8 @@ virNetSSHAuthenticate(virNetSSHSession *sess)
     }

     /* obtain list of supported auth methods */
-    auth_list = libssh2_userauth_list(sess->session,
-                                      sess->auths[0]->username,
-                                      strlen(sess->auths[0]->username));
+    auth_list = libssh2_userauth_list(sess->session, sess->username,
+                                      strlen(sess->username));
     if (!auth_list) {
         /* unlikely event, authentication succeeded with NONE as method */
         if (libssh2_userauth_authenticated(sess->session) == 1)
@@ -845,7 +842,7 @@ virNetSSHAuthenticate(virNetSSHSession *sess)
             break;
         case VIR_NET_SSH_AUTH_AGENT:
             if (strstr(auth_list, "publickey"))
-                ret = virNetSSHAuthenticateAgent(sess, auth);
+                ret = virNetSSHAuthenticateAgent(sess);
             break;
         case VIR_NET_SSH_AUTH_PRIVKEY:
             if (strstr(auth_list, "publickey"))
@@ -853,7 +850,7 @@ virNetSSHAuthenticate(virNetSSHSession *sess)
             break;
         case VIR_NET_SSH_AUTH_PASSWORD:
             if (strstr(auth_list, "password"))
-                ret = virNetSSHAuthenticatePassword(sess, auth);
+                ret = virNetSSHAuthenticatePassword(sess);
             break;
         }

@@ -969,11 +966,9 @@ virNetSSHSessionAuthReset(virNetSSHSession *sess)

 int
 virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess,
-                                    virURI *uri,
-                                    const char *username)
+                                    virURI *uri)
 {
     virNetSSHAuthMethod *auth;
-    char *user = NULL;

     if (uri) {
         VIR_FREE(sess->authPath);
@@ -982,75 +977,50 @@ virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess,
             goto error;
     }

-    if (!username) {
-        if (!(user = virAuthGetUsernamePath(sess->authPath, sess->cred,
-                                            "ssh", NULL, sess->hostname)))
-            goto error;
-    } else {
-        user = g_strdup(username);
-    }
-
     virObjectLock(sess);

     if (!(auth = virNetSSHSessionAuthMethodNew(sess)))
         goto error;

-    auth->username = user;
     auth->method = VIR_NET_SSH_AUTH_PASSWORD;

     virObjectUnlock(sess);
     return 0;

  error:
-    VIR_FREE(user);
     virObjectUnlock(sess);
     return -1;
 }

 int
-virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess,
-                                 const char *username)
+virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess)
 {
     virNetSSHAuthMethod *auth;
-    char *user = NULL;
-
-    if (!username) {
-        virReportError(VIR_ERR_SSH, "%s",
-                       _("Username must be provided "
-                         "for ssh agent authentication"));
-        return -1;
-    }

     virObjectLock(sess);

-    user = g_strdup(username);
-
     if (!(auth = virNetSSHSessionAuthMethodNew(sess)))
         goto error;

-    auth->username = user;
     auth->method = VIR_NET_SSH_AUTH_AGENT;

     virObjectUnlock(sess);
     return 0;

  error:
-    VIR_FREE(user);
     virObjectUnlock(sess);
     return -1;
 }

 int
 virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,
-                                   const char *username,
                                    const char *keyfile)
 {
     virNetSSHAuthMethod *auth;

-    if (!username || !keyfile) {
+    if (!keyfile) {
         virReportError(VIR_ERR_SSH, "%s",
-                       _("Username and key file path must be provided "
-                         "for private key authentication"));
+                       _("Key file path must be provided for private key authentication"));
         return -1;
     }

@@ -1059,7 +1029,6 @@ virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,
     if (!(auth = virNetSSHSessionAuthMethodNew(sess)))
         return -1;

-    auth->username = g_strdup(username);
     auth->filename = g_strdup(keyfile);
     auth->method = VIR_NET_SSH_AUTH_PRIVKEY;

@@ -1069,27 +1038,15 @@ virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,

 int
 virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess,
-                                    const char *username,
                                     int tries)
 {
     virNetSSHAuthMethod *auth;
-    char *user = NULL;
-
-    if (!username) {
-        virReportError(VIR_ERR_SSH, "%s",
-                       _("Username must be provided "
-                         "for ssh agent authentication"));
-        return -1;
-    }

     virObjectLock(sess);

-    user = g_strdup(username);
-
     if (!(auth = virNetSSHSessionAuthMethodNew(sess)))
         goto error;

-    auth->username = user;
     auth->tries = tries;
     auth->method = VIR_NET_SSH_AUTH_KEYBOARD_INTERACTIVE;

@@ -1097,7 +1054,6 @@ virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess,
     return 0;

  error:
-    VIR_FREE(user);
     virObjectUnlock(sess);
     return -1;

@@ -1170,7 +1126,7 @@ virNetSSHSessionSetHostKeyVerification(virNetSSHSession *sess,
 }

 /* allocate and initialize a ssh session object */
-virNetSSHSession *virNetSSHSessionNew(void)
+virNetSSHSession *virNetSSHSessionNew(const char *username)
 {
     virNetSSHSession *sess = NULL;

@@ -1180,6 +1136,8 @@ virNetSSHSession *virNetSSHSessionNew(void)
     if (!(sess = virObjectLockableNew(virNetSSHSessionClass)))
         goto error;

+    sess->username = g_strdup(username);
+
     /* initialize session data, use the internal data for callbacks
      * and stick to default memory management functions */
     if (!(sess->session = libssh2_session_init_ex(NULL,
diff --git a/src/rpc/virnetsshsession.h b/src/rpc/virnetsshsession.h
index 8d6c99c547..8187346000 100644
--- a/src/rpc/virnetsshsession.h
+++ b/src/rpc/virnetsshsession.h
@@ -25,7 +25,7 @@

 typedef struct _virNetSSHSession virNetSSHSession;

-virNetSSHSession *virNetSSHSessionNew(void);
+virNetSSHSession *virNetSSHSessionNew(const char *username);
 void virNetSSHSessionFree(virNetSSHSession *sess);

 typedef enum {
@@ -48,18 +48,14 @@ int virNetSSHSessionAuthSetCallback(virNetSSHSession *sess,
                                     virConnectAuthPtr auth);

 int virNetSSHSessionAuthAddPasswordAuth(virNetSSHSession *sess,
-                                        virURI *uri,
-                                        const char *username);
+                                        virURI *uri);

-int virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess,
-                                     const char *username);
+int virNetSSHSessionAuthAddAgentAuth(virNetSSHSession *sess);

 int virNetSSHSessionAuthAddPrivKeyAuth(virNetSSHSession *sess,
-                                       const char *username,
                                        const char *keyfile);

 int virNetSSHSessionAuthAddKeyboardAuth(virNetSSHSession *sess,
-                                        const char *username,
                                         int tries);

 int virNetSSHSessionSetHostKeyVerification(virNetSSHSession *sess,
-- 
2.38.1



More information about the libvir-list mailing list