diff --git a/drivers/net/wireless/ath/ath12k/mac.c b/drivers/net/wireless/ath/ath12k/mac.c index ee804d4a3fd8..ae8a253c466c 100644 --- a/drivers/net/wireless/ath/ath12k/mac.c +++ b/drivers/net/wireless/ath/ath12k/mac.c @@ -4657,6 +4657,7 @@ static int ath12k_mac_op_set_key(struct ieee80211_hw *hw, enum set_key_cmd cmd, if (sta) { ahsta = ath12k_sta_to_ahsta(sta); + /* For an ML STA Pairwise key is same for all associated link Stations, * hence do set key for all link STAs which are active. */ @@ -4679,41 +4680,47 @@ static int ath12k_mac_op_set_key(struct ieee80211_hw *hw, enum set_key_cmd cmd, if (ret) break; } - } else { - arsta = &ahsta->deflink; - arvif = arsta->arvif; - if (WARN_ON(!arvif)) { - ret = -EINVAL; - goto out; - } - ret = ath12k_mac_set_key(arvif->ar, cmd, arvif, arsta, key); - } - } else { - if (key->link_id >= 0 && key->link_id < IEEE80211_MLD_MAX_NUM_LINKS) { - link_id = key->link_id; - arvif = wiphy_dereference(hw->wiphy, ahvif->link[link_id]); - } else { - link_id = 0; - arvif = &ahvif->deflink; + return 0; } - if (!arvif || !arvif->is_created) { - cache = ath12k_ahvif_get_link_cache(ahvif, link_id); - if (!cache) - return -ENOSPC; - - ret = ath12k_mac_update_key_cache(cache, cmd, sta, key); + arsta = &ahsta->deflink; + arvif = arsta->arvif; + if (WARN_ON(!arvif)) + return -EINVAL; + ret = ath12k_mac_set_key(arvif->ar, cmd, arvif, arsta, key); + if (ret) return ret; - } - ret = ath12k_mac_set_key(arvif->ar, cmd, arvif, NULL, key); + return 0; } -out: + if (key->link_id >= 0 && key->link_id < IEEE80211_MLD_MAX_NUM_LINKS) { + link_id = key->link_id; + arvif = wiphy_dereference(hw->wiphy, ahvif->link[link_id]); + } else { + link_id = 0; + arvif = &ahvif->deflink; + } - return ret; + if (!arvif || !arvif->is_created) { + cache = ath12k_ahvif_get_link_cache(ahvif, link_id); + if (!cache) + return -ENOSPC; + + ret = ath12k_mac_update_key_cache(cache, cmd, sta, key); + if (ret) + return ret; + + return 0; + } + + ret = ath12k_mac_set_key(arvif->ar, cmd, arvif, NULL, key); + if (ret) + return ret; + + return 0; } static int