Fix pylint warnings

This commit is contained in:
Yue Zhang
2015-07-20 14:54:00 +08:00
parent cb3eb72589
commit 52e3cb60e5
83 changed files with 3552 additions and 3228 deletions
+63
View File
@@ -0,0 +1,63 @@
###############################################################################
# Set default behavior to automatically normalize line endings.
###############################################################################
* text=auto
###############################################################################
# Set default behavior for command prompt diff.
#
# This is need for earlier builds of msysgit that does not have it on by
# default for csharp files.
# Note: This is only used by command line
###############################################################################
#*.cs diff=csharp
###############################################################################
# Set the merge driver for project and solution files
#
# Merging from the command prompt will add diff markers to the files if there
# are conflicts (Merging from VS is not affected by the settings below, in VS
# the diff markers are never inserted). Diff markers may cause the following
# file extensions to fail to load in VS. An alternative would be to treat
# these files as binary and thus will always conflict and require user
# intervention with every merge. To do so, just uncomment the entries below
###############################################################################
#*.sln merge=binary
#*.csproj merge=binary
#*.vbproj merge=binary
#*.vcxproj merge=binary
#*.vcproj merge=binary
#*.dbproj merge=binary
#*.fsproj merge=binary
#*.lsproj merge=binary
#*.wixproj merge=binary
#*.modelproj merge=binary
#*.sqlproj merge=binary
#*.wwaproj merge=binary
###############################################################################
# behavior for image files
#
# image files are treated as binary by default.
###############################################################################
#*.jpg binary
#*.png binary
#*.gif binary
###############################################################################
# diff behavior for common document formats
#
# Convert binary document formats to text before diffing them. This feature
# is only available from the command line. Turn it on by uncommenting the
# entries below.
###############################################################################
#*.doc diff=astextplain
#*.DOC diff=astextplain
#*.docx diff=astextplain
#*.DOCX diff=astextplain
#*.dot diff=astextplain
#*.DOT diff=astextplain
#*.pdf diff=astextplain
#*.PDF diff=astextplain
#*.rtf diff=astextplain
#*.RTF diff=astextplain
+3
View File
@@ -54,3 +54,6 @@ docs/_build/
target/
waagentc
*.pyproj
*.sln
*.suo
+5 -10
View File
@@ -1,4 +1,3 @@
#!/bin/bash
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Implements parts of RFC 2131, 1541, 1497 and
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
# Requires Python 2.4+ and Openssl 1.0+
#
read Line
import azurelinuxagent.agent as agent
if [ $Line == "y" ] ; then
exit 0
else
exit 1
fi
print "hehe"
agent.main()
+72 -45
View File
@@ -17,38 +17,48 @@
# Requires Python 2.4+ and Openssl 1.0+
#
"""
Module agent
"""
import os
import sys
import re
import shutil
import time
import traceback
import threading
import subprocess
import azurelinuxagent.logger as logger
from azurelinuxagent.metadata import GuestAgentName, GuestAgentLongVersion, \
DistroName, DistroVersion, DistroFullName
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.handler import Handlers
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.fileutil as fileutil
from azurelinuxagent.metadata import AGENT_NAME, agent_long_version, \
DISTRO_NAME, DISTRO_VERSION
from azurelinuxagent.utils.osutil import OSUTIL
from azurelinuxagent.handler import HANDLERS
def Init(verbose):
Handlers.initHandler.init(verbose)
def Run():
Handlers.runHandler.run()
def Deprovision(force=False, deluser=False):
Handlers.deprovisionHandler.deprovision(force=force, deluser=deluser)
def ParseArgs(sysArgv):
def init(verbose):
"""
Initialize agent running environment.
"""
HANDLERS.init_handler.init(verbose)
def run():
"""
Run agent daemon
"""
HANDLERS.main_handler.run()
def deprovision(force=False, deluser=False):
"""
Run deprovision command
"""
HANDLERS.deprovision_handler.deprovision(force=force, deluser=deluser)
def parse_args(sys_args):
"""
Parse command line arguments
"""
cmd = "help"
force = False
verbose = False
for a in sysArgv:
if re.match("^([-/]*)deprovision\+user", a):
for a in sys_args:
if re.match("^([-/]*)deprovision\\+user", a):
cmd = "deprovision+user"
elif re.match("^([-/]*)deprovision", a):
cmd = "deprovision"
@@ -64,48 +74,65 @@ def ParseArgs(sysArgv):
verbose = True
elif re.match("^([-/]*)force", a):
force = True
elif re.match("^([-/]*)(help|usage|\?)", a):
elif re.match("^([-/]*)(help|usage|\\?)", a):
cmd = "help"
else:
cmd = "help"
break
return cmd, force, verbose
def Version():
print("{0} running on {1} {2}".format(GuestAgentLongVersion, DistroName,
DistroVersion))
def Usage():
def version():
"""
Show agent version
"""
print("{0} running on {1} {2}".format(agent_long_version, DISTRO_NAME,
DISTRO_VERSION))
def usage():
"""
Show agent usage
"""
print("")
print(("usage: {0} [-verbose] [-force] [-help]"
"-deprovision[+user]|-register-service|-version|-daemon|-start]"
"").format(sys.argv[0]))
print("")
def Start():
def start():
"""
Start agent daemon in a background process and set stdout/stderr to
/dev/null
"""
devnull = open(os.devnull, 'w')
subprocess.Popen([sys.argv[0], '-daemon'], stdout=devnull, stderr=devnull)
def RegisterService():
print "Register {0} service".format(GuestAgentName)
OSUtil.RegisterAgentService()
print "Start {0} service".format(GuestAgentName)
OSUtil.StartAgentService()
def register_service():
"""
Register agent as a service
"""
print "Register {0} service".format(AGENT_NAME)
OSUTIL.register_agent_service()
print "Start {0} service".format(AGENT_NAME)
OSUTIL.start_agent_service()
def Main():
command, force, verbose = ParseArgs(sys.argv[1:])
def main():
"""
Parse command line arguments, exit with usage() on error.
Invoke different methods according to different command
"""
command, force, verbose = parse_args(sys.argv[1:])
if command == "version":
Version()
version()
elif command == "help":
Usage()
else:
Init(verbose)
usage()
else:
init(verbose)
if command == "deprovision+user":
Deprovision(force, deluser=True)
deprovision(force, deluser=True)
elif command == "deprovision":
Deprovision(force, deluser=False)
deprovision(force, deluser=False)
elif command == "start":
Start()
start()
elif command == "register-service":
RegisterService()
register_service()
elif command == "daemon":
Run()
run()
+45 -27
View File
@@ -17,9 +17,12 @@
# Requires Python 2.4+ and Openssl 1.0+
#
"""
Module conf loads and parses configuration file
"""
import os
import azurelinuxagent.utils.fileutil as fileutil
from azurelinuxagent.exception import *
from azurelinuxagent.exception import AgentConfigError
class ConfigurationProvider(object):
"""
@@ -40,52 +43,67 @@ class ConfigurationProvider(object):
else:
self.values[parts[0]] = None
def get(self, key, defaultValue=None):
def get(self, key, default_val=None):
val = self.values.get(key)
return val if val is not None else defaultValue
return val if val is not None else default_val
def getSwitch(self, key, defaultValue=False):
def get_switch(self, key, default_val=False):
val = self.values.get(key)
if val is not None and val.lower() == 'y':
return True
elif val is not None and val.lower() == 'n':
return False
return defaultValue
return default_val
def getInt(self, key, defaultValue=-1):
def get_int(self, key, default_val=-1):
try:
return int(self.values.get(key))
except:
return defaultValue
except TypeError:
return default_val
except ValueError:
return default_val
__Config__ = ConfigurationProvider()
__config__ = ConfigurationProvider()
def LoadConfiguration(confFilePath, conf=__Config__):
if os.path.isfile(confFilePath) == False:
raise AgentConfigError("Missing configuration in {0}".format(confFilePath))
def load_conf(conf_file_path, conf=__config__):
"""
Load conf file from: conf_file_path
"""
if os.path.isfile(conf_file_path) == False:
raise AgentConfigError(("Missing configuration in {0}"
"").format(conf_file_path))
try:
content = fileutil.GetFileContents(confFilePath)
content = fileutil.read_file(conf_file_path)
conf.load(content)
except IOError, e:
raise AgentConfigError("Failed to load conf file:{0}".format(confFilePath))
except IOError as err:
raise AgentConfigError(("Failed to load conf file:{0}, {1}"
"").format(conf_file_path, err))
def Get(key, defaultValue=None, conf=__Config__):
def get(key, default_val=None, conf=__config__):
"""
Get option value by key, return default_val if not found
"""
if conf is not None:
return conf.get(key, defaultValue)
return conf.get(key, default_val)
else:
return defaultValue
def GetSwitch(key, defaultValue=None, conf=__Config__):
if conf is not None:
return conf.getSwitch(key, defaultValue)
else:
return defaultValue
return default_val
def GetInt(key, defaultValue=None, conf=__Config__):
def get_switch(key, default_val=None, conf=__config__):
"""
Get bool option value by key, return default_val if not found
"""
if conf is not None:
return conf.getInt(key, defaultValue)
return conf.get_switch(key, default_val)
else:
return defaultValue
return default_val
def get_int(key, default_val=None, conf=__config__):
"""
Get int option value by key, return default_val if not found
"""
if conf is not None:
return conf.get_int(key, default_val)
else:
return default_val
+4 -4
View File
@@ -17,9 +17,9 @@
# Requires Python 2.4+ and Openssl 1.0+
#
from azurelinuxagent.metadata import DistroName, DistroVersion
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
import azurelinuxagent.distro.redhat.loader as redhat
def GetOSUtil():
return redhat.GetOSUtil()
def get_osutil():
return redhat.get_osutil()
+4 -4
View File
@@ -21,10 +21,10 @@ import azurelinuxagent.utils.fileutil as fileutil
from azurelinuxagent.distro.default.deprovision import DeprovisionHandler, DeprovisionAction
class CoreOSDeprovisionHandler(DeprovisionHandler):
def setUp(self, deluser):
warnings, actions = super(CoreOSDeprovisionHandler, self).setUp(deluser)
def setup(self, deluser):
warnings, actions = super(CoreOSDeprovisionHandler, self).setup(deluser)
warnings.append("WARNING! /etc/machine-id will be removed.")
filesToDel = ['/etc/machine-id']
actions.append(DeprovisionAction(fileutil.RemoveFiles, filesToDel))
files_to_del = ['/etc/machine-id']
actions.append(DeprovisionAction(fileutil.rm_files, files_to_del))
return warnings, actions
@@ -23,5 +23,5 @@ from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory
class CoreOSHandlerFactory(DefaultHandlerFactory):
def __init__(self):
super(CoreOSHandlerFactory, self).__init__()
self.deprovisionHandler = CoreOSDeprovisionHandler()
self.deprovision_handler = CoreOSDeprovisionHandler()
+2 -2
View File
@@ -18,11 +18,11 @@
#
def GetOSUtil():
def get_osutil():
from azurelinuxagent.distro.coreos.osutil import CoreOSUtil
return CoreOSUtil()
def GetHandlers():
def get_handlers():
from azurelinuxagent.distro.coreos.handlerFactory import CoreOSHandlerFactory
return CoreOSHandlerFactory()
+26 -26
View File
@@ -37,7 +37,7 @@ class CoreOSUtil(DefaultOSUtil):
super(CoreOSUtil, self).__init__()
self.waagent_path='/usr/share/oem/bin/waagent'
self.python_path='/usr/share/oem/python/bin'
self.configPath = '/usr/share/oem/waagent.conf'
self.conf_path = '/usr/share/oem/waagent.conf'
if 'PATH' in os.environ:
path = "{0}:{1}".format(os.environ['PATH'], self.python_path)
else:
@@ -51,40 +51,40 @@ class CoreOSUtil(DefaultOSUtil):
py_path = self.waagent_path
os.environ['PYTHONPATH'] = py_path
def IsSysUser(self, userName):
def is_sys_user(self, username):
#User 'core' is not a sysuser
if userName == 'core':
if username == 'core':
return False
return super(CoreOSUtil, self).IsSysUser(userName)
return super(CoreOSUtil, self).IsSysUser(username)
def IsDhcpEnabled(self):
def is_dhcp_enabled(self):
return True
def StartNetwork(self) :
return shellutil.Run("systemctl start systemd-networkd", chk_err=False)
def RestartInterface(self, iface):
shellutil.Run("systemctl restart systemd-networkd")
def RestartSshService(self):
return shellutil.Run("systemctl restart sshd", chk_err=False)
def start_network(self) :
return shellutil.run("systemctl start systemd-networkd", chk_err=False)
def StopDhcpService(self):
return shellutil.Run("systemctl stop systemd-networkd", chk_err=False)
def restart_if(self, iface):
shellutil.run("systemctl restart systemd-networkd")
def StartDhcpService(self):
return shellutil.Run("systemctl start systemd-networkd", chk_err=False)
def restart_ssh_service(self):
return shellutil.run("systemctl restart sshd", chk_err=False)
def StartAgentService(self):
return shellutil.Run("systemctl start wagent", chk_err=False)
def stop_dhcp_service(self):
return shellutil.run("systemctl stop systemd-networkd", chk_err=False)
def StopAgentService(self):
return shellutil.Run("systemctl stop wagent", chk_err=False)
def GetDhcpProcessId(self):
ret= shellutil.RunGetOutput("pidof systemd-networkd")
def start_dhcp_service(self):
return shellutil.run("systemctl start systemd-networkd", chk_err=False)
def start_agent_service(self):
return shellutil.run("systemctl start wagent", chk_err=False)
def stop_agent_service(self):
return shellutil.run("systemctl stop wagent", chk_err=False)
def get_dhcp_pid(self):
ret= shellutil.run_get_output("pidof systemd-networkd")
return ret[1] if ret[0] == 0 else None
def TranslateCustomData(self, data):
def decode_customdata(self, data):
return base64.b64decode(data)
+1 -1
View File
@@ -18,7 +18,7 @@
#
def GetOSUtil():
def get_osutil():
from azurelinuxagent.distro.debian.osutil import DebianOSUtil
return DebianOSUtil()
+8 -8
View File
@@ -30,18 +30,18 @@ import azurelinuxagent.logger as logger
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.textutil as textutil
from azurelinuxagent.distro.default.osutil import OSUtil
from azurelinuxagent.distro.default.osutil import DefaultOSUtil
class DebianOSUtil(OSUtil):
class DebianOSUtil(DefaultOSUtil):
def __init__(self):
super(DebianOSUtil, self).__init__()
def RestartSshService(self):
return shellutil.Run("service sshd restart", chk_err=False)
def restart_ssh_service(self):
return shellutil.run("service sshd restart", chk_err=False)
def StopAgentService(self):
return shellutil.Run("service azurelinuxagent stop", chk_err=False)
def stop_agent_service(self):
return shellutil.run("service azurelinuxagent stop", chk_err=False)
def StartAgentService(self):
return shellutil.Run("service azurelinuxagent start", chk_err=False)
def start_agent_service(self):
return shellutil.run("service azurelinuxagent start", chk_err=False)
+42 -42
View File
@@ -18,7 +18,7 @@
#
import azurelinuxagent.conf as conf
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.protocol as prot
import azurelinuxagent.protocol.ovfenv as ovf
import azurelinuxagent.utils.fileutil as fileutil
@@ -35,74 +35,74 @@ class DeprovisionAction(object):
class DeprovisionHandler(object):
def deleteRootPassword(self, warnings, actions):
def del_root_password(self, warnings, actions):
warnings.append("WARNING! root password will be disabled. "
"You will not be able to login as root.")
actions.append(DeprovisionAction(OSUtil.DeleteRootPassword))
def deleteUser(self, warnings, actions):
actions.append(DeprovisionAction(OSUTIL.del_root_password))
def del_user(self, warnings, actions):
try:
ovfenv = ovf.GetOvfEnv()
ovfenv = ovf.get_ovf_env()
except prot.ProtocolError:
warnings.append("WARNING! ovf-env.xml is not found.")
warnings.append("WARNING! Skip delete user.")
return
userName = ovfenv.getUserName()
username = ovfenv.get_username()
warnings.append(("WARNING! {0} account and entire home directory "
"will be deleted.").format(userName))
actions.append(DeprovisionAction(OSUtil.DeleteAccount, [userName]))
"will be deleted.").format(username))
actions.append(DeprovisionAction(OSUTIL.del_account, [username]))
def regenerateHostKeyPair(self, warnings, actions):
def regen_ssh_host_key(self, warnings, actions):
warnings.append("WARNING! All SSH host key pairs will be deleted.")
actions.append(DeprovisionAction(OSUtil.SetHostname,
actions.append(DeprovisionAction(OSUTIL.set_hostname,
['localhost.localdomain']))
actions.append(DeprovisionAction(shellutil.Run,
actions.append(DeprovisionAction(shellutil.run,
['rm -f /etc/ssh/ssh_host_*key*']))
def stopAgentService(self, warnings, actions):
def stop_agent_service(self, warnings, actions):
warnings.append("WARNING! The waagent service will be stopped.")
actions.append(DeprovisionAction(OSUtil.StopAgentService))
actions.append(DeprovisionAction(OSUTIL.stop_agent_service))
def deleteFiles(self, warnings, actions):
filesToDel = ['/root/.bash_history', '/var/log/waagent.log']
actions.append(DeprovisionAction(fileutil.RemoveFiles, filesToDel))
def del_files(self, warnings, actions):
files_to_del = ['/root/.bash_history', '/var/log/waagent.log']
actions.append(DeprovisionAction(fileutil.rm_files, files_to_del))
def deleteDhcpLease(self, warnings, actions):
def del_dhcp_lease(self, warnings, actions):
warnings.append("WARNING! Cached DHCP leases will be deleted.")
dirsToDel = ["/var/lib/dhclient", "/var/lib/dhcpcd", "/var/lib/dhcp"]
actions.append(DeprovisionAction(fileutil.CleanupDirs, dirsToDel))
dirs_to_del = ["/var/lib/dhclient", "/var/lib/dhcpcd", "/var/lib/dhcp"]
actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del))
def deleteLibDir(self, warnings, actions):
dirsToDel = [OSUtil.GetLibDir()]
actions.append(DeprovisionAction(fileutil.CleanupDirs, dirsToDel))
def del_lib_dir(self, warnings, actions):
dirs_to_del = [OSUTIL.get_lib_dir()]
actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del))
def setUp(self, deluser):
def setup(self, deluser):
warnings = []
actions = []
self.stopAgentService(warnings, actions)
if conf.GetSwitch("Provisioning.RegenerateSshHostkey", False):
self.regenerateHostKeyPair(warnings, actions)
self.deleteDhcpLease(warnings, actions)
self.stop_agent_service(warnings, actions)
if conf.get_switch("Provisioning.RegenerateSshHostkey", False):
self.regen_ssh_host_key(warnings, actions)
if conf.GetSwitch("Provisioning.DeleteRootPassword", False):
self.deleteRootPassword(warnings, actions)
self.del_dhcp_lease(warnings, actions)
self.deleteLibDir(warnings, actions)
self.deleteFiles(warnings, actions)
if conf.get_switch("Provisioning.DeleteRootPassword", False):
self.del_root_password(warnings, actions)
self.del_lib_dir(warnings, actions)
self.del_files(warnings, actions)
if deluser:
self.deleteUser(warnings, actions)
self.del_user(warnings, actions)
return warnings, actions
def deprovision(self, force=False, deluser=False):
warnings, actions = self.setUp(deluser)
warnings, actions = self.setup(deluser)
for warning in warnings:
print warning
@@ -110,8 +110,8 @@ class DeprovisionHandler(object):
confirm = raw_input("Do you want to proceed (y/n)")
if not confirm.lower().startswith('y'):
return
for action in actions:
action.invoke()
+121 -120
View File
@@ -20,171 +20,172 @@ import socket
import array
import time
import azurelinuxagent.logger as logger
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
from azurelinuxagent.exception import AgentNetworkError
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
from azurelinuxagent.utils.textutil import *
WireServerAddrFile="WireServer"
WIRE_SERVER_ADDR_FILE_NAME="WireServer"
class DhcpHandler(object):
def __init__(self):
self.endpoint = None
self.gateway = None
self.routes = None
def waitForNetwork(self):
ipv4 = OSUtil.GetIpv4Address()
def wait_for_network(self):
ipv4 = OSUTIL.get_ip4_addr()
while ipv4 == '' or ipv4 == '0.0.0.0':
logger.Info("Waiting for network.")
logger.info("Waiting for network.")
time.sleep(10)
OSUtil.StartNetwork()
ipv4 = OSUtil.GetIpv4Address()
OSUTIL.start_network()
ipv4 = OSUTIL.get_ip4_addr()
def probe(self):
logger.Info("Send dhcp request")
self.waitForNetwork()
macAddress = OSUtil.GetMacAddress()
req = BuildDhcpRequest(macAddress)
resp = SendDhcpRequest(req)
endpoint, gateway, routes = ParseDhcpResponse(resp)
logger.info("Send dhcp request")
self.wait_for_network()
mac_addr = OSUTIL.get_mac_addr()
req = build_dhcp_request(mac_addr)
resp = send_dhcp_request(req)
endpoint, gateway, routes = parse_dhcp_resp(resp)
self.endpoint = endpoint
logger.Info("Wire server endpoint:{0}", endpoint)
logger.Info("Gateway:{0}", gateway)
logger.Info("Routes:{0}", routes)
logger.info("Wire server endpoint:{0}", endpoint)
logger.info("Gateway:{0}", gateway)
logger.info("Routes:{0}", routes)
if endpoint is not None:
path = os.path.join(OSUtil.GetLibDir(), WireServerAddrFile)
fileutil.SetFileContents(path, endpoint)
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME)
fileutil.write_file(path, endpoint)
self.gateway = gateway
self.routes = routes
self.configRoutes()
self.conf_routes()
def getEndpoint(self):
def get_endpoint(self):
return self.endpoint
def configRoutes(self):
logger.Info("Configure routes")
def conf_routes(self):
logger.info("Configure routes")
#Add default gateway
if self.gateway is not None:
OSUtil.RouteAdd(0 , 0, self.gateway)
OSUTIL.route_add(0 , 0, self.gateway)
if self.routes is not None:
for route in self.routes:
OSUtil.RouteAdd(route[0], route[1], route[2])
OSUTIL.route_add(route[0], route[1], route[2])
def ValidateDhcpResponse(request, response):
bytesReceived = len(response)
if bytesReceived < 0xF6:
logger.Error("HandleDhcpResponse: Too few bytes received:{0}",
str(bytesReceived))
def validate_dhcp_resp(request, response):
bytes_recv = len(response)
if bytes_recv < 0xF6:
logger.error("HandleDhcpResponse: Too few bytes received:{0}",
str(bytes_recv))
return False
logger.Verbose("BytesReceived:{0}", hex(bytesReceived))
logger.Verbose("DHCP response:{0}", HexDump(response, bytesReceived))
logger.verb("BytesReceived:{0}", hex(bytes_recv))
logger.verb("DHCP response:{0}", hex_dump(response, bytes_recv))
# check transactionId, cookie, MAC address cookie should never mismatch
# transactionId and MAC address may mismatch if we see a response
# transactionId and MAC address may mismatch if we see a response
# meant from another machine
if not CompareBytes(request, response, 0xEC, 4):
logger.Verbose("Cookie not match:\nsend={0},\nreceive={1}",
HexDump3(request, 0xEC, 4),
HexDump3(response, 0xEC, 4))
if not compare_bytes(request, response, 0xEC, 4):
logger.verb("Cookie not match:\nsend={0},\nreceive={1}",
hex_dump3(request, 0xEC, 4),
hex_dump3(response, 0xEC, 4))
raise AgentNetworkError("Cookie in dhcp respones "
"doesn't match the request")
if not CompareBytes(request, response, 4, 4):
logger.Verbose("TransactionID not match:\nsend={0},\nreceive={1}",
HexDump3(request, 4, 4),
HexDump3(response, 4, 4))
if not compare_bytes(request, response, 4, 4):
logger.verb("TransactionID not match:\nsend={0},\nreceive={1}",
hex_dump3(request, 4, 4),
hex_dump3(response, 4, 4))
raise AgentNetworkError("TransactionID in dhcp respones "
"doesn't match the request")
if not CompareBytes(request, response, 0x1C, 6):
logger.Verbose("Mac Address not match:\nsend={0},\nreceive={1}",
HexDump3(request, 0x1C, 6),
HexDump3(response, 0x1C, 6))
if not compare_bytes(request, response, 0x1C, 6):
logger.verb("Mac Address not match:\nsend={0},\nreceive={1}",
hex_dump3(request, 0x1C, 6),
hex_dump3(response, 0x1C, 6))
raise AgentNetworkError("Mac Addr in dhcp respones "
"doesn't match the request")
def ParseRoute(response, option, i, length, bytesReceived):
def parse_route(response, option, i, length, bytes_recv):
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
logger.Verbose("Routes at offset: {0} with length:{1}",
hex(i),
logger.verb("Routes at offset: {0} with length:{1}",
hex(i),
hex(length))
routes = []
if length < 5:
logger.Error("Data too small for option:{0}", str(option))
logger.error("Data too small for option:{0}", str(option))
j = i + 2
while j < (i + length + 2):
maskLengthBits = Ord(response[j])
maskLengthBytes = (((maskLengthBits + 7) & ~7) >> 3)
mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - maskLengthBits))
mask_len_bits = str_to_ord(response[j])
mask_len_bytes = (((mask_len_bits + 7) & ~7) >> 3)
mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - mask_len_bits))
j += 1
net = UnpackBigEndian(response, j, maskLengthBytes)
net <<= (32 - maskLengthBytes * 8)
net = unpack_big_endian(response, j, mask_len_bytes)
net <<= (32 - mask_len_bytes * 8)
net &= mask
j += maskLengthBytes
gateway = UnpackBigEndian(response, j, 4)
j += mask_len_bytes
gateway = unpack_big_endian(response, j, 4)
j += 4
routes.append((net, mask, gateway))
if j != (i + length + 2):
logger.Error("Unable to parse routes")
logger.error("Unable to parse routes")
return routes
def ParseIpAddress(response, option, i, length, bytesReceived):
if i + 5 < bytesReceived:
def parse_ip_addr(response, option, i, length, bytes_recv):
if i + 5 < bytes_recv:
if length != 4:
logger.Error("Endpoint or Default Gateway not 4 bytes")
logger.error("Endpoint or Default Gateway not 4 bytes")
return None
addr = UnpackBigEndian(response, i + 2, 4)
IpAddress = IntegerToIpAddressV4String(addr)
return IpAddress
addr = unpack_big_endian(response, i + 2, 4)
ip_addr = int_to_ip4_addr(addr)
return ip_addr
else:
logger.Error("Data too small for option:{0}", str(option))
logger.error("Data too small for option:{0}", str(option))
return None
def ParseDhcpResponse(response):
def parse_dhcp_resp(response):
"""
Parse DHCP response:
Returns endpoint server or None on error.
"""
logger.Verbose("parse Dhcp Response")
bytesReceived = len(response)
logger.verb("parse Dhcp Response")
bytes_recv = len(response)
endpoint = None
gateway = None
routes = None
# Walk all the returned options, parsing out what we need, ignoring the
# Walk all the returned options, parsing out what we need, ignoring the
# others. We need the custom option 245 to find the the endpoint we talk to,
# as well as, to handle some Linux DHCP client incompatibilities,
# options 3 for default gateway and 249 for routes. And 255 is end.
i = 0xF0 # offset to first option
while i < bytesReceived:
option = Ord(response[i])
while i < bytes_recv:
option = str_to_ord(response[i])
length = 0
if (i + 1) < bytesReceived:
length = Ord(response[i + 1])
logger.Verbose("DHCP option {0} at offset:{1} with length:{2}",
hex(option),
hex(i),
if (i + 1) < bytes_recv:
length = str_to_ord(response[i + 1])
logger.verb("DHCP option {0} at offset:{1} with length:{2}",
hex(option),
hex(i),
hex(length))
if option == 255:
logger.Verbose("DHCP packet ended at offset:{0}", hex(i))
logger.verb("DHCP packet ended at offset:{0}", hex(i))
break
elif option == 249:
routes = ParseRoute(response, option, i, length, bytesReceived)
routes = parse_route(response, option, i, length, bytes_recv)
elif option == 3:
gateway = ParseIpAddress(response, option, i, length, bytesReceived)
logger.Verbose("Default gateway:{0}, at {1}",
gateway,
gateway = parse_ip_addr(response, option, i, length, bytes_recv)
logger.verb("Default gateway:{0}, at {1}",
gateway,
hex(i))
elif option == 245:
endpoint = ParseIpAddress(response, option, i, length, bytesReceived)
logger.Verbose("Azure wire protocol endpoint:{0}, at {1}",
gateway,
endpoint = parse_ip_addr(response, option, i, length, bytes_recv)
logger.verb("Azure wire protocol endpoint:{0}, at {1}",
gateway,
hex(i))
else:
logger.Verbose("Skipping DHCP option:{0} at {1} with length {2}",
logger.verb("Skipping DHCP option:{0} at {1} with length {2}",
hex(option),
hex(i),
hex(length))
@@ -192,71 +193,71 @@ def ParseDhcpResponse(response):
return endpoint, gateway, routes
def AllowBroadcastForDhcp(func):
def allow_dhcp_broadcast(func):
"""
Temporary allow broadcase for dhcp. Remove the route when done.
"""
def Wrapper(*args, **kwargs):
missingDefaultRoute = OSUtil.IsMissingDefaultRoute()
ifname = OSUtil.GetInterfaceName()
if missingDefaultRoute:
OSUtil.SetBroadcastRouteForDhcp(ifname)
def wrapper(*args, **kwargs):
missing_default_route = OSUTIL.is_missing_default_route()
ifname = OSUTIL.get_if_name()
if missing_default_route:
OSUTIL.set_route_for_dhcp_broadcast(ifname)
result = func(*args, **kwargs)
if missingDefaultRoute:
OSUtil.RemoveBroadcastRouteForDhcp(ifname)
if missing_default_route:
OSUTIL.remove_route_for_dhcp_broadcast(ifname)
return result
return Wrapper
return wrapper
def DisableDhcpServiceIfNeeded(func):
def disable_dhcp_service(func):
"""
In some distros, dhcp service needs to be shutdown before agent probe
endpoint through dhcp.
"""
def Wrapper(*args, **kwargs):
if OSUtil.IsDhcpEnabled():
OSUtil.StopDhcpService()
def wrapper(*args, **kwargs):
if OSUTIL.is_dhcp_enabled():
OSUTIL.stop_dhcp_service()
result = func(*args, **kwargs)
OSUtil.StartDhcpService()
OSUTIL.start_dhcp_service()
return result
else:
return func(*args, **kwargs)
return Wrapper
return wrapper
@AllowBroadcastForDhcp
@DisableDhcpServiceIfNeeded
def SendDhcpRequest(request):
__SleepDuration = [0, 10, 30, 60, 60]
@allow_dhcp_broadcast
@disable_dhcp_service
def send_dhcp_request(request):
__waiting_duration__ = [0, 10, 30, 60, 60]
sock = None
for duration in __SleepDuration:
for duration in __waiting_duration__:
try:
OSUtil.OpenPortForDhcp()
response = SocketSend(request)
ValidateDhcpResponse(request, response)
OSUTIL.allow_dhcp_broadcast()
response = socket_send(request)
validate_dhcp_resp(request, response)
return response
except AgentNetworkError as e:
logger.Error("Failed to send DHCP request: {0}", e)
logger.error("Failed to send DHCP request: {0}", e)
return None
finally:
if sock:
sock.close()
time.sleep(duration)
def SocketSend(request):
sock = socket.socket(socket.AF_INET,
socket.SOCK_DGRAM,
def socket_send(request):
sock = socket.socket(socket.AF_INET,
socket.SOCK_DGRAM,
socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("0.0.0.0", 68))
sock.bind(("0.0.0.0", 68))
sock.sendto(request, ("<broadcast>", 67))
sock.settimeout(10)
logger.Verbose("Send DHCP request: Setting socket.timeout=10, "
logger.verb("Send DHCP request: Setting socket.timeout=10, "
"entering recv")
response = sock.recv(1024)
return response
def BuildDhcpRequest(macAddress):
def build_dhcp_request(mac_addr):
"""
Build DHCP request string.
"""
@@ -291,7 +292,7 @@ def BuildDhcpRequest(macAddress):
# (struct.pack_into would be good here, but requires Python 2.5)
request = [0] * 244
transactionID = GenTransactionId()
trans_id = gen_trans_id()
# Opcode = 1
# HardwareAddressType = 1 (ethernet/MAC)
@@ -301,15 +302,15 @@ def BuildDhcpRequest(macAddress):
# fill in transaction id (random number to ensure response matches request)
for a in range(0, 4):
request[4 + a] = Ord(transactionID[a])
request[4 + a] = str_to_ord(trans_id[a])
logger.Verbose("BuildDhcpRequest: transactionId:%s,%04X" % (
HexDump2(transactionID),
UnpackBigEndian(request, 4, 4)))
logger.verb("BuildDhcpRequest: transactionId:%s,%04X" % (
hex_dump2(trans_id),
unpack_big_endian(request, 4, 4)))
# fill in ClientHardwareAddress
for a in range(0, 6):
request[0x1C + a] = Ord(macAddress[a])
request[0x1C + a] = str_to_ord(mac_addr[a])
# DHCP Magic Cookie: 99, 130, 83, 99
# MessageTypeCode = 53 DHCP Message Type
@@ -320,5 +321,5 @@ def BuildDhcpRequest(macAddress):
request[0xEC + a] = [99, 130, 83, 99, 53, 1, 1, 255][a]
return array.array("B", request)
def GenTransactionId():
def gen_trans_id():
return os.urandom(4)
+33 -33
View File
@@ -23,7 +23,7 @@ import threading
import time
import azurelinuxagent.logger as logger
import azurelinuxagent.conf as conf
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
class EnvHandler(object):
"""
@@ -31,35 +31,35 @@ class EnvHandler(object):
If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric.
Monitor scsi disk.
If new scsi disk found, set
If new scsi disk found, set
"""
def __init__(self, handlers):
self.monitor = EnvMonitor(handlers.dhcpHandler)
self.monitor = EnvMonitor(handlers.dhcp_handler)
def startMonitor(self):
def start(self):
self.monitor.start()
def stopMonitor(self):
def stop(self):
self.monitor.stop()
class EnvMonitor(object):
def __init__(self, dhcpHandler):
self.dhcpHandler = dhcpHandler
def __init__(self, dhcp_handler):
self.dhcp_handler = dhcp_handler
self.stopped = True
self.hostname = None
self.dhcpid = None
self.server_thread=None
def start(self):
if not self.stopped:
logger.Info("Stop existing env monitor service.")
logger.info("Stop existing env monitor service.")
self.stop()
self.stopped = False
logger.Info("Start env monitor service.")
logger.info("Start env monitor service.")
self.hostname = socket.gethostname()
self.dhcpid = OSUtil.GetDhcpProcessId()
self.dhcpid = OSUTIL.get_dhcp_pid()
self.server_thread = threading.Thread(target = self.monitor)
self.server_thread.setDaemon(True)
self.server_thread.start()
@@ -70,39 +70,39 @@ class EnvMonitor(object):
If dhcp clinet process re-start has occurred, reset routes.
"""
while not self.stopped:
OSUtil.RemoveRulesFiles()
timeout = conf.Get("OS.RootDeviceScsiTimeout", None)
OSUTIL.remove_rules_files()
timeout = conf.get("OS.RootDeviceScsiTimeout", None)
if timeout is not None:
OSUtil.SetScsiDiskTimeout(timeout)
if conf.GetSwitch("Provisioning.MonitorHostName", False):
self.handleHostnameUpdate()
self.handleDhcpClientRestart()
OSUTIL.set_scsi_disks_timeout(timeout)
if conf.get_switch("Provisioning.MonitorHostName", False):
self.handle_hostname_update()
self.handle_dhclient_restart()
time.sleep(5)
def handleHostnameUpdate(self):
currHostname = socket.gethostname()
if currHostname != self.hostname:
logger.Info("EnvMonitor: Detected host name change: {0} -> {1}",
self.hostname, currHostname)
OSUtil.SetHostname(currHostname)
OSUtil.PublishHostname(currHostname)
self.hostname = currHostname
def handle_hostname_update(self):
curr_hostname = socket.gethostname()
if curr_hostname != self.hostname:
logger.info("EnvMonitor: Detected host name change: {0} -> {1}",
self.hostname, curr_hostname)
OSUTIL.set_hostname(curr_hostname)
OSUTIL.publish_hostname(curr_hostname)
self.hostname = curr_hostname
def handleDhcpClientRestart(self):
def handle_dhclient_restart(self):
if self.dhcpid is None:
logger.Warn("Dhcp client is not running. ")
self.dhcpid = OSUtil.GetDhcpProcessId()
logger.warn("Dhcp client is not running. ")
self.dhcpid = OSUTIL.get_dhcp_pid()
return
#The dhcp process hasn't changed since last check
if os.path.isdir(os.path.join('/proc', self.dhcpid.strip())):
return
newpid = OSUtil.GetDhcpProcessId()
newpid = OSUTIL.get_dhcp_pid()
if newpid is not None and newpid != self.dhcpid:
logger.Info("EnvMonitor: Detected dhcp client restart. "
logger.info("EnvMonitor: Detected dhcp client restart. "
"Restoring routing table.")
self.dhcpHandler.configRoutes()
self.dhcp_handler.conf_routes()
self.dhcpid = newpid
def stop(self):
+307 -307
View File
@@ -22,22 +22,22 @@ import time
import json
import subprocess
import azurelinuxagent.logger as logger
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.protocol as prot
from azurelinuxagent.event import AddExtensionEvent, WALAEventOperation
from azurelinuxagent.event import add_event, WALAEventOperation
from azurelinuxagent.exception import ExtensionError
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.restutil as restutil
import azurelinuxagent.utils.shellutil as shellutil
ValidExtensionStatus = ['transitioning', 'error', 'success', 'warning']
VALID_EXTENSION_STATUS = ['transitioning', 'error', 'success', 'warning']
def validate_has_key(obj, key, fullName):
def validate_has_key(obj, key, fullname):
if key not in obj:
raise ExtensionError("Missing: {0}".format(fullName))
raise ExtensionError("Missing: {0}".format(fullname))
def validate_in_range(val, validRange, name):
if val not in validRange:
def validate_in_range(val, valid_range, name):
if val not in valid_range:
raise ExtensionError("Invalid {0}: {1}".format(name, val))
def try_get(dictionary, key, default=None):
@@ -52,12 +52,12 @@ def extension_sub_status_to_v2(substatus):
validate_has_key(substatus, 'status', 'substatus/status')
validate_has_key(substatus, 'code', 'substatus/code')
validate_has_key(substatus, 'formattedMessage', 'substatus/formattedMessage')
validate_has_key(substatus['formattedMessage'], 'lang',
validate_has_key(substatus['formattedMessage'], 'lang',
'substatus/formattedMessage/lang')
validate_has_key(substatus['formattedMessage'], 'message',
validate_has_key(substatus['formattedMessage'], 'message',
'substatus/formattedMessage/message')
validate_in_range(substatus['status'], ValidExtensionStatus,
validate_in_range(substatus['status'], VALID_EXTENSION_STATUS,
'substatus/status')
status = prot.ExtensionSubStatus()
status.name = try_get(substatus, 'name')
@@ -66,212 +66,212 @@ def extension_sub_status_to_v2(substatus):
status.message = try_get(substatus['formattedMessage'], 'message')
return status
def extension_status_to_v2(extStatus, seqNo):
def ext_status_to_v2(ext_status, seq_no):
#Check extension status format
validate_has_key(extStatus, 'status', 'status')
validate_has_key(extStatus['status'], 'status', 'status/status')
validate_has_key(extStatus['status'], 'operation', 'status/operation')
validate_has_key(extStatus['status'], 'code', 'status/code')
validate_has_key(extStatus['status'], 'name', 'status/name')
validate_has_key(extStatus['status'], 'formattedMessage',
validate_has_key(ext_status, 'status', 'status')
validate_has_key(ext_status['status'], 'status', 'status/status')
validate_has_key(ext_status['status'], 'operation', 'status/operation')
validate_has_key(ext_status['status'], 'code', 'status/code')
validate_has_key(ext_status['status'], 'name', 'status/name')
validate_has_key(ext_status['status'], 'formattedMessage',
'status/formattedMessage')
validate_has_key(extStatus['status']['formattedMessage'], 'lang',
validate_has_key(ext_status['status']['formattedMessage'], 'lang',
'status/formattedMessage/lang')
validate_has_key(extStatus['status']['formattedMessage'], 'message',
validate_has_key(ext_status['status']['formattedMessage'], 'message',
'status/formattedMessage/message')
validate_in_range(extStatus['status']['status'], ValidExtensionStatus,
validate_in_range(ext_status['status']['status'], VALID_EXTENSION_STATUS,
'status/status')
status = prot.ExtensionStatus()
status.name = try_get(extStatus['status'], 'name')
status.configurationAppliedTime = try_get(extStatus['status'],
'configurationAppliedTime')
status.operation = try_get(extStatus['status'], 'operation')
status.status = try_get(extStatus['status'], 'status')
status.code = try_get(extStatus['status'], 'code')
status.message = try_get(extStatus['status']['formattedMessage'], 'message')
status.sequenceNumber = seqNo
substatusList = try_get(extStatus['status'], 'substatus', [])
for substatus in substatusList:
status = prot.ExtensionStatus()
status.name = try_get(ext_status['status'], 'name')
status.configurationAppliedTime = try_get(ext_status['status'],
'configurationAppliedTime')
status.operation = try_get(ext_status['status'], 'operation')
status.status = try_get(ext_status['status'], 'status')
status.code = try_get(ext_status['status'], 'code')
status.message = try_get(ext_status['status']['formattedMessage'], 'message')
status.sequenceNumber = seq_no
substatus_list = try_get(ext_status['status'], 'substatus', [])
for substatus in substatus_list:
status.substatusList.extend(extension_sub_status_to_v2(substatus))
return status
class ExtensionsHandler(object):
def process(self):
protocol = prot.Factory.getDefaultProtocol()
extList = protocol.getExtensions()
handlerStatusList = []
for extension in extList.extensions:
#TODO handle extension in parallel
packageList = protocol.getExtensionPackages(extension)
handlerStatus = self.processExtension(extension, packageList)
handlerStatusList.append(handlerStatus)
protocol = prot.Factory.get_default_protocol()
ext_list = protocol.get_extensions()
return handlerStatusList
def processExtension(self, extension, packageList):
installedVersion = GetInstalledExtensionVersion(extension.name)
if installedVersion is not None:
ext = ExtensionInstance(extension, packageList,
installedVersion, installed=True)
h_status_list = []
for extension in ext_list.extensions:
#TODO handle extension in parallel
pkg_list = protocol.get_extension_pkgs(extension)
h_status = self.process_extension(extension, pkg_list)
h_status_list.append(h_status)
return h_status_list
def process_extension(self, extension, pkg_list):
installed_version = get_installed_version(extension.name)
if installed_version is not None:
ext = ExtensionInstance(extension, pkg_list,
installed_version, installed=True)
else:
ext = ExtensionInstance(extension, packageList,
ext = ExtensionInstance(extension, pkg_list,
extension.properties.version)
try:
ext.initLog()
ext.init_logger()
ext.handle()
status = ext.collectHandlerStatus()
status = ext.collect_handler_status()
except ExtensionError as e:
logger.Error("Failed to handle extension: {0}-{1}\n {2}",
ext.getName(), ext.getVersion(), e)
AddExtensionEvent(name=ext.getName(), isSuccess=False,
op=ext.getCurrOperation(), message = str(e))
extStatus = prot.ExtensionStatus(status='error', code='-1',
operation = ext.getCurrOperation(),
logger.error("Failed to handle extension: {0}-{1}\n {2}",
ext.get_name(), ext.get_version(), e)
add_event(name=ext.get_name(), is_success=False,
op=ext.get_curr_op(), message = str(e))
ext_status = prot.ExtensionStatus(status='error', code='-1',
operation = ext.get_curr_op(),
message = str(e),
sequenceNumber = ext.getSeqNo())
status = ext.createHandlerStatus(extStatus)
seq_no = ext.get_seq_no())
status = ext.create_handler_status(ext_status)
status.status = "Ready"
return status
def ParseExtensionDirName(dirName):
def parse_extension_dirname(dirname):
"""
Parse installed extension dir name. Sample: ExtensionName-Version/
"""
seprator = dirName.rfind('-')
seprator = dirname.rfind('-')
if seprator < 0:
raise ExtensionError("Invalid extenation dir name")
return dirName[0:seprator], dirName[seprator + 1:]
return dirname[0:seprator], dirname[seprator + 1:]
def GetInstalledExtensionVersion(targetName):
def get_installed_version(target_name):
"""
Return the highest version instance with the same name
"""
installedVersion = None
libDir = OSUtil.GetLibDir()
for dirName in os.listdir(libDir):
path = os.path.join(libDir, dirName)
if os.path.isdir(path) and dirName.startswith(targetName):
name, version = ParseExtensionDirName(dirName)
installed_version = None
lib_dir = OSUTIL.get_lib_dir()
for dir_name in os.listdir(lib_dir):
path = os.path.join(lib_dir, dir_name)
if os.path.isdir(path) and dir_name.startswith(target_name):
name, version = parse_extension_dirname(dir_name)
#Here we need to ensure names are exactly the same.
if name == targetName:
if installedVersion is None or installedVersion < version:
installedVersion = version
return installedVersion
if name == target_name:
if installed_version is None or installed_version < version:
installed_version = version
return installed_version
class ExtensionInstance(object):
def __init__(self, extension, packageList, currVersion, installed=False):
def __init__(self, extension, pkg_list, curr_version, installed=False):
self.extension = extension
self.packageList = packageList
self.currVersion = currVersion
self.libDir = OSUtil.GetLibDir()
self.pkg_list = pkg_list
self.curr_version = curr_version
self.lib_dir = OSUTIL.get_lib_dir()
self.installed = installed
self.settings = None
#Extension will have no more than 1 settings instance
if len(extension.properties.extensions) > 0:
self.settings = extension.properties.extensions[0]
self.enabled = False
self.currOperation = None
self.curr_op = None
prefix = "[{0}]".format(self.getFullName())
self.logger = logger.Logger(logger.DefaultLogger, prefix)
def initLog(self):
prefix = "[{0}]".format(self.get_full_name())
self.logger = logger.Logger(logger.default_logger, prefix)
def init_logger(self):
#Init logger appender for extension
fileutil.CreateDir(self.getLogDir(), mode=0700)
logFile = os.path.join(self.getLogDir(), "CommandExecution.log")
self.logger.addLoggerAppender(logger.AppenderType.FILE,
logger.LogLevel.INFO, logFile)
fileutil.mkdir(self.get_log_dir(), mode=0700)
log_file = os.path.join(self.get_log_dir(), "CommandExecution.log")
self.logger.add_appender(logger.AppenderType.FILE,
logger.LogLevel.INFO, log_file)
def handle(self):
self.logger.info("Process extension settings:")
self.logger.info(" Name: {0}", self.getName())
self.logger.info(" Version: {0}", self.getVersion())
self.logger.info(" Name: {0}", self.get_name())
self.logger.info(" Version: {0}", self.get_version())
if self.installed:
self.logger.info("Installed version:{0}", self.currVersion)
handlerStatus = self.getHandlerStatus()
self.enabled = (handlerStatus == "Ready")
state = self.getState()
self.logger.info("Installed version:{0}", self.curr_version)
h_status = self.get_handler_status()
self.enabled = (h_status == "Ready")
state = self.get_state()
if state == 'enabled':
self.handleEnable()
self.handle_enable()
elif state == 'disabled':
self.handleDisable()
self.handle_disable()
elif state == 'uninstall':
self.handleDisable()
self.handleUninstall()
self.handle_disable()
self.handle_uninstall()
else:
raise ExtensionError("Unknown extension state:{0}".format(state))
def handleEnable(self):
targetVersion = self.getTargetVersion()
def handle_enable(self):
target_version = self.get_target_version()
if self.installed:
if targetVersion > self.currVersion:
self.upgrade(targetVersion)
elif targetVersion == self.currVersion:
if target_version > self.curr_version:
self.upgrade(target_version)
elif target_version == self.curr_version:
self.enable()
else:
raise ExtensionError("A newer version has already been installed")
else:
if targetVersion > self.getVersion():
if target_version > self.get_version():
#This will happen when auto upgrade policy is enabled
self.logger.info("Auto upgrade to new version:{0}",
targetVersion)
self.currVersion = targetVersion
self.logger.info("Auto upgrade to new version:{0}",
target_version)
self.curr_version = target_version
self.download()
self.initExtensionDir()
self.init_dir()
self.install()
self.enable()
def handleDisable(self):
def handle_disable(self):
if not self.installed or not self.enabled:
return
self.disable()
def handleUninstall(self):
def handle_uninstall(self):
if not self.installed:
return
self.uninstall()
def upgrade(self, targetVersion):
self.logger.info("Upgrade from: {0} to {1}", self.currVersion,
targetVersion)
self.currOperation=WALAEventOperation.Upgrade
def upgrade(self, target_version):
self.logger.info("Upgrade from: {0} to {1}", self.curr_version,
target_version)
self.curr_op=WALAEventOperation.Upgrade
old = self
new = ExtensionInstance(self.extension, self.packageList, targetVersion)
new = ExtensionInstance(self.extension, self.pkg_list, target_version)
self.logger.info("Download new extension package")
new.initLog()
new.init_logger()
new.download()
self.logger.info("Initialize new extension directory")
new.initExtensionDir()
new.init_dir()
old.disable()
self.logger.info("Update new extension")
new.update()
old.uninstall()
man = new.loadManifest()
if man.isUpdateWithInstall():
man = new.load_manifest()
if man.is_update_with_install():
self.logger.info("Install new extension")
new.install()
self.logger.info("Enable new extension")
new.enable()
AddExtensionEvent(name=self.getName(), isSuccess=True,
op=self.currOperation, message="")
add_event(name=self.get_name(), is_success=True,
op=self.curr_op, message="")
def download(self):
self.logger.info("Download extension package")
self.currOperation=WALAEventOperation.Download
uris = self.getPackageUris()
self.curr_op=WALAEventOperation.Download
uris = self.get_package_uris()
package = None
for uri in uris:
try:
resp = restutil.HttpGet(uri.uri, chkProxy=True)
resp = restutil.http_get(uri.uri, chk_proxy=True)
package = resp.read()
break
except restutil.HttpError as e:
@@ -279,167 +279,167 @@ class ExtensionInstance(object):
if package is None:
raise ExtensionError("Download extension failed")
self.logger.info("Unpack extension package")
pkgFile = os.path.join(self.libDir, os.path.basename(uri.uri) + ".zip")
fileutil.SetFileContents(pkgFile, bytearray(package))
zipfile.ZipFile(pkgFile).extractall(self.getBaseDir())
chmod = "find {0} -type f | xargs chmod u+x".format(self.getBaseDir())
shellutil.Run(chmod)
AddExtensionEvent(name=self.getName(), isSuccess=True,
op=self.currOperation, message="")
def initExtensionDir(self):
pkg_file = os.path.join(self.lib_dir, os.path.basename(uri.uri) + ".zip")
fileutil.write_file(pkg_file, bytearray(package))
zipfile.ZipFile(pkg_file).extractall(self.get_base_dir())
chmod = "find {0} -type f | xargs chmod u+x".format(self.get_base_dir())
shellutil.run(chmod)
add_event(name=self.get_name(), is_success=True,
op=self.curr_op, message="")
def init_dir(self):
self.logger.info("Initialize extension directory")
#Save HandlerManifest.json
manFile = fileutil.SearchForFile(self.getBaseDir(),
man_file = fileutil.search_file(self.get_base_dir(),
'HandlerManifest.json')
man = fileutil.GetFileContents(manFile, removeBom=True)
fileutil.SetFileContents(self.getManifestFile(), man)
man = fileutil.read_file(man_file, remove_bom=True)
fileutil.write_file(self.get_manifest_file(), man)
#Create status and config dir
statusDir = self.getStatusDir()
fileutil.CreateDir(statusDir, mode=0700)
configDir = self.getConfigDir()
fileutil.CreateDir(configDir, mode=0700)
status_dir = self.get_status_dir()
fileutil.mkdir(status_dir, mode=0700)
conf_dir = self.get_conf_dir()
fileutil.mkdir(conf_dir, mode=0700)
#Init handler state to uninstall
self.setHandlerStatus("NotReady")
self.set_handler_status("NotReady")
#Save HandlerEnvironment.json
self.createHandlerEnvironment()
self.create_handler_env()
def enable(self):
self.logger.info("Enable extension.")
self.currOperation=WALAEventOperation.Enable
man = self.loadManifest()
self.launchCommand(man.getEnableCommand())
self.setHandlerStatus("Ready")
AddExtensionEvent(name=self.getName(), isSuccess=True,
op=self.currOperation, message="")
self.curr_op=WALAEventOperation.Enable
man = self.load_manifest()
self.launch_command(man.get_enable_command())
self.set_handler_status("Ready")
add_event(name=self.get_name(), is_success=True,
op=self.curr_op, message="")
def disable(self):
self.logger.info("Disable extension.")
self.currOperation=WALAEventOperation.Disable
man = self.loadManifest()
self.launchCommand(man.getDisableCommand(), timeout=900)
self.setHandlerStatus("Ready")
AddExtensionEvent(name=self.getName(), isSuccess=True,
op=self.currOperation, message="")
self.curr_op=WALAEventOperation.Disable
man = self.load_manifest()
self.launch_command(man.get_disable_command(), timeout=900)
self.set_handler_status("Ready")
add_event(name=self.get_name(), is_success=True,
op=self.curr_op, message="")
def install(self):
self.logger.info("Install extension.")
self.currOperation=WALAEventOperation.Install
man = self.loadManifest()
self.setHandlerStatus("Installing")
self.launchCommand(man.getInstallCommand(), timeout=900)
self.setHandlerStatus("Ready")
AddExtensionEvent(name=self.getName(), isSuccess=True,
op=self.currOperation, message="")
self.curr_op=WALAEventOperation.Install
man = self.load_manifest()
self.set_handler_status("Installing")
self.launch_command(man.get_install_command(), timeout=900)
self.set_handler_status("Ready")
add_event(name=self.get_name(), is_success=True,
op=self.curr_op, message="")
def uninstall(self):
self.logger.info("Uninstall extension.")
self.currOperation=WALAEventOperation.UnInstall
man = self.loadManifest()
self.launchCommand(man.getUninstallCommand())
self.setHandlerStatus("NotReady")
AddExtensionEvent(name=self.getName(), isSuccess=True,
op=self.currOperation, message="")
self.curr_op=WALAEventOperation.UnInstall
man = self.load_manifest()
self.launch_command(man.get_uninstall_command())
self.set_handler_status("NotReady")
add_event(name=self.get_name(), is_success=True,
op=self.curr_op, message="")
def update(self):
self.logger.info("Update extension.")
self.currOperation=WALAEventOperation.Update
man = self.loadManifest()
self.launchCommand(man.getUpdateCommand(), timeout=900)
AddExtensionEvent(name=self.getName(), isSuccess=True,
op=self.currOperation, message="")
def createHandlerStatus(self, extStatus, heartbeat=None):
status = prot.ExtensionHandlerStatus()
status.handlerName = self.getName()
status.handlerVersion = self.getVersion()
status.status = self.getHandlerStatus()
status.extensionStatusList.append(extStatus)
return status
self.curr_op=WALAEventOperation.Update
man = self.load_manifest()
self.launch_command(man.get_update_command(), timeout=900)
add_event(name=self.get_name(), is_success=True,
op=self.curr_op, message="")
def collectHandlerStatus(self):
man = self.loadManifest()
heartbeat=None
if man.isReportHeartbeat():
heartbeat = self.getHeartbeat()
extStatus = self.getExtensionStatus()
status= self.createHandlerStatus(extStatus, heartbeat)
status.status = self.getHandlerStatus()
if heartbeat is not None:
status.status = heartbeat['status']
status.extensionStatusList.append(extStatus)
def create_handler_status(self, ext_status, heartbeat=None):
status = prot.ExtensionHandlerStatus()
status.handlerName = self.get_name()
status.handlerVersion = self.get_version()
status.status = self.get_handler_status()
status.extensionStatusList.append(ext_status)
return status
def getExtensionStatus(self):
extStatusFile = self.getStatusFile()
def collect_handler_status(self):
man = self.load_manifest()
heartbeat=None
if man.is_report_heartbeat():
heartbeat = self.collect_heartbeat()
ext_status = self.collect_extension_status()
status= self.create_handler_status(ext_status, heartbeat)
status.status = self.get_handler_status()
if heartbeat is not None:
status.status = heartbeat['status']
status.extensionStatusList.append(ext_status)
return status
def collect_extension_status(self):
ext_status_file = self.get_status_file()
try:
extStatusJson = fileutil.GetFileContents(extStatusFile)
extStatus = json.loads(extStatusJson)
ext_status_str = fileutil.read_file(ext_status_file)
ext_status = json.loads(ext_status_str)
except IOError as e:
raise ExtensionError("Failed to get status file: {0}".format(e))
except ValueError as e:
raise ExtensionError("Malformed status file: {0}".format(e))
return extension_status_to_v2(extStatus[0],
return ext_status_to_v2(ext_status[0],
self.settings.sequenceNumber)
def getHandlerStatus(self):
handlerStatus = "uninstalled"
handlerStatusFile = self.getHandlerStateFile()
def get_handler_status(self):
h_status = "uninstalled"
h_status_file = self.get_handler_state_file()
try:
handlerStatus = fileutil.GetFileContents(handlerStatusFile)
return handlerStatus
h_status = fileutil.read_file(h_status_file)
return h_status
except IOError as e:
raise ExtensionError("Failed to get handler status: {0}".format(e))
def setHandlerStatus(self, status):
handlerStatusFile = self.getHandlerStateFile()
def set_handler_status(self, status):
h_status_file = self.get_handler_state_file()
try:
fileutil.SetFileContents(handlerStatusFile, status)
fileutil.write_file(h_status_file, status)
except IOError as e:
raise ExtensionError("Failed to set handler status: {0}".format(e))
def getHeartbeat(self):
def collect_heartbeat(self):
self.logger.info("Collect heart beat")
heartbeatFile = os.path.join(OSUtil.GetLibDir(),
self.getHeartbeatFile())
if not os.path.isfile(heartbeatFile):
heartbeat_file = os.path.join(OSUTIL.get_lib_dir(),
self.get_heartbeat_file())
if not os.path.isfile(heartbeat_file):
raise ExtensionError("Failed to get heart beat file")
if not self.isResponsive(heartbeatFile):
if not self.is_responsive(heartbeat_file):
return {
"status": "Unresponsive",
"code": -1,
"message": "Extension heartbeat is not responsive"
}
}
try:
heartbeatJson = fileutil.GetFileContents(heartbeatFile)
heartbeat = json.loads(heartbeatJson)[0]['heartbeat']
heartbeat_json = fileutil.read_file(heartbeat_file)
heartbeat = json.loads(heartbeat_json)[0]['heartbeat']
except IOError as e:
raise ExtensionError("Failed to get heartbeat file:{0}".format(e))
except ValueError as e:
raise ExtensionError("Malformed heartbeat file: {0}".format(e))
return heartbeat
def isResponsive(self, heartbeatFile):
lastUpdate=int(time.time()-os.stat(heartbeatFile).st_mtime)
return lastUpdate > 600 # not updated for more than 10 min
def is_responsive(self, heartbeat_file):
last_update=int(time.time()-os.stat(heartbeat_file).st_mtime)
return last_update > 600 # not updated for more than 10 min
def launchCommand(self, cmd, timeout=300):
def launch_command(self, cmd, timeout=300):
self.logger.info("Launch command:{0}", cmd)
baseDir = self.getBaseDir()
self.updateSettings()
base_dir = self.get_base_dir()
self.update_settings()
try:
devnull = open(os.devnull, 'w')
child = subprocess.Popen(baseDir + "/" + cmd, shell=True,
cwd=baseDir, stdout=devnull)
child = subprocess.Popen(base_dir + "/" + cmd, shell=True,
cwd=base_dir, stdout=devnull)
except Exception as e:
#TODO do not catch all exception
raise ExtensionError("Failed to launch: {0}, {1}".format(cmd, e))
retry = timeout / 5
while retry > 0 and child.poll == None:
time.sleep(5)
@@ -451,11 +451,11 @@ class ExtensionInstance(object):
ret = child.wait()
if ret == None or ret != 0:
raise ExtensionError("Non-zero exit code: {0}, {1}".format(ret, cmd))
def loadManifest(self):
manFile = self.getManifestFile()
def load_manifest(self):
man_file = self.get_manifest_file()
try:
data = json.loads(fileutil.GetFileContents(manFile))
data = json.loads(fileutil.read_file(man_file))
except IOError as e:
raise ExtensionError('Failed to load manifest file.')
except ValueError as e:
@@ -464,141 +464,141 @@ class ExtensionInstance(object):
return HandlerManifest(data[0])
def updateSettings(self):
def update_settings(self):
if self.settings is None:
self.logger.verbose("Extension has no settings")
return
handlerSettings = {
settings = {
'publicSettings': self.settings.publicSettings,
'protectedSettings': self.settings.privateSettings,
'protectedSettingsCertThumbprint': self.settings.certificateThumbprint
}
extSettings = {
ext_settings = {
"runtimeSettings":[{
"handlerSettings": handlerSettings
"handlerSettings": settings
}]
}
fileutil.SetFileContents(self.getSettingsFile(), json.dumps(extSettings))
fileutil.write_file(self.get_settings_file(), json.dumps(ext_settings))
latest = os.path.join(self.getConfigDir(), "latest")
fileutil.SetFileContents(latest, self.settings.sequenceNumber)
latest = os.path.join(self.get_conf_dir(), "latest")
fileutil.write_file(latest, self.settings.sequenceNumber)
def createHandlerEnvironment(self):
def create_handler_env(self):
env = [{
"name": self.getName(),
"version" : self.getVersion(),
"name": self.get_name(),
"version" : self.get_version(),
"handlerEnvironment" : {
"logFolder" : self.getLogDir(),
"configFolder" : self.getConfigDir(),
"statusFolder" : self.getStatusDir(),
"heartbeatFile" : self.getHeartbeatFile()
"logFolder" : self.get_log_dir(),
"configFolder" : self.get_conf_dir(),
"statusFolder" : self.get_status_dir(),
"heartbeatFile" : self.get_heartbeat_file()
}
}]
fileutil.SetFileContents(self.getEnvironmentFile(),
fileutil.write_file(self.get_env_file(),
json.dumps(env))
def getTargetVersion(self):
version = self.getVersion()
updatePolicy = self.getUpgradePolicy()
if updatePolicy is None or updatePolicy.lower() != 'auto':
def get_target_version(self):
version = self.get_version()
update_policy = self.get_upgrade_policy()
if update_policy is None or update_policy.lower() != 'auto':
return version
major = version.split('.')[0]
if major is None:
raise ExtensionError("Wrong version format: {0}".format(version))
packages = filter(lambda x : x.version.startswith(major + "."),
self.packageList.versions)
packages = filter(lambda x : x.version.startswith(major + "."),
self.pkg_list.versions)
packages = sorted(packages, key=lambda x: x.version, reverse=True)
if len(packages) <= 0:
raise ExtensionError("Can't find version: {0}.*".format(major))
return packages[0].version
def getPackageUris(self):
version = self.getVersion()
packages = self.packageList.versions
def get_package_uris(self):
version = self.get_version()
packages = self.pkg_list.versions
if packages is None:
raise ExtensionError("Package uris is None.")
for package in packages:
if package.version == version:
return package.uris
raise ExtensionError("Can't get package uris for {0}.".format(version))
def getCurrOperation(self):
return self.currOperation
def getName(self):
def get_curr_op(self):
return self.curr_op
def get_name(self):
return self.extension.name
def getVersion(self):
def get_version(self):
return self.extension.properties.version
def getState(self):
def get_state(self):
return self.extension.properties.state
def getSeqNo(self):
def get_seq_no(self):
return self.settings.sequenceNumber
def getUpgradePolicy(self):
def get_upgrade_policy(self):
return self.extension.properties.upgradePolicy
def getFullName(self):
return "{0}-{1}".format(self.getName(), self.currVersion)
def getBaseDir(self):
return os.path.join(OSUtil.GetLibDir(), self.getFullName())
def get_full_name(self):
return "{0}-{1}".format(self.get_name(), self.curr_version)
def getStatusDir(self):
return os.path.join(self.getBaseDir(), "status")
def get_base_dir(self):
return os.path.join(OSUTIL.get_lib_dir(), self.get_full_name())
def getStatusFile(self):
return os.path.join(self.getStatusDir(),
def get_status_dir(self):
return os.path.join(self.get_base_dir(), "status")
def get_status_file(self):
return os.path.join(self.get_status_dir(),
"{0}.status".format(self.settings.sequenceNumber))
def getConfigDir(self):
return os.path.join(self.getBaseDir(), 'config')
def get_conf_dir(self):
return os.path.join(self.get_base_dir(), 'config')
def getSettingsFile(self):
return os.path.join(self.getConfigDir(),
def get_settings_file(self):
return os.path.join(self.get_conf_dir(),
"{0}.settings".format(self.settings.sequenceNumber))
def getHandlerStateFile(self):
return os.path.join(self.getConfigDir(), 'HandlerState')
def get_handler_state_file(self):
return os.path.join(self.get_conf_dir(), 'HandlerState')
def getHeartbeatFile(self):
return os.path.join(self.getBaseDir(), 'heartbeat.log')
def get_heartbeat_file(self):
return os.path.join(self.get_base_dir(), 'heartbeat.log')
def getManifestFile(self):
return os.path.join(self.getBaseDir(), 'HandlerManifest.json')
def get_manifest_file(self):
return os.path.join(self.get_base_dir(), 'HandlerManifest.json')
def getEnvironmentFile(self):
return os.path.join(self.getBaseDir(), 'HandlerEnvironment.json')
def get_env_file(self):
return os.path.join(self.get_base_dir(), 'HandlerEnvironment.json')
def getLogDir(self):
return os.path.join(OSUtil.GetExtLogDir(), self.getName(),
self.currVersion)
def get_log_dir(self):
return os.path.join(OSUTIL.get_ext_log_dir(), self.get_name(),
self.curr_version)
class HandlerEnvironment(object):
def __init__(self, data):
self.data = data
def getVersion(self):
def get_version(self):
return self.data["version"]
def getLogDir(self):
def get_log_dir(self):
return self.data["handlerEnvironment"]["logFolder"]
def getConfigDir(self):
def get_conf_dir(self):
return self.data["handlerEnvironment"]["configFolder"]
def getStatusDir(self):
def get_status_dir(self):
return self.data["handlerEnvironment"]["statusFolder"]
def getHeartbeatFile(self):
def get_heartbeat_file(self):
return self.data["handlerEnvironment"]["heartbeatFile"]
class HandlerManifest(object):
@@ -607,39 +607,39 @@ class HandlerManifest(object):
raise ExtensionError('Malformed manifest file.')
self.data = data
def getName(self):
def get_name(self):
return self.data["name"]
def getVersion(self):
def get_version(self):
return self.data["version"]
def getInstallCommand(self):
def get_install_command(self):
return self.data['handlerManifest']["installCommand"]
def getUninstallCommand(self):
def get_uninstall_command(self):
return self.data['handlerManifest']["uninstallCommand"]
def getUpdateCommand(self):
def get_update_command(self):
return self.data['handlerManifest']["updateCommand"]
def getEnableCommand(self):
def get_enable_command(self):
return self.data['handlerManifest']["enableCommand"]
def getDisableCommand(self):
def get_disable_command(self):
return self.data['handlerManifest']["disableCommand"]
def isRebootAfterInstall(self):
def is_reboot_after_install(self):
#TODO handle reboot after install
if "rebootAfterInstall" not in self.data['handlerManifest']:
return False
return self.data['handlerManifest']["rebootAfterInstall"]
def isReportHeartbeat(self):
def is_report_heartbeat(self):
if "reportHeartbeat" not in self.data['handlerManifest']:
return False
return self.data['handlerManifest']["reportHeartbeat"]
def isUpdateWithInstall(self):
def is_update_with_install(self):
if "updateMode" not in self.data['handlerManifest']:
return False
if "updateMode" in self.data:
@@ -17,7 +17,7 @@
# Requires Python 2.4+ and Openssl 1.0+
#
from init import InitHandler
from run import RunHandler
from run import MainHandler
from scvmm import ScvmmHandler
from dhcp import DhcpHandler
from env import EnvHandler
@@ -25,16 +25,16 @@ from provision import ProvisionHandler
from resourceDisk import ResourceDiskHandler
from extension import ExtensionsHandler
from deprovision import DeprovisionHandler
class DefaultHandlerFactory(object):
def __init__(self):
self.initHandler = InitHandler()
self.runHandler = RunHandler(self)
self.scvmmHandler = ScvmmHandler()
self.dhcpHandler = DhcpHandler()
self.envHandler = EnvHandler(self)
self.provisionHandler = ProvisionHandler()
self.resourceDiskHandler = ResourceDiskHandler()
self.extensionHandler = ExtensionsHandler()
self.deprovisionHandler = DeprovisionHandler()
self.init_handler = InitHandler()
self.main_handler = MainHandler(self)
self.scvmm_handler = ScvmmHandler()
self.dhcp_handler = DhcpHandler()
self.env_handler = EnvHandler(self)
self.provision_handler = ProvisionHandler()
self.resource_disk_handler = ResourceDiskHandler()
self.extension_handler = ExtensionsHandler()
self.deprovision_handler = DeprovisionHandler()
+11 -11
View File
@@ -20,7 +20,7 @@
import os
import azurelinuxagent.conf as conf
import azurelinuxagent.logger as logger
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.utils.fileutil as fileutil
@@ -28,22 +28,22 @@ class InitHandler(object):
def init(self, verbose):
#Init stdout log
level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO
logger.AddLoggerAppender(logger.AppenderType.STDOUT, level)
logger.add_logger_appender(logger.AppenderType.STDOUT, level)
#Init config
configPath = OSUtil.GetConfigurationPath()
conf.LoadConfiguration(configPath)
conf_file_path = OSUTIL.get_conf_file_path()
conf.load_conf(conf_file_path)
#Init log
verbose = verbose or conf.GetSwitch("Logs.Verbose", False)
verbose = verbose or conf.get_switch("Logs.Verbose", False)
level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO
logger.AddLoggerAppender(logger.AppenderType.FILE, level,
logger.add_logger_appender(logger.AppenderType.FILE, level,
path="/var/log/waagent.log")
logger.AddLoggerAppender(logger.AppenderType.CONSOLE, level,
logger.add_logger_appender(logger.AppenderType.CONSOLE, level,
path="/dev/console")
#Create lib dir
fileutil.CreateDir(OSUtil.GetLibDir(), mode=0700)
os.chdir(OSUtil.GetLibDir())
fileutil.mkdir(OSUTIL.get_lib_dir(), mode=0700)
os.chdir(OSUTIL.get_lib_dir())
+4 -4
View File
@@ -17,12 +17,12 @@
# Requires Python 2.4+ and Openssl 1.0+
#
def GetOSUtil():
def get_osutil():
from azurelinuxagent.distro.default.osutil import DefaultOSUtil
return DefaultOSUtil()
def GetHandlers():
def get_handlers():
from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory
return DefaultHandlerFactory()
+305 -307
View File
@@ -18,23 +18,23 @@
import os
import re
import pwd
import shutil
import socket
import array
import struct
import fcntl
import time
import pwd
import fcntl
import azurelinuxagent.logger as logger
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.textutil as textutil
RulesFiles = [ "/lib/udev/rules.d/75-persistent-net-generator.rules",
"/etc/udev/rules.d/70-persistent-net.rules" ]
__RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules",
"/etc/udev/rules.d/70-persistent-net.rules" ]
"""
Define distro specific behavior. OSUtil class defines default behavior
Define distro specific behavior. OSUtil class defines default behavior
for all distros. Each concrete distro classes could overwrite default behavior
if needed.
"""
@@ -45,70 +45,70 @@ class OSUtilError(Exception):
class DefaultOSUtil(object):
def __init__(self):
self.libDir = "/var/lib/waagent"
self.extLogDir = "/var/log/azure"
self.dvdMountPoint = "/mnt/cdrom/secure"
self.ovfenvPathOnDvd = "/mnt/cdrom/secure/ovf-env.xml"
self.agentPidPath = "/var/run/waagent.pid"
self.passwdPath = "/etc/shadow"
self.lib_dir = "/var/lib/waagent"
self.ext_log_dir = "/var/log/azure"
self.dvd_mount_point = "/mnt/cdrom/secure"
self.ovf_env_file_path = "/mnt/cdrom/secure/ovf-env.xml"
self.agent_pid_file_path = "/var/run/waagent.pid"
self.passwd_file_path = "/etc/shadow"
self.home = '/home'
self.sshdConfigPath = '/etc/ssh/sshd_config'
self.opensslCmd = '/usr/bin/openssl'
self.configPath = '/etc/waagent.conf'
self.sshd_conf_file_path = '/etc/ssh/sshd_config'
self.openssl_cmd = '/usr/bin/openssl'
self.conf_file_path = '/etc/waagent.conf'
self.selinux=None
def GetLibDir(self):
return self.libDir
def get_lib_dir(self):
return self.lib_dir
def GetExtLogDir(self):
return self.extLogDir
def get_ext_log_dir(self):
return self.ext_log_dir
def GetDvdMountPoint(self):
return self.dvdMountPoint
def get_dvd_mount_point(self):
return self.dvd_mount_point
def GetConfigurationPath(self):
return self.configPath
def get_conf_file_path(self):
return self.conf_file_path
def GetOvfEnvPathOnDvd(self):
return self.ovfenvPathOnDvd
def get_ovf_env_file_path_on_dvd(self):
return self.ovf_env_file_path
def GetAgentPidPath(self):
return self.agentPidPath
def get_agent_pid_file_path(self):
return self.agent_pid_file_path
def GetOpensslCmd(self):
return self.opensslCmd
def get_openssl_cmd(self):
return self.openssl_cmd
def UpdateUserAccount(self, userName, password, expiration=None):
def set_user_account(self, username, password, expiration=None):
"""
Update password and ssh key for user account.
New account will be created if not exists.
"""
if userName is None:
if username is None:
raise OSUtilError("User name is empty")
if self.IsSysUser(userName):
if self.is_sys_user(username):
raise OSUtilError(("User {0} is a system user. "
"Will not set passwd.").format(userName))
"Will not set passwd.").format(username))
userentry = self.GetUserEntry(userName)
userentry = self.get_userentry(username)
if userentry is None:
self.CreateUserAccount(userName, expiration)
self.ConfigSudoer(userName, password is None)
self.useradd(username, expiration)
def GetUserEntry(self, userName):
self.conf_sudoer(username, password is None)
def get_userentry(self, username):
try:
return pwd.getpwnam(userName)
return pwd.getpwnam(username)
except KeyError:
return None
def IsSysUser(self, userName):
userentry = self.GetUserEntry(userName)
def is_sys_user(self, username):
userentry = self.get_userentry(username)
uidmin = None
try:
uidminDef = fileutil.GetLineStartingWith("UID_MIN", "/etc/login.defs")
if uidminDef is not None:
uidmin = int(uidminDef.split()[1])
uidmin_def = fileutil.get_line_startingwith("UID_MIN", "/etc/login.defs")
if uidmin_def is not None:
uidmin = int(uidmin_def.split()[1])
except IOError as e:
pass
if uidmin == None:
@@ -117,328 +117,328 @@ class DefaultOSUtil(object):
return True
else:
return False
def CreateUserAccount(self, userName, expiration=None):
def useradd(self, username, expiration=None):
if expiration is not None:
cmd = "useradd -m {0} -e {1}".format(userName, expiration)
cmd = "useradd -m {0} -e {1}".format(username, expiration)
else:
cmd = "useradd -m {0}".format(userName)
retcode, out = shellutil.RunGetOutput(cmd)
cmd = "useradd -m {0}".format(username)
retcode, out = shellutil.run_get_output(cmd)
if retcode != 0:
raise OSUtilError(("Failed to create user account:{0}, "
"retcode:{1}, "
"output:{2}").format(userName, retcode, out))
"output:{2}").format(username, retcode, out))
def ChangePassword(self, userName, password, useSalt=True, saltType=6,
saltLength=10):
passwdHash = textutil.GetPasswordHash(password, useSalt, saltType,
saltLength)
def chpasswd(self, username, password, use_salt=True, salt_type=6,
salt_len=10):
passwd_hash = textutil.gen_password_hash(password, use_salt, salt_type,
salt_len)
try:
passwdContent = fileutil.GetFileContents(self.passwdPath)
passwd = passwdContent.split("\n")
newPasswd = filter(lambda x : not x.startswith(userName), passwd)
newPasswd.append("{0}:{1}:14600::::::".format(userName, passwdHash))
fileutil.SetFileContents(self.passwdPath, "\n".join(newPasswd))
passwd_content = fileutil.read_file(self.passwd_file_path)
passwd = passwd_content.split("\n")
new_passwd = filter(lambda x : not x.startswith(username), passwd)
new_passwd.append("{0}:{1}:14600::::::".format(username, passwd_hash))
fileutil.write_file(self.passwd_file_path, "\n".join(new_passwd))
except IOError as e:
raise OSUtilError(("Failed to set password for {0}: {1}"
"").format(userName, e))
"").format(username, e))
def ConfigSudoer(self, userName, nopasswd):
def conf_sudoer(self, username, nopasswd):
# for older distros create sudoers.d
if not os.path.isdir('/etc/sudoers.d/'):
# create the /etc/sudoers.d/ directory
os.mkdir('/etc/sudoers.d/')
# add the include of sudoers.d to the /etc/sudoers
sudoers = '\n' + '#includedir /etc/sudoers.d/\n'
fileutil.AppendFileContents('/etc/sudoers', sudoers)
fileutil.append_file('/etc/sudoers', sudoers)
sudoer = None
if nopasswd:
sudoer = "{0} ALL = (ALL) NOPASSWD\n".format(userName)
sudoer = "{0} ALL = (ALL) NOPASSWD\n".format(username)
else:
sudoer = "{0} ALL = (ALL) ALL\n".format(userName)
fileutil.AppendFileContents('/etc/sudoers.d/waagent', sudoer)
fileutil.ChangeMod('/etc/sudoers.d/waagent', 0440)
sudoer = "{0} ALL = (ALL) ALL\n".format(username)
fileutil.append_file('/etc/sudoers.d/waagent', sudoer)
fileutil.chmod('/etc/sudoers.d/waagent', 0440)
def DeleteRootPassword(self):
def del_root_password(self):
try:
passwdContent = fileutil.GetFileContents(self.passwdPath)
passwd = passwdContent.split('\n')
newPasswd = filter(lambda x : not x.startswith("root:"), passwd)
newPasswd.insert(0, "root:*LOCK*:14600::::::")
fileutil.SetFileContents(self.passwdPath, "\n".join(newPasswd))
passwd_content = fileutil.read_file(self.passwd_file_path)
passwd = passwd_content.split('\n')
new_passwd = filter(lambda x : not x.startswith("root:"), passwd)
new_passwd.insert(0, "root:*LOCK*:14600::::::")
fileutil.write_file(self.passwd_file_path, "\n".join(new_passwd))
except IOError as e:
raise OSUtilError("Failed to delete root password:{0}".format(e))
def GetHome(self):
def get_home(self):
return self.home
def GetPubKeyFromPrv(self, fileName):
cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.opensslCmd,
fileName)
pub = shellutil.RunGetOutput(cmd)[1]
def get_pubkey_from_prv(self, file_name):
cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd,
file_name)
pub = shellutil.run_get_output(cmd)[1]
return pub
def GetPubKeyFromCrt(self, fileName):
cmd = "{0} x509 -in {1} -pubkey -noout".format(self.opensslCmd,
fileName)
pub = shellutil.RunGetOutput(cmd)[1]
def get_pubkey_from_crt(self, file_name):
cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd,
file_name)
pub = shellutil.run_get_output(cmd)[1]
return pub
def _NormPath(self, filepath):
home = self.GetHome()
def _norm_path(self, filepath):
home = self.get_home()
# Expand HOME variable if present in path
path = os.path.normpath(filepath.replace("$HOME", home))
return path
def GetThumbprintFromCrt(self, fileName):
cmd="{0} x509 -in {1} -fingerprint -noout".format(self.opensslCmd,
fileName)
thumbprint = shellutil.RunGetOutput(cmd)[1]
def get_thumbprint_from_crt(self, file_name):
cmd="{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd,
file_name)
thumbprint = shellutil.run_get_output(cmd)[1]
thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper()
return thumbprint
def DeploySshKeyPair(self, userName, thumbprint, path):
def deploy_ssh_keypair(self, username, thumbprint, path):
"""
Deploy id_rsa and id_rsa.pub
"""
path = self._NormPath(path)
dirPath = os.path.dirname(path)
fileutil.CreateDir(dirPath, mode=0700, owner=userName)
libDir = self.GetLibDir()
prvPath = os.path.join(libDir, thumbprint + '.prv')
if not os.path.isfile(prvPath):
logger.Error("Failed to deploy key pair, thumbprint: {0}",
path = self._norm_path(path)
dir_path = os.path.dirname(path)
fileutil.mkdir(dir_path, mode=0700, owner=username)
lib_dir = self.get_lib_dir()
prv_path = os.path.join(lib_dir, thumbprint + '.prv')
if not os.path.isfile(prv_path):
logger.error("Failed to deploy key pair, thumbprint: {0}",
thumbprint)
return
shutil.copyfile(prvPath, path)
pubPath = path + '.pub'
pub = self.GetPubKeyFromPrv(prvPath)
fileutil.SetFileContents(pubPath, pub)
self.SetSelinuxContext(pubPath, 'unconfined_u:object_r:ssh_home_t:s0')
self.SetSelinuxContext(path, 'unconfined_u:object_r:ssh_home_t:s0')
shutil.copyfile(prv_path, path)
pub_path = path + '.pub'
pub = self.get_pubkey_from_prv(prv_path)
fileutil.write_file(pub_path, pub)
self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0')
self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
os.chmod(path, 0644)
os.chmod(pubPath, 0600)
os.chmod(pub_path, 0600)
def OpenSslToOpenSsh(self, inputFile, outputFile):
shellutil.Run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(inputFile,
outputFile))
def openssl_to_openssh(self, input_file, output_file):
shellutil.run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file,
output_file))
def DeploySshPublicKey(self, userName, thumbprint, path):
def deploy_ssh_pubkey(self, username, thumbprint, path):
"""
Deploy authorized_key
"""
path = self._NormPath(path)
dirPath = os.path.dirname(path)
fileutil.CreateDir(dirPath, mode=0700, owner=userName)
libDir = self.GetLibDir()
crtPath = os.path.join(libDir, thumbprint + '.crt')
if not os.path.isfile(crtPath):
logger.Error("Failed to deploy public key, thumbprint: {0}",
path = self._norm_path(path)
dir_path = os.path.dirname(path)
fileutil.mkdir(dir_path, mode=0700, owner=username)
lib_dir = self.get_lib_dir()
crt_path = os.path.join(lib_dir, thumbprint + '.crt')
if not os.path.isfile(crt_path):
logger.error("Failed to deploy public key, thumbprint: {0}",
thumbprint)
return
pubPath = os.path.join(libDir, thumbprint + '.pub')
pub = self.GetPubKeyFromCrt(crtPath)
fileutil.SetFileContents(pubPath, pub)
self.SetSelinuxContext(pubPath, 'unconfined_u:object_r:ssh_home_t:s0')
self.OpenSslToOpenSsh(pubPath, path)
self.SetSelinuxContext(path, 'unconfined_u:object_r:ssh_home_t:s0')
fileutil.ChangeOwner(path, userName)
fileutil.ChangeMod(path, 0644)
fileutil.ChangeMod(pubPath, 0600)
def IsSelinuxSystem(self):
pub_path = os.path.join(lib_dir, thumbprint + '.pub')
pub = self.get_pubkey_from_crt(crt_path)
fileutil.write_file(pub_path, pub)
self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0')
self.openssl_to_openssh(pub_path, path)
self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
fileutil.chowner(path, username)
fileutil.chmod(path, 0644)
fileutil.chmod(pub_path, 0600)
def is_selinux_system(self):
"""
Checks and sets self.selinux = True if SELinux is available on system.
"""
if self.selinux == None:
if shellutil.Run("which getenforce", chk_err=False) == 0:
if shellutil.run("which getenforce", chk_err=False) == 0:
self.selinux = True
else:
self.selinux = False
return self.selinux
def IsSelinuxRunning(self):
def is_selinux_enforcing(self):
"""
Calls shell command 'getenforce' and returns True if 'Enforcing'.
"""
if self.IsSelinuxSystem():
output = shellutil.RunGetOutput("getenforce")[1]
if self.is_selinux_system():
output = shellutil.run_get_output("getenforce")[1]
return output.startswith("Enforcing")
else:
return False
def SetSelinuxEnforce(self, state):
def set_selinux_enforce(self, state):
"""
Calls shell command 'setenforce' with 'state'
Calls shell command 'setenforce' with 'state'
and returns resulting exit code.
"""
if self.IsSelinuxSystem():
if self.is_selinux_system():
if state: s = '1'
else: s='0'
return shellutil.Run("setenforce "+s)
return shellutil.run("setenforce "+s)
def SetSelinuxContext(self, path, cn):
def set_selinux_context(self, path, con):
"""
Calls shell 'chcon' with 'path' and 'cn' context.
Calls shell 'chcon' with 'path' and 'con' context.
Returns exit result.
"""
if self.IsSelinuxSystem():
return shellutil.Run('chcon ' + cn + ' ' + path)
def GetSshdConfigPath(self):
return self.sshdConfigPath
if self.is_selinux_system():
return shellutil.run('chcon ' + con + ' ' + path)
def SetSshClientAliveInterval(self):
configPath = self.GetSshdConfigPath()
config = fileutil.GetFileContents(configPath).split("\n")
textutil.SetSshConfig(config, "ClientAliveInterval", "180")
fileutil.ReplaceFileContentsAtomic(configPath, '\n'.join(config))
logger.Info("Configured SSH client probing to keep connections alive.")
def ConfigSshd(self, disablePassword):
option = "no" if disablePassword else "yes"
configPath = self.GetSshdConfigPath()
config = fileutil.GetFileContents(configPath).split("\n")
textutil.SetSshConfig(config, "PasswordAuthentication", option)
textutil.SetSshConfig(config, "ChallengeResponseAuthentication", option)
fileutil.ReplaceFileContentsAtomic(configPath, "\n".join(config))
logger.Info("Disabled SSH password-based authentication methods.")
def get_sshd_conf_file_path(self):
return self.sshd_conf_file_path
def set_ssh_client_alive_interval(self):
conf_file_path = self.get_sshd_conf_file_path()
conf = fileutil.read_file(conf_file_path).split("\n")
textutil.set_ssh_config(conf, "ClientAliveInterval", "180")
fileutil.replace_file(conf_file_path, '\n'.join(conf))
logger.info("Configured SSH client probing to keep connections alive.")
def conf_sshd(self, disable_password):
option = "no" if disable_password else "yes"
conf_file_path = self.get_sshd_conf_file_path()
conf = fileutil.read_file(conf_file_path).split("\n")
textutil.set_ssh_config(conf, "PasswordAuthentication", option)
textutil.set_ssh_config(conf, "ChallengeResponseAuthentication", option)
fileutil.replace_file(conf_file_path, "\n".join(conf))
logger.info("Disabled SSH password-based authentication methods.")
def GetDvdDevice(self, devDir='/dev'):
def get_dvd_device(self, dev_dir='/dev'):
patten=r'(sr[0-9]|hd[c-z]|cdrom[0-9])'
for dvd in [re.match(patten, dev) for dev in os.listdir(devDir)]:
for dvd in [re.match(patten, dev) for dev in os.listdir(dev_dir)]:
if dvd is not None:
return "/dev/{0}".format(dvd.group(0))
raise OSUtilError("Failed to get dvd device")
def MountDvd(self, maxRetry=6, chk_err=True):
dvd = self.GetDvdDevice()
mountPoint = self.GetDvdMountPoint()
mountlist = shellutil.RunGetOutput("mount")[1]
existing = self.GetMountPoint(mountlist, dvd)
def mount_dvd(self, max_retry=6, chk_err=True):
dvd = self.get_dvd_device()
mount_point = self.get_dvd_mount_point()
mountlist = shellutil.run_get_output("mount")[1]
existing = self.get_mount_point(mountlist, dvd)
if existing is not None: #Already mounted
logger.Info("{0} is already mounted at {1}", dvd, existing)
logger.info("{0} is already mounted at {1}", dvd, existing)
return
if not os.path.isdir(mountPoint):
os.makedirs(mountPoint)
for retry in range(0, maxRetry):
retcode = self.Mount(dvd, mountPoint, option="-o ro -t iso9660,udf",
if not os.path.isdir(mount_point):
os.makedirs(mount_point)
for retry in range(0, max_retry):
retcode = self.mount(dvd, mount_point, option="-o ro -t iso9660,udf",
chk_err=chk_err)
if retcode == 0:
logger.Info("Successfully mounted dvd")
logger.info("Successfully mounted dvd")
return
if retry < maxRetry - 1:
logger.Warn("Mount dvd failed: retry={0}, ret={1}", retry,
if retry < max_retry - 1:
logger.warn("Mount dvd failed: retry={0}, ret={1}", retry,
retcode)
time.sleep(5)
if chk_err:
raise OSUtilError("Failed to mount dvd.")
def UmountDvd(self, chk_err=True):
mountPoint = self.GetDvdMountPoint()
retcode = self.Umount(mountPoint, chk_err=chk_err)
def umount_dvd(self, chk_err=True):
mount_point = self.get_dvd_mount_point()
retcode = self.umount(mount_point, chk_err=chk_err)
if chk_err and retcode != 0:
raise OSUtilError("Failed to umount dvd.")
def LoadAtapiixModule(self):
if self.IsAtaPiixModuleLoaded():
def load_atappix_mod(self):
if self.is_atapiix_mod_loaded():
return
ret, kernVersion = shellutil.RunGetOutput("uname -r")
ret, kern_version = shellutil.run_get_output("uname -r")
if ret != 0:
raise Exception("Failed to call uname -r")
modulePath = os.path.join('/lib/modules',
kernVersion.strip('\n'),
'kernel/drivers/ata/ata_piix.ko')
if not os.path.isfile(modulePath):
raise Exception("Can't find module file:{0}".format(modulePath))
mod_path = os.path.join('/lib/modules',
kern_version.strip('\n'),
'kernel/drivers/ata/ata_piix.ko')
if not os.path.isfile(mod_path):
raise Exception("Can't find module file:{0}".format(mod_path))
ret, output = shellutil.RunGetOutput("insmod " + modulePath)
ret, output = shellutil.run_get_output("insmod " + mod_path)
if ret != 0:
raise Exception("Error calling insmod for ATAPI CD-ROM driver")
if not self.IsAtaPiixModuleLoaded(maxRetry=3):
raise Exception("Failed to load ATAPI CD-ROM driver")
if not self.is_atapiix_mod_loaded(max_retry=3):
raise Exception("Failed to load ATAPI CD-ROM driver")
def IsAtaPiixModuleLoaded(self, maxRetry=1):
for retry in range(0, maxRetry):
ret = shellutil.Run("lsmod | grep ata_piix", chk_err=False)
def is_atapiix_mod_loaded(self, max_retry=1):
for retry in range(0, max_retry):
ret = shellutil.run("lsmod | grep ata_piix", chk_err=False)
if ret == 0:
logger.Info("Module driver for ATAPI CD-ROM is already present.")
logger.info("Module driver for ATAPI CD-ROM is already present.")
return True
if retry < maxRetry - 1:
if retry < max_retry - 1:
time.sleep(1)
return False
def Mount(self, dvd, mountPoint, option="", chk_err=True):
cmd = "mount {0} {1} {2}".format(dvd, option, mountPoint)
return shellutil.RunGetOutput(cmd, chk_err)[0]
def Umount(self, mountPoint, chk_err=True):
return shellutil.Run("umount {0}".format(mountPoint), chk_err=chk_err)
def mount(self, dvd, mount_point, option="", chk_err=True):
cmd = "mount {0} {1} {2}".format(dvd, option, mount_point)
return shellutil.run_get_output(cmd, chk_err)[0]
def OpenPortForDhcp(self):
def umount(self, mount_point, chk_err=True):
return shellutil.run("umount {0}".format(mount_point), chk_err=chk_err)
def allow_dhcp_broadcast(self):
#Open DHCP port if iptables is enabled.
# We supress error logging on error.
shellutil.Run("iptables -D INPUT -p udp --dport 68 -j ACCEPT",
chk_err=False)
shellutil.Run("iptables -I INPUT -p udp --dport 68 -j ACCEPT",
shellutil.run("iptables -D INPUT -p udp --dport 68 -j ACCEPT",
chk_err=False)
shellutil.run("iptables -I INPUT -p udp --dport 68 -j ACCEPT",
chk_err=False)
def GenerateTransportCert(self):
def gen_transport_cert(self):
"""
Create ssl certificate for https communication with endpoint server.
"""
cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 "
"-newkey rsa:2048 -keyout TransportPrivate.pem "
"-out TransportCert.pem").format(self.opensslCmd)
shellutil.Run(cmd)
"-out TransportCert.pem").format(self.openssl_cmd)
shellutil.run(cmd)
def RemoveRulesFiles(self, rulesFiles=RulesFiles):
libDir = self.GetLibDir()
for src in rulesFiles:
fileName = fileutil.GetLastPathElement(src)
dest = os.path.join(libDir, fileName)
def remove_rules_files(self, rules_files=__RULES_FILES__):
lib_dir = self.get_lib_dir()
for src in rules_files:
file_name = fileutil.base_name(src)
dest = os.path.join(lib_dir, file_name)
if os.path.isfile(dest):
os.remove(dest)
if os.path.isfile(src):
logger.Warn("Move rules file {0} to {1}", fileName, dest)
logger.warn("Move rules file {0} to {1}", file_name, dest)
shutil.move(src, dest)
def RestoreRulesFiles(self, rulesFiles=RulesFiles):
libDir = self.GetLibDir()
for dest in rulesFiles:
fileName = fileutil.GetLastPathElement(dest)
src = os.path.join(libDir, fileName)
def restore_rules_files(self, rules_files=__RULES_FILES__):
lib_dir = self.get_lib_dir()
for dest in rules_files:
filename = fileutil.base_name(dest)
src = os.path.join(lib_dir, filename)
if os.path.isfile(dest):
continue
if os.path.isfile(src):
logger.Warn("Move rules file {0} to {1}", fileName, dest)
logger.warn("Move rules file {0} to {1}", filename, dest)
shutil.move(src, dest)
def GetMacAddress(self):
def get_mac_addr(self):
"""
Convienience function, returns mac addr bound to
first non-loobback interface.
"""
ifname=''
while len(ifname) < 2 :
ifname=self.GetFirstActiveNetworkInterfaceNonLoopback()[0]
addr = self.GetInterfaceMac(ifname)
return textutil.HexStringToByteArray(addr)
ifname=self.get_first_if()[0]
addr = self.get_if_mac(ifname)
return textutil.hexstr_to_bytearray(addr)
def GetInterfaceMac(self, ifname):
def get_if_mac(self, ifname):
"""
Return the mac-address bound to the socket.
"""
sock = socket.socket(socket.AF_INET,
socket.SOCK_DGRAM,
sock = socket.socket(socket.AF_INET,
socket.SOCK_DGRAM,
socket.IPPROTO_UDP)
param = struct.pack('256s', (ifname[:15]+('\0'*241)).encode('latin-1'))
info = fcntl.ioctl(sock.fileno(), 0x8927, param)
return ''.join(['%02X' % textutil.Ord(char) for char in info[18:24]])
return ''.join(['%02X' % textutil.str_to_ord(char) for char in info[18:24]])
def GetFirstActiveNetworkInterfaceNonLoopback(self):
def get_first_if(self):
"""
Return the interface name, and ip addr of the
first active non-loopback interface.
@@ -446,17 +446,17 @@ class DefaultOSUtil(object):
iface=''
expected=16 # how many devices should I expect...
struct_size=40 # for 64bit the size is 40 bytes
sock = socket.socket(socket.AF_INET,
socket.SOCK_DGRAM,
sock = socket.socket(socket.AF_INET,
socket.SOCK_DGRAM,
socket.IPPROTO_UDP)
buff=array.array('B', b'\0' * (expected * struct_size))
param = struct.pack('iL',
expected*struct_size,
param = struct.pack('iL',
expected*struct_size,
buff.buffer_info()[0])
ret = fcntl.ioctl(sock.fileno(), 0x8912, param)
retsize=(struct.unpack('iL', ret)[0])
if retsize == (expected * struct_size):
logger.Warn(('SIOCGIFCONF returned more than {0} up '
logger.warn(('SIOCGIFCONF returned more than {0} up '
'network interfaces.'), expected)
sock = buff.tostring()
for i in range(0, struct_size * expected, struct_size):
@@ -467,114 +467,114 @@ class DefaultOSUtil(object):
break
return iface.decode('latin-1'), socket.inet_ntoa(sock[i+20:i+24])
def IsMissingDefaultRoute(self):
routes = shellutil.RunGetOutput("route -n")[1]
def is_missing_default_route(self):
routes = shellutil.run_get_output("route -n")[1]
for route in routes:
if route.startswith("0.0.0.0 ") or route.startswith("default "):
return False
return False
return True
def GetInterfaceName(self):
return self.GetFirstActiveNetworkInterfaceNonLoopback()[0]
def GetIpv4Address(self):
return self.GetFirstActiveNetworkInterfaceNonLoopback()[1]
def SetBroadcastRouteForDhcp(self, ifname):
return shellutil.Run("route add 255.255.255.255 dev {0}".format(ifname),
def get_if_name(self):
return self.get_first_if()[0]
def get_ip4_addr(self):
return self.get_first_if()[1]
def set_route_for_dhcp_broadcast(self, ifname):
return shellutil.run("route add 255.255.255.255 dev {0}".format(ifname),
chk_err=False)
def RemoveBroadcastRouteForDhcp(self, ifname):
shellutil.Run("route del 255.255.255.255 dev {0}".format(ifname),
def remove_route_for_dhcp_broadcast(self, ifname):
shellutil.run("route del 255.255.255.255 dev {0}".format(ifname),
chk_err=False)
def IsDhcpEnabled(self):
def is_dhcp_enabled(self):
return False
def StopDhcpService(self):
def stop_dhcp_service(self):
pass
def StartDhcpService(self):
def start_dhcp_service(self):
pass
def StartNetwork(self):
def start_network(self):
pass
def StartAgentService(self):
def start_agent_service(self):
pass
def StopAgentService(self):
def stop_agent_service(self):
pass
def RegisterAgentService(self):
def register_agent_service(self):
pass
def UnregisterAgentService(self):
def unregister_agent_service(self):
pass
def RestartSshService(self):
def restart_ssh_service(self):
pass
def RouteAdd(self, net, mask, gateway):
def route_add(self, net, mask, gateway):
"""
Add specified route using /sbin/route add -net.
"""
cmd = ("/sbin/route add -net "
"{0} netmask {1} gw {2}").format(net, mask, gateway)
return shellutil.Run(cmd, chk_err=False)
return shellutil.run(cmd, chk_err=False)
def GetDhcpProcessId(self):
ret= shellutil.RunGetOutput("pidof dhclient")
def get_dhcp_pid(self):
ret= shellutil.run_get_output("pidof dhclient")
return ret[1] if ret[0] == 0 else None
def SetHostname(self, hostname):
fileutil.SetFileContents('/etc/hostname', hostname)
shellutil.Run("hostname {0}".format(hostname), chk_err=False)
def set_hostname(self, hostname):
fileutil.write_file('/etc/hostname', hostname)
shellutil.run("hostname {0}".format(hostname), chk_err=False)
def SetDhcpHostname(self, hostname):
autoSend = r'^[^#]*?send\s*host-name.*?(<hostname>|gethostname[(,)])'
dhclientFiles = ['/etc/dhcp/dhclient.conf', '/etc/dhcp3/dhclient.conf']
for confFile in dhclientFiles:
if not os.path.isfile(confFile):
def set_dhcp_hostname(self, hostname):
autosend = r'^[^#]*?send\s*host-name.*?(<hostname>|gethostname[(,)])'
dhclient_files = ['/etc/dhcp/dhclient.conf', '/etc/dhcp3/dhclient.conf']
for conf_file in dhclient_files:
if not os.path.isfile(conf_file):
continue
if fileutil.FindStringInFile(confFile, autoSend):
if fileutil.findstr_in_file(conf_file, autosend):
#Return if auto send host-name is configured
return
fileutil.UpdateConfigFile(confFile,
fileutil.update_conf_file(conf_file,
'send host-name',
'send host-name {0}'.format(hostname))
def RestartInterface(self, ifname):
shellutil.Run("ifdown {0} && ifup {1}".format(ifname, ifname))
def restart_if(self, ifname):
shellutil.run("ifdown {0} && ifup {1}".format(ifname, ifname))
def PublishHostname(self, hostname):
self.SetDhcpHostname(hostname)
ifname = self.GetInterfaceName()
self.RestartInterface(ifname)
def publish_hostname(self, hostname):
self.set_dhcp_hostname(hostname)
ifname = self.get_if_name()
self.restart_if(ifname)
def SetScsiDiskTimeout(self, timeout):
def set_scsi_disks_timeout(self, timeout):
for dev in os.listdir("/sys/block"):
if dev.startswith('sd'):
self.SetBlockDeviceTimeout(dev, timeout)
self.set_block_device_timeout(dev, timeout)
def SetBlockDeviceTimeout(self, dev, timeout):
def set_block_device_timeout(self, dev, timeout):
if dev is not None and timeout is not None:
filePath = "/sys/block/{0}/device/timeout".format(dev)
content = fileutil.GetFileContents(filePath)
file_path = "/sys/block/{0}/device/timeout".format(dev)
content = fileutil.read_file(file_path)
original = content.splitlines()[0].rstrip()
if original != timeout:
fileutil.SetFileContents(filePath, timeout)
logger.Info("Set block dev timeout: {0} with timeout: {1}",
fileutil.write_file(file_path, timeout)
logger.info("Set block dev timeout: {0} with timeout: {1}",
dev, timeout)
def GetMountPoint(self, mountlist, device):
def get_mount_point(self, mountlist, device):
"""
Example of mountlist:
/dev/sda1 on / type ext4 (rw)
proc on /proc type proc (rw)
sysfs on /sys type sysfs (rw)
devpts on /dev/pts type devpts (rw,gid=5,mode=620)
tmpfs on /dev/shm type tmpfs
tmpfs on /dev/shm type tmpfs
(rw,rootcontext="system_u:object_r:tmpfs_t:s0")
none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw)
/dev/sdb1 on /mnt/resource type ext4 (rw)
@@ -586,24 +586,23 @@ class DefaultOSUtil(object):
#Return the 3rd column of this line
return tokens[2] if len(tokens) > 2 else None
return None
def DeviceForIdePort(self, n):
def device_for_ide_port(self, port_id):
"""
Return device name attached to ide port 'n'.
"""
if n > 3:
if port_id > 3:
return None
g0 = "00000000"
if n > 1:
if port_id > 1:
g0 = "00000001"
n = n - 2
port_id = port_id - 2
device = None
path = "/sys/bus/vmbus/devices/"
for vmbus in os.listdir(path):
deviceid = fileutil.GetFileContents(os.path.join(path, vmbus,
"device_id"))
deviceid = fileutil.read_file(os.path.join(path, vmbus, "device_id"))
guid = deviceid.lstrip('{').split('-')
if guid[0] == g0 and guid[1] == "000" + str(n):
if guid[0] == g0 and guid[1] == "000" + str(port_id):
for root, dirs, files in os.walk(path + vmbus):
if root.endswith("/block"):
device = dirs[0]
@@ -616,38 +615,37 @@ class DefaultOSUtil(object):
break
return device
def DeleteAccount(self, userName):
if self.IsSysUser(userName):
logger.Error("{0} is a system user. Will not delete it.", userName)
shellutil.Run("> /var/run/utmp")
shellutil.Run("userdel -f -r " + userName)
def del_account(self, username):
if self.is_sys_user(username):
logger.error("{0} is a system user. Will not delete it.", username)
shellutil.run("> /var/run/utmp")
shellutil.run("userdel -f -r " + username)
#Remove user from suders
if os.path.isfile("/etc/suders.d/waagent"):
try:
content = fileutil.GetFileContents("/etc/sudoers.d/waagent")
content = fileutil.read_file("/etc/sudoers.d/waagent")
sudoers = content.split("\n")
sudoers = filter(lambda x : userName not in x, sudoers)
fileutil.SetFileContents("/etc/sudoers.d/waagent",
sudoers = filter(lambda x : username not in x, sudoers)
fileutil.write_file("/etc/sudoers.d/waagent",
"\n".join(sudoers))
except IOError as e:
raise OSUtilError("Failed to remove sudoer: {0}".format(e))
def TranslateCustomData(self, data):
def decode_customdata(self, data):
return data
def GetTotalMemory(self):
def get_total_mem(self):
cmd = "grep MemTotal /proc/meminfo |awk '{print $2}'"
ret = shellutil.RunGetOutput(cmd)
ret = shellutil.run_get_output(cmd)
if ret[0] == 0:
return int(ret[1])/1024
else:
raise OSUtilError("Failed to get total memory: {0}".format(ret[1]))
def GetProcessorCores(self):
ret = shellutil.RunGetOutput("grep 'processor.*:' /proc/cpuinfo |wc -l")
def get_processor_cores(self):
ret = shellutil.run_get_output("grep 'processor.*:' /proc/cpuinfo |wc -l")
if ret[0] == 0:
return int(ret[1])
else:
raise OSUtilError("Failed to get procerssor cores")
OSUtil = DefaultOSUtil
+80 -78
View File
@@ -18,133 +18,135 @@
import os
import azurelinuxagent.logger as logger
import azurelinuxagent.conf as conf
from azurelinuxagent.event import AddExtensionEvent, WALAEventOperation
from azurelinuxagent.event import add_event, WALAEventOperation
from azurelinuxagent.exception import *
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.protocol as prot
import azurelinuxagent.protocol.ovfenv as ovf
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.fileutil as fileutil
CustomDataFile="CustomData"
CUSTOM_DATA_FILE="CustomData"
class ProvisionHandler(object):
def process(self):
#If provision is not enabled, return
if not conf.GetSwitch("Provisioning.Enabled", True):
logger.Info("Provisioning is disabled. Skip.")
if not conf.get_switch("Provisioning.Enabled", True):
logger.info("Provisioning is disabled. Skip.")
return
provisioned = os.path.join(OSUtil.GetLibDir(), "provisioned")
provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned")
if os.path.isfile(provisioned):
return
logger.Info("Run provision handler.")
protocol = prot.Factory.getDefaultProtocol()
logger.info("run provision handler.")
protocol = prot.Factory.get_default_protocol()
try:
status = prot.ProvisionStatus(status="NotReady",
subStatus="ProvisionStatus")
protocol.reportProvisionStatus(status)
status = prot.ProvisionStatus(status="NotReady",
subStatus="Provision started")
protocol.report_provision_status(status)
self.provision()
fileutil.SetFileContents(provisioned, "")
thumbprint = self.regenerateSshHostKey()
logger.Info("Finished provisioning")
fileutil.write_file(provisioned, "")
thumbprint = self.reg_ssh_host_key()
logger.info("Finished provisioning")
status = prot.ProvisionStatus(status="Ready")
status.properties.certificateThumbprint = thumbprint
protocol.reportProvisionStatus(status)
protocol.report_provision_status(status)
AddExtensionEvent(name="WALA", isSuccess=True, message="",
add_event(name="WALA", is_success=True, message="",
op=WALAEventOperation.Provision)
except ProvisionError as e:
logger.Error("Provision failed: {0}", e)
protocol.reportProvisionStatus(status="NotReady", subStatus=str(e))
AddExtensionEvent(name="WALA", isSuccess=False, message=str(e),
logger.error("Provision failed: {0}", e)
status = prot.ProvisionStatus(status="NotReady",
subStatus= str(e))
protocol.report_provision_status(status)
add_event(name="WALA", is_success=False, message=str(e),
op=WALAEventOperation.Provision)
def regenerateSshHostKey(self):
keyPairType = conf.Get("Provisioning.SshHostKeyPairType", "rsa")
if conf.GetSwitch("Provisioning.RegenerateSshHostKeyPair"):
shellutil.Run("rm -f /etc/ssh/ssh_host_*key*")
shellutil.Run(("ssh-keygen -N '' -t {0} -f /etc/ssh/ssh_host_{1}_key"
"").format(keyPairType, keyPairType))
thumbprint = self.getSshHostKeyThumbprint(keyPairType)
def reg_ssh_host_key(self):
keypair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa")
if conf.get_switch("Provisioning.RegenerateSshHostKeyPair"):
shellutil.run("rm -f /etc/ssh/ssh_host_*key*")
shellutil.run(("ssh-keygen -N '' -t {0} -f /etc/ssh/ssh_host_{1}_key"
"").format(keypair_type, keypair_type))
thumbprint = self.get_ssh_host_key_thumbprint(keypair_type)
return thumbprint
def getSshHostKeyThumbprint(self, keyPairType):
cmd = "ssh-keygen -lf /etc/ssh/ssh_host_{0}_key.pub".format(keyPairType)
ret = shellutil.RunGetOutput(cmd)
def get_ssh_host_key_thumbprint(self, keypair_type):
cmd = "ssh-keygen -lf /etc/ssh/ssh_host_{0}_key.pub".format(keypair_type)
ret = shellutil.run_get_output(cmd)
if ret[0] == 0:
return ret[1].rstrip().split()[1].replace(':', '')
else:
raise ProvisionError(("Failed to generate ssh host key: "
"ret={0}, out= {1}").format(ret[0], ret[1]))
def provision(self):
logger.Info("Copy ovf-env.xml.")
logger.info("Copy ovf-env.xml.")
try:
ovfenv = ovf.CopyOvfEnv()
ovfenv = ovf.copy_ovf_env()
except prot.ProtocolError as e:
raise ProvisionError("Failed to copy ovf-env.xml: {0}".format(e))
password = ovfenv.getUserPassword()
ovfenv.clearUserPassword()
logger.Info("Set host name.")
OSUtil.SetHostname(ovfenv.getComputerName())
logger.Info("Publish host name.")
OSUtil.PublishHostname(ovfenv.getComputerName())
logger.Info("Create user account.")
OSUtil.UpdateUserAccount(ovfenv.getUserName(), password)
password = ovfenv.get_user_password()
ovfenv.clear_user_password()
logger.info("Set host name.")
OSUTIL.set_hostname(ovfenv.get_computer_name())
logger.info("Publish host name.")
OSUTIL.publish_hostname(ovfenv.get_computer_name())
logger.info("Create user account.")
OSUTIL.set_user_account(ovfenv.get_username(), password)
if password is not None:
userSalt = conf.GetSwitch("Provision.UseSalt", True)
saltType = conf.GetSwitch("Provision.SaltType", 6)
logger.Info("Set user password.")
OSUtil.ChangePassword(ovfenv.getUserName(), password, userSalt,
saltType)
use_salt = conf.get_switch("Provision.UseSalt", True)
salt_type = conf.get_switch("Provision.SaltType", 6)
logger.info("Set user password.")
OSUTIL.chpasswd(ovfenv.get_username(), password, use_salt,
salt_type)
logger.Info("Configure sshd.")
OSUtil.ConfigSshd(ovfenv.getDisableSshPasswordAuthentication())
logger.info("Configure sshd.")
OSUTIL.conf_sshd(ovfenv.get_disable_ssh_password_auth())
#Disable selinux temporary
sel = OSUtil.IsSelinuxRunning()
sel = OSUTIL.is_selinux_enforcing()
if sel:
OSUtil.SetSelinuxEnforce(0)
self.deploySshPublicKeys(ovfenv)
self.deploySshKeyPairs(ovfenv)
self.saveCustomData(ovfenv)
OSUTIL.set_selinux_enforce(0)
self.deploy_ssh_pubkeys(ovfenv)
self.deploy_ssh_keypairs(ovfenv)
self.save_customdata(ovfenv)
if sel:
OSUtil.SetSelinuxEnforce(1)
OSUTIL.set_selinux_enforce(1)
OSUtil.RestartSshService()
OSUTIL.restart_ssh_service()
if conf.GetSwitch("Provisioning.DeleteRootPassword"):
OSUtil.DeleteRootPassword()
if conf.get_switch("Provisioning.DeleteRootPassword"):
OSUTIL.del_root_password()
def saveCustomData(self, ovfenv):
logger.Info("Save custom data")
customData = ovfenv.getCustomData()
if customData is None:
def save_customdata(self, ovfenv):
logger.info("Save custom data")
customdata = ovfenv.get_customdata()
if customdata is None:
return
libDir = OSUtil.GetLibDir()
fileutil.SetFileContents(os.path.join(libDir, CustomDataFile),
OSUtil.TranslateCustomData(customData))
lib_dir = OSUTIL.get_lib_dir()
fileutil.write_file(os.path.join(lib_dir, CUSTOM_DATA_FILE),
OSUTIL.decode_customdata(customdata))
def deploy_ssh_pubkeys(self, ovfenv):
for thumbprint, path in ovfenv.get_ssh_pubkeys():
logger.info("Deploy ssh public key.")
OSUTIL.deploy_ssh_pubkey(ovfenv.get_username(), thumbprint, path)
def deploy_ssh_keypairs(self, ovfenv):
for thumbprint, path in ovfenv.get_ssh_keypairs():
logger.info("Deploy ssh key pairs.")
OSUTIL.deploy_ssh_keypair(ovfenv.get_username(), thumbprint, path)
def deploySshPublicKeys(self, ovfenv):
for thumbprint, path in ovfenv.getSshPublicKeys():
logger.Info("Deploy ssh public key.")
OSUtil.DeploySshPublicKey(ovfenv.getUserName(), thumbprint, path)
def deploySshKeyPairs(self, ovfenv):
for thumbprint, path in ovfenv.getSshKeyPairs():
logger.Info("Deploy ssh key pairs.")
OSUtil.DeploySshKeyPair(ovfenv.getUserName(), thumbprint, path)
+90 -90
View File
@@ -22,15 +22,15 @@ import re
import threading
import azurelinuxagent.logger as logger
import azurelinuxagent.conf as conf
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.event import AddExtensionEvent, WALAEventOperation
from azurelinuxagent.utils.osutil import OSUTIL
from azurelinuxagent.event import add_event, WALAEventOperation
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
from azurelinuxagent.exception import ResourceDiskError
DataLossWarningFile="DATALOSS_WARNING_README.txt"
DataLossWarning="""\
WARNING: THIS IS A TEMPORARY DISK.
DATALOSS_WARNING_FILE_NAME="DATALOSS_WARNING_README.txt"
DATA_LOSS_WARNING="""\
WARNING: THIS IS A TEMPORARY DISK.
Any data stored on this drive is SUBJECT TO LOSS and THERE IS NO WAY TO RECOVER IT.
@@ -41,126 +41,126 @@ For additional details to please refer to the MSDN documentation at : http://msd
class ResourceDiskHandler(object):
def startActivateResourceDisk(self):
diskThread = threading.Thread(target = self.run)
diskThread.start()
def start_activate_resource_disk(self):
disk_thread = threading.Thread(target = self.run)
disk_thread.start()
def run(self):
mountpoint = None
if conf.GetSwitch("ResourceDisk.Format", False):
mountpoint = self.activateResourceDisk()
if mountpoint is not None and \
conf.GetSwitch("ResourceDisk.EnableSwap", False):
self.enableSwap(mountpoint)
mount_point = None
if conf.get_switch("ResourceDisk.Format", False):
mount_point = self.activate_resource_disk()
if mount_point is not None and \
conf.get_switch("ResourceDisk.EnableSwap", False):
self.enable_swap(mount_point)
def activateResourceDisk(self):
logger.Info("Activate resource disk")
def activate_resource_disk(self):
logger.info("Activate resource disk")
try:
mountpoint = conf.Get("ResourceDisk.MountPoint", "/mnt/resource")
fs = conf.Get("ResourceDisk.Filesystem", "ext3")
mountpoint = self.mountResourceDisk(mountpoint, fs)
warningFile = os.path.join(mountpoint, DataLossWarningFile)
mount_point = conf.get("ResourceDisk.MountPoint", "/mnt/resource")
fs = conf.get("ResourceDisk.Filesystem", "ext3")
mount_point = self.mount_resource_disk(mount_point, fs)
warning_file = os.path.join(mount_point, DATALOSS_WARNING_FILE_NAME)
try:
fileutil.SetFileContents(warningFile, DataLossWarning)
fileutil.write_file(warning_file, DATA_LOSS_WARNING)
except IOError as e:
logger.Warn("Failed to write data loss warnning:{0}", e)
return mountpoint
logger.warn("Failed to write data loss warnning:{0}", e)
return mount_point
except ResourceDiskError as e:
logger.Error("Failed to mount resource disk {0}", e)
AddExtensionEvent(name="WALA", isSuccess=False, message=str(e),
logger.error("Failed to mount resource disk {0}", e)
add_event(name="WALA", is_success=False, message=str(e),
op=WALAEventOperation.ActivateResourceDisk)
def enableSwap(self, mountpoint):
logger.Info("Enable swap")
try:
sizeMB = conf.GetInt("ResourceDisk.SwapSizeMB", 0)
self.createSwapSpace(mountpoint, sizeMB)
except ResourceDiskError as e:
logger.Error("Failed to enable swap {0}", e)
def mountResourceDisk(self, mountpoint, fs):
device = OSUtil.DeviceForIdePort(1)
def enable_swap(self, mount_point):
logger.info("Enable swap")
try:
size_mb = conf.get_int("ResourceDisk.SwapSizeMB", 0)
self.create_swap_space(mount_point, size_mb)
except ResourceDiskError as e:
logger.error("Failed to enable swap {0}", e)
def mount_resource_disk(self, mount_point, fs):
device = OSUTIL.device_for_ide_port(1)
if device is None:
raise ResourceDiskError("unable to detect disk topology")
device = "/dev/" + device
mountlist = shellutil.RunGetOutput("mount")[1]
existing = OSUtil.GetMountPoint(mountlist, device)
mountlist = shellutil.run_get_output("mount")[1]
existing = OSUTIL.get_mount_point(mountlist, device)
if(existing):
logger.Info("Resource disk {0}1 is already mounted", device)
logger.info("Resource disk {0}1 is already mounted", device)
return existing
fileutil.CreateDir(mountpoint, mode=0755)
logger.Info("Detect GPT...")
fileutil.mkdir(mount_point, mode=0755)
logger.info("Detect GPT...")
partition = device + "1"
ret = shellutil.RunGetOutput("parted {0} print".format(device))
ret = shellutil.run_get_output("parted {0} print".format(device))
if ret[0]:
raise ResourceDiskError("({0}) {1}".format(device, ret[1]))
if "gpt" in ret[1]:
logger.Info("GPT detected")
logger.Info("Get GPT partitions")
parts = filter(lambda x : re.match("^\s*[0-9]+", x),
logger.info("GPT detected")
logger.info("Get GPT partitions")
parts = filter(lambda x : re.match("^\s*[0-9]+", x),
ret[1].split("\n"))
logger.Info("Found more than {0} GPT partitions.", len(parts))
logger.info("Found more than {0} GPT partitions.", len(parts))
if len(parts) > 1:
logger.Info("Remove old GPT partitions")
logger.info("Remove old GPT partitions")
for i in range(1, len(parts) + 1):
logger.Info("Remove partition: {0}", i)
shellutil.Run("parted {0} rm {1}".format(device, i))
logger.info("Remove partition: {0}", i)
shellutil.run("parted {0} rm {1}".format(device, i))
logger.Info("Create a new GPT partition using entire disk space")
shellutil.Run("parted {0} mkpart primary 0% 100%".format(device))
logger.Info("Format partition: {0} with fstype {1}",partition,fs)
shellutil.Run("mkfs." + fs + " " + partition + " -F")
logger.info("Create a new GPT partition using entire disk space")
shellutil.run("parted {0} mkpart primary 0% 100%".format(device))
logger.info("Format partition: {0} with fstype {1}",partition,fs)
shellutil.run("mkfs." + fs + " " + partition + " -F")
else:
logger.Info("GPT not detected")
logger.Info("Check fstype")
ret = shellutil.RunGetOutput("sfdisk -q -c {0} 1".format(device))
logger.info("GPT not detected")
logger.info("Check fstype")
ret = shellutil.run_get_output("sfdisk -q -c {0} 1".format(device))
if ret[1].rstrip() == "7" and fs != "ntfs":
logger.Info("The partition is formatted with ntfs")
logger.Info("Format partition: {0} with fstype {1}",partition,fs)
shellutil.Run("sfdisk -c {0} 1 83".format(device))
shellutil.Run("mkfs." + fs + " " + partition + " -F")
logger.info("The partition is formatted with ntfs")
logger.info("Format partition: {0} with fstype {1}",partition,fs)
shellutil.run("sfdisk -c {0} 1 83".format(device))
shellutil.run("mkfs." + fs + " " + partition + " -F")
logger.Info("Mount resource disk")
retCode = shellutil.Run("mount {0} {1}".format(partition, mountpoint),
logger.info("Mount resource disk")
ret = shellutil.run("mount {0} {1}".format(partition, mount_point),
chk_err=False)
if retCode:
logger.Warn("Failed to mount resource disk. Retry mounting")
shellutil.Run("mkfs." + fs + " " + partition + " -F")
retCode = shellutil.Run("mount {0} {1}".format(partition, mountpoint))
if retCode:
raise ResourceDiskError("({0}) {1}".format(partition, retCode))
if ret:
logger.warn("Failed to mount resource disk. Retry mounting")
shellutil.run("mkfs." + fs + " " + partition + " -F")
ret = shellutil.run("mount {0} {1}".format(partition, mount_point))
if ret:
raise ResourceDiskError("({0}) {1}".format(partition, ret))
logger.Info("Resource disk ({0}) is mounted at {1} with fstype {2}",
device, mountpoint, fs)
return mountpoint
logger.info("Resource disk ({0}) is mounted at {1} with fstype {2}",
device, mount_point, fs)
return mount_point
def createSwapSpace(self, mountpoint, sizeMB):
sizeKB = sizeMB * 1024
size = sizeKB * 1024
swapfile = os.path.join(mountpoint, 'swapfile')
swapList = shellutil.RunGetOutput("swapon -s")[1]
def create_swap_space(self, mount_point, size_mb):
size_kb = size_mb * 1024
size = size_kb * 1024
swapfile = os.path.join(mount_point, 'swapfile')
swaplist = shellutil.run_get_output("swapon -s")[1]
if swapfile in swapList and os.path.getsize(swapfile) == size:
logger.Info("Swap already enabled")
return
if swapfile in swaplist and os.path.getsize(swapfile) == size:
logger.info("Swap already enabled")
return
if os.path.isfile(swapfile) and os.path.getsize(swapfile) != size:
logger.Info("Remove old swap file")
shellutil.Run("swapoff -a", chk_err=False)
logger.info("Remove old swap file")
shellutil.run("swapoff -a", chk_err=False)
os.remove(swapfile)
if not os.path.isfile(swapfile):
logger.Info("Create swap file")
shellutil.Run(("dd if=/dev/zero of={0} bs=1024 "
"count={1}").format(swapfile, sizeKB))
shellutil.Run("mkswap {0}".format(swapfile))
if shellutil.Run("swapon {0}".format(swapfile)):
logger.info("Create swap file")
shellutil.run(("dd if=/dev/zero of={0} bs=1024 "
"count={1}").format(swapfile, size_kb))
shellutil.run("mkswap {0}".format(swapfile))
if shellutil.run("swapon {0}".format(swapfile)):
raise ResourceDiskError("{0}".format(swapfile))
logger.Info("Enabled {0}KB of swap at {1}".format(sizeKB, swapfile))
logger.info("Enabled {0}KB of swap at {1}".format(size_kb, swapfile))
+31 -31
View File
@@ -21,61 +21,61 @@ import os
import time
import azurelinuxagent.logger as logger
import azurelinuxagent.conf as conf
from azurelinuxagent.metadata import GuestAgentLongName, GuestAgentVersion, \
DistroName, DistroVersion, DistroFullName
from azurelinuxagent.metadata import agent_long_name, AGENT_VERSION, \
DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
import azurelinuxagent.protocol as prot
import azurelinuxagent.event as event
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.utils.fileutil as fileutil
class RunHandler(object):
class MainHandler(object):
def __init__(self, handlers):
self.handlers = handlers
def run(self):
logger.Info("{0} Version:{1}", GuestAgentLongName, GuestAgentVersion)
logger.Info("OS: {0} {1}", DistroName, DistroVersion)
logger.info("{0} Version:{1}", agent_long_name, AGENT_VERSION)
logger.info("OS: {0} {1}", DISTRO_NAME, DISTRO_VERSION)
event.EnableUnhandledErrorDump("Azure Linux Agent")
fileutil.SetFileContents(OSUtil.GetAgentPidPath(),
event.enable_unhandled_err_dump("Azure Linux Agent")
fileutil.write_file(OSUTIL.get_agent_pid_file_path(),
str(os.getpid()))
if conf.GetSwitch("DetectScvmmEnv", False):
if self.handlers.scvmmHandler.detectScvmmEnv():
if conf.get_switch("DetectScvmmEnv", False):
if self.handlers.scvmm_handler.detect_scvmm_env():
return
self.handlers.dhcpHandler.probe()
prot.DetectDefaultProtocol()
event.EventMonitor().startEventsLoop()
self.handlers.dhcp_handler.probe()
self.handlers.provisionHandler.process()
prot.detect_default_protocol()
if conf.GetSwitch("ResourceDisk.Format", False):
self.handlers.resourceDiskHandler.startActivateResourceDisk()
self.handlers.envHandler.startMonitor()
event.EventMonitor().start()
protocol = prot.Factory.getDefaultProtocol()
self.handlers.provision_handler.process()
if conf.get_switch("ResourceDisk.Format", False):
self.handlers.resource_disk_handler.start_activate_resource_disk()
self.handlers.env_handler.start()
protocol = prot.Factory.get_default_protocol()
while True:
#Handle extensions
handlerStatusList = self.handlers.extensionHandler.process()
h_status_list = self.handlers.extension_handler.process()
#Report status
vmStatus = prot.VMStatus()
vmStatus.vmAgent.agentVersion = GuestAgentLongName
vmStatus.vmAgent.status = "Ready"
vmStatus.vmAgent.message = "Guest Agent is running"
for handlerStatus in handlerStatusList:
vmStatus.extensionHandlers.append(handlerStatus)
vm_status = prot.VMStatus()
vm_status.vmAgent.agentVersion = agent_long_name
vm_status.vmAgent.status = "Ready"
vm_status.vmAgent.message = "Guest Agent is running"
for h_status in h_status_list:
vm_status.extensionHandlers.append(h_status)
try:
logger.Info("Report vm status")
protocol.reportStatus(vmStatus)
logger.info("Report vm status")
protocol.report_status(vm_status)
except prot.ProtocolError as e:
logger.Error("Failed to report vm status: {0}", e)
logger.error("Failed to report vm status: {0}", e)
time.sleep(25)
+17 -17
View File
@@ -20,28 +20,28 @@
import os
import subprocess
import azurelinuxagent.logger as logger
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
VmmConfigFileName = "linuxosconfiguration.xml"
VmmStartupScriptName= "install"
VMM_CONF_FILE_NAME = "linuxosconfiguration.xml"
VMM_STARTUP_SCRIPT_NAME= "install"
class ScvmmHandler(object):
def detectScvmmEnv(self):
logger.Info("Detecting Microsoft System Center VMM Environment")
OSUtil.MountDvd(maxRetry=1, chk_err=False)
mountPoint = OSUtil.GetDvdMountPoint()
found = os.path.isfile(os.path.join(mountPoint, VmmConfigFileName))
if found:
self.startScvmmAgent()
def detect_scvmm_env(self):
logger.info("Detecting Microsoft System Center VMM Environment")
OSUTIL.mount_dvd(max_retry=1, chk_err=False)
mount_point = OSUTIL.get_dvd_mount_point()
found = os.path.isfile(os.path.join(mount_point, VMM_CONF_FILE_NAME))
if found:
self.start_scvmm_agent()
else:
OSUtil.UmountDvd(chk_err=False)
OSUTIL.umount_dvd(chk_err=False)
return found
def startScvmmAgent(self):
logger.Info("Starting Microsoft System Center VMM Initialization "
def start_scvmm_agent(self):
logger.info("Starting Microsoft System Center VMM Initialization "
"Process")
mountPoint = OSUtil.GetDvdMountPoint()
startupScript = os.path.join(mountPoint, VmmStartupScriptName)
subprocess.Popen(["/bin/bash", startupScript, "-p " + mountPoint])
mount_point = OSUTIL.get_dvd_mount_point()
startup_script = os.path.join(mount_point, VMM_STARTUP_SCRIPT_NAME)
subprocess.Popen(["/bin/bash", startup_script, "-p " + mount_point])
+16 -16
View File
@@ -16,31 +16,31 @@
#
import azurelinuxagent.logger as logger
from azurelinuxagent.metadata import DistroName
import azurelinuxagent.distro.default.loader as defaultLoader
from azurelinuxagent.metadata import DISTRO_NAME
import azurelinuxagent.distro.default.loader as default_loader
def GetDistroLoader():
def get_distro_loader():
try:
logger.Verbose("Loading distro implemetation from: {0}", DistroName)
pkgName = "azurelinuxagent.distro.{0}.loader".format(DistroName)
return __import__(pkgName, fromlist="loader")
logger.verb("Loading distro implemetation from: {0}", DISTRO_NAME)
pkg_name = "azurelinuxagent.distro.{0}.loader".format(DISTRO_NAME)
return __import__(pkg_name, fromlist="loader")
except ImportError as e:
logger.Warn("Unable to load distro implemetation for {0}.", DistroName)
logger.Warn("Use default distro implemetation instead.")
return defaultLoader
logger.warn("Unable to load distro implemetation for {0}.", DISTRO_NAME)
logger.warn("Use default distro implemetation instead.")
return default_loader
distroLoader = GetDistroLoader()
DISTRO_LOADER = get_distro_loader()
def GetOSUtil():
def get_osutil():
try:
return distroLoader.GetOSUtil()
return DISTRO_LOADER.get_osutil()
except AttributeError:
return defaultLoader.GetOSUtil()
return default_loader.get_osutil()
def GetHandlers():
def get_handlers():
try:
return distroLoader.GetHandlers()
return DISTRO_LOADER.get_handlers()
except AttributeError:
return defaultLoader.GetHandlers()
return default_loader.get_handlers()
+4 -4
View File
@@ -17,9 +17,9 @@
# Requires Python 2.4+ and Openssl 1.0+
#
from azurelinuxagent.metadata import DistroName, DistroVersion
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
import azurelinuxagent.distro.redhat.loader as redhat
def GetOSUtil():
return redhat.GetOSUtil()
def get_osutil():
return redhat.get_osutil()
+3 -3
View File
@@ -17,11 +17,11 @@
# Requires Python 2.4+ and Openssl 1.0+
#
from azurelinuxagent.metadata import DistroName, DistroVersion
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
def GetOSUtil():
def get_osutil():
from azurelinuxagent.distro.redhat.osutil import Redhat6xOSUtil, RedhatOSUtil
if DistroVersion < "7":
if DISTRO_VERSION < "7":
return Redhat6xOSUtil()
else:
return RedhatOSUtil()
+46 -46
View File
@@ -30,54 +30,54 @@ import azurelinuxagent.logger as logger
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.textutil as textutil
from azurelinuxagent.distro.default.osutil import OSUtil, OSUtilError
from azurelinuxagent.distro.default.osutil import DefaultOSUtil, OSUtilError
class Redhat6xOSUtil(OSUtil):
class Redhat6xOSUtil(DefaultOSUtil):
def __init__(self):
super(Redhat6xOSUtil, self).__init__()
self.sshdConfigPath = '/etc/ssh/sshd_config'
self.opensslCmd = '/usr/bin/openssl'
self.configPath = '/etc/waagent.conf'
self.sshd_conf_file_path = '/etc/ssh/sshd_config'
self.openssl_cmd = '/usr/bin/openssl'
self.conf_file_path = '/etc/waagent.conf'
self.selinux=None
def StartNetwork(self):
return shellutil.Run("/sbin/service networking start", chk_err=False)
def start_network(self):
return shellutil.run("/sbin/service networking start", chk_err=False)
def RestartSshService(self):
return shellutil.Run("/sbin/service sshd condrestart", chk_err=False)
def restart_ssh_service(self):
return shellutil.run("/sbin/service sshd condrestart", chk_err=False)
def StopAgentService(self):
return shellutil.Run("/sbin/service waagent stop", chk_err=False)
def stop_agent_service(self):
return shellutil.run("/sbin/service waagent stop", chk_err=False)
def StartAgentService(self):
return shellutil.Run("/sbin/service waagent start", chk_err=False)
def start_agent_service(self):
return shellutil.run("/sbin/service waagent start", chk_err=False)
def RegisterAgentService(self):
return shellutil.Run("chkconfig --add waagent", chk_err=False)
def UnregisterAgentService(self):
return shellutil.Run("chkconfig --del waagent", chk_err=False)
def register_agent_service(self):
return shellutil.run("chkconfig --add waagent", chk_err=False)
def RsaPublicKeyToSshRsa(self, publicKey):
lines = publicKey.split("\n")
def unregister_agent_service(self):
return shellutil.run("chkconfig --del waagent", chk_err=False)
def asn1_to_ssh_rsa(self, pubkey):
lines = pubkey.split("\n")
lines = filter(lambda x : not x.startswith("----"), lines)
base64Encoded = "".join(lines)
base64_encoded = "".join(lines)
try:
#TODO remove pyasn1 dependency
from pyasn1.codec.der import decoder as der_decoder
derEncoded = base64.b64decode(base64Encoded)
derEncoded = der_decoder.decode(derEncoded)[0][1]
k = der_decoder.decode(textutil.BitsToString(derEncoded))[0]
der_encoded = base64.b64decode(base64_encoded)
der_encoded = der_decoder.decode(der_encoded)[0][1]
k = der_decoder.decode(textutil.bits_to_str(der_encoded))[0]
n=k[0]
e=k[1]
keydata=""
keydata += struct.pack('>I',len("ssh-rsa"))
keydata += "ssh-rsa"
keydata += struct.pack('>I',len(textutil.NumberToBytes(e)))
keydata += textutil.NumberToBytes(e)
keydata += struct.pack('>I',len(textutil.NumberToBytes(n)) + 1)
keydata += struct.pack('>I',len(textutil.num_to_bytes(e)))
keydata += textutil.num_to_bytes(e)
keydata += struct.pack('>I',len(textutil.num_to_bytes(n)) + 1)
keydata += "\0"
keydata += textutil.NumberToBytes(n)
keydata += textutil.num_to_bytes(n)
return "ssh-rsa " + base64.b64encode(keydata) + "\n"
except ImportError as e:
raise OSUtilError("Failed to load pyasn1.codec.der")
@@ -85,37 +85,37 @@ class Redhat6xOSUtil(OSUtil):
raise OSUtilError(("Failed to convert public key: {0} {1}"
"").format(type(e).__name__, e))
def OpenSslToOpenSsh(self, inputFile, outputFile):
publicKey = fileutil.GetFileContents(inputFile)
sshRsaPublicKey = self.RsaPublicKeyToSshRsa(publicKey)
fileutil.SetFileContents(outputFile, sshRsaPublicKey)
def openssl_to_openssh(self, input_file, output_file):
pubkey = fileutil.read_file(input_file)
ssh_rsa_pubkey = self.asn1_to_ssh_rsa(pubkey)
fileutil.write_file(output_file, ssh_rsa_pubkey)
#Override
def GetDhcpProcessId(self):
ret= shellutil.RunGetOutput("pidof dhclient")
def get_dhcp_pid(self):
ret= shellutil.run_get_output("pidof dhclient")
return ret[1] if ret[0] == 0 else None
class RedhatOSUtil(Redhat6xOSUtil):
def __init__(self):
super(RedhatOSUtil, self).__init__()
def SetHostname(self, hostname):
super(RedhatOSUtil, self).SetHostname(hostname)
fileutil.UpdateConfigFile('/etc/sysconfig/network',
def set_hostname(self, hostname):
super(RedhatOSUtil, self).set_hostname(hostname)
fileutil.update_conf_file('/etc/sysconfig/network',
'HOSTNAME',
'HOSTNAME={0}'.format(hostname))
def SetDhcpHostname(self, hostname):
ifname = self.GetInterfaceName()
def set_dhcp_hostname(self, hostname):
ifname = self.get_if_name()
filepath = "/etc/sysconfig/network-scripts/ifcfg-{0}".format(ifname)
fileutil.UpdateConfigFile(filepath,
fileutil.update_conf_file(filepath,
'DHCP_HOSTNAME',
'DHCP_HOSTNAME={0}'.format(hostname))
def RegisterAgentService(self):
return shellutil.Run("systemctl enable waagent", chk_err=False)
def UnregisterAgentService(self):
return shellutil.Run("systemctl disable waagent", chk_err=False)
def register_agent_service(self):
return shellutil.run("systemctl enable waagent", chk_err=False)
def unregister_agent_service(self):
return shellutil.run("systemctl disable waagent", chk_err=False)
+4 -4
View File
@@ -17,12 +17,12 @@
# Requires Python 2.4+ and Openssl 1.0+
#
from azurelinuxagent.metadata import DistroName, DistroVersion, DistroFullName
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
def GetOSUtil():
def get_osutil():
from azurelinuxagent.distro.suse.osutil import SUSE11OSUtil, SUSEOSUtil
if DistroFullName=='SUSE Linux Enterprise Server' and DistroVersion < '12' \
or DistroFullName == 'openSUSE' and DistroVersion < '13.2':
if DISTRO_FULL_NAME=='SUSE Linux Enterprise Server' and DISTRO_VERSION < '12' \
or DISTRO_FULL_NAME == 'openSUSE' and DISTRO_VERSION < '13.2':
return SUSE11OSUtil()
else:
return SUSEOSUtil()
+35 -35
View File
@@ -29,60 +29,60 @@ import azurelinuxagent.logger as logger
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.textutil as textutil
from azurelinuxagent.metadata import DistroName, DistroVersion, DistroFullName
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
from azurelinuxagent.distro.default.osutil import DefaultOSUtil
class SUSE11OSUtil(DefaultOSUtil):
def __init__(self):
super(SUSE11OSUtil, self).__init__()
self.dhcpClientName='dhcpcd'
self.dhclient_name='dhcpcd'
def SetHostname(self, hostname):
fileutil.SetFileContents('/etc/HOSTNAME', hostname)
shellutil.Run("hostname {0}".format(hostname), chk_err=False)
def set_hostname(self, hostname):
fileutil.write_file('/etc/HOSTNAME', hostname)
shellutil.run("hostname {0}".format(hostname), chk_err=False)
def GetDhcpProcessId(self):
ret= shellutil.RunGetOutput("pidof {0}".format(self.dhcpClientName))
def get_dhcp_pid(self):
ret= shellutil.run_get_output("pidof {0}".format(self.dhclient_name))
return ret[1] if ret[0] == 0 else None
def IsDhcpEnabled(self):
def is_dhcp_enabled(self):
return True
def StopDhcpService(self):
cmd = "/sbin/service {0} stop".format(self.dhcpClientName)
return shellutil.Run(cmd, chk_err=False)
def stop_dhcp_service(self):
cmd = "/sbin/service {0} stop".format(self.dhclient_name)
return shellutil.run(cmd, chk_err=False)
def StartDhcpService(self):
cmd = "/sbin/service {0} start".format(self.dhcpClientName)
return shellutil.Run(cmd, chk_err=False)
def start_dhcp_service(self):
cmd = "/sbin/service {0} start".format(self.dhclient_name)
return shellutil.run(cmd, chk_err=False)
def StartNetwork(self) :
return shellutil.Run("/sbin/service start network", chk_err=False)
def start_network(self) :
return shellutil.run("/sbin/service start network", chk_err=False)
def RestartSshService(self):
return shellutil.Run("/sbin/service sshd restart", chk_err=False)
def restart_ssh_service(self):
return shellutil.run("/sbin/service sshd restart", chk_err=False)
def StopAgentService(self):
return shellutil.Run("/sbin/service waagent stop", chk_err=False)
def stop_agent_service(self):
return shellutil.run("/sbin/service waagent stop", chk_err=False)
def StartAgentService(self):
return shellutil.Run("/sbin/service waagent start", chk_err=False)
def RegisterAgentService(self):
return shellutil.Run("/sbin/insserv waagent", chk_err=False)
def UnregisterAgentService(self):
return shellutil.Run("/sbin/insserv -r waagent", chk_err=False)
def start_agent_service(self):
return shellutil.run("/sbin/service waagent start", chk_err=False)
def register_agent_service(self):
return shellutil.run("/sbin/insserv waagent", chk_err=False)
def unregister_agent_service(self):
return shellutil.run("/sbin/insserv -r waagent", chk_err=False)
class SUSEOSUtil(SUSE11OSUtil):
def __init__(self):
super(SUSEOSUtil, self).__init__()
self.dhcpClientName = 'wickedd-dhcp4'
self.dhclient_name = 'wickedd-dhcp4'
def RegisterAgentService(self):
return shellutil.Run("systemctl enable waagent", chk_err=False)
def UnregisterAgentService(self):
return shellutil.Run("systemctl disable waagent", chk_err=False)
def register_agent_service(self):
return shellutil.run("systemctl enable waagent", chk_err=False)
def unregister_agent_service(self):
return shellutil.run("systemctl disable waagent", chk_err=False)
+8 -8
View File
@@ -22,22 +22,22 @@ import azurelinuxagent.logger as logger
import azurelinuxagent.utils.fileutil as fileutil
from azurelinuxagent.distro.default.deprovision import DeprovisionHandler, DeprovisionAction
def DeleteResolve():
def del_resolv():
if os.path.realpath('/etc/resolv.conf') != '/run/resolvconf/resolv.conf':
logger.Info("resolvconf is not configured. Removing /etc/resolv.conf")
fileutil.RemoveFiles('/etc/resolv.conf')
logger.info("resolvconf is not configured. Removing /etc/resolv.conf")
fileutil.rm_files('/etc/resolv.conf')
else:
logger.Info("resolvconf is enabled; leaving /etc/resolv.conf intact")
fileutil.RemoveFiles('/etc/resolvconf/resolv.conf.d/tail',
logger.info("resolvconf is enabled; leaving /etc/resolv.conf intact")
fileutil.rm_files('/etc/resolvconf/resolv.conf.d/tail',
'/etc/resolvconf/resolv.conf.d/originial')
class UbuntuDeprovisionHandler(DeprovisionHandler):
def setUp(self, deluser):
warnings, actions = super(UbuntuDeprovisionHandler, self).setUp(deluser)
def setup(self, deluser):
warnings, actions = super(UbuntuDeprovisionHandler, self).setup(deluser)
warnings.append("WARNING! Nameserver configuration in "
"/etc/resolvconf/resolv.conf.d/{tail,originial} "
"will be deleted.")
actions.append(DeprovisionAction(DeleteResolve))
actions.append(DeprovisionAction(del_resolv))
return warnings, actions
@@ -17,11 +17,13 @@
# Requires Python 2.4+ and Openssl 1.0+
#
from provision import UbuntuProvisionHandler
from azurelinuxagent.distro.ubuntu.provision import UbuntuProvisionHandler
from azurelinuxagent.distro.ubuntu.deprovision import UbuntuDeprovisionHandler
from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory
class UbuntuHandlerFactory(DefaultHandlerFactory):
def __init__(self):
super(UbuntuHandlerFactory, self).__init__()
self.provisionHandler = UbuntuProvisionHandler()
self.provision_handler = UbuntuProvisionHandler()
self.deprovision_handler = UbuntuDeprovisionHandler()
+6 -6
View File
@@ -17,20 +17,20 @@
# Requires Python 2.4+ and Openssl 1.0+
#
from azurelinuxagent.metadata import DistroName, DistroVersion
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
def GetOSUtil():
def get_osutil():
from azurelinuxagent.distro.ubuntu.osutil import Ubuntu1204OSUtil, \
UbuntuOSUtil, \
Ubuntu14xOSUtil
if DistroVersion == "12.04":
if DISTRO_VERSION == "12.04":
return Ubuntu1204OSUtil()
elif DistroVersion == "14.04" or DistroVersion == "14.10":
elif DISTRO_VERSION == "14.04" or DISTRO_VERSION == "14.10":
return Ubuntu14xOSUtil()
else:
return UbuntuOSUtil()
def GetHandlers():
def get_handlers():
from azurelinuxagent.distro.ubuntu.handlerFactory import UbuntuHandlerFactory
return UbuntuHandlerFactory()
+14 -14
View File
@@ -35,31 +35,31 @@ class Ubuntu14xOSUtil(DefaultOSUtil):
def __init__(self):
super(Ubuntu14xOSUtil, self).__init__()
def StartNetwork(self):
return shellutil.Run("service networking start", chk_err=False)
def StopAgentService(self):
return shellutil.Run("service walinuxagent stop", chk_err=False)
def start_network(self):
return shellutil.run("service networking start", chk_err=False)
def StartAgentService(self):
return shellutil.Run("service walinuxagent start", chk_err=False)
def stop_agent_service(self):
return shellutil.run("service walinuxagent stop", chk_err=False)
def start_agent_service(self):
return shellutil.run("service walinuxagent start", chk_err=False)
class Ubuntu1204OSUtil(Ubuntu14xOSUtil):
def __init__(self):
super(Ubuntu1204OSUtil, self).__init__()
#Override
def GetDhcpProcessId(self):
ret= shellutil.RunGetOutput("pidof dhclient3")
def get_dhcp_pid(self):
ret= shellutil.run_get_output("pidof dhclient3")
return ret[1] if ret[0] == 0 else None
class UbuntuOSUtil(Ubuntu14xOSUtil):
def __init__(self):
super(UbuntuOSUtil, self).__init__()
def RegisterAgentService(self):
return shellutil.Run("systemctl unmask walinuxagent", chk_err=False)
def UnregisterAgentService(self):
return shellutil.Run("systemctl mask walinuxagent", chk_err=False)
def register_agent_service(self):
return shellutil.run("systemctl unmask walinuxagent", chk_err=False)
def unregister_agent_service(self):
return shellutil.run("systemctl mask walinuxagent", chk_err=False)
+20 -20
View File
@@ -23,7 +23,7 @@ import azurelinuxagent.logger as logger
import azurelinuxagent.conf as conf
import azurelinuxagent.protocol as prot
from azurelinuxagent.exception import *
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.fileutil as fileutil
from azurelinuxagent.distro.default.provision import ProvisionHandler
@@ -34,38 +34,38 @@ On ubuntu image, provision could be disabled.
class UbuntuProvisionHandler(ProvisionHandler):
def process(self):
#If provision is enabled, run default provision handler
if conf.GetSwitch("Provisioning.Enabled", False):
if conf.get_switch("Provisioning.Enabled", False):
super(UbuntuProvisionHandler, self).process()
return
logger.Info("Run Ubuntu provision handler")
provisioned = os.path.join(OSUtil.GetLibDir(), "provisioned")
logger.info("run Ubuntu provision handler")
provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned")
if os.path.isfile(provisioned):
return
logger.Info("Waiting cloud-init to finish provisioning.")
protocol = prot.Factory.getDefaultProtocol()
logger.info("Waiting cloud-init to finish provisioning.")
protocol = prot.Factory.get_default_protocol()
try:
logger.Info("Wait for ssh host key to be generated.")
thumbprint = self.waitForSshHostKey()
fileutil.SetFileContents(provisioned, "")
logger.info("Wait for ssh host key to be generated.")
thumbprint = self.wait_for_ssh_host_key()
fileutil.write_file(provisioned, "")
logger.Info("Finished provisioning")
logger.info("Finished provisioning")
status = prot.ProvisionStatus(status="Ready")
status.properties.certificateThumbprint = thumbprint
protocol.reportProvisionStatus(status)
protocol.report_provision_status(status)
except ProvisionError as e:
logger.Error("Provision failed: {0}", e)
protocol.reportProvisionStatus(status="NotReady", subStatus=str(e))
logger.error("Provision failed: {0}", e)
protocol.report_provision_status(status="NotReady", subStatus=str(e))
def waitForSshHostKey(self, maxRetry=60):
keyPairType = conf.Get("Provisioning.SshHostKeyPairType", "rsa")
path = '/etc/ssh/ssh_host_{0}_key'.format(keyPairType)
for retry in range(0, maxRetry):
def wait_for_ssh_host_key(self, max_retry=60):
kepair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa")
path = '/etc/ssh/ssh_host_{0}_key'.format(kepair_type)
for retry in range(0, max_retry):
if os.path.isfile(path):
return self.getSshHostKeyThumbprint(keyPairType)
if retry < maxRetry - 1:
logger.Info("Wait for ssh host key be generated: {0}", path)
return self.get_ssh_host_key_thumbprint(kepair_type)
if retry < max_retry - 1:
logger.info("Wait for ssh host key be generated: {0}", path)
time.sleep(5)
raise ProvisionError("Ssh hsot key is not generated.")
+89 -89
View File
@@ -26,9 +26,9 @@ import threading
import platform
import azurelinuxagent.logger as logger
import azurelinuxagent.protocol as prot
from azurelinuxagent.metadata import DistroName, DistroVersion, DistroCodeName,\
GuestAgentVersion
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_CODE_NAME,\
AGENT_VERSION
from azurelinuxagent.utils.osutil import OSUTIL
class EventError(Exception):
pass
@@ -42,108 +42,108 @@ class WALAEventOperation:
Enable = "Enable"
Download = "Download"
Upgrade = "Upgrade"
Update = "Update"
Update = "Update"
ActivateResourceDisk="ActivateResourceDisk"
UnhandledError="UnhandledError"
class EventMonitor(object):
def __init__(self):
self.sysInfo = []
self.eventDir = os.path.join(OSUtil.GetLibDir(), "events")
self.initSystemInfo()
self.sysinfo = []
self.event_dir = os.path.join(OSUTIL.get_lib_dir(), "events")
self.init_sysinfo()
def initSystemInfo(self):
osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(),
DistroName,
DistroVersion,
DistroCodeName,
def init_sysinfo(self):
osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(),
DISTRO_NAME,
DISTRO_VERSION,
DISTRO_CODE_NAME,
platform.release())
self.sysInfo.append(prot.TelemetryEventParam("OSVersion", osversion))
self.sysInfo.append(prot.TelemetryEventParam("GAVersion",
GuestAgentVersion))
self.sysInfo.append(prot.TelemetryEventParam("RAM",
OSUtil.GetTotalMemory()))
self.sysInfo.append(prot.TelemetryEventParam("Processors",
OSUtil.GetProcessorCores()))
protocol = prot.Factory.getDefaultProtocol()
metadata = protocol.getInstanceMetadata()
self.sysInfo.append(prot.TelemetryEventParam("TenantName",
self.sysinfo.append(prot.TelemetryEventParam("OSVersion", osversion))
self.sysinfo.append(prot.TelemetryEventParam("GAVersion",
AGENT_VERSION))
self.sysinfo.append(prot.TelemetryEventParam("RAM",
OSUTIL.get_total_mem()))
self.sysinfo.append(prot.TelemetryEventParam("Processors",
OSUTIL.get_processor_cores()))
protocol = prot.Factory.get_default_protocol()
metadata = protocol.get_instance_metadata()
self.sysinfo.append(prot.TelemetryEventParam("TenantName",
metadata.deploymentName))
self.sysInfo.append(prot.TelemetryEventParam("RoleName",
self.sysinfo.append(prot.TelemetryEventParam("RoleName",
metadata.roleName))
self.sysInfo.append(prot.TelemetryEventParam("RoleInstanceName",
self.sysinfo.append(prot.TelemetryEventParam("RoleInstanceName",
metadata.roleInstanceId))
self.sysInfo.append(prot.TelemetryEventParam("ContainerId",
self.sysinfo.append(prot.TelemetryEventParam("ContainerId",
metadata.containerId))
def startEventsLoop(self):
eventThread = threading.Thread(target = self.eventsLoop)
eventThread.setDaemon(True)
eventThread.start()
def start(self):
event_thread = threading.Thread(target = self.run)
event_thread.setDaemon(True)
event_thread.start()
def collectEvent(self, eventFilePath):
def collect_event(self, evt_file_name):
try:
with open(eventFilePath, "rb") as hfile:
with open(evt_file_name, "rb") as evt_file:
#if fail to open or delete the file, throw exception
jsonStr = hfile.read().decode("utf-8",'ignore')
os.remove(eventFilePath)
return jsonStr
json_str = evt_file.read().decode("utf-8",'ignore')
os.remove(evt_file_name)
return json_str
except IOError as e:
msg = "Failed to process {0}, {1}".format(eventFilePath, e)
msg = "Failed to process {0}, {1}".format(evt_file_name, e)
raise EventError(msg)
def collectAndSendEvents(self):
eventList = prot.TelemetryEventList()
eventFiles = os.listdir(self.eventDir)
for eventFile in eventFiles:
if not eventFile.endswith(".tld"):
def collect_and_send_events(self):
event_list = prot.TelemetryEventList()
event_files = os.listdir(self.event_dir)
for event_file in event_files:
if not event_file.endswith(".tld"):
continue
eventFilePath = os.path.join(self.eventDir, eventFile)
event_file_path = os.path.join(self.event_dir, event_file)
try:
dataStr = self.collectEvent(eventFilePath)
data_str = self.collect_event(event_file_path)
except EventError as e:
logger.Error("{0}", e)
logger.error("{0}", e)
continue
try:
data = json.loads(dataStr)
data = json.loads(data_str)
except ValueError as e:
logger.Verbose(dataStr)
logger.Error("Failed to decode json event file{0}", e)
logger.verb(data_str)
logger.error("Failed to decode json event file{0}", e)
continue
event = prot.TelemetryEvent()
prot.set_properties(event, data)
event.parameters.extend(self.sysInfo)
eventList.events.append(event)
if len(eventList.events) == 0:
event.parameters.extend(self.sysinfo)
event_list.events.append(event)
if len(event_list.events) == 0:
return
try:
protocol = prot.Factory.getDefaultProtocol()
protocol.reportEvent(eventList)
protocol = prot.Factory.get_default_protocol()
protocol.report_event(event_list)
except prot.ProtocolError as e:
logger.Error("{0}", e)
logger.error("{0}", e)
def eventsLoop(self):
lastHeatbeat = datetime.datetime.min
def run(self):
last_heartbeat = datetime.datetime.min
period = datetime.timedelta(hours = 12)
while(True):
if (datetime.datetime.now()-lastHeatbeat) > period:
lastHeatbeat = datetime.datetime.now()
AddExtensionEvent(op=WALAEventOperation.HeartBeat,
name="WALA",isSuccess=True)
self.collectAndSendEvents()
if (datetime.datetime.now()-last_heartbeat) > period:
last_heartbeat = datetime.datetime.now()
add_event(op=WALAEventOperation.HeartBeat,
name="WALA",is_success=True)
self.collect_and_send_events()
time.sleep(60)
def SaveEvent(data):
eventfolder = os.path.join(OSUtil.GetLibDir(), 'events')
if not os.path.exists(eventfolder):
os.mkdir(eventfolder)
os.chmod(eventfolder,0700)
if len(os.listdir(eventfolder)) > 1000:
raise EventError("Too many files under: {0}", eventfolder)
filename = os.path.join(eventfolder, str(int(time.time()*1000000)))
def save_event(data):
event_dir = os.path.join(OSUTIL.get_lib_dir(), 'events')
if not os.path.exists(event_dir):
os.mkdir(event_dir)
os.chmod(event_dir,0700)
if len(os.listdir(event_dir)) > 1000:
raise EventError("Too many files under: {0}", event_dir)
filename = os.path.join(event_dir, str(int(time.time()*1000000)))
try:
with open(filename+".tmp",'wb+') as hfile:
hfile.write(data.encode("utf-8"))
@@ -152,38 +152,38 @@ def SaveEvent(data):
raise EventError("Failed to write events to file:{0}", e)
def AddExtensionEvent(name, op, isSuccess, duration=0, version="1.0",
message="", evtType="", isInternal=False):
def add_event(name, op, is_success, duration=0, version="1.0",
message="", evt_type="", is_internal=False):
event = prot.TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975")
event.parameters.append(prot.TelemetryEventParam('Name', name))
event.parameters.append(prot.TelemetryEventParam('Version', version))
event.parameters.append(prot.TelemetryEventParam('IsInternal', isInternal))
event.parameters.append(prot.TelemetryEventParam('Operation', op))
event.parameters.append(prot.TelemetryEventParam('OperationSuccess',
isSuccess))
event.parameters.append(prot.TelemetryEventParam('Message', message))
event.parameters.append(prot.TelemetryEventParam('Duration', duration))
event.parameters.append(prot.TelemetryEventParam('ExtensionType', evtType))
event.parameters.append(prot.TelemetryEventParam('Name', name))
event.parameters.append(prot.TelemetryEventParam('Version', version))
event.parameters.append(prot.TelemetryEventParam('IsInternal', is_internal))
event.parameters.append(prot.TelemetryEventParam('Operation', op))
event.parameters.append(prot.TelemetryEventParam('OperationSuccess',
is_success))
event.parameters.append(prot.TelemetryEventParam('Message', message))
event.parameters.append(prot.TelemetryEventParam('Duration', duration))
event.parameters.append(prot.TelemetryEventParam('ExtensionType', evt_type))
data = prot.get_properties(event)
try:
SaveEvent(json.dumps(data))
save_event(json.dumps(data))
except EventError as e:
logger.Error("{0}", e)
logger.error("{0}", e)
def DumpUnhandledError(name):
def dump_unhandled_err(name):
if hasattr(sys, 'last_type') and hasattr(sys, 'last_value') and \
hasattr(sys, 'last_traceback'):
last_type = getattr(sys, 'last_type')
last_value = getattr(sys, 'last_value')
last_traceback = getattr(sys, 'last_traceback')
error = traceback.format_exception(last_type, last_value,
error = traceback.format_exception(last_type, last_value,
last_traceback)
message= "".join(error)
logger.Error(message)
AddExtensionEvent(name, isSuccess=False, message=message,
logger.error(message)
add_event(name, is_success=False, message=message,
op=WALAEventOperation.UnhandledError)
def EnableUnhandledErrorDump(name):
atexit.register(DumpUnhandledError, name)
def enable_unhandled_err_dump(name):
atexit.register(dump_unhandled_err, name)
+22 -18
View File
@@ -16,46 +16,50 @@
#
# Requires Python 2.4+ and Openssl 1.0+
#
"""
Defines all exceptions
"""
"""
Base class of agent error.
"""
class AgentError(Exception):
"""
Base class of agent error.
"""
def __init__(self, errno, msg):
msg = "({0}){1}".format(errno, msg)
super(AgentError, self).__init__(msg)
"""
When configure file is not found or malformed.
"""
class AgentConfigError(AgentError):
"""
When configure file is not found or malformed.
"""
def __init__(self, msg):
super(AgentConfigError, self).__init__('000001', msg)
"""
When network is not avaiable.
"""
class AgentNetworkError(AgentError):
"""
When network is not avaiable.
"""
def __init__(self, msg):
super(AgentNetworkError, self).__init__('000002', msg)
"""
When failed to execute an extension
"""
class ExtensionError(AgentError):
"""
When failed to execute an extension
"""
def __init__(self, msg):
super(ExtensionError, self).__init__('000003', msg)
"""
When provision failed
"""
class ProvisionError(AgentError):
"""
When provision failed
"""
def __init__(self, msg):
super(ProvisionError, self).__init__('000004', msg)
"""
Mount resource disk failed
"""
class ResourceDiskError(AgentError):
"""
Mount resource disk failed
"""
def __init__(self, msg):
super(ResourceDiskError, self).__init__('000005', msg)
+4 -2
View File
@@ -18,9 +18,11 @@
#
"""
Load OSUtil implementation from azurelinuxagent.distro
Handler handles different tasks like, provisioning, deprovisioning etc.
The handlers could be extended for different distros. The default
implementation is under azurelinuxagent.distros.default
"""
import azurelinuxagent.distro.loader as loader
Handlers = loader.GetHandlers()
HANDLERS = loader.get_handlers()
+45 -53
View File
@@ -17,13 +17,18 @@
# Implements parts of RFC 2131, 1541, 1497 and
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
"""
Log utils
"""
import sys
import traceback
import azurelinuxagent.utils.textutil as textutil
from datetime import datetime
class Logger(object):
"""
Logger class
"""
def __init__(self, logger=None, prefix=None):
self.appenders = []
if logger is not None:
@@ -43,23 +48,24 @@ class Logger(object):
self.log(LogLevel.ERROR, msg_format, *args)
def log(self, level, msg_format, *args):
msg_format = textutil.Ascii(msg_format)
args = map(lambda x : textutil.Ascii(x), args)
if(len(args) > 0):
msg_format = textutil.ascii(msg_format)
args = map(lambda x: textutil.ascii(x), args)
if len(args) > 0:
msg = msg_format.format(*args)
else:
msg = msg_format
time = datetime.now().strftime('%Y/%m/%d %H:%M:%S.%f')
levelStr = LogLevel.STRINGS[level]
level_str = LogLevel.STRINGS[level]
if self.prefix is not None:
logItem = "{0} {1} {2} {3}".format(time, levelStr, self.prefix, msg)
log_item = "{0} {1} {2} {3}".format(time, level_str, self.prefix,
msg)
else:
logItem = "{0} {1} {2}".format(time, levelStr, msg)
log_item = "{0} {1} {2}".format(time, level_str, msg)
for appender in self.appenders:
appender.write(level, logItem)
appender.write(level, log_item)
def addLoggerAppender(self, appenderType, level, path):
appender = CreateLoggerAppender(appenderType, level, path)
def add_appender(self, appender_type, level, path):
appender = _create_logger_appender(appender_type, level, path)
self.appenders.append(appender)
class ConsoleAppender(object):
@@ -68,15 +74,15 @@ class ConsoleAppender(object):
if level >= LogLevel.INFO:
self.level = level
self.path = path
def write(self, level, msg):
if self.level <= level:
try:
with open(self.path, "w") as console :
console.write(msg.encode('ascii','ignore') + "\n")
except IOError as e:
with open(self.path, "w") as console:
console.write(msg.encode('ascii', 'ignore') + "\n")
except IOError:
pass
class FileAppender(object):
def __init__(self, level, path):
self.level = level
@@ -86,8 +92,8 @@ class FileAppender(object):
if self.level <= level:
try:
with open(self.path, "a+") as log_file:
log_file.write(msg.encode('ascii','ignore') + "\n")
except IOError as e:
log_file.write(msg.encode('ascii', 'ignore') + "\n")
except IOError:
pass
class StdoutAppender(object):
@@ -97,13 +103,13 @@ class StdoutAppender(object):
def write(self, level, msg):
if self.level <= level:
try:
sys.stdout.write(msg.encode('ascii','ignore') + "\n")
except IOError as e:
sys.stdout.write(msg.encode('ascii', 'ignore') + "\n")
except IOError:
pass
#Initialize logger instance
DefaultLogger = Logger()
default_logger = Logger()
class LogLevel(object):
VERBOSE = 0
@@ -118,49 +124,35 @@ class LogLevel(object):
]
class AppenderType(object):
FILE=0
CONSOLE=1
STDOUT=2
FILE = 0
CONSOLE = 1
STDOUT = 2
def AddLoggerAppender(appenderType, level=LogLevel.INFO, path=None):
DefaultLogger.addLoggerAppender(appenderType, level, path)
def add_logger_appender(appender_type, level=LogLevel.INFO, path=None):
default_logger.add_appender(appender_type, level, path)
def Verbose(msg_format, *args):
DefaultLogger.verbose(msg_format, *args)
def verb(msg_format, *args):
default_logger.verbose(msg_format, *args)
def Info(msg_format, *args):
DefaultLogger.info(msg_format, *args)
def info(msg_format, *args):
default_logger.info(msg_format, *args)
def Warn(msg_format, *args):
DefaultLogger.warn(msg_format, *args)
def warn(msg_format, *args):
default_logger.warn(msg_format, *args)
def Error(msg_format, *args):
DefaultLogger.error(msg_format, *args)
def error(msg_format, *args):
default_logger.error(msg_format, *args)
def Log(level, msg_format, *args):
DefaultLogger.log(level, msg_format, args)
def log(level, msg_format, *args):
default_logger.log(level, msg_format, args)
def CreateLoggerAppender(appenderType, level=LogLevel.INFO, path=None):
if appenderType == AppenderType.CONSOLE :
def _create_logger_appender(appender_type, level=LogLevel.INFO, path=None):
if appender_type == AppenderType.CONSOLE:
return ConsoleAppender(level, path)
elif appenderType == AppenderType.FILE :
elif appender_type == AppenderType.FILE:
return FileAppender(level, path)
elif appenderType == AppenderType.STDOUT :
elif appender_type == AppenderType.STDOUT:
return StdoutAppender(level)
else:
raise ValueError("Unknown appender type")
def LogError(operation):
def Decorator(func):
def Wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
except Exception, e:
Error("Failed to {0} :{1} {2}",
operation,
e,
traceback.format_exc())
raise e
return result
return Wrapper
return Decorator
+21 -21
View File
@@ -21,40 +21,40 @@ import os
import re
import platform
def GetDistroInfo():
def get_distro():
if 'FreeBSD' in platform.system():
release = re.sub('\-.*\Z', '', str(platform.release()))
osInfo = ['freebsd', release, '', 'freebsd']
osinfo = ['freebsd', release, '', 'freebsd']
if 'linux_distribution' in dir(platform):
osInfo = list(platform.linux_distribution(full_distribution_name=0))
fullName = platform.linux_distribution()[0].strip()
osInfo.append(fullName)
osinfo = list(platform.linux_distribution(full_distribution_name=0))
full_name = platform.linux_distribution()[0].strip()
osinfo.append(full_name)
else:
osInfo = platform.dist()
osinfo = platform.dist()
#The platform.py lib has issue with detecting oracle linux distribution.
#Merge the following patch provided by oracle as a temparory fix.
if os.path.exists("/etc/oracle-release"):
osInfo[2]="oracle"
osInfo[3]="Oracle Linux"
if os.path.exists("/etc/oracle-release"):
osinfo[2] = "oracle"
osinfo[3] = "Oracle Linux"
#Remove trailing whitespace and quote in distro name
osInfo[0] = osInfo[0].strip('"').strip(' ').lower()
return osInfo
osinfo[0] = osinfo[0].strip('"').strip(' ').lower()
return osinfo
GuestAgentName = "AzureLinuxAgent"
GuestAgentLongName = "Azure Linux Agent"
GuestAgentVersion='2.1.0'
GuestAgentLongVersion = "{0}-{1}".format(GuestAgentName, GuestAgentVersion)
GuestAgentDescription = """\
AGENT_NAME = "AzureLinuxAgent"
agent_long_name = "Azure Linux Agent"
AGENT_VERSION = '2.1.0'
agent_long_version = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION)
agent_description = """\
The Azure Linux Agent supports the provisioning and running of Linux
VMs in the Azure cloud. This package should be installed on Linux disk
images that are built to run in the Azure environment.
"""
__DistroInfo = GetDistroInfo()
DistroName = __DistroInfo[0]
DistroVersion = __DistroInfo[1]
DistroCodeName = __DistroInfo[2]
DistroFullName = __DistroInfo[3]
__distro__ = get_distro()
DISTRO_NAME = __distro__[0]
DISTRO_VERSION = __distro__[1]
DISTRO_CODE_NAME = __distro__[2]
DISTRO_FULL_NAME = __distro__[3]
+1 -1
View File
@@ -19,5 +19,5 @@
from azurelinuxagent.protocol.common import *
from azurelinuxagent.protocol.protocolFactory import Factory, \
DetectDefaultProtocol
detect_default_protocol
+27 -27
View File
@@ -22,7 +22,7 @@ import re
import json
import xml.dom.minidom
import azurelinuxagent.logger as logger
from azurelinuxagent.utils.textutil import GetNodeTextData
from azurelinuxagent.utils.textutil import get_node_text
import azurelinuxagent.utils.fileutil as fileutil
class ProtocolError(Exception):
@@ -31,12 +31,12 @@ class ProtocolError(Exception):
class ProtocolNotFound(Exception):
pass
def validata_param(name, val, expectedType):
def validata_param(name, val, expected_type):
if val is None:
raise ProtocolError("Param {0} is None".format(name))
if not isinstance(val, expectedType):
if not isinstance(val, expected_type):
raise ProtocolError("Param {0} type should be {1}".format(name,
expectedType))
expected_type))
def set_properties(obj, data):
validata_param("obj", obj, DataContract)
@@ -45,20 +45,20 @@ def set_properties(obj, data):
props = vars(obj)
for name, val in props.items():
try:
newVal = data[name]
new_val = data[name]
except KeyError:
continue
if isinstance(newVal, dict):
set_properties(val, newVal)
elif isinstance(newVal, list):
if isinstance(new_val, dict):
set_properties(val, new_val)
elif isinstance(new_val, list):
validata_param("list", val, DataContractList)
for dataItem in newVal:
item = val.itemType()
set_properties(item, dataItem)
for data_item in new_val:
item = val.itemType()
set_properties(item, data_item)
val.append(item)
else:
setattr(obj, name, newVal)
setattr(obj, name, new_val)
def get_properties(obj):
validata_param("obj", obj, DataContract)
@@ -86,7 +86,7 @@ class DataContractList(list):
def __init__(self, itemType):
self.itemType = itemType
class VmInfo(DataContract):
class VMInfo(DataContract):
def __init__(self, subscriptionId=None, vmName=None):
self.subscriptionId = subscriptionId
self.vmName = vmName
@@ -179,14 +179,14 @@ class ExtensionSubStatus(DataContract):
class ExtensionStatus(DataContract):
def __init__(self, name=None, configurationAppliedTime=None, operation=None,
status=None, code=None, message=None, sequenceNumber=None):
status=None, code=None, message=None, seq_no=None):
self.name = name
self.configurationAppliedTime = configurationAppliedTime
self.operation = operation
self.status = status
self.code = code
self.message = message
self.sequenceNumber = sequenceNumber
self.sequenceNumber = seq_no
self.substatusList = DataContractList(ExtensionSubStatus)
class ExtensionHandlerStatus(DataContract):
@@ -223,27 +223,27 @@ class Protocol(DataContract):
def initialize(self):
raise NotImplementedError()
def getVmInfo(self):
def get_vminfo(self):
raise NotImplementedError()
def getCerts(self):
def get_certs(self):
raise NotImplementedError()
def getExtensions(self):
raise NotImplementedError()
def getExtensionPackages(self, extension):
raise NotImplementedError()
def getInstanceMetadata(self):
def get_extensions(self):
raise NotImplementedError()
def reportProvisionStatus(self, status):
def get_extension_pkgs(self, extension):
raise NotImplementedError()
def reportStatus(self, status):
def get_instance_metadata(self):
raise NotImplementedError()
def reportEvent(self, event):
def report_provision_status(self, status):
raise NotImplementedError()
def report_status(self, status):
raise NotImplementedError()
def report_event(self, event):
raise NotImplementedError()
+86 -86
View File
@@ -19,161 +19,161 @@
from azurelinuxagent.protocol.common import *
from azurelinuxagent.utils.osutil import OSUtil, OSUtilError
from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError
def GetOvfEnv():
ovfFilePath = os.path.join(OSUtil.GetLibDir(), OvfFileName)
if os.path.isfile(ovfFilePath):
xmlText = fileutil.GetFileContents(ovfFilePath)
return OvfEnv(xmlText)
def get_ovf_env():
ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME)
if os.path.isfile(ovf_file_path):
xml_text = fileutil.read_file(ovf_file_path)
return OvfEnv(xml_text)
else:
raise ProtocolError("ovf-env.xml is missing.")
def CopyOvfEnv():
def copy_ovf_env():
"""
Copy ovf env file from dvd to hard disk.
Copy ovf env file from dvd to hard disk.
Remove password before save it to the disk
"""
try:
OSUtil.MountDvd()
ovfFile = OSUtil.GetOvfEnvPathOnDvd()
OSUTIL.mount_dvd()
ovf_file_path_on_dvd = OSUTIL.get_ovf_env_file_path_on_dvd()
ovfxml = fileutil.GetFileContents(ovfFile, removeBom=True)
ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True)
ovfenv = OvfEnv(ovfxml)
ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml)
ovfFilePath = os.path.join(OSUtil.GetLibDir(), OvfFileName)
fileutil.SetFileContents(ovfFilePath, ovfxml)
OSUtil.UmountDvd()
ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME)
fileutil.write_file(ovf_file_path, ovfxml)
OSUTIL.umount_dvd()
except IOError as e:
raise ProtocolError(str(e))
except OSUtilError as e:
raise ProtocolError(str(e))
return ovfenv
OvfFileName="ovf-env.xml"
OVF_FILE_NAME="ovf-env.xml"
class OvfEnv(object):
"""
Read, and process provisioning info from provisioning file OvfEnv.xml
"""
def __init__(self, xmlText):
if xmlText is None:
def __init__(self, xml_text):
if xml_text is None:
raise ValueError("ovf-env is None")
logger.Verbose("Load ovf-env.xml")
self.parse(xmlText)
logger.verb("Load ovf-env.xml")
self.parse(xml_text)
def reinitialize(self):
"""
Reset members.
"""
self.WaNs = "http://schemas.microsoft.com/windowsazure"
self.OvfNs = "http://schemas.dmtf.org/ovf/environment/1"
self.MajorVersion = 1
self.MinorVersion = 0
self.ComputerName = None
self.UserName = None
self.UserPassword = None
self.CustomData = None
self.DisableSshPasswordAuthentication = True
self.SshPublicKeys = []
self.SshKeyPairs = []
self.wa_ns = "http://schemas.microsoft.com/windowsazure"
self.ovf_ns = "http://schemas.dmtf.org/ovf/environment/1"
self.major_version = 1
self.minor_version = 0
self.compute_name = None
self.user_name = None
self.user_password = None
self.customdata = None
self.disable_ssh_password_auth = True
self.ssh_pubkeys = []
self.ssh_keypairs = []
def getMajorVersion(self):
return self.MajorVersion
def get_major_version(self):
return self.major_version
def getMinorVersion(self):
return self.MinorVersion
def get_minor_version(self):
return self.minor_version
def getComputerName(self):
return self.ComputerName
def get_computer_name(self):
return self.compute_name
def getUserName(self):
return self.UserName
def get_username(self):
return self.user_name
def getUserPassword(self):
return self.UserPassword
def get_user_password(self):
return self.user_password
def clearUserPassword(self):
self.UserPassword = None
def clear_user_password(self):
self.user_password = None
def getCustomData(self):
return self.CustomData
def get_customdata(self):
return self.customdata
def getDisableSshPasswordAuthentication(self):
return self.DisableSshPasswordAuthentication
def get_disable_ssh_password_auth(self):
return self.disable_ssh_password_auth
def getSshPublicKeys(self):
return self.SshPublicKeys
def get_ssh_pubkeys(self):
return self.ssh_pubkeys
def getSshKeyPairs(self):
return self.SshKeyPairs
def get_ssh_keypairs(self):
return self.ssh_keypairs
def parse(self, xmlText):
def parse(self, xml_text):
"""
Parse xml tree, retreiving user and ssh key information.
Return self.
"""
self.reinitialize()
dom = xml.dom.minidom.parseString(xmlText)
if len(dom.getElementsByTagNameNS(self.OvfNs, "Environment")) != 1:
logger.Error("Unable to parse OVF XML.")
dom = xml.dom.minidom.parseString(xml_text)
if len(dom.getElementsByTagNameNS(self.ovf_ns, "Environment")) != 1:
logger.error("Unable to parse OVF XML.")
section = None
newer = False
for p in dom.getElementsByTagNameNS(self.WaNs, "ProvisioningSection"):
for p in dom.getElementsByTagNameNS(self.wa_ns, "ProvisioningSection"):
for n in p.childNodes:
if n.localName == "Version":
verparts = GetNodeTextData(n).split('.')
verparts = get_node_text(n).split('.')
major = int(verparts[0])
minor = int(verparts[1])
if major > self.MajorVersion:
if major > self.major_version:
newer = True
if major != self.MajorVersion:
if major != self.major_version:
break
if minor > self.MinorVersion:
if minor > self.minor_version:
newer = True
section = p
if newer == True:
logger.Warn("Newer provisioning configuration detected. "
logger.warn("Newer provisioning configuration detected. "
"Please consider updating waagent.")
if section == None:
logger.Error("Could not find ProvisioningSection with "
"major version={0}", self.MajorVersion)
logger.error("Could not find ProvisioningSection with "
"major version={0}", self.major_version)
return None
self.ComputerName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "HostName")[0])
self.UserName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserName")[0])
self.compute_name = get_node_text(section.getElementsByTagNameNS(self.wa_ns, "HostName")[0])
self.user_name = get_node_text(section.getElementsByTagNameNS(self.wa_ns, "UserName")[0])
try:
self.UserPassword = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserPassword")[0])
self.user_password = get_node_text(section.getElementsByTagNameNS(self.wa_ns, "UserPassword")[0])
except:
pass
CDSection=None
CDSection=section.getElementsByTagNameNS(self.WaNs, "CustomData")
if len(CDSection) > 0 :
self.CustomData=GetNodeTextData(CDSection[0])
disableSshPass = section.getElementsByTagNameNS(self.WaNs, "DisableSshPasswordAuthentication")
if len(disableSshPass) != 0:
self.DisableSshPasswordAuthentication = (GetNodeTextData(disableSshPass[0]).lower() == "true")
for pkey in section.getElementsByTagNameNS(self.WaNs, "PublicKey"):
logger.Verbose(repr(pkey))
cd_section=None
cd_section=section.getElementsByTagNameNS(self.wa_ns, "CustomData")
if len(cd_section) > 0 :
self.customdata=get_node_text(cd_section[0])
disable_ssh_password_auth = section.getElementsByTagNameNS(self.wa_ns, "DisableSshPasswordAuthentication")
if len(disable_ssh_password_auth) != 0:
self.disable_ssh_password_auth = (get_node_text(disable_ssh_password_auth[0]).lower() == "true")
for pkey in section.getElementsByTagNameNS(self.wa_ns, "PublicKey"):
logger.verb(repr(pkey))
fp = None
path = None
for c in pkey.childNodes:
if c.localName == "Fingerprint":
fp = GetNodeTextData(c).upper()
logger.Verbose(fp)
fp = get_node_text(c).upper()
logger.verb(fp)
if c.localName == "Path":
path = GetNodeTextData(c)
logger.Verbose(path)
self.SshPublicKeys += [[fp, path]]
for keyp in section.getElementsByTagNameNS(self.WaNs, "KeyPair"):
path = get_node_text(c)
logger.verb(path)
self.ssh_pubkeys += [[fp, path]]
for keyp in section.getElementsByTagNameNS(self.wa_ns, "KeyPair"):
fp = None
path = None
logger.Verbose(repr(keyp))
logger.verb(repr(keyp))
for c in keyp.childNodes:
if c.localName == "Fingerprint":
fp = GetNodeTextData(c).upper()
logger.Verbose(fp)
fp = get_node_text(c).upper()
logger.verb(fp)
if c.localName == "Path":
path = GetNodeTextData(c)
logger.Verbose(path)
self.SshKeyPairs += [[fp, path]]
path = get_node_text(c)
logger.verb(path)
self.ssh_keypairs += [[fp, path]]
return self
+56 -56
View File
@@ -21,19 +21,19 @@ import traceback
import threading
import azurelinuxagent.logger as logger
import azurelinuxagent.utils.fileutil as fileutil
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
from azurelinuxagent.protocol.common import *
from azurelinuxagent.protocol.v1 import ProtocolV1
from azurelinuxagent.protocol.v2 import ProtocolV2
from azurelinuxagent.protocol.v1 import WIRE_PROTOCOL
from azurelinuxagent.protocol.v2 import MetadataProtocol
WireServerAddrFile = "WireServer"
WireProtocol = "WireProtocol"
MetaDataProtocol = "MetaDataProtocol"
WIRE_SERVER_ADDR_FILE_NAME = "WireServer"
WIRE_PROTOCOL = "WireProtocol"
METADATA_PROTOCOL = "MetaDataProtocol"
def GetWireProtocolEndpoint():
path = os.path.join(OSUtil.GetLibDir(), WireServerAddrFile)
def get_wire_protocol_endpoint():
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME)
try:
endpoint = fileutil.GetFileContents(path)
endpoint = fileutil.read_file(path)
except IOError as e:
raise ProtocolNotFound("Wire server endpoint not found: {0}".format(e))
@@ -42,84 +42,84 @@ def GetWireProtocolEndpoint():
return endpoint
def DetectV1():
endpoint = GetWireProtocolEndpoint()
def detect_wire_protocol():
endpoint = get_wire_protocol_endpoint()
OSUtil.GenerateTransportCert()
protocol = ProtocolV1(endpoint)
protocol.initialize()
logger.Info("Protocol V1 found.")
path = os.path.join(OSUtil.GetLibDir(), WireProtocol)
fileutil.SetFileContents(path, "")
return protocol
def DetectV2():
protocol = ProtocolV2()
OSUTIL.gen_transport_cert()
protocol = WIRE_PROTOCOL(endpoint)
protocol.initialize()
logger.Info("Protocol V2 found.")
path = os.path.join(OSUtil.GetLibDir(), MetaDataProtocol)
fileutil.SetFileContents(path, "")
logger.info("Protocol V1 found.")
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_PROTOCOL)
fileutil.write_file(path, "")
return protocol
def DetectAvailableProtocols(probeFuncs=[DetectV1, DetectV2]):
availableProtocols = []
for probeFunc in probeFuncs:
def detect_metadata_protocol():
protocol = MetadataProtocol()
protocol.initialize()
logger.info("Protocol V2 found.")
path = os.path.join(OSUTIL.get_lib_dir(), METADATA_PROTOCOL)
fileutil.write_file(path, "")
return protocol
def detect_available_protocols(prob_funcs=[detect_wire_protocol, detect_metadata_protocol]):
available_protocols = []
for probe_func in prob_funcs:
try:
protocol = probeFunc()
availableProtocols.append(protocol)
protocol = probe_func()
available_protocols.append(protocol)
except ProtocolNotFound as e:
logger.Info(str(e))
return availableProtocols
logger.info(str(e))
return available_protocols
def DetectDefaultProtocol():
logger.Info("Detect default protocol.")
availableProtocols = DetectAvailableProtocols()
return ChooseDefaultProtocol(availableProtocols)
def detect_default_protocol():
logger.info("Detect default protocol.")
available_protocols = detect_available_protocols()
return choose_default_protocol(available_protocols)
def ChooseDefaultProtocol(availableProtocols):
if len(availableProtocols) > 0:
return availableProtocols[0]
def choose_default_protocol(protocols):
if len(protocols) > 0:
return protocols[0]
else:
raise ProtocolNotFound("No available protocol detected.")
def GetV1():
path = os.path.join(OSUtil.GetLibDir(), WireProtocol)
def get_wire_protocol():
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_PROTOCOL)
if not os.path.isfile(path):
raise ProtocolNotFound("Protocol V1 not found")
endpoint = GetWireProtocolEndpoint()
return ProtocolV1(endpoint)
def GetV2():
path = os.path.join(OSUtil.GetLibDir(), MetaDataProtocol)
endpoint = get_wire_protocol_endpoint()
return WIRE_PROTOCOL(endpoint)
def get_metadata_protocol():
path = os.path.join(OSUTIL.get_lib_dir(), METADATA_PROTOCOL)
if not os.path.isfile(path):
raise ProtocolNotFound("Protocol V2 not found")
return ProtocolV2()
return MetadataProtocol()
def GetAvailableProtocols(getters=[GetV1, GetV2]):
availableProtocols = []
def get_available_protocols(getters=[get_wire_protocol, get_metadata_protocol]):
available_protocols = []
for getter in getters:
try:
protocol = getter()
availableProtocols.append(protocol)
available_protocols.append(protocol)
except ProtocolNotFound as e:
logger.Info(str(e))
return availableProtocols
logger.info(str(e))
return available_protocols
class ProtocolFactory(object):
def __init__(self):
self._protocol = None
self._lock = threading.Lock()
def getDefaultProtocol(self):
def get_default_protocol(self):
if self._protocol is None:
self._lock.acquire()
if self._protocol is None:
availableProtocols = GetAvailableProtocols()
self._protocol = ChooseDefaultProtocol(availableProtocols)
available_protocols = get_available_protocols()
self._protocol = choose_default_protocol(available_protocols)
self._lock.release()
return self._protocol
File diff suppressed because it is too large Load Diff
+43 -45
View File
@@ -21,34 +21,32 @@ import json
import azurelinuxagent.utils.restutil as restutil
from azurelinuxagent.protocol.common import *
DefaultEndpoint='169.254.169.254'
DefaultApiVersion='2015-01-01'
BaseUri = "https://{0}/Microsoft.Computer/{1}?$api-version={{{2}}}{3}"
ENDPOINT='169.254.169.254'
APIVERSION='2015-01-01'
BASE_URI = "https://{0}/Microsoft.Computer/{1}?$api-version={{{2}}}{3}"
class ProtocolV2(Protocol):
class MetadataProtocol(Protocol):
def __init__(self, apiVersion=DefaultApiVersion, endpoint=DefaultEndpoint):
self.apiVersion = apiVersion
def __init__(self, apiversion=APIVERSION, endpoint=ENDPOINT):
self.apiversion = apiversion
self.endpoint = endpoint
self.identityUri = BaseUri.format(self.endpoint, "identity",
self.apiVersion, "&expand=*")
self.certUri = BaseUri.format(self.endpoint, "certificates",
self.apiVersion, "&expand=*")
self.certUri = BaseUri.format(self.endpoint, "certificates",
self.apiVersion, "&expand=*")
self.extUri = BaseUri.format(self.endpoint, "extensionHandlers",
self.apiVersion, "&expand=*")
self.provisionStatusUri = BaseUri.format(self.endpoint,
"provisionStatus",
self.apiVersion, "")
self.statusUri = BaseUri.format(self.endpoint, "status",
self.apiVersion, "")
self.eventUri = BaseUri.format(self.endpoint, "status/telemetry",
self.apiVersion, "")
self.identity_uri = __base__uri.format(self.endpoint, "identity",
self.apiversion, "&expand=*")
self.cert_uri = __base__uri.format(self.endpoint, "certificates",
self.apiversion, "&expand=*")
self.ext_uri = __base__uri.format(self.endpoint, "extensionHandlers",
self.apiversion, "&expand=*")
self.provision_status_uri = __base__uri.format(self.endpoint,
"provisionStatus",
self.apiversion, "")
self.status_uri = __base__uri.format(self.endpoint, "status",
self.apiversion, "")
self.event_uri = __base__uri.format(self.endpoint, "status/telemetry",
self.apiversion, "")
def _getData(self, dataType, url, headers=None):
def _get_data(self, data_type, url, headers=None):
try:
resp = restutil.HttpGet(url, headers)
resp = restutil.http_get(url, headers)
except restutil.HttpError as e:
raise ProtocolError(str(e))
@@ -58,23 +56,23 @@ class ProtocolV2(Protocol):
data = json.loads(resp.read())
except ValueError as e:
raise ProtocolError(str(e))
obj = dataType()
obj = data_type()
set_properties(obj, data)
return obj
def _putData(self, url, obj, headers=None):
def _put_data(self, url, obj, headers=None):
data = get_properties(obj)
try:
resp = restutil.HttpPut(url, json.dumps(data))
resp = restutil.http_put(url, json.dumps(data))
except restutil.HttpError as e:
raise ProtocolError(str(e))
if resp.status != httplib.OK:
raise ProtocolError("{0} - PUT: {1}".format(resp.status, url))
def _postData(self, url, obj, headers=None):
def _post_data(self, url, obj, headers=None):
data = get_properties(obj)
try:
resp = restutil.HttpPost(url, json.dumps(data))
resp = restutil.http_post(url, json.dumps(data))
except restutil.HttpError as e:
raise ProtocolError(str(e))
if resp.status != httplib.CREATED:
@@ -82,27 +80,27 @@ class ProtocolV2(Protocol):
def initialize(self):
pass
def getVmInfo(self):
return self._getData(VmInfo, self.identityUri)
def getCerts(self):
certs = self._getData(CertList, self.certUri)
def get_vminfo(self):
return self._get_data(VMInfo, self.identity_uri)
def get_certs(self):
certs = self._get_data(CertList, self.cert_uri)
#TODO download pfx and convert to pem
return certs
def getExtensions(self):
return self._getData(ExtensionList, self.extUri)
def get_extensions(self):
return self._get_data(ExtensionList, self.ext_uri)
def reportProvisionStatus(self, status):
def report_provision_status(self, status):
validata_param('status', status, ProvisionStatus)
self._putData(self.provisionStatusUri, status)
def reportStatus(self, status):
validata_param('status', status, VMStatus)
self._putData(self.statusUri, status)
def reportEvent(self, events):
validata_param('events', events, TelemetryEventList)
self._postData(self.eventUri, events)
self._put_data(self.provision_status_uri, status)
def report_status(self, status):
validata_param('status', status, VMStatus)
self._put_data(self.status_uri, status)
def report_event(self, events):
validata_param('events', events, TelemetryEventList)
self._post_data(self.event_uri, events)
+70 -73
View File
@@ -17,163 +17,160 @@
# Requires Python 2.4+ and Openssl 1.0+
#
import platform
"""
File operation util functions
"""
import os
import re
import shutil
import pwd
import tempfile
import subprocess
import azurelinuxagent.logger as logger
import azurelinuxagent.utils.textutil as textutil
"""
File operation util functions
"""
def GetFileContents(filepath, asbin=False, removeBom=False):
def read_file(filepath, asbin=False, remove_bom=False):
"""
Read and return contents of 'filepath'.
"""
mode='r'
mode = 'r'
if asbin:
mode+='b'
with open(filepath, mode) as F :
c=F.read()
if (not asbin) and removeBom:
c = textutil.RemoveBom(c)
return c
mode += 'b'
with open(filepath, mode) as in_file:
contents = in_file.read()
if (not asbin) and remove_bom:
contents = textutil.remove_bom(contents)
return contents
def SetFileContents(filepath, contents):
def write_file(filepath, contents):
"""
Write 'contents' to 'filepath'.
"""
#if type(contents) == str :
#contents=contents.encode('latin-1', 'ignore')
with open(filepath, "wb+") as F :
F.write(contents)
with open(filepath, "wb") as out_file:
out_file.write(contents)
def AppendFileContents(filepath, contents):
def append_file(filepath, contents):
"""
Append 'contents' to 'filepath'.
"""
#if type(contents) == str :
#contents=contents.encode('latin-1')
with open(filepath, "a+") as F :
F.write(contents)
with open(filepath, "a+") as out_file:
out_file.write(contents)
def ReplaceFileContentsAtomic(filepath, contents):
def replace_file(filepath, contents):
"""
Write 'contents' to 'filepath' by creating a temp file, and replacing original.
Write 'contents' to 'filepath' by creating a temp file,
and replacing original.
"""
handle, temp = tempfile.mkstemp(dir = os.path.dirname(filepath))
#if type(contents) == str :
handle, temp = tempfile.mkstemp(dir=os.path.dirname(filepath))
#if type(contents) == str:
#contents=contents.encode('latin-1')
try:
os.write(handle, contents)
except IOError, e:
logger.Error('Write to file {0}, Exception is {1}', filepath, e)
except IOError as err:
logger.error('Write to file {0}, Exception is {1}', filepath, err)
return 1
finally:
os.close(handle)
try:
os.rename(temp, filepath)
except IOError, e:
logger.Info('Rename {0} to {1}, Exception is {2}',temp, filepath, e)
logger.Info('Remove original file and retry')
except IOError as err:
logger.info('Rename {0} to {1}, Exception is {2}', temp, filepath, err)
logger.info('Remove original file and retry')
try:
os.remove(filepath)
except IOError, e:
logger.Error('Remove {0}, Exception is {1}',temp, filepath, e)
except IOError as err:
logger.error('Remove {0}, Exception is {1}', temp, filepath, err)
try:
os.rename(temp, filepath)
except IOError, e:
logger.Error('Rename {0} to {1}, Exception is {2}',temp, filepath, e)
except IOError, err:
logger.error('Rename {0} to {1}, Exception is {2}', temp, filepath,
err)
return 1
return 0
def GetLastPathElement(path):
def base_name(path):
head, tail = os.path.split(path)
return tail
def GetLineStartingWith(prefix, filepath):
def get_line_startingwith(prefix, filepath):
"""
Return line from 'filepath' if the line startswith 'prefix'
"""
for line in GetFileContents(filepath).split('\n'):
for line in read_file(filepath).split('\n'):
if line.startswith(prefix):
return line
return None
#End File operation util functions
def CreateDir(dirpath, mode=None, owner=None):
def mkdir(dirpath, mode=None, owner=None):
if not os.path.isdir(dirpath):
os.makedirs(dirpath)
if mode is not None:
ChangeMod(dirpath, mode)
chmod(dirpath, mode)
if owner is not None:
ChangeOwner(dirpath, owner)
chowner(dirpath, owner)
def ChangeOwner(path, owner):
ownerInfo = pwd.getpwnam(owner)
os.chown(path, ownerInfo[2], ownerInfo[3])
def chowner(path, owner):
owner_info = pwd.getpwnam(owner)
os.chown(path, owner_info[2], owner_info[3])
def ChangeMod(path, mode):
def chmod(path, mode):
os.chmod(path, mode)
def RemoveFiles(*args, **kwargs):
def rm_files(*args):
for path in args:
if os.path.isfile(path):
os.remove(path)
def CleanupDirs(*args, **kwargs):
def rm_dirs(*args):
"""
Remove all the contents under the directry
"""
for dirName in args:
if os.path.isdir(dirName):
for item in os.listdir(dirName):
path = os.path.join(dirName, item)
for dir_name in args:
if os.path.isdir(dir_name):
for item in os.listdir(dir_name):
path = os.path.join(dir_name, item)
if os.path.isfile(path):
os.remove(path)
elif os.path.isdir(path):
shutil.rmtree(path)
def UpdateConfigFile(path, lineStart, val, chk_err=False):
config = []
def update_conf_file(path, line_start, val, chk_err=False):
conf = []
if not os.path.isfile(path) and chk_err:
raise Exception("Can't find config file:{0}".format(path))
config = GetFileContents(path).split('\n')
config = filter(lambda x : not x.startswith(lineStart), config)
config.append(val)
ReplaceFileContentsAtomic(path, '\n'.join(config))
conf = read_file(path).split('\n')
conf = filter(lambda x: not x.startswith(line_start), conf)
conf.append(val)
replace_file(path, '\n'.join(conf))
def SearchForFile(dirName, fileName):
for root, dirs, files in os.walk(dirName):
for f in files:
if f == fileName:
return os.path.join(root, f)
def search_file(target_dir_name, target_file_name):
for root, dirs, files in os.walk(target_dir_name):
for file_name in files:
if file_name == target_file_name:
return os.path.join(root, file_name)
return None
def ChangeTreeMod(path, mode):
def chmod_tree(path, mode):
for root, dirs, files in os.walk(path):
for f in files:
os.chmod(os.path.join(root, f), mode)
for file_name in files:
os.chmod(os.path.join(root, file_name), mode)
def FindStringInFile(fname, matchs):
def findstr_in_file(file_path, pattern_str):
"""
Return match object if found in file.
"""
try:
ms=re.compile(matchs)
for l in (open(fname,'r')).readlines():
m=re.search(ms,l)
if m:
return m
pattern = re.compile(pattern_str)
for line in (open(file_path, 'r')).readlines():
match = re.search(pattern, line)
if match:
return match
except:
raise
return None
+1 -1
View File
@@ -23,5 +23,5 @@ Load OSUtil implementation from azurelinuxagent.distro
from azurelinuxagent.distro.default.osutil import OSUtilError
import azurelinuxagent.distro.loader as loader
OSUtil = loader.GetOSUtil()
OSUTIL = loader.get_osutil()
+56 -56
View File
@@ -29,54 +29,54 @@ from urlparse import urlparse
"""
REST api util functions
"""
__RetryWaitingInterval=10
RETRY_WAITING_INTERVAL = 10
class HttpError(Exception):
pass
def _ParseUrl(url):
def _parse_url(url):
o = urlparse(url)
relativeUrl = o.path
rel_uri = o.path
if o.fragment:
relativeUrl = "{0}#{1}".format(relativeUrl, o.fragment)
rel_uri = "{0}#{1}".format(rel_uri, o.fragment)
if o.query:
relativeUrl = "{0}?{1}".format(relativeUrl, o.query)
rel_uri = "{0}?{1}".format(rel_uri, o.query)
secure = False
if o.scheme.lower() == "https":
secure = True
return o.hostname, o.port, secure, relativeUrl
return o.hostname, o.port, secure, rel_uri
def GetHttpProxy():
def get_http_proxy():
"""
Get http_proxy and https_proxy from environment variables.
Username and password is not supported now.
"""
host = conf.Get("HttpProxy.Host", None)
port = conf.Get("HttpProxy.Port", None)
return (host, port)
host = conf.get("HttpProxy.Host", None)
port = conf.get("HttpProxy.Port", None)
return (host, port)
def _HttpRequest(method, host, relativeUrl, port=None, data=None, secure=False,
headers=None, proxyHost=None, proxyPort=None):
def _http_request(method, host, rel_uri, port=None, data=None, secure=False,
headers=None, proxy_host=None, proxy_port=None):
url, conn = None, None
if secure:
port = 443 if port is None else port
if proxyHost is not None and proxyPort is not None:
conn = httplib.HTTPSConnection(proxyHost, proxyPort)
if proxy_host is not None and proxy_port is not None:
conn = httplib.HTTPSConnection(proxy_host, proxy_port)
conn.set_tunnel(host, port)
#If proxy is used, full url is needed.
url = "https://{0}:{1}{2}".format(host, port, relativeUrl)
url = "https://{0}:{1}{2}".format(host, port, rel_uri)
else:
conn = httplib.HTTPSConnection(host, port)
url = relativeUrl
url = rel_uri
else:
port = 80 if port is None else port
if proxyHost is not None and proxyPort is not None:
conn = httplib.HTTPConnection(proxyHost, proxyPort)
if proxy_host is not None and proxy_port is not None:
conn = httplib.HTTPConnection(proxy_host, proxy_port)
#If proxy is used, full url is needed.
url = "http://{0}:{1}{2}".format(host, port, relativeUrl)
url = "http://{0}:{1}{2}".format(host, port, rel_uri)
else:
conn = httplib.HTTPConnection(host, port)
url = relativeUrl
url = rel_uri
if headers == None:
conn.request(method, url, data)
else:
@@ -84,65 +84,65 @@ def _HttpRequest(method, host, relativeUrl, port=None, data=None, secure=False,
resp = conn.getresponse()
return resp
def HttpRequest(method, url, data, headers=None, maxRetry=3, chkProxy=False):
def http_request(method, url, data, headers=None, max_retry=3, chk_proxy=False):
"""
Sending http request to server
On error, sleep 10 and retry maxRetry times.
On error, sleep 10 and retry max_retry times.
"""
logger.Verbose("HTTP Req: {0} {1}", method, url)
logger.Verbose(" Data={0}", data)
logger.Verbose(" Header={0}", headers)
host, port, secure, relativeUrl = _ParseUrl(url)
logger.verb("HTTP Req: {0} {1}", method, url)
logger.verb(" Data={0}", data)
logger.verb(" Header={0}", headers)
host, port, secure, rel_uri = _parse_url(url)
#Check proxy
proxyHost, proxyPort = (None, None)
if chkProxy:
proxyHost, proxyPort = GetHttpProxy()
proxy_host, proxy_port = (None, None)
if chk_proxy:
proxy_host, proxy_port = get_http_proxy()
#If httplib module is not built with ssl support. Fallback to http
if secure and not hasattr(httplib, "HTTPSConnection"):
logger.Warn("httplib is not built with ssl support")
logger.warn("httplib is not built with ssl support")
secure = False
#If httplib module doesn't support https tunnelling. Fallback to http
if secure and \
proxyHost is not None and \
proxyPort is not None and \
proxy_host is not None and \
proxy_port is not None and \
not hasattr(httplib.HTTPSConnection, "set_tunnel"):
logger.Warn("httplib doesn't support https tunnelling(new in python 2.7)")
logger.warn("httplib doesn't support https tunnelling(new in python 2.7)")
secure = False
for retry in range(0, maxRetry):
for retry in range(0, max_retry):
try:
resp = _HttpRequest(method, host, relativeUrl, port, data,
secure, headers, proxyHost, proxyPort)
logger.Verbose("HTTP Resp: Status={0}", resp.status)
logger.Verbose(" Header={0}", resp.getheaders())
resp = _http_request(method, host, rel_uri, port=port, data=data, secure=secure,
headers=headers, proxy_host=proxy_host, proxy_port=proxy_port)
logger.verb("HTTP Resp: Status={0}", resp.status)
logger.verb(" Header={0}", resp.getheaders())
return resp
except httplib.HTTPException as e:
logger.Warn('HTTPException {0}, args:{1}', e, repr(e.args))
logger.warn('HTTPException {0}, args:{1}', e, repr(e.args))
except IOError as e:
logger.Warn('Socket IOError {0}, args:{1}', e, repr(e.args))
logger.warn('Socket IOError {0}, args:{1}', e, repr(e.args))
if retry < maxRetry - 1:
logger.Info("Retry={0}, {1} {2}", retry, method, url)
time.sleep(__RetryWaitingInterval)
if retry < max_retry - 1:
logger.info("Retry={0}, {1} {2}", retry, method, url)
time.sleep(RETRY_WAITING_INTERVAL)
raise HttpError("HTTP Err: {0} {1}".format(method, url))
def HttpGet(url, headers=None, maxRetry=3, chkProxy=False):
return HttpRequest("GET", url, None, headers, maxRetry, chkProxy)
def HttpHead(url, headers=None, maxRetry=3, chkProxy=False):
return HttpRequest("HEAD", url, None, headers, maxRetry, chkProxy)
def HttpPost(url, data, headers=None, maxRetry=3, chkProxy=False):
return HttpRequest("POST", url, data, headers, maxRetry, chkProxy)
def http_get(url, headers=None, max_retry=3, chk_proxy=False):
return http_request("GET", url, data=None, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
def HttpPut(url, data, headers=None, maxRetry=3, chkProxy=False):
return HttpRequest("PUT", url, data, headers, maxRetry, chkProxy)
def http_head(url, headers=None, max_retry=3, chk_proxy=False):
return http_request("HEAD", url, None, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
def HttpDelete(url, headers=None, maxRetry=3, chkProxy=False):
return HttpRequest("DELETE", url, None, headers, maxRetry, chkProxy)
def http_post(url, data, headers=None, max_retry=3, chk_proxy=False):
return http_request("POST", url, data, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
def http_put(url, data, headers=None, max_retry=3, chk_proxy=False):
return http_request("PUT", url, data, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
def http_delete(url, headers=None, max_retry=3, chk_proxy=False):
return http_request("DELETE", url, None, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
#End REST api util functions
+10 -34
View File
@@ -50,59 +50,35 @@ if not hasattr(subprocess,'check_output'):
subprocess.check_output=check_output
subprocess.CalledProcessError=CalledProcessError
"""
Shell command util functions
"""
def Run(cmd, chk_err=True):
def run(cmd, chk_err=True):
"""
Calls RunGetOutput on 'cmd', returning only the return code.
Calls run_get_output on 'cmd', returning only the return code.
If chk_err=True then errors will be reported in the log.
If chk_err=False then errors will be suppressed from the log.
"""
retcode,out=RunGetOutput(cmd,chk_err)
retcode,out=run_get_output(cmd,chk_err)
return retcode
def RunGetOutput(cmd, chk_err=True):
def run_get_output(cmd, chk_err=True):
"""
Wrapper for subprocess.check_output.
Execute 'cmd'. Returns return code and STDOUT, trapping expected exceptions.
Reports exceptions to Error if chk_err parameter is True
"""
logger.Verbose("Run cmd '{0}'", cmd)
try:
logger.verb("run cmd '{0}'", cmd)
try:
output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True)
except subprocess.CalledProcessError,e :
if chk_err :
logger.Error("Run cmd '{0}' failed", e.cmd)
logger.Error("Error Code:{0}", e.returncode)
logger.Error("Result:{0}", e.output[:-1].decode('latin-1'))
logger.error("run cmd '{0}' failed", e.cmd)
logger.error("Error Code:{0}", e.returncode)
logger.error("Result:{0}", e.output[:-1].decode('latin-1'))
return e.returncode, e.output.decode('latin-1')
return 0, output
def RunSendStdin(cmd, input, chk_err=True):
"""
Wrapper for subprocess.Popen.
Execute 'cmd', sending 'input' to STDIN of 'cmd'.
Returns return code and STDOUT, trapping expected exceptions.
Reports exceptions to Error if chk_err parameter is True
"""
logger.Verbose("Run cmd '{0}'", cmd)
try:
me=subprocess.Popen([cmd], shell=True, stdin=subprocess.PIPE,
stderr=subprocess.STDOUT,stdout=subprocess.PIPE)
output=me.communicate(input)
except OSError , e :
if chk_err :
logger.Error("Run cmd '{0}' failed", e.cmd)
logger.Error("Error Code:{0}", e.returncode)
logger.Error("Result:{0}", e.output[:-1].decode('latin-1'))
return e.returncode, e.output.decode('latin-1')
if me.returncode is not 0 and chk_err is True:
logger.Error("Run cmd '{0}' failed", cmd)
logger.Error("Error Code:{0}", me.returncode)
logger.Error("Result:{0}", output[0].decode('latin-1'))
return me.returncode, output[0].decode('latin-1')
#End shell command util functions
+44 -43
View File
@@ -15,21 +15,22 @@
# limitations under the License.
#
# Requires Python 2.4+ and Openssl 1.0+
import crypt
import random
import string
import struct
def FindFirstNode(xmlDoc, selector):
nodes = FindAllNodes(xmlDoc, selector)
def find_first_node(xml_doc, selector):
nodes = find_all_nodes(xml_doc, selector)
if len(nodes) > 0:
return nodes[0]
def FindAllNodes(xmlDoc, selector):
nodes = xmlDoc.findall(selector)
def find_all_nodes(xml_doc, selector):
nodes = xml_doc.findall(selector)
return nodes
def GetNodeTextData(a):
def get_node_text(a):
"""
Filter non-text nodes from DOM tree
"""
@@ -37,54 +38,54 @@ def GetNodeTextData(a):
if b.nodeType == b.TEXT_NODE:
return b.data
def Unpack(buf, offset, range):
def unpack(buf, offset, range):
"""
Unpack bytes into python values.
"""
result = 0
for i in range:
result = (result << 8) | Ord(buf[offset + i])
result = (result << 8) | str_to_ord(buf[offset + i])
return result
def UnpackLittleEndian(buf, offset, length):
def unpack_little_endian(buf, offset, length):
"""
Unpack little endian bytes into python values.
"""
return Unpack(buf, offset, list(range(length - 1, -1, -1)))
return unpack(buf, offset, list(range(length - 1, -1, -1)))
def UnpackBigEndian(buf, offset, length):
def unpack_big_endian(buf, offset, length):
"""
Unpack big endian bytes into python values.
"""
return Unpack(buf, offset, list(range(0, length)))
return unpack(buf, offset, list(range(0, length)))
def HexDump3(buf, offset, length):
def hex_dump3(buf, offset, length):
"""
Dump range of buf in formatted hex.
"""
return ''.join(['%02X' % Ord(char) for char in buf[offset:offset + length]])
return ''.join(['%02X' % str_to_ord(char) for char in buf[offset:offset + length]])
def HexDump2(buf):
def hex_dump2(buf):
"""
Dump buf in formatted hex.
"""
return HexDump3(buf, 0, len(buf))
return hex_dump3(buf, 0, len(buf))
def IsInRangeInclusive(a, low, high):
def is_in_range(a, low, high):
"""
Return True if 'a' in 'low' <= a >= 'high'
"""
return (a >= low and a <= high)
def IsPrintable(ch):
def is_printable(ch):
"""
Return True if character is displayable.
"""
return (IsInRangeInclusive(ch, Ord('A'), Ord('Z'))
or IsInRangeInclusive(ch, Ord('a'), Ord('z'))
or IsInRangeInclusive(ch, Ord('0'), Ord('9')))
return (is_in_range(ch, str_to_ord('A'), str_to_ord('Z'))
or is_in_range(ch, str_to_ord('a'), str_to_ord('z'))
or is_in_range(ch, str_to_ord('0'), str_to_ord('9')))
def HexDump(buffer, size):
def hex_dump(buffer, size):
"""
Return Hex formated dump of a 'buffer' of 'size'.
"""
@@ -111,16 +112,16 @@ def HexDump(buffer, size):
for j in range(i - (i % 16), i + 1):
byte=buffer[j]
if type(byte) == str:
byte = ord(byte.decode('latin1'))
byte = str_to_ord(byte.decode('latin1'))
k = '.'
if IsPrintable(byte):
if is_printable(byte):
k = chr(byte)
result += k
if (i + 1) != size:
result += "\n"
return result
def Ord(a):
def str_to_ord(a):
"""
Allows indexing into a string or an array of integers transparently.
Generic utility function.
@@ -129,25 +130,25 @@ def Ord(a):
a = ord(a)
return a
def CompareBytes(a, b, start, length):
def compare_bytes(a, b, start, length):
for offset in range(start, start + length):
if Ord(a[offset]) != Ord(b[offset]):
if str_to_ord(a[offset]) != str_to_ord(b[offset]):
return False
return True
def IntegerToIpAddressV4String(a):
def int_to_ip4_addr(a):
"""
Build DHCP request string.
"""
return "%u.%u.%u.%u" % ((a >> 24) & 0xFF,
(a >> 16) & 0xFF,
(a >> 8) & 0xFF,
return "%u.%u.%u.%u" % ((a >> 24) & 0xFF,
(a >> 16) & 0xFF,
(a >> 8) & 0xFF,
(a) & 0xFF)
def Ascii(val):
def ascii(val):
uni = None
if type(val) == str:
uni = unicode(val, 'utf-8', errors='ignore')
uni = unicode(val, 'utf-8', errors='ignore')
else:
uni = unicode(val)
if uni is None:
@@ -155,7 +156,7 @@ def Ascii(val):
else:
return uni.encode('ascii', 'backslashreplace')
def HexStringToByteArray(a):
def hexstr_to_bytearray(a):
"""
Return hex string packed into a binary struct.
"""
@@ -164,7 +165,7 @@ def HexStringToByteArray(a):
b += struct.pack("B", int(a[c * 2:c * 2 + 2], 16))
return b
def SetSshConfig(config, name, val):
def set_ssh_config(config, name, val):
notfound = True
for i in range(0, len(config)):
if config[i].startswith(name):
@@ -177,20 +178,20 @@ def SetSshConfig(config, name, val):
config.insert(i, "{0} {1}".format(name, val))
return config
def RemoveBom(c):
if ord(c[0]) > 128 and ord(c[1]) > 128 and ord(c[2]) > 128:
def remove_bom(c):
if str_to_ord(c[0]) > 128 and str_to_ord(c[1]) > 128 and str_to_ord(c[2]) > 128:
c = c[3:]
return c
def GetPasswordHash(password, useSalt, saltType, saltLength):
salt="$6$"
if useSalt:
def gen_password_hash(password, use_salt, salt_type, salt_len):
salt="$6$"
if use_salt:
collection = string.ascii_letters + string.digits
salt = ''.join(random.choice(collection) for _ in range(saltLength))
salt = "${0}${1}".format(saltType, salt)
salt = ''.join(random.choice(collection) for _ in range(salt_len))
salt = "${0}${1}".format(salt_type, salt)
return crypt.crypt(password, salt)
def NumberToBytes(i):
def num_to_bytes(i):
"""
Pack number into bytes. Retun as string.
"""
@@ -201,7 +202,7 @@ def NumberToBytes(i):
result.reverse()
return ''.join(result)
def BitsToString(a):
def bits_to_str(a):
"""
Return string representation of bits in a.
"""
+627 -214
View File
File diff suppressed because it is too large Load Diff
+11 -11
View File
@@ -18,11 +18,11 @@
#
import os
from azurelinuxagent.metadata import GuestAgentName, GuestAgentVersion, \
GuestAgentDescription, \
DistroName, DistroVersion, DistroFullName
from azurelinuxagent.metadata import AGENT_NAME, AGENT_VERSION, \
agent_description, \
DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.agent as agent
import setuptools
from setuptools import find_packages
@@ -103,9 +103,9 @@ class install(_install):
def initialize_options(self):
_install.initialize_options(self)
self.lnx_distro = DistroName
self.lnx_distro_version = DistroVersion
self.lnx_distro_fullname = DistroFullName
self.lnx_distro = DISTRO_NAME
self.lnx_distro_version = DISTRO_VERSION
self.lnx_distro_fullname = DISTRO_FULL_NAME
self.register_service = False
def finalize_options(self):
@@ -118,11 +118,11 @@ class install(_install):
def run(self):
_install.run(self)
if self.register_service:
agent.RegisterService()
agent.register_service()
setuptools.setup(name=GuestAgentName,
version=GuestAgentVersion,
long_description=GuestAgentDescription,
setuptools.setup(name=AGENT_NAME,
version=AGENT_VERSION,
long_description=agent_description,
author= 'Yue Zhang, Stephen Zarkos, Eric Gable',
author_email = 'walinuxagent@microsoft.com',
platforms = 'Linux',
-75
View File
@@ -1,75 +0,0 @@
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+ and Openssl 1.0+
#
# Implements parts of RFC 2131, 1541, 1497 and
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
import env
import tests.tools as tools
from tools import *
import uuid
import shutil
import unittest
import os
import azurelinuxagent.logger as logger
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.textutil as textutil
from azurelinuxagent.utils.osutil import CurrOSInfo, CurrOS
import test
"""
OS related test. Need to run with root privilege.
CAUSION: during the test, user account and system config may be changed
"""
class TestUserOperation(unittest.TestCase):
def test_sysuser(self):
self.assertTrue(CurrOS.IsSysUser('root'))
def test_update_user_account(self):
userName="nobodywillusethisname"
shellutil.Run('userdel -f -r ' + userName)
self.assertFalse(tools.simple_file_grep('/etc/passwd', userName))
self.assertFalse(os.path.isdir(os.path.join(CurrOS.GetHome(), userName)))
CurrOS.UpdateUserAccount(userName, "User@123")
self.assertTrue(tools.simple_file_grep('/etc/passwd', userName))
self.assertTrue(tools.simple_file_grep('/etc/sudoers.d/waagent',
userName))
self.assertTrue(os.path.isdir(os.path.join(CurrOS.GetHome(), userName)))
class TestSshOperation(unittest.TestCase):
def _setUp(self):
logger.AddLoggerAppender(logger.AppenderConfig({
"type":"CONSOLE",
"level":"VERBOSE",
"console_path":"/dev/stdout"
}))
def test_regen_ssh_host_key(self):
oldKey = fileutil.GetFileContents('/etc/ssh/ssh_host_rsa_key')
CurrOS.RegenerateSshHostkey('rsa')
newKey = fileutil.GetFileContents('/etc/ssh/ssh_host_rsa_key')
self.assertNotEquals(oldKey, newKey)
#TODO test dvd mount
#TODO test set scsi
#TODO test set/publish hostname
if __name__ == '__main__':
unittest.main()
-38
View File
@@ -1,38 +0,0 @@
#!/usr/bin/env python
#
# Windows Azure Linux Agent
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+ and Openssl 1.0+
#
# Implements parts of RFC 2131, 1541, 1497 and
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
#
import os
import sys
test_root = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(test_root)
sys.path.insert(0, project_root)
import azurelinuxagent.agent as agent
if __name__ == '__main__':
agent.Main()
+2 -2
View File
@@ -29,14 +29,14 @@ import azurelinuxagent.agent as agent
class TestAgent(unittest.TestCase):
def test_parse_args(self):
cmd, force, verbose = agent.ParseArgs(["deprovision+user",
cmd, force, verbose = agent.parse_args(["deprovision+user",
"-force",
"/verbose"])
self.assertEquals("deprovision+user", cmd)
self.assertTrue(force)
self.assertTrue(verbose)
cmd, force, verbose = agent.ParseArgs(["wrong cmd"])
cmd, force, verbose = agent.parse_args(["wrong cmd"])
self.assertEquals("help", cmd)
self.assertFalse(force)
self.assertFalse(verbose)
+10 -11
View File
@@ -27,7 +27,7 @@ import json
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.protocol.v1 as v1
CertificatesSample="""\
certs_sample="""\
<?xml version="1.0" encoding="utf-8"?>
<CertificateFile xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="certificates10.xsd">
<Version>2012-11-30</Version>
@@ -111,7 +111,7 @@ h+249Wj0Bw==
</CertificateFile>
"""
TransportCert="""\
transport_cert="""\
-----BEGIN CERTIFICATE-----
MIIDBzCCAe+gAwIBAgIJANujJuVt5eC8MA0GCSqGSIb3DQEBCwUAMBkxFzAVBgNV
BAMMDkxpbnV4VHJhbnNwb3J0MCAXDTE0MTAyNDA3MjgwN1oYDzIxMDQwNzEyMDcy
@@ -133,7 +133,7 @@ DsfY6XGSEIhZnA4=
-----END CERTIFICATE-----
"""
TransportPrivate="""\
transport_private="""\
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDT3CQJHeleTXqI
EioyHk1zlguccyk8riT7eGGTFUqLFNJSOmtMB6foNnz9ts61bMFvkCc+gFVqIyBK
@@ -180,19 +180,18 @@ class TestCertificates(unittest.TestCase):
os.remove(crt2)
if os.path.isfile(prv2):
os.remove(prv2)
fileutil.SetFileContents(os.path.join('/tmp', v1.TransportCertFile),
TransportCert)
fileutil.SetFileContents(os.path.join('/tmp', v1.TransportPrivateFile),
TransportPrivate)
config = v1.Certificates(CertificatesSample)
fileutil.write_file(os.path.join('/tmp', "TransportCert.pem"),
transport_cert)
fileutil.write_file(os.path.join('/tmp', "TransportPrivate.pem"),
transport_private)
config = v1.Certificates(certs_sample)
self.assertNotEquals(None, config)
self.assertTrue(os.path.isfile(crt1))
self.assertTrue(os.path.isfile(crt2))
self.assertTrue(os.path.isfile(prv2))
certs = config.getCerts()
self.assertNotEquals(0, len(certs.certificates))
cert = certs.certificates[0]
self.assertNotEquals(0, len(config.cert_list.certificates))
cert = config.cert_list.certificates[0]
self.assertNotEquals(None, cert.thumbprint)
if __name__ == '__main__':
+9 -9
View File
@@ -42,15 +42,15 @@ class TestConfiguration(unittest.TestCase):
def test_parse_conf(self):
config = conf.ConfigurationProvider()
config.load(TestConf)
self.assertEquals(True, config.getSwitch("foo.bar.switch"))
self.assertEquals(False, config.getSwitch("foo.bar.switch2"))
self.assertEquals(False, config.getSwitch("foo.bar.switch3"))
self.assertEquals(True, config.getSwitch("foo.bar.switch4", True))
self.assertEquals(True, config.get_switch("foo.bar.switch"))
self.assertEquals(False, config.get_switch("foo.bar.switch2"))
self.assertEquals(False, config.get_switch("foo.bar.switch3"))
self.assertEquals(True, config.get_switch("foo.bar.switch4", True))
self.assertEquals("foobar", config.get("foo.bar.str"))
self.assertEquals("foobar1", config.get("foo.bar.str1", "foobar1"))
self.assertEquals(300, config.getInt("foo.bar.int"))
self.assertEquals(-1, config.getInt("foo.bar.int2"))
self.assertEquals(-1, config.getInt("foo.bar.str"))
self.assertEquals(300, config.get_int("foo.bar.int"))
self.assertEquals(-1, config.get_int("foo.bar.int2"))
self.assertEquals(-1, config.get_int("foo.bar.str"))
def test_parse_malformed_conf(self):
with self.assertRaises(AgentConfigError) as cm:
@@ -63,8 +63,8 @@ class TestConfiguration(unittest.TestCase):
F.close()
config = conf.ConfigurationProvider()
conf.LoadConfiguration('/tmp/test_conf', conf=config)
self.assertEquals(True, config.getSwitch("foo.bar.switch"), False)
conf.load_conf('/tmp/test_conf', conf=config)
self.assertEquals(True, config.get_switch("foo.bar.switch"), False)
if __name__ == '__main__':
unittest.main()
+8 -8
View File
@@ -21,7 +21,7 @@
import env
from tests.tools import *
import unittest
import azurelinuxagent.distro.default.deprovision as deprovisionHandler
import azurelinuxagent.distro.default.deprovision as deprovision_handler
def MockAction(param):
#print param
@@ -30,24 +30,24 @@ def MockAction(param):
def MockSetup(self, deluser):
warnings = ["Print warning to console"]
actions = [
deprovisionHandler.DeprovisionAction(MockAction, ['Take action'])
deprovision_handler.DeprovisionAction(MockAction, ['Take action'])
]
return warnings, actions
class TestDeprovisionHandler(unittest.TestCase):
def test_setUp(self):
handler = deprovisionHandler.DeprovisionHandler()
warnings, actions = handler.setUp(False)
def test_setup(self):
handler = deprovision_handler.DeprovisionHandler()
warnings, actions = handler.setup(False)
self.assertNotEquals(None, warnings)
self.assertNotEquals(0, len(warnings))
self.assertNotEquals(None, actions)
self.assertNotEquals(0, len(actions))
self.assertEquals(deprovisionHandler.DeprovisionAction, type(actions[0]))
self.assertEquals(deprovision_handler.DeprovisionAction, type(actions[0]))
@Mockup(deprovisionHandler.DeprovisionHandler, 'setUp', MockSetup)
@mock(deprovision_handler.DeprovisionHandler, 'setup', MockSetup)
def test_deprovision(self):
handler = deprovisionHandler.DeprovisionHandler()
handler = deprovision_handler.DeprovisionHandler()
handler.deprovision(force=True)
if __name__ == '__main__':
+13 -13
View File
@@ -25,34 +25,34 @@ import unittest
import os
import json
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.distro.default.dhcp as dhcpHandler
import azurelinuxagent.distro.default.dhcp as dhcp_handler
SampleDhcpResponse = None
with open(os.path.join(env.test_root, "dhcp")) as F:
SampleDhcpResponse = F.read()
MockSocketSend = MockFunc('SocketSend', SampleDhcpResponse)
MockGenTransactionId = MockFunc('GenTransactionId', "\xC6\xAA\xD1\x5D")
MockGetMacAddress = MockFunc('GetMacAddress', "\x00\x15\x5D\x38\xAA\x38")
mock_socket_send = MockFunc('socket_send', SampleDhcpResponse)
mock_gen_trans_id = MockFunc('gen_trans_id', "\xC6\xAA\xD1\x5D")
mock_get_mac_addr = MockFunc('get_mac_addr', "\x00\x15\x5D\x38\xAA\x38")
class TestdhcpHandler(unittest.TestCase):
def test_build_dhcp_req(self):
req = dhcpHandler.BuildDhcpRequest(MockGetMacAddress())
req = dhcp_handler.build_dhcp_request(mock_get_mac_addr())
self.assertNotEquals(None, req)
@Mockup(dhcpHandler, "GenTransactionId", MockGenTransactionId)
@Mockup(dhcpHandler, "SocketSend", MockSocketSend)
@mock(dhcp_handler, "gen_trans_id", mock_gen_trans_id)
@mock(dhcp_handler, "socket_send", mock_socket_send)
def test_send_dhcp_req(self):
req = dhcpHandler.BuildDhcpRequest(MockGetMacAddress())
resp = dhcpHandler.SendDhcpRequest(req)
req = dhcp_handler.build_dhcp_request(mock_get_mac_addr())
resp = dhcp_handler.send_dhcp_request(req)
self.assertNotEquals(None, resp)
@Mockup(dhcpHandler, "SocketSend", MockSocketSend)
@Mockup(dhcpHandler, "GenTransactionId", MockGenTransactionId)
@Mockup(dhcpHandler.OSUtil, "GetMacAddress", MockGetMacAddress)
@mock(dhcp_handler, "socket_send", mock_socket_send)
@mock(dhcp_handler, "gen_trans_id", mock_gen_trans_id)
@mock(dhcp_handler.OSUtil, "get_mac_addr", mock_get_mac_addr)
def test_handle_dhcp(self):
dh = dhcpHandler.DhcpHandler()
dh = dhcp_handler.DhcpHandler()
dh.probe()
self.assertEquals("10.62.144.1", dh.gateway)
self.assertEquals("10.62.144.140", dh.endpoint)
+12 -12
View File
@@ -21,22 +21,22 @@
import env
from tests.tools import *
import unittest
from azurelinuxagent.utils.osutil import OSUtil, OSUtilError
from azurelinuxagent.handler import Handlers
from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError
from azurelinuxagent.handler import HANDLERS
import azurelinuxagent.distro.default.osutil as osutil
class TestDistroLoader(unittest.TestCase):
def test_loader(self):
self.assertNotEquals(osutil.DefaultOSUtil, type(OSUtil))
self.assertNotEquals(None, Handlers.initHandler)
self.assertNotEquals(None, Handlers.runHandler)
self.assertNotEquals(None, Handlers.scvmmHandler)
self.assertNotEquals(None, Handlers.dhcpHandler)
self.assertNotEquals(None, Handlers.envHandler)
self.assertNotEquals(None, Handlers.provisionHandler)
self.assertNotEquals(None, Handlers.resourceDiskHandler)
self.assertNotEquals(None, Handlers.envHandler)
self.assertNotEquals(None, Handlers.deprovisionHandler)
self.assertNotEquals(osutil.DefaultOSUtil, type(OSUTIL))
self.assertNotEquals(None, HANDLERS.init_handler)
self.assertNotEquals(None, HANDLERS.main_handler)
self.assertNotEquals(None, HANDLERS.scvmm_handler)
self.assertNotEquals(None, HANDLERS.dhcp_handler)
self.assertNotEquals(None, HANDLERS.env_handler)
self.assertNotEquals(None, HANDLERS.provision_handler)
self.assertNotEquals(None, HANDLERS.resource_disk_handler)
self.assertNotEquals(None, HANDLERS.env_handler)
self.assertNotEquals(None, HANDLERS.deprovision_handler)
if __name__ == '__main__':
unittest.main()
+12 -10
View File
@@ -22,28 +22,30 @@ import env
from tests.tools import *
import unittest
import time
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
from azurelinuxagent.distro.default.env import EnvMonitor
class MockDhcpHandler(object):
def configRoutes(self):
def conf_routes(self):
pass
MockDhcpProcessIdNotChange = MockFunc(retval="1234")
def MockDhcpProcessIdChange():
def mock_get_dhcp_pid():
return "1234"
def mock_dhcp_pid_change():
return str(time.time())
class TestEnvMonitor(unittest.TestCase):
@Mockup(OSUtil, 'GetDhcpProcessId', MockDhcpProcessIdNotChange)
def test_dhcpProcessIdNotChange(self):
@mock(OSUTIL, 'get_dhcp_pid', mock_get_dhcp_pid)
def test_dhcp_pid_not_change(self):
monitor = EnvMonitor(MockDhcpHandler())
monitor.handleDhcpClientRestart()
monitor.handle_dhclient_restart()
@Mockup(OSUtil, 'GetDhcpProcessId', MockDhcpProcessIdChange)
def test_dhcpProcessIdChange(self):
@mock(OSUTIL, 'get_dhcp_pid', mock_dhcp_pid_change)
def test_dhcp_pid_change(self):
monitor = EnvMonitor(MockDhcpHandler())
monitor.handleDhcpClientRestart()
monitor.handle_dhclient_restart()
if __name__ == '__main__':
unittest.main()
+7 -7
View File
@@ -29,25 +29,25 @@ import azurelinuxagent.event as evt
import azurelinuxagent.protocol as prot
class MockProtocol(object):
def getInstanceMetadata(self):
def get_instance_metadata(self):
return prot.InstanceMetadata(deploymentName='foo', roleName='foo',
roleInstanceId='foo', containerId='foo')
def reportEvent(self, data): pass
def report_event(self, data): pass
class TestEvent(unittest.TestCase):
def test_save(self):
if not os.path.exists("/tmp/events"):
os.mkdir("/tmp/events")
evt.AddExtensionEvent("Test", "Test", True)
evt.add_event("Test", "Test", True)
eventsFile = os.listdir("/tmp/events")
self.assertNotEquals(0, len(eventsFile))
shutil.rmtree("/tmp/events")
@Mockup(evt.prot.Factory, 'getDefaultProtocol',
MockFunc(retval=MockProtocol()))
def test_initSystemInfo(self):
@mock(evt.prot.Factory, 'get_default_protocol',
MockFunc(retval=MockProtocol()))
def test_init_sys_info(self):
monitor = evt.EventMonitor()
self.assertNotEquals(0, len(monitor.sysInfo))
self.assertNotEquals(0, len(monitor.sysinfo))
if __name__ == '__main__':
unittest.main()
+72 -72
View File
@@ -25,12 +25,12 @@ import unittest
import os
import json
import azurelinuxagent.logger as logger
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.protocol as prot
import azurelinuxagent.distro.default.extension as ext
extensionData = {
ext_sample_json = {
"name":"TestExt",
"properties":{
"version":"2.0",
@@ -51,10 +51,10 @@ extensionData = {
}]
}
}
extension = prot.Extension()
prot.set_properties(extension, extensionData)
ext_sample = prot.Extension()
prot.set_properties(ext_sample, ext_sample_json)
packageListData={
pkd_list_sample_str={
"versions": [{
"version": "2.0",
"uris":[{
@@ -67,10 +67,10 @@ packageListData={
}]
}]
}
packageList = prot.ExtensionPackageList()
prot.set_properties(packageList, packageListData)
pkg_list_sample = prot.ExtensionPackageList()
prot.set_properties(pkg_list_sample, pkd_list_sample_str)
manJson = {
manifest_sample_str = {
"handlerManifest":{
"installCommand": "echo 'install'",
"uninstallCommand": "echo 'uninstall'",
@@ -79,99 +79,99 @@ manJson = {
"disableCommand": "echo 'disable'",
}
}
man = ext.HandlerManifest(manJson)
manifest_sample = ext.HandlerManifest(manifest_sample_str)
def MockLoadManifest(self):
return man
def mock_load_manifest(self):
return manifest_sample
MockLaunchCommand = MockFunc()
MockSetHandlerStatus = MockFunc()
mock_launch_command = MockFunc()
mock_set_handler_status = MockFunc()
def MockDownload(self):
fileutil.CreateDir(self.getBaseDir())
fileutil.SetFileContents(self.getManifestFile(), json.dumps(manJson))
def mock_download(self):
fileutil.mkdir(self.get_base_dir())
fileutil.write_file(self.get_manifest_file(), json.dumps(manifest_sample_str))
#logger.LoggerInit("/dev/null", "/dev/stdout")
class TestExtensions(unittest.TestCase):
def test_load_ext(self):
libDir = OSUtil.GetLibDir()
testExt1 = os.path.join(libDir, 'TestExt-1.0')
testExt2 = os.path.join(libDir, 'TestExt-2.0')
testExt2 = os.path.join(libDir, 'TestExt-2.1')
for path in [testExt1, testExt2]:
libDir = OSUTIL.get_lib_dir()
test_ext1 = os.path.join(libDir, 'TestExt-1.0')
test_ext2 = os.path.join(libDir, 'TestExt-2.0')
test_ext2 = os.path.join(libDir, 'TestExt-2.1')
for path in [test_ext1, test_ext2]:
if not os.path.isdir(path):
os.mkdir(path)
testExt = ext.GetInstalledExtensionVersion('TestExt')
self.assertEqual('2.1', testExt)
test_ext = ext.get_installed_version('TestExt')
self.assertEqual('2.1', test_ext)
def test_getters(self):
testExt = ext.ExtensionInstance(extension, packageList,
extension.properties.version, False)
self.assertEqual("/tmp/TestExt-2.0", testExt.getBaseDir())
self.assertEqual("/tmp/TestExt-2.0/status", testExt.getStatusDir())
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
ext_sample.properties.version, False)
self.assertEqual("/tmp/TestExt-2.0", test_ext.get_base_dir())
self.assertEqual("/tmp/TestExt-2.0/status", test_ext.get_status_dir())
self.assertEqual("/tmp/TestExt-2.0/status/0.status",
testExt.getStatusFile())
test_ext.get_status_file())
self.assertEqual("/tmp/TestExt-2.0/config/HandlerState",
testExt.getHandlerStateFile())
self.assertEqual("/tmp/TestExt-2.0/config", testExt.getConfigDir())
test_ext.get_handler_state_file())
self.assertEqual("/tmp/TestExt-2.0/config", test_ext.get_conf_dir())
self.assertEqual("/tmp/TestExt-2.0/config/0.settings",
testExt.getSettingsFile())
test_ext.get_settings_file())
self.assertEqual("/tmp/TestExt-2.0/heartbeat.log",
testExt.getHeartbeatFile())
test_ext.get_heartbeat_file())
self.assertEqual("/tmp/TestExt-2.0/HandlerManifest.json",
testExt.getManifestFile())
test_ext.get_manifest_file())
self.assertEqual("/tmp/TestExt-2.0/HandlerEnvironment.json",
testExt.getEnvironmentFile())
self.assertEqual("/tmp/log/TestExt/2.0", testExt.getLogDir())
test_ext.get_env_file())
self.assertEqual("/tmp/log/TestExt/2.0", test_ext.get_log_dir())
testExt = ext.ExtensionInstance(extension, packageList, "2.1", False)
self.assertEqual("/tmp/TestExt-2.1", testExt.getBaseDir())
self.assertEqual("2.1", testExt.getTargetVersion())
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample, "2.1", False)
self.assertEqual("/tmp/TestExt-2.1", test_ext.get_base_dir())
self.assertEqual("2.1", test_ext.get_target_version())
@Mockup(ext.ExtensionInstance, 'loadManifest', MockLoadManifest)
@Mockup(ext.ExtensionInstance, 'launchCommand', MockLaunchCommand)
@Mockup(ext.ExtensionInstance, 'setHandlerStatus', MockSetHandlerStatus)
@mock(ext.ExtensionInstance, 'load_manifest', mock_load_manifest)
@mock(ext.ExtensionInstance, 'launch_command', mock_launch_command)
@mock(ext.ExtensionInstance, 'set_handler_status', mock_set_handler_status)
def test_handle_uninstall(self):
MockLaunchCommand.args = None
MockSetHandlerStatus.args = None
testExt = ext.ExtensionInstance(extension, packageList,
extension.properties.version, False)
testExt.handleUninstall()
self.assertEqual(None, MockLaunchCommand.args)
self.assertEqual(None, MockSetHandlerStatus.args)
self.assertEqual(None, testExt.getCurrOperation())
mock_launch_command.args = None
mock_set_handler_status.args = None
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
ext_sample.properties.version, False)
test_ext.handle_uninstall()
self.assertEqual(None, mock_launch_command.args)
self.assertEqual(None, mock_set_handler_status.args)
self.assertEqual(None, test_ext.get_curr_op())
testExt = ext.ExtensionInstance(extension, packageList,
extension.properties.version, True)
testExt.handleUninstall()
self.assertEqual(man.getUninstallCommand(), MockLaunchCommand.args[0])
self.assertEqual("UnInstall", testExt.getCurrOperation())
self.assertEqual("NotReady", MockSetHandlerStatus.args[0])
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
ext_sample.properties.version, True)
test_ext.handle_uninstall()
self.assertEqual(manifest_sample.get_uninstall_command(), mock_launch_command.args[0])
self.assertEqual("UnInstall", test_ext.get_curr_op())
self.assertEqual("NotReady", mock_set_handler_status.args[0])
@Mockup(ext.ExtensionInstance, 'loadManifest', MockLoadManifest)
@Mockup(ext.ExtensionInstance, 'launchCommand', MockLaunchCommand)
@Mockup(ext.ExtensionInstance, 'download', MockDownload)
@Mockup(ext.ExtensionInstance, 'getHandlerStatus', MockFunc(retval="enabled"))
@Mockup(ext.ExtensionInstance, 'setHandlerStatus', MockSetHandlerStatus)
@mock(ext.ExtensionInstance, 'load_manifest', mock_load_manifest)
@mock(ext.ExtensionInstance, 'launch_command', mock_launch_command)
@mock(ext.ExtensionInstance, 'download', mock_download)
@mock(ext.ExtensionInstance, 'get_handler_status', MockFunc(retval="enabled"))
@mock(ext.ExtensionInstance, 'set_handler_status', mock_set_handler_status)
def test_handle(self):
#Test enable
testExt = ext.ExtensionInstance(extension, packageList,
extension.properties.version, False)
testExt.initLog()
self.assertEqual(1, len(testExt.logger.appenders) - len(logger.DefaultLogger.appenders))
testExt.handle()
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
ext_sample.properties.version, False)
test_ext.init_logger()
self.assertEqual(1, len(test_ext.logger.appenders) - len(logger.default_logger.appenders))
test_ext.handle()
#Test upgrade
testExt = ext.ExtensionInstance(extension, packageList,
extension.properties.version, False)
testExt.initLog()
self.assertEqual(1, len(testExt.logger.appenders) - len(logger.DefaultLogger.appenders))
testExt.handle()
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
ext_sample.properties.version, False)
test_ext.init_logger()
self.assertEqual(1, len(test_ext.logger.appenders) - len(logger.default_logger.appenders))
test_ext.handle()
def test_status_convert(self):
extStatus = json.loads('[{"status": {"status": "success", "formattedMessage": {"lang": "en-US", "message": "Script is finished"}, "operation": "Enable", "code": "0", "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"}, "version": "1.0", "timestampUTC": "2015-06-27T08:34:50Z"}]')
ext.extension_status_to_v2(extStatus[0], 0)
ext_status = json.loads('[{"status": {"status": "success", "formattedMessage": {"lang": "en-US", "message": "Script is finished"}, "operation": "Enable", "code": "0", "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"}, "version": "1.0", "timestampUTC": "2015-06-27T08:34:50Z"}]')
ext.ext_status_to_v2(ext_status[0], 0)
if __name__ == '__main__':
+8 -8
View File
@@ -26,7 +26,7 @@ import os
import json
import azurelinuxagent.protocol.v1 as v1
ExtensionsConfigSample="""\
ext_conf_sample="""\
<Extensions version="1.0.0.0" goalStateIncarnation="9"><GuestAgentExtension xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
<GAFamilies>
<GAFamily>
@@ -68,13 +68,13 @@ ExtensionsConfigSample="""\
</Plugins>
<PluginSettings>
<Plugin name="OSTCExtensions.ExampleHandlerLinux" version="1.4">
<RuntimeSettings seqNo="6">{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3","protectedSettings":"MIICWgYJK","publicSettings":{"foo":"bar"}}}]}</RuntimeSettings>
<runtimeSettings seqNo="6">{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3","protectedSettings":"MIICWgYJK","publicSettings":{"foo":"bar"}}}]}</runtimeSettings>
</Plugin>
</PluginSettings>
<StatusUploadBlob>https://yuezhatest.blob.core.windows.net/vhds/test-cs12.test-cs12.test-cs12.status?sr=b&amp;sp=rw&amp;se=9999-01-01&amp;sk=key1&amp;sv=2014-02-14&amp;sig=hfRh7gzUE7sUtYwke78IOlZOrTRCYvkec4hGZ9zZzXo%3D</StatusUploadBlob></Extensions>
"""
ManifestSample="""\
manifest_sample="""\
<?xml version="1.0" encoding="utf-8"?>
<PluginVersionManifest xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
<Plugins>
@@ -113,8 +113,8 @@ EmptySettings="""\
class TestExtensionsConfig(unittest.TestCase):
def test_extensions_config(self):
config = v1.ExtensionsConfig(ExtensionsConfigSample)
extensions = config.extList.extensions
config = v1.ExtensionsConfig(ext_conf_sample)
extensions = config.ext_list.extensions
self.assertNotEquals(None, extensions)
self.assertEquals(1, len(extensions))
self.assertNotEquals(None, extensions[0])
@@ -131,9 +131,9 @@ class TestExtensionsConfig(unittest.TestCase):
self.assertEquals(json.loads('{"foo":"bar"}'),
settings.publicSettings)
man = v1.ExtensionManifest(ManifestSample)
self.assertNotEquals(None, man.packageList)
self.assertEquals(3, len(man.packageList.versions))
man = v1.ExtensionManifest(manifest_sample)
self.assertNotEquals(None, man.pkg_list)
self.assertEquals(3, len(man.pkg_list.versions))
def test_empty_settings(self):
config = v1.ExtensionsConfig(EmptySettings)
+9 -9
View File
@@ -30,36 +30,36 @@ class TestFileOperations(unittest.TestCase):
def test_get_set_file_contents(self):
test_file='/tmp/test_file'
content = str(uuid.uuid4())
fileutil.SetFileContents(test_file, content)
fileutil.write_file(test_file, content)
self.assertTrue(tools.simple_file_grep(test_file, content))
self.assertEquals(content, fileutil.GetFileContents('/tmp/test_file'))
self.assertEquals(content, fileutil.read_file('/tmp/test_file'))
os.remove(test_file)
def test_append_file(self):
test_file='/tmp/test_file2'
content = str(uuid.uuid4())
fileutil.AppendFileContents(test_file, content)
fileutil.append_file(test_file, content)
self.assertTrue(tools.simple_file_grep(test_file, content))
os.remove(test_file)
def test_replace_file(self):
test_file='/tmp/test_file3'
contentOld = str(uuid.uuid4())
old_content = str(uuid.uuid4())
content = str(uuid.uuid4())
with open(test_file, "a+") as F:
F.write(contentOld)
fileutil.ReplaceFileContentsAtomic(test_file, content)
self.assertFalse(tools.simple_file_grep(test_file, contentOld))
F.write(old_content)
fileutil.replace_file(test_file, content)
self.assertFalse(tools.simple_file_grep(test_file, old_content))
self.assertTrue(tools.simple_file_grep(test_file, content))
os.remove(test_file)
def test_get_last_path_element(self):
filepath = '/tmp/abc.def'
filename = fileutil.GetLastPathElement(filepath)
filename = fileutil.base_name(filepath)
self.assertEquals('abc.def', filename)
filepath = '/tmp/abc'
filename = fileutil.GetLastPathElement(filepath)
filename = fileutil.base_name(filepath)
self.assertEquals('abc', filename)
if __name__ == '__main__':
+10 -10
View File
@@ -26,7 +26,7 @@ import os
import test
import azurelinuxagent.protocol.v1 as v1
GoalStateSample="""
goal_state_sample="""
<?xml version="1.0" encoding="utf-8"?>
<GoalState xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="goalstate10.xsd">
<Version>2010-12-15</Version>
@@ -58,15 +58,15 @@ GoalStateSample="""
class TestGoalState(unittest.TestCase):
def test_goal_state(self):
goalState = v1.GoalState(GoalStateSample)
self.assertEquals('1', goalState.getIncarnation())
self.assertNotEquals(None, goalState.getExpectedState())
self.assertNotEquals(None, goalState.getHostingEnvUri())
self.assertNotEquals(None, goalState.getSharedConfigUri())
self.assertNotEquals(None, goalState.getCertificatesUri())
self.assertNotEquals(None, goalState.getExtensionsUri())
self.assertNotEquals(None, goalState.getRoleInstanceId())
self.assertNotEquals(None, goalState.getContainerId())
goal_state = v1.GoalState(goal_state_sample)
self.assertEquals('1', goal_state.incarnation)
self.assertNotEquals(None, goal_state.expected_state)
self.assertNotEquals(None, goal_state.hosting_env_uri)
self.assertNotEquals(None, goal_state.shared_conf_uri)
self.assertNotEquals(None, goal_state.certs_uri)
self.assertNotEquals(None, goal_state.ext_uri)
self.assertNotEquals(None, goal_state.role_instance_id)
self.assertNotEquals(None, goal_state.container_id)
if __name__ == '__main__':
unittest.main()
+8 -8
View File
@@ -25,7 +25,7 @@ import unittest
import os
import azurelinuxagent.protocol.v1 as v1
HostingEnvSample="""
hosting_env_sample="""
<HostingEnvironmentConfig version="1.0.0.0" goalStateIncarnation="1">
<StoredCertificates>
<StoredCertificate name="Stored0Microsoft.WindowsAzure.Plugins.RemoteAccess.PasswordEncryption" certificateId="sha1:C093FA5CD3AAE057CB7C4E04532B2E16E07C26CA" storeName="My" configurationLevel="System" />
@@ -36,7 +36,7 @@ HostingEnvSample="""
</Deployment>
<Incarnation number="1" instance="MachineRole_IN_0" guid="{a0faca35-52e5-4ec7-8fd1-63d2bc107d9b}" />
<Role guid="{73d95f1c-6472-e58e-7a1a-523554e11d46}" name="MachineRole" hostingEnvironmentVersion="1" software="" softwareType="ApplicationPackage" entryPoint="" parameters="" settleTimeSeconds="10" />
<HostingEnvironmentSettings name="full" Runtime="rd_fabric_stable.110217-1402.RuntimePackage_1.0.0.8.zip">
<HostingEnvironmentSettings name="full" runtime="rd_fabric_stable.110217-1402.runtimePackage_1.0.0.8.zip">
<CAS mode="full" />
<PrivilegeLevel mode="max" />
<AdditionalProperties><CgiHandlers></CgiHandlers></AdditionalProperties>
@@ -53,13 +53,13 @@ HostingEnvSample="""
"""
class TestHostingEvn(unittest.TestCase):
def test_hostingenv(self):
hostingenv = v1.HostingEnv(HostingEnvSample)
self.assertNotEquals(None, hostingenv)
self.assertEquals("MachineRole_IN_0", hostingenv.getVmName())
self.assertEquals("MachineRole", hostingenv.getRoleName())
def test_hosting_env(self):
hosting_env = v1.HostingEnv(hosting_env_sample)
self.assertNotEquals(None, hosting_env)
self.assertEquals("MachineRole_IN_0", hosting_env.vm_name)
self.assertEquals("MachineRole", hosting_env.role_name)
self.assertEquals("db00a7755a5e4e8a8fe4b19bc3b330c3",
hostingenv.getDeploymentName())
hosting_env.deployment_name)
if __name__ == '__main__':
+3 -3
View File
@@ -51,7 +51,7 @@ class TestLogger(unittest.TestCase):
def test_file_appender(self):
_logger = logger.Logger()
_logger.addLoggerAppender(logger.AppenderType.FILE,
_logger.add_appender(logger.AppenderType.FILE,
logger.LogLevel.INFO,
'/tmp/testlog')
@@ -69,14 +69,14 @@ class TestLogger(unittest.TestCase):
def test_log_to_non_exists_dev(self):
_logger = logger.Logger()
_logger.addLoggerAppender(logger.AppenderType.CONSOLE,
_logger.add_appender(logger.AppenderType.CONSOLE,
logger.LogLevel.INFO,
'/dev/nonexists')
_logger.info("something")
def test_log_to_non_exists_file(self):
_logger = logger.Logger()
_logger.addLoggerAppender(logger.AppenderType.FILE,
_logger.add_appender(logger.AppenderType.FILE,
logger.LogLevel.INFO,
'/tmp/nonexists')
_logger.info("something")
+7 -7
View File
@@ -21,16 +21,16 @@
import env
from tests.tools import *
import unittest
from azurelinuxagent.metadata import GuestAgentName, GuestAgentVersion, \
DistroName, DistroVersion, DistroCodeName, \
DistroFullName
from azurelinuxagent.metadata import AGENT_NAME, AGENT_VERSION, \
DISTRO_NAME, DISTRO_VERSION, DISTRO_CODE_NAME, \
DISTRO_FULL_NAME
class TestOSInfo(unittest.TestCase):
def test_curr_os_info(self):
self.assertNotEquals(None, DistroName)
self.assertNotEquals(None, DistroVersion)
self.assertNotEquals(None, DistroCodeName)
self.assertNotEquals(None, DistroFullName)
self.assertNotEquals(None, DISTRO_NAME)
self.assertNotEquals(None, DISTRO_VERSION)
self.assertNotEquals(None, DISTRO_CODE_NAME)
self.assertNotEquals(None, DISTRO_FULL_NAME)
if __name__ == '__main__':
unittest.main()
+71 -71
View File
@@ -28,14 +28,14 @@ import time
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
import azurelinuxagent.conf as conf
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
import test
class TestOSUtil(unittest.TestCase):
def test_current_distro(self):
self.assertNotEquals(None, OSUtil)
self.assertNotEquals(None, OSUTIL)
MountlistSample="""\
mount_list_sample="""\
/dev/sda1 on / type ext4 (rw)
proc on /proc type proc (rw)
sysfs on /sys type sysfs (rw)
@@ -48,29 +48,29 @@ none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw)
class TestCurrOS(unittest.TestCase):
#class TestCurrOS(object):
def test_get_paths(self):
self.assertNotEquals(None, OSUtil.GetHome())
self.assertNotEquals(None, OSUtil.GetLibDir())
self.assertNotEquals(None, OSUtil.GetAgentPidPath())
self.assertNotEquals(None, OSUtil.GetConfigurationPath())
self.assertNotEquals(None, OSUtil.GetDvdMountPoint())
self.assertNotEquals(None, OSUtil.GetOvfEnvPathOnDvd())
self.assertNotEquals(None, OSUTIL.get_home())
self.assertNotEquals(None, OSUTIL.get_lib_dir())
self.assertNotEquals(None, OSUTIL.get_agent_pid_file_path())
self.assertNotEquals(None, OSUTIL.get_conf_file_path())
self.assertNotEquals(None, OSUTIL.get_dvd_mount_point())
self.assertNotEquals(None, OSUTIL.get_ovf_env_file_path_on_dvd())
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
@Mockup(shellutil, 'RunSendStdin', MockFunc(retval=[0, '']))
@Mockup(fileutil, 'SetFileContents', MockFunc())
@Mockup(fileutil, 'AppendFileContents', MockFunc())
@Mockup(fileutil, 'GetFileContents', MockFunc(retval=''))
@Mockup(fileutil, 'ChangeMod', MockFunc())
@mock(fileutil, 'write_file', MockFunc())
@mock(fileutil, 'append_file', MockFunc())
@mock(fileutil, 'chmod', MockFunc())
@mock(fileutil, 'read_file', MockFunc(retval=''))
@mock(shellutil, 'run', MockFunc())
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
def test_update_user_account(self):
OSUtil.UpdateUserAccount('api', 'api')
OSUtil.DeleteAccount('api')
OSUTIL.set_user_account('api', 'api')
OSUTIL.del_account('api')
@Mockup(fileutil, 'GetFileContents', MockFunc(retval='root::::'))
@Mockup(fileutil, 'SetFileContents', MockFunc())
@mock(fileutil, 'read_file', MockFunc(retval='root::::'))
@mock(fileutil, 'write_file', MockFunc())
def test_delete_root_password(self):
OSUtil.DeleteRootPassword()
OSUTIL.del_root_password()
self.assertEquals('root:*LOCK*:14600::::::',
fileutil.SetFileContents.args[1])
fileutil.write_file.args[1])
def test_cert_operation(self):
if os.path.isfile('/tmp/test.prv'):
@@ -81,84 +81,84 @@ class TestCurrOS(unittest.TestCase):
os.remove('/tmp/test.crt')
shutil.copyfile(os.path.join(env.test_root, 'test.crt'),
'/tmp/test.crt')
pub1 = OSUtil.GetPubKeyFromPrv('/tmp/test.prv')
pub2 = OSUtil.GetPubKeyFromCrt('/tmp/test.crt')
pub1 = OSUTIL.get_pubkey_from_prv('/tmp/test.prv')
pub2 = OSUTIL.get_pubkey_from_crt('/tmp/test.crt')
self.assertEquals(pub1, pub2)
thumbprint = OSUtil.GetThumbprintFromCrt('/tmp/test.crt')
thumbprint = OSUTIL.get_thumbprint_from_crt('/tmp/test.crt')
self.assertEquals('33B0ABCE4673538650971C10F7D7397E71561F35', thumbprint)
def test_selinux(self):
if not OSUtil.IsSelinuxSystem():
if not OSUTIL.is_selinux_system():
return
isRunning = OSUtil.IsSelinuxRunning()
if not OSUtil.IsSelinuxRunning():
OSUtil.SetSelinuxEnforce(0)
self.assertEquals(False, OSUtil.IsSelinuxRunning())
OSUtil.SetSelinuxEnforce(1)
self.assertEquals(True, OSUtil.IsSelinuxRunning())
isrunning = OSUTIL.is_selinux_enforcing()
if not OSUTIL.is_selinux_enforcing():
OSUTIL.set_selinux_enforce(0)
self.assertEquals(False, OSUTIL.is_selinux_enforcing())
OSUTIL.set_selinux_enforce(1)
self.assertEquals(True, OSUTIL.is_selinux_enforcing())
if os.path.isfile('/tmp/abc'):
os.remove('/tmp/abc')
fileutil.SetFileContents('/tmp/abc', '')
OSUtil.SetSelinuxContext('/tmp/abc','unconfined_u:object_r:ssh_home_t:s')
OSUtil.SetSelinuxEnforce(1 if isRunning else 0)
fileutil.write_file('/tmp/abc', '')
OSUTIL.set_selinux_context('/tmp/abc','unconfined_u:object_r:ssh_home_t:s')
OSUTIL.set_selinux_enforce(1 if isrunning else 0)
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
@Mockup(fileutil, 'SetFileContents', MockFunc())
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
@mock(fileutil, 'write_file', MockFunc())
def test_network_operation(self):
OSUtil.StartNetwork()
OSUtil.OpenPortForDhcp()
OSUtil.GenerateTransportCert()
mac = OSUtil.GetMacAddress()
OSUTIL.start_network()
OSUTIL.allow_dhcp_broadcast()
OSUTIL.gen_transport_cert()
mac = OSUTIL.get_mac_addr()
self.assertNotEquals(None, mac)
OSUtil.IsMissingDefaultRoute()
OSUtil.SetBroadcastRouteForDhcp('api')
OSUtil.RemoveBroadcastRouteForDhcp('api')
OSUtil.RouteAdd('', '', '')
OSUtil.GetDhcpProcessId()
OSUtil.SetHostname('api')
OSUtil.PublishHostname('api')
OSUTIL.is_missing_default_route()
OSUTIL.set_route_for_dhcp_broadcast('api')
OSUTIL.remove_route_for_dhcp_broadcast('api')
OSUTIL.route_add('', '', '')
OSUTIL.get_dhcp_pid()
OSUTIL.set_hostname('api')
OSUTIL.publish_hostname('api')
@Mockup(OSUtil, 'GetHome', MockFunc(retval='/tmp/home'))
@mock(OSUTIL, 'get_home', MockFunc(retval='/tmp/home'))
def test_deploy_key(self):
if os.path.isdir('/tmp/home'):
shutil.rmtree('/tmp/home')
user = shellutil.RunGetOutput('whoami')[1].strip()
OSUtil.DeploySshKeyPair(user, 'test', '$HOME/.ssh/id_rsa')
OSUtil.DeploySshPublicKey(user, 'test', '$HOME/.ssh/authorized_keys')
user = shellutil.run_get_output('whoami')[1].strip()
OSUTIL.deploy_ssh_keypair(user, 'test', '$HOME/.ssh/id_rsa')
OSUTIL.deploy_ssh_pubkey(user, 'test', '$HOME/.ssh/authorized_keys')
self.assertTrue(os.path.isfile('/tmp/home/.ssh/id_rsa'))
self.assertTrue(os.path.isfile('/tmp/home/.ssh/id_rsa.pub'))
self.assertTrue(os.path.isfile('/tmp/home/.ssh/authorized_keys'))
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
@Mockup(OSUtil, 'GetSshdConfigPath', MockFunc(retval='/tmp/sshd_config'))
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
@mock(OSUTIL, 'get_sshd_conf_file_path', MockFunc(retval='/tmp/sshd_config'))
def test_ssh_operation(self):
shellutil.RunGetOutput.retval=[0,
shellutil.run_get_output.retval=[0,
'2048 f1:fe:14:66:9d:46:9a:60:8b:8c:'
'80:43:39:1c:20:9e root@api (RSA)']
sshdConfig = OSUtil.GetSshdConfigPath()
self.assertEquals('/tmp/sshd_config', sshdConfig)
if os.path.isfile(sshdConfig):
os.remove(sshdConfig)
shutil.copyfile(os.path.join(env.test_root, 'sshd_config'), sshdConfig)
OSUtil.SetSshClientAliveInterval()
OSUtil.ConfigSshd(True)
self.assertTrue(simple_file_grep(sshdConfig,
sshd_conf = OSUTIL.get_sshd_conf_file_path()
self.assertEquals('/tmp/sshd_config', sshd_conf)
if os.path.isfile(sshd_conf):
os.remove(sshd_conf)
shutil.copyfile(os.path.join(env.test_root, 'sshd_config'), sshd_conf)
OSUTIL.set_ssh_client_alive_interval()
OSUTIL.conf_sshd(True)
self.assertTrue(simple_file_grep(sshd_conf,
'PasswordAuthentication no'))
self.assertTrue(simple_file_grep(sshdConfig,
self.assertTrue(simple_file_grep(sshd_conf,
'ChallengeResponseAuthentication no'))
self.assertTrue(simple_file_grep(sshdConfig,
self.assertTrue(simple_file_grep(sshd_conf,
'ClientAliveInterval 180'))
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
@Mockup(OSUtil, 'GetDvdMountPoint', MockFunc(retval='/tmp/cdrom'))
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
@mock(OSUTIL, 'get_mount_point', MockFunc(retval='/tmp/cdrom'))
def test_mount(self):
OSUtil.MountDvd()
OSUtil.UmountDvd()
mountPoint = OSUtil.GetMountPoint(MountlistSample, '/dev/sda')
self.assertNotEquals(None, mountPoint)
OSUTIL.mount_dvd()
OSUTIL.umount_dvd()
mount_point = OSUTIL.get_mount_point(mount_list_sample, '/dev/sda')
self.assertNotEquals(None, mount_point)
def test_getdvd(self):
OSUtil.GetDvdDevice()
OSUTIL.get_dvd_device()
if __name__ == '__main__':
unittest.main()
+11 -11
View File
@@ -59,17 +59,17 @@ ExtensionsConfigSample="""
class TestOvf(unittest.TestCase):
def test_ovf(self):
config = ovfenv.OvfEnv(ExtensionsConfigSample)
self.assertEquals(1, config.getMajorVersion())
self.assertEquals(0, config.getMinorVersion())
self.assertEquals("HostName", config.getComputerName())
self.assertEquals("UserName", config.getUserName())
self.assertEquals("UserPassword", config.getUserPassword())
self.assertEquals(False, config.getDisableSshPasswordAuthentication())
self.assertEquals("CustomData", config.getCustomData())
self.assertNotEquals(None, config.getSshPublicKeys())
self.assertEquals(1, len(config.getSshPublicKeys()))
self.assertNotEquals(None, config.getSshKeyPairs())
self.assertEquals(1, len(config.getSshKeyPairs()))
self.assertEquals(1, config.get_major_version())
self.assertEquals(0, config.get_minor_version())
self.assertEquals("HostName", config.get_computer_name())
self.assertEquals("UserName", config.get_username())
self.assertEquals("UserPassword", config.get_user_password())
self.assertEquals(False, config.get_disable_ssh_password_auth())
self.assertEquals("CustomData", config.get_customdata())
self.assertNotEquals(None, config.get_ssh_pubkeys())
self.assertEquals(1, len(config.get_ssh_pubkeys()))
self.assertNotEquals(None, config.get_ssh_keypairs())
self.assertEquals(1, len(config.get_ssh_keypairs()))
if __name__ == '__main__':
unittest.main()
+1 -1
View File
@@ -62,7 +62,7 @@ extensionDataStr = """
class TestProtocolContract(unittest.TestCase):
def test_get_properties(self):
data = get_properties(VmInfo())
data = get_properties(VMInfo())
data = get_properties(Cert())
data = get_properties(ExtensionPackageList())
data = get_properties(InstanceMetadata())
+2 -2
View File
@@ -29,11 +29,11 @@ import azurelinuxagent.protocol.protocolFactory as protocolFactory
class TestWireProtocolEndpoint(unittest.TestCase):
def test_get_available_protocol(self):
with self.assertRaises(protocol.ProtocolNotFound):
protocol.Factory.getDefaultProtocol()
protocol.Factory.get_default_protocol()
def test_get_available_protocols(self):
mockGetV1 = MockFunc(retval="Mock protocol")
protocols = protocolFactory.GetAvailableProtocols([mockGetV1])
protocols = protocolFactory.get_available_protocols([mockGetV1])
self.assertNotEquals(None, protocols)
self.assertNotEquals(0, len(protocols))
+4 -4
View File
@@ -23,7 +23,7 @@ from tests.tools import *
import unittest
from azurelinuxagent.distro.redhat.osutil import RedhatOSUtil
TestPublicKey="""\
test_pubkey="""\
-----BEGIN PUBLIC KEY-----
MIIBIDANBgkqhkiG9w0BAQEFAAOCAQ0AMIIBCAKCAQEA2wo22vf1N8NWE+5lLfit
T7uzkfwqdw0IAoHZ0l2BtP0ajy6f835HCR3w3zLWw5ut7Xvyo26x1OMOzjo5lqtM
@@ -35,15 +35,15 @@ fQIBIw==
-----END PUBLIC KEY-----
"""
Expected="""\
expected_ssh_rsa_pubkey="""\
ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEA2wo22vf1N8NWE+5lLfitT7uzkfwqdw0IAoHZ0l2BtP0ajy6f835HCR3w3zLWw5ut7Xvyo26x1OMOzjo5lqtMh8iyQwfHtWf6Cekxfkf+6Pca99bNuDgwRopOTOyoVgwDzJB0+slpn/sJjeGbhxJlToT8tNPLrBmnnpaMZLMIANcPQtTRCQcV/ycv+/omKXFB+zULYkN8v22o5mysoCuQfzXiJP3Mlnf+V2XMl1WAJylhOJif04K8j+G8oF5ECBIQiph4ZLQS1yTYlozPXU8k8vB6A5+UiOGxBnOQYnp42cS5d4qSQ8LORCRGXrCj4DCP+lvkUDLUHx2WN+1ivZkOfQ==
"""
class TestRedhat(unittest.TestCase):
def test_RsaPublicKeyToSshRsa(self):
OSUtil = RedhatOSUtil()
sshRsaPublicKey = OSUtil.RsaPublicKeyToSshRsa(TestPublicKey)
self.assertEquals(Expected, sshRsaPublicKey)
ssh_rsa_pubkey = OSUtil.asn1_to_ssh_rsa(test_pubkey)
self.assertEquals(expected_ssh_rsa_pubkey, ssh_rsa_pubkey)
if __name__ == '__main__':
unittest.main()
+12 -12
View File
@@ -23,11 +23,11 @@ from tests.tools import *
import unittest
import azurelinuxagent.distro.default.resourceDisk as rdh
import azurelinuxagent.logger as logger
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
#logger.LoggerInit("/dev/null", "/dev/stdout")
MockGPTOutput="""
gpt_output_sample="""
Model: Msft Virtual Disk (scsi)
Disk /dev/sda: 32.2GB
Sector size (logical/physical): 512B/4096B
@@ -40,24 +40,24 @@ Number Start End Size Type File system Flags
class TestResourceDisk(unittest.TestCase):
@Mockup(rdh.OSUtil, 'DeviceForIdePort', MockFunc(retval='foo'))
@Mockup(rdh.shellutil, 'RunGetOutput', MockFunc(retval=(0, MockGPTOutput)))
@Mockup(rdh.shellutil, 'Run', MockFunc(retval=0))
@mock(rdh.OSUtil, 'device_for_ide_port', MockFunc(retval='foo'))
@mock(rdh.shellutil, 'run_get_output', MockFunc(retval=(0, gpt_output_sample)))
@mock(rdh.shellutil, 'run', MockFunc(retval=0))
def test_mountGPT(self):
handler = rdh.ResourceDiskHandler()
handler.mountResourceDisk('/tmp/foo', 'ext4')
handler.mount_resource_disk('/tmp/foo', 'ext4')
@Mockup(rdh.OSUtil, 'DeviceForIdePort', MockFunc(retval='foo'))
@Mockup(rdh.shellutil, 'RunGetOutput', MockFunc(retval=(0, "")))
@Mockup(rdh.shellutil, 'Run', MockFunc(retval=0))
@mock(rdh.OSUtil, 'device_for_ide_port', MockFunc(retval='foo'))
@mock(rdh.shellutil, 'run_get_output', MockFunc(retval=(0, "")))
@mock(rdh.shellutil, 'run', MockFunc(retval=0))
def test_mountMBR(self):
handler = rdh.ResourceDiskHandler()
handler.mountResourceDisk('/tmp/foo', 'ext4')
handler.mount_resource_disk('/tmp/foo', 'ext4')
@Mockup(rdh.shellutil, 'Run', MockFunc(retval=0))
@mock(rdh.shellutil, 'run', MockFunc(retval=0))
def test_createSwapSpace(self):
handler = rdh.ResourceDiskHandler()
handler.createSwapSpace('/tmp/foo', 512)
handler.create_swap_space('/tmp/foo', 512)
if __name__ == '__main__':
unittest.main()
+9 -9
View File
@@ -33,32 +33,32 @@ import azurelinuxagent.logger as logger
class TestHttpOperations(unittest.TestCase):
def test_parse_url(self):
host, port, secure, relativeUrl = restutil._ParseUrl("http://abc.def/ghi#hash?jkl=mn")
host, port, secure, rel_uri = restutil._parse_url("http://abc.def/ghi#hash?jkl=mn")
self.assertEquals("abc.def", host)
self.assertEquals("/ghi#hash?jkl=mn", relativeUrl)
self.assertEquals("/ghi#hash?jkl=mn", rel_uri)
host, port, secure, relativeUrl = restutil._ParseUrl("http://abc.def/")
host, port, secure, rel_uri = restutil._parse_url("http://abc.def/")
self.assertEquals("abc.def", host)
self.assertEquals("/", relativeUrl)
self.assertEquals("/", rel_uri)
self.assertEquals(False, secure)
host, port, secure, relativeUrl = restutil._ParseUrl("https://abc.def/ghi?jkl=mn")
host, port, secure, rel_uri = restutil._parse_url("https://abc.def/ghi?jkl=mn")
self.assertEquals(True, secure)
host, port, secure, relativeUrl = restutil._ParseUrl("http://abc.def:80/")
host, port, secure, rel_uri = restutil._parse_url("http://abc.def:80/")
self.assertEquals("abc.def", host)
def _test_http_get(self):
resp = restutil.HttpGet("http://httpbin.org/get").read()
resp = restutil.http_get("http://httpbin.org/get").read()
self.assertNotEquals(None, resp)
msg = str(uuid.uuid4())
resp = restutil.HttpGet("http://httpbin.org/get", {"x-abc":msg}).read()
resp = restutil.http_get("http://httpbin.org/get", {"x-abc":msg}).read()
self.assertNotEquals(None, resp)
self.assertTrue(msg in resp)
def _test_https_get(self):
resp = restutil.HttpGet("https://httpbin.org/get").read()
resp = restutil.http_get("https://httpbin.org/get").read()
self.assertNotEquals(None, resp)
if __name__ == '__main__':
+2 -2
View File
@@ -25,7 +25,7 @@ import unittest
import os
import azurelinuxagent.protocol.v1 as v1
SharedConfigSample="""
shared_config_sample="""
<SharedConfig version="1.0.0.0" goalStateIncarnation="1">
<Deployment name="db00a7755a5e4e8a8fe4b19bc3b330c3" guid="{ce5a036f-5c93-40e7-8adf-2613631008ab}" incarnation="2">
@@ -73,7 +73,7 @@ SharedConfigSample="""
class TestSharedConfig(unittest.TestCase):
def test_sharedconfig(self):
sharedconfig = v1.SharedConfig(SharedConfigSample)
shared_conf = v1.SharedConfig(shared_config_sample)
if __name__ == '__main__':
unittest.main()
+4 -13
View File
@@ -26,23 +26,14 @@ import os
import azurelinuxagent.utils.shellutil as shellutil
import test
class TestRunCmd(unittest.TestCase):
class TestrunCmd(unittest.TestCase):
def test_run_get_output(self):
output = shellutil.RunGetOutput("ls /")
output = shellutil.run_get_output("ls /")
self.assertNotEquals(None, output)
self.assertEquals(0, output[0])
err = shellutil.RunGetOutput("ls /not-exists")
err = shellutil.run_get_output("ls /not-exists")
self.assertNotEquals(0, err[0])
def test_run_send_stdin(self):
test_sh = os.path.join(env.test_root, "read_stdin.sh")
output = shellutil.RunSendStdin(test_sh, "y")
self.assertEquals(0, output[0])
output = shellutil.RunSendStdin(test_sh, "n")
self.assertEquals(1, output[0])
if __name__ == '__main__':
unittest.main()
+3 -3
View File
@@ -27,9 +27,9 @@ import azurelinuxagent.utils.textutil as textutil
import test
class TestTextUtil(unittest.TestCase):
def test_GetPasswordHash(self):
passwdHash = textutil.GetPasswordHash("asdf", True, 6, 10)
self.assertNotEquals(None, passwdHash)
def test_get_password_hash(self):
password_hash = textutil.gen_password_hash("asdf", True, 6, 10)
self.assertNotEquals(None, password_hash)
if __name__ == '__main__':
unittest.main()
+71 -71
View File
@@ -29,96 +29,96 @@ import httplib
import azurelinuxagent.logger as logger
import azurelinuxagent.protocol.v1 as v1
from test_version import VersionInfoSample
from test_goalstate import GoalStateSample
from test_hostingenv import HostingEnvSample
from test_sharedconfig import SharedConfigSample
from test_certificates import CertificatesSample, TransportCert
from test_extensionsconfig import ExtensionsConfigSample, ManifestSample
from test_goalstate import goal_state_sample
from test_hostingenv import hosting_env_sample
from test_sharedconfig import shared_config_sample
from test_certificates import certs_sample, transport_cert
from test_extensionsconfig import ext_conf_sample, manifest_sample
#logger.LoggerInit("/dev/stdout", "/dev/null", verbose=True)
#logger.LoggerInit("/dev/stdout", "/dev/null", verbose=False)
def MockFetchUri(url, headers=None, chkProxy=False):
def mock_fetch_uri(url, headers=None, chk_proxy=False):
content = None
if "versions" in url:
content = VersionInfoSample
elif "goalstate" in url:
content = GoalStateSample
content = goal_state_sample
elif "hostingenvuri" in url:
content = HostingEnvSample
content = hosting_env_sample
elif "sharedconfiguri" in url:
content = SharedConfigSample
content = shared_config_sample
elif "certificatesuri" in url:
content = CertificatesSample
content = certs_sample
elif "extensionsconfiguri" in url:
content = ExtensionsConfigSample
content = ext_conf_sample
elif "manifest.xml" in url:
content = ManifestSample
content = manifest_sample
else:
raise Exception("Bad url {0}".format(url))
return content
def MockFetchManifest(uris):
return ManifestSample
def mock_fetch_manifest(uris):
return manifest_sample
def MockFetchCache(filePath):
def mock_fetch_cache(file_path):
content = None
if "Incarnation" in filePath:
if "Incarnation" in file_path:
content = 1
elif "GoalState" in filePath:
content = GoalStateSample
elif "HostingEnvironmentConfig" in filePath:
content = HostingEnvSample
elif "SharedConfig" in filePath:
content = SharedConfigSample
elif "Certificates" in filePath:
content = CertificatesSample
elif "TransportCert" in filePath:
content = TransportCert
elif "ExtensionsConfig" in filePath:
content = ExtensionsConfigSample
elif "manifest" in filePath:
content = ManifestSample
elif "GoalState" in file_path:
content = goal_state_sample
elif "HostingEnvironmentConfig" in file_path:
content = hosting_env_sample
elif "SharedConfig" in file_path:
content = shared_config_sample
elif "Certificates" in file_path:
content = certs_sample
elif "TransportCert" in file_path:
content = transport_cert
elif "ExtensionsConfig" in file_path:
content = ext_conf_sample
elif "manifest" in file_path:
content = manifest_sample
else:
raise Exception("Bad filepath {0}".format(filePath))
raise Exception("Bad filepath {0}".format(file_path))
return content
class TestWireClint(unittest.TestCase):
@Mockup(v1, '_fetchCache', MockFetchCache)
def testGet(self):
@mock(v1, '_fetch_cache', mock_fetch_cache)
def test_get(self):
os.chdir('/tmp')
client = v1.WireClient("foobar")
goalState = client.getGoalState()
goalState = client.get_goal_state()
self.assertNotEquals(None, goalState)
hostingEnv = client.getHostingEnv()
hostingEnv = client.get_hosting_env()
self.assertNotEquals(None, hostingEnv)
sharedConfig = client.getSharedConfig()
sharedConfig = client.get_shared_conf()
self.assertNotEquals(None, sharedConfig)
extensionsConfig = client.getExtensionsConfig()
extensionsConfig = client.get_ext_conf()
self.assertNotEquals(None, extensionsConfig)
@Mockup(v1, '_fetchCache', MockFetchCache)
def testGetHeaderWithCert(self):
@mock(v1, '_fetch_cache', mock_fetch_cache)
def test_get_head_for_cert(self):
client = v1.WireClient("foobar")
header = client.getHeaderWithCert()
header = client.get_header_for_cert()
self.assertNotEquals(None, header)
@Mockup(v1.WireClient, 'getHeaderWithCert', MockFunc())
@Mockup(v1, '_fetchUri', MockFetchUri)
@Mockup(v1.fileutil, 'SetFileContents', MockFunc())
def testUpdateGoalState(self):
@mock(v1.WireClient, 'get_header_for_cert', MockFunc())
@mock(v1, '_fetch_uri', mock_fetch_uri)
@mock(v1.fileutil, 'write_file', MockFunc())
def test_update_goal_state(self):
client = v1.WireClient("foobar")
client.updateGoalState()
goalState = client.getGoalState()
self.assertNotEquals(None, goalState)
hostingEnv = client.getHostingEnv()
self.assertNotEquals(None, hostingEnv)
sharedConfig = client.getSharedConfig()
self.assertNotEquals(None, sharedConfig)
extensionsConfig = client.getExtensionsConfig()
self.assertNotEquals(None, extensionsConfig)
client.update_goal_state()
goal_state = client.get_goal_state()
self.assertNotEquals(None, goal_state)
hosting_env = client.get_hosting_env()
self.assertNotEquals(None, hosting_env)
shared_config = client.get_shared_conf()
self.assertNotEquals(None, shared_config)
ext_conf = client.get_ext_conf()
self.assertNotEquals(None, ext_conf)
class MockResp(object):
def __init__(self, status):
@@ -126,40 +126,40 @@ class MockResp(object):
class TestStatusBlob(unittest.TestCase):
def testToJson(self):
vmStatus = v1.VMStatus()
statusBlob = v1.StatusBlob(vmStatus)
self.assertNotEquals(None, statusBlob.toJson())
vm_status = v1.VMStatus()
status_blob = v1.StatusBlob(vm_status)
self.assertNotEquals(None, status_blob.toJson())
@Mockup(v1.restutil, 'HttpPut', MockFunc(retval=MockResp(httplib.CREATED)))
@Mockup(v1.restutil, 'HttpHead', MockFunc(retval=MockResp(httplib.OK)))
@mock(v1.restutil, 'http_put', MockFunc(retval=MockResp(httplib.CREATED)))
@mock(v1.restutil, 'http_head', MockFunc(retval=MockResp(httplib.OK)))
def test_put_page_blob(self):
vmStatus = v1.VMStatus()
statusBlob = v1.StatusBlob(vmStatus)
vm_status = v1.VMStatus()
status_blob = v1.StatusBlob(vm_status)
data = ['a'] * 100
statusBlob.putPageBlob("http://foo.bar", data)
status_blob.put_page_blob("http://foo.bar", data)
class TestConvert(unittest.TestCase):
def test_status(self):
vmStatus = v1.VMStatus()
handlerStatus = v1.ExtensionHandlerStatus()
vm_status = v1.VMStatus()
handler_status = v1.ExtensionHandlerStatus()
substatus = v1.ExtensionSubStatus()
extStatus = v1.ExtensionStatus()
ext_status = v1.ExtensionStatus()
vmStatus.extensionHandlers.append(handlerStatus)
v1.vm_status_to_v1(vmStatus)
vm_status.extensionHandlers.append(handler_status)
v1.vm_status_to_v1(vm_status)
handlerStatus.extensionStatusList.append(extStatus)
v1.vm_status_to_v1(vmStatus)
handler_status.extensionStatusList.append(ext_status)
v1.vm_status_to_v1(vm_status)
extStatus.substatusList.append(substatus)
v1.vm_status_to_v1(vmStatus)
ext_status.substatusList.append(substatus)
v1.vm_status_to_v1(vm_status)
def test_param(self):
param = v1.TelemetryEventParam()
event = v1.TelemetryEvent()
event.parameters.append(param)
v1.event_to_xml(event)
v1.event_to_v1(event)
if __name__ == '__main__':
unittest.main()
+5 -5
View File
@@ -42,11 +42,11 @@ VersionInfoSample="""\
class TestVersionInfo(unittest.TestCase):
def test_version_info(self):
config = v1.VersionInfo(VersionInfoSample)
self.assertEquals("2012-11-30", config.getPreferred())
self.assertNotEquals(None, config.getSupported())
self.assertEquals(2, len(config.getSupported()))
self.assertEquals("2010-12-15", config.getSupported()[0])
self.assertEquals("2010-28-10", config.getSupported()[1])
self.assertEquals("2012-11-30", config.get_preferred())
self.assertNotEquals(None, config.get_supported())
self.assertEquals(2, len(config.get_supported()))
self.assertEquals("2010-12-15", config.get_supported()[0])
self.assertEquals("2010-28-10", config.get_supported()[1])
if __name__ == '__main__':
unittest.main()
+8 -10
View File
@@ -21,7 +21,7 @@
import os
import sys
from azurelinuxagent.utils.osutil import OSUtil
from azurelinuxagent.utils.osutil import OSUTIL
parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent)
@@ -31,9 +31,9 @@ def simple_file_grep(file_path, search_str):
if search_str in line:
return line
def Mockup(target, name, mock):
def Decorator(func):
def Wrapper(*args, **kwargs):
def mock(target, name, mock):
def decorator(func):
def wrapper(*args, **kwargs):
origin = getattr(target, name)
setattr(target, name, mock)
try:
@@ -43,8 +43,8 @@ def Mockup(target, name, mock):
finally:
setattr(target, name, origin)
return result
return Wrapper
return Decorator
return wrapper
return decorator
class MockFunc(object):
def __init__(self, name='', retval=None):
@@ -57,9 +57,7 @@ class MockFunc(object):
self.kwargs = kwargs
return self.retval
def Dummy():
pass
#Mock osutil so that the test of other part will be os unrelated
OSUtil.GetLibDir = MockFunc(retval='/tmp')
OSUtil.GetExtLogDir = MockFunc(retval='/tmp/log')
OSUTIL.get_lib_dir = MockFunc(retval='/tmp')
OSUTIL.get_ext_log_dir = MockFunc(retval='/tmp/log')