رفتن به نوشته‌ها

آموزش تنسورفلو قسمت پنجم (Estimator API)

در قسمت‌های قبلی با برخی مفاهیم اولیه در تنسورفلو آشنا شدیم و دیدیم که چطوری میشه یک مدل رو پیاده‌سازی کرد و در نهایت آن را سرو کنیم! خوشبختانه دعوای بین گوگل و فیسبوک به نفع ما تموم میشه و هرکدوم سعی میکنن ویژگی‌های بهتر و جذاب‌تری رو ارائه بدن تا رقابت رو از دیگری ببرند. یکی از چیزهایی که تقریبا در راستای همین دعوا ارائه شده است Estimator API نام دارد. گوگل این ویژگی‌ رو ارائه کرد تا علاوه بر اینکه کد زدن رو راحت‌تر کنه تو حوزه رایانش ابری و تی‌پی‌یو برگ برنده‌ای نسبت به فیسبوک داشته باشه.

۱) Estimator چیست؟

Estimator یک API سطح بالا در تنسورفلو است که از سطح انتزاع (abstraction) بیشتری برخوردار است. یعنی چی؟ یعنی لازم نیست خیلی قسمت‌ها رو از صفر پیاده‌سازی کنیم بلکه این API، آنها پیاده‌سازی کرده است و در اختیار ما قرار می‌دهد. به شکل زیر، که معماری کلی تنسورفلو را نشان می‌دهد، نگاه کنید:

یک معماری کلی از تنسورفلو (عکس از تنسورفلو)

همانطور که در شکل بالا مشاهده می‌شود، Estimator هم سطح Keras است. یعنی انتزاعی که در Estimator وجود دارد تقریبا معادل همون چیزی است که ما در کراس داریم. از آن جالب‌تر اینکه هر دو (کراس و Estimator) روی tf.keras.layers پیاده‌سازی شده‌اند و تقریبا میشه گفت این Estimator چیز جدیدی ندارد و با ارائه ماژول Estimator خواستند از برخی جهات به توسعه‌دهنده‌ها کمک کنند.

چرا از همون کراس استفاده نکردند؟ راستش این واسه خودمم سواله! ولی تا جایی میدونم این است که کراس عملکرد جالبی روی سیستم‌‌های توزیع شده گوگل (Google cloud) و Multi-TPU نداره و بخاطر همین گوگل این API را ارائه کرده است.

خب بگذریم… این Estimator چیکار قرار است بکنه؟ اگر چندتا پروژه دیپ‌لرنیگ رو از صفر در تنسورفلو پیاده‌سازی کنید بعد از مدتی متوجه میشید که خیلی از قسمت‌هایی که کد زدید در پروژه‌های بعدی هم عینا تکرار میشه مثلا: قسمتهای مربوط به تنسوربورد (summaryها)، قسمت‌های مربوط به ذخیره و بازیابی مدل (save & restore)، حلقه اصلی آموزش برنامه و … هدف Estimator این است که این قسمت‌ها را به حداقل برساند (اصلاحا قسمت‌های boilerplate رو کم کنه).

اگر به شکل بالا (معماری تنسورفلو) نگاه کنید بالاتر از Estimator و کراس یک ماژول به نام Pre-made Estimator داریم که در واقع این ماژول‌ خواسته کار ما را حتی از این هم راحت‌تر کند و یک سری مدل از پیش تعریف شده ارائه شده است و فقط کافی است پارامترها را تعریف کنیم و دیتای خودمون رو بهش بدیم. در Pre-made estimatorها حتی معماری شبکه هم بر اساس معماری معروف پیاده‌سازی شده است و لازم نیست شما از صفر بشینید کد بزنید. لیست Pre-made estimatorها رو میتونید از اینجا مشاهده کنید که از معروف‌ترین آنها میشه به DNNClassifier و ‌BoostedTreeRegressor و… اشاره کرد. مثلا برای DNNClassifier معماری زیر ارائه شده است که شما کافی است بگید ورودی من چیه، تعداد لایه‌های مخفی (hidden layer) چیه و چندتا کلاس خروجی دارم.

معماری شبکه DNNClassifier (عکس از تنسورفلو)

که اگر بخواهیم شبکه بالا را در Estimator پیاده کنیم، میتونیم بگیم:

همانطور که مشاهده می‌کنید شما لازم نیست شبکه رو از صفر تعریف کنید یا اینکه نگران سازگاری اندازه لایه‌هاتون در لایه‌های متوالی باشید. یه شی از DNNClassifier درست می‌کنید و تمام!

اگر با scikit-learn کار کرده باشید یه شباهت‌هایی بین اینترفیس Estimator و scikit پی می‌برید. توی scikit هم ما یه چیزی به نام BaseEstimator داشتیم که مثلا میخواستیم مدل خودمون رو از صفر تعریف کنیم ازش ارث‌بری میکردیم یا اگرم میخواستیم از مدل‌های خود scikit استفاده کنیم کافی بود فقط پارامتر‌هاشو بهش بدیم و fit و transform رو براش صدا بزنیم.

مثلا فرض کنید روی دیتاست iris بخواهیم از scikit استفاده کنیم یه همچین چیزی مینوشتیم (iris یک دیتاست برای دسته‌بندی گلها است):

حالا فرض کنید همین کد رو با pre-estimatorها پیاده کنیم چیزی شبیه زیر خواهیم داشت:

اگر یه نگاه کلی به کد بالا بندازید متوجه یه شباهت‌هایی با scikit یه شباهت‌هایی میشیم. مشابه scikit ابتدا مدل رو تعریف میکنیم و بعدش train و evaluate رو صدا می‌زنیم که تقریبا معادل fit و predict توی scikit است (feature_columns رو فعلا نادیده بگیرید). همچنین مثل scikit دیگه لازم نیست از جزئیات داخل مدل سر دربیاریم، کافیه تعداد لایه‌ها hidden_units و اندازه‌اشون و n_classes رو بهش بدیم و دیتا‌هامون رو به خورد مدل بدیم (feed).

بیایید یکم دقیق‌تر به کد بالا نگاه کنیم، همانطور که در مدل بالا و توی سایر کدهای مربوط به Estimatorها هم خواهید دید، ورودی‌های (ارگومان‌های) توابع معمولا با چیزی که توی کتابخانه‌های معروف مثل scikit کار کردیم متفاوت است. یعنی ما توی scikit دیتا رو یه پیش‌پردازش می‌کردیم و مستقیم list یا np.array رو برای fit میفرستیم. ولی اگر در کد بالا نگاه کنید، classifier.train برای ورودی input_fn میخواد که خب شاید یکم گیج کننده باشد. حالا این input_fn چی هست؟

… A function that provides input data for training as minibatches

reference

همانطور که می‌بینید طبق تعریف بالا که از سایت تنسورفلو آورده شده اورده شده است input_fn یک تابع است که فیچرها و برچسب‌های ما را در قالب دسته‌هایی مثلا 128 یا 256 یا … به مدل مدل‌امون میدهد. به به عبارت دیگه معمولا input_fn ها یه همچین چیزی خواهند بود:

همانطور که می‌بینید این input_fn تهش قرار است به ما لیستی به اندازه batch_sizeامون از جفت (ویژگی – برچسب) برگرداند. پس input_fn یک تابع است نه یک لیست یا ارایه!

خب کد بالا رو یکبار دیگه نگاه کنید ما اولش آمدیم یه چیزی به نام feature_columns تعریف کردیم که معمولا همچین قرتی‌بازی‌هایی توی scikit نداشتیم. این چیه؟

همانطور که احتمالا میدونید در دیتاست iris ما چهار نوع ویژگی (feature) داریم: SepalWidth , SepalLength, PetalWidth و PetalLength
که معادل فارسی‌اشون ظاهرا میشه طول و عرض کاسبرگ و گلبرگ!

کاسبرگ و گلبرگ چیه!

ما با استفاده از feature_columns تعریف می‌کنیم که نوع داده‌امون چیه (int یا float یا…)، اسمش feature چیه، اندازه (shape) فیچرهامون چطوریه، مقدار پیشفرض برای فیچرهامون چیه، در صورتی که لازم باشه نرمالایز اشون هم میتونیم بکنیم. در واقع اگر feature_columns که در بالا داشتیم رو چاپ کنیم یه چیزی شبیه زیر خواهید داشت:

پس در واقع این feature_columns به Estimator میگه داده‌هایی که قرار است شبکه رو باهاش آموزش بدی چه شکلی هستند.

حالا اگر کد بالا رو اجرا کنیم (کد کامل رو میتونید از اینجا مشاهده کنید) یه اتفاق خیلی جالب برای میافته و اونم این است که Estimator خودش فایل‌های مربوط Tensorboard و Checkpoint را برای شما می‌سازد یعنی شما عمل خیلی از کد زدن‌هاتون کم می‌کنه:

دایرکتوری ما بعد از اجرا DNNClassifier

حالا اگر دوباره کد رو اجرا کنید متوجه می‌شید که دیگه نمیاد از دوباره آموزش بده! خودش میره فایلهای checkpoint شما رو نگاه میکنه و مدل‌اتون رو بازیابی (restore) میکنه و از روی آن evaluation را صدا می‌زند. بنظرم که خیلی جذابه!

و اگر دستور زیر رو ترمینال یا کامند اجرا کنید می‌بینید تنسوربورد اجرا میشه و میتونید گراف محاسباتی و خطای مدل‌ را مشاهده کنیم:

tensorboard --logdir ./tmp

دیگه خبری از optimizer و loss و training loop و … نیست! همانطور که می‌بینید در عرض چند دقیقه میتونید یک مدل دیپ رو درست کنید، گراف محاسباتی‌اشو ببنید و حتی اون رو آماده سرو کنید!

ولی خب این همه انتزاع مشکلات و دردسر‌های خودشو داره و ما خیلی اوقات دوست داریم که شبکه‌امون رو خودمون بسازیم. طبق این مقاله‌ای که برای Estimator در اینجا منتشر شده است، گفته که گوگل تو ۴۲٪ موارد از Estimatorهای شخصی‌سازی شده یا customize استفاده می‌کنه:

۴۲٪ درصد از مدلهایی که در گوگل استفاده می‌شوند از Custom Estimator استفاده می‌کنند.

خب پس بیایید ببینیم چطوری میشه از Estimatorهایی استفاده کنیم که خودمون شبکه، تابع خطا و … رو میتونیم تعریف کنیم.

۲) استفاده از Estimatorهای شخصی‌سازی شده

استفاده از Estimatorهای شخصی شده خیلی راحت است و برای اینکه بخواهیم Estimator خودمون رو بسازیم باید از کلاس tf.estimator.Estimator یک نمونه درست کنیم. ولی قبل از اینکه به پیاده‌سازی آن بپردازیم بد نیست یک نگاه به سازنده (constructor) آن داشته باشیم:

آرگومان‌های سازنده tf.estimator.Estimator

توضیحاتی که تنسورفلو در اینجا قرار داده است میگه:

The Estimator object wraps a model which is specified by a model_fn, which, given inputs and a number of other parameters, returns the ops necessary to perform training, evaluation, or predictions.

یعنی چی؟ این کلاس tf.estimator.Estimator یک wrapper روی model_fn است. همانطور که قبلا داشتیم و احتمالا از اسمش هم حدس میزنید، model_fn یک تابع است که قراره مدل‌امون رو داخل آن تعریف کنیم.

پارامتر بعدی که در سازنده می‌بینید model_dir که مثل قبل آدرسی که تمام فایل‌های checkpoint و event رو باید ذخیره کنیم نشان می‌دهد. پارامتر سوم، config است که میتونید بگید هرچند وقت یکبار مدل‌امون رو ذخیره کن یا لاگ رو هرچند وقت یکبار نشون بده یا GPU رو چطوری تخصیص بده. ضمنا میتونید model_dir رو در config هم تنظیم کنید. با پارامترهای بعدی هم فعلا کاری نداریم.

خب بیایید یکم بیشتر راجع به model_fn صحبت کنیم. تابعی که به عنوان model_fn تعریف میکنیم باید (signature) شبیه زیر داشته باشه:

که از این بین mode، config و params اختیاری هستند و model_fn میتونه این پارامترها رو نداشته باشد ولی features و labels اجباری است. خب هرکدام از اینها چی هستند؟

  • features اولین آیتمی که از input_fn برمیگردد و مستقیم به train، eval یا predict پاس داده می‌شود.
  • labels دومین آیتمی است که از input_fn برمیگردد و مانند قبل به eval, train یا predict فرستاده می‌شود. لازم به ذکر است که وقتی labels به predict فرستاده می‌شود labels=None به آن فرستاده میشود.
  • mode مشخص می‌کند که فاز ما چیست! یعنی توی train هستیم یا evaluation یا predict! این mode سه حالت هم بیشتر نمی‌تونه داشته باشه: TRAIN, PREDICT, EVAL
  • params در واقع هایپرپارامترهای ماست که به مدل ارسال می‌شود. این آرگومان وقتی میخواهیم هایپرپارامترها را تنظیم کنیم (hyperparameters tuning) بسیار بدرد میخوره
  • config هم یک شی از جنس tf.estimator.RunConfig است که مسیر ذخیره checkpoints‌ها، اینکه هر چند بار از مدل checkpoint بگیریم یا اینکه چندتا GPU داریم و… را مشخص می‌کند.

خب حالا سوالی که پیش میاد این است که این model_fn اینا رو میگیره چی باید برگردونه؟ یک شی از نوع tf.estimator.EstimatorSpec. این چیه دیگه؟ بزارید یه کد تقریبی برای model_fn رو ببینیم:

ما توی model_fn چندتا کار اصلی باید انجام بدیم تا در نهایت tf.estimator.EstimatorSpec رو برگردانیم. این کار شامل زیر می‌شود:

  • تعریف مدل (تعریف logic مدل)
  • تعریف تابع خطا (loss)
  • تعریف بهینه کننده (optimizer)
  • تعریف معیارهای ارزیابی (evaluation metrics)

در کد بالا هم دقیقا همین کار رو کردیم یعنی ابتدا با استفاده از تابع neural_net_model یک شبکه عصبی تعریف کردیم که خروجی شبکه عصبی رو برگرداندیم (logit). بعد class_prediction رو تعریف کردیم. یعنی چی؟ logit خروجی خام شبکه را به ما می‌دهد یعنی نمیگه کدام کلاس را پیش‌بینی کرده است ولی ما با استفاده tf.argmax محمتل‌ترین پیش‌بینی را به عنوان class_prediction انتخاب می‌کنیم.

در قسمت بعدی آمدیم loss و optimizer رو تعریف کردیم که شاید بپرسید در حالت evaluation ما loss و optimizer نداریم! بعله درسته! بخاطر همین به طور پیش‌فرض با None مقداردهی کردیم و چندتا if هم گذاشتیم.

در نهایت هم متریک ‌امون رو تعریف می‌کنیم (tf.metrics.accuracy) که من در اینجا فرض کردم تنها یک متریک دارم و اونم دقت (acc) است.

در پایان هم همه این چیزایی که حساب کردم رو میریزیم توی یک شی به نام tf.estimator.EstimatorSpec که کار باهاش تر و تمیز باشه! و تمام!

تقریبا تموم شد! قسمت اصلی کد زدن با استفاده از Estimator API تعریف model_fn است.

فقط یک چیز موند و اونم تعریف مدل (neural_net_model) است. فرض کنید تسک‌امون digit recognition روی داده‌های mnist است و از یک کانولوشون بخواهیم استفاده کنیم که کدمون برای این قسمت چیزی شبیه زیر خواهد بود:

خب مرحله آخر میرسه به فراخوانی توابع train و evaluation که این رو توی DNNClassifier داشتیم پس داریم:

و تمام! شما الان تونستید یک مدل دلخواه را با استفاده از Estimator API پیاده‌سازی کنید! کد کامل این قسمت را میتونید از اینجا مشاهده کنید.

که اگر بخواهیم برای قسمت بالا یک دیاگرام هم بکشیم کاری که کردیم شبیه زیر است:

یک دیاگرام کلی برای مدلی که آموزش دادیم

۳) تبدیل کراس به Estimator

همانطور که اشاره شد Estimator مزیت‌های زیادی دارد که کراس نداره! بخاطر همین کراس این قابلیت را در اختیار ما قرار داده است تا مدل‌هایی که در کراس آموزش میدیم را به Estimator پورت کنیم. این کار با استفاده از تابعی به نام tf.keras.estimator.model_to_estimator انجام می‌گیرد و کافی است مدلی که در کراس درست کردید را به model_to_estimator بفرستید و بعد توابع train و evaluate را برای estimator صدا بزنید.

۴) استفاده از چند GPU

اگر چندتا GPU دارید و میخواهید مدل‌اتون روی چندتا GPU آموزش بدید به راحتی میتونید اندکی config‌ای که به Estimator پاس می‌دید تغییر بدید تا حداکثر استفاده را ببرید بدون اینکه لازم باشه تغییری در کدتون ایجاد کنید:

خب فکر کنم تا همینجا برای مقدمه کافی باشه و مابقی رو خودتون به راحتی میتونید تو مستندات تنسورفلو بگردید و پیدا کنید. اگر شما تجربه‌ای درباره کار کردن با Estimator دارید خوشحال میشم با من به اشتراک بزارید.

منتشر شده در تنسورفلویادگیری عمیقیادگیری ماشین

نظر

  1. ممنون از اطلاعات بسیار بسیار ارزشمند شما

    • هادیفر هادیفر

      مرسی از لطف شما 🙂

  2. عارفه عارفه

    خیلی وقت بود همچین مقاله فارسی با این حجم از کیفیت نخوونده بودم ! عالی بود.

    • هادیفر هادیفر

      ممنون از لطفتون

  3. mahsa mahsa

    سلام واقعا عالی بود ممنون.چطور میتونم از خروجی tf.predictکه در کد من بردارهای ویژگی لایه آخر رو پیش بینی میکنه برای کلاس بندی استفاده کنم؟

پاسخ دادن به mahsa لغو پاسخ

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *